Source code for ablkit.utils.logger

"""
Copyright (c) OpenMMLab. All rights reserved.
Modified from
https://github.com/open-mmlab/mmengine/blob/main/mmengine/logging/logger.py
"""

import logging
import os
import os.path as osp
import sys
from logging import Logger, LogRecord
from typing import Optional, Union

from termcolor import colored

from .manager import ManagerMixin, _accquire_lock, _release_lock


class FilterDuplicateWarning(logging.Filter):
    """
    Filter for eliminating repeated warning messages in logging.

    This filter checks for duplicate warning messages and allows only the first occurrence of
    each message to be logged, filtering out subsequent duplicates. Records of any level other
    than ``WARNING`` are always passed through.

    Parameters
    ----------
    name : str, optional
        The name of the filter. Defaults to "abl". ``None`` is treated as an empty name.

    Notes
    -----
    Duplicates are detected on the raw ``record.msg`` (the message before ``%``-style
    argument interpolation). Two warnings that share a format string but differ only in
    their arguments are therefore treated as duplicates.
    """

    def __init__(self, name: Optional[str] = "abl"):
        # ``logging.Filter`` calls ``len(name)``, so map ``None`` to its default "".
        super().__init__("" if name is None else name)
        # Keys of warning messages that have already been emitted once.
        self.seen: set = set()

    def filter(self, record: LogRecord) -> bool:
        """
        Filter the repeated warning message.

        Parameters
        ----------
        record : LogRecord
            The log record.

        Returns
        -------
        bool
            Whether to output the log record.
        """
        if record.levelno != logging.WARNING:
            return True

        key = record.msg
        try:
            hash(key)
        except TypeError:
            # ``msg`` may be an arbitrary object (e.g. a list or dict). Fall back to its
            # repr so unhashable messages are deduplicated instead of raising here.
            key = repr(key)

        if key in self.seen:
            return False
        self.seen.add(key)
        return True


class ABLFormatter(logging.Formatter):
    """
    Colorful format for ABLLogger. If the log level is error or critical, the logger will
    additionally output the location of the code.

    Parameters
    ----------
    color : bool, optional
        Whether to use colorful format. filehandler is not
        allowed to use color format, otherwise it will be garbled.
        Defaults to True.
    blink : bool, optional
        Whether to blink the ``INFO`` and ``DEBUG`` logging
        level. Defaults to False.
    kwargs : dict
        Keyword arguments passed to
        :meth:``logging.Formatter.__init__``.

    Notes
    -----
    Records whose level is not one of ``DEBUG``, ``INFO``, ``WARNING``, ``ERROR`` or
    ``CRITICAL`` (e.g. custom levels) are formatted with their ``levelname`` as the prefix.
    """

    _color_mapping: dict = dict(
        CRITICAL="red", ERROR="red", WARNING="yellow", INFO="white", DEBUG="green"
    )

    def __init__(self, color: bool = True, blink: bool = False, **kwargs):
        super().__init__(**kwargs)
        assert color or not blink, "blink should only be available when color is True"
        # Get prefix format according to color. ERROR/WARNING/CRITICAL always blink
        # when colored; INFO/DEBUG blink only on request.
        critical_prefix = self._get_prefix("CRITICAL", color, blink=True)
        error_prefix = self._get_prefix("ERROR", color, blink=True)
        warn_prefix = self._get_prefix("WARNING", color, blink=True)
        info_prefix = self._get_prefix("INFO", color, blink)
        debug_prefix = self._get_prefix("DEBUG", color, blink)

        # Severe records additionally report where they were emitted from.
        location = "%(pathname)s - %(funcName)s - %(lineno)d - "

        # Config output format.
        self.critical_format = (
            f"%(asctime)s - %(name)s - {critical_prefix} - {location}%(message)s"
        )
        self.err_format = f"%(asctime)s - %(name)s - {error_prefix} - {location}%(message)s"
        self.warn_format = f"%(asctime)s - %(name)s - {warn_prefix} - %(message)s"
        self.info_format = f"%(asctime)s - %(name)s - {info_prefix} - %(message)s"
        self.debug_format = f"%(asctime)s - %(name)s - {debug_prefix} - %(message)s"
        # Fallback for levels without a dedicated prefix (e.g. custom levels).
        self.default_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

    def _get_prefix(self, level: str, color: bool, blink: bool = False) -> str:
        """
        Get the prefix of the target log level.

        Parameters
        ----------
        level : str
            Log level.
        color : bool
            Whether to get a colorful prefix.
        blink : bool, optional
            Whether the prefix will blink. Defaults to False.

        Returns
        -------
        str
            The plain or colorful prefix.
        """
        if not color:
            return level
        attrs = ["underline"]
        if blink:
            attrs.append("blink")
        return colored(level, self._color_mapping[level], attrs=attrs)

    def format(self, record: LogRecord) -> str:
        """
        Override the ``logging.Formatter.format`` method. Output the
        message according to the specified log level.

        Parameters
        ----------
        record : LogRecord
            A LogRecord instance representing an event being logged.

        Returns
        -------
        str
            Formatted result.
        """
        level_formats = {
            logging.CRITICAL: self.critical_format,
            logging.ERROR: self.err_format,
            logging.WARNING: self.warn_format,
            logging.INFO: self.info_format,
            logging.DEBUG: self.debug_format,
        }
        # Always select a format explicitly so an unknown level never inherits the
        # (hard-coded) prefix of whichever record was formatted previously.
        fmt = level_formats.get(record.levelno, self.default_format)
        self._style._fmt = fmt  # pylint: disable=protected-access
        return super().format(record)


class ABLLogger(Logger, ManagerMixin):
    """
    Formatted logger used to record messages with different log levels and features.

    ``ABLLogger`` provides a formatted logger that can log messages with different
    log levels. It allows the creation of logger instances in a similar manner to
    ``ManagerMixin``. Every instance writes to ``sys.stdout`` with colored level
    prefixes and to a log file with plain prefixes.

    Parameters
    ----------
    name : str
        Global instance name.
    logger_name : str, optional
        ``name`` attribute of ``logging.Logger`` instance. Defaults to 'abl'.
    log_file : str, optional
        The log filename. If None, a file ``results/<timestamp>/<timestamp>.log`` is
        created. Missing parent directories are created in either case.
        Defaults to None.
    log_level : Union[int, str], optional
        The log level of the stream and file handlers. Level names are matched
        case-insensitively. Defaults to 'INFO'.
    file_mode : str, optional
        The file mode used to open log file. Defaults to 'w'.

    Notes
    -----
    - The ``name`` of the logger and the ``instance_name`` of ``ABLLogger`` could be
      different. ``ABLLogger`` instances are retrieved using ``ABLLogger.get_instance``,
      not ``logging.getLogger``. This ensures ``ABLLogger`` is not influenced by
      third-party logging configurations.
    - Unlike ``logging.Logger``, ``ABLLogger`` will not log warning or error messages
      without ``Handler``.

    Examples
    --------
    >>> logger = ABLLogger.get_instance(name='ABLLogger', logger_name='Logger')
    >>> # Although logger has a name attribute like ``logging.Logger``
    >>> # We cannot get logger instance by ``logging.getLogger``.
    >>> assert logger.name == 'Logger'
    >>> assert logger.instance_name == 'ABLLogger'
    >>> assert id(logger) != id(logging.getLogger('Logger'))
    >>> # Get logger that stores logs in ``results/<timestamp>/<timestamp>.log``.
    >>> logger1 = ABLLogger.get_instance('logger1')
    >>> # Get logger that stores logs in a user-specified file.
    >>> logger2 = ABLLogger.get_instance('logger2', log_file='out.log')
    """

    def __init__(
        self,
        name: str,
        logger_name="abl",
        log_file: Optional[str] = None,
        log_level: Union[int, str] = "INFO",
        file_mode: str = "w",
    ):
        Logger.__init__(self, logger_name)
        ManagerMixin.__init__(self, name)
        if isinstance(log_level, str):
            # Accept "info" as well as "INFO".
            log_level = logging._nameToLevel[log_level.upper()]  # pylint: disable=protected-access

        stream_handler = logging.StreamHandler(stream=sys.stdout)
        # ``StreamHandler`` record month, day, hour, minute, and second
        # timestamp.
        stream_handler.setFormatter(ABLFormatter(color=True, datefmt="%m/%d %H:%M:%S"))
        stream_handler.setLevel(log_level)
        stream_handler.addFilter(FilterDuplicateWarning(logger_name))
        self.handlers.append(stream_handler)

        if log_file is None:
            import time  # pylint: disable=import-outside-toplevel

            local_time = time.strftime("%Y%m%d_%H_%M_%S", time.localtime())
            log_dir = osp.join("results", local_time)
            log_file = osp.join(log_dir, local_time + ".log")
        else:
            # A bare filename lives in the current working directory.
            log_dir = osp.dirname(log_file) or os.curdir
        # ``exist_ok`` avoids the race between an existence check and creation.
        os.makedirs(log_dir, exist_ok=True)
        # Always set, so the ``log_dir`` property works for user-supplied files too.
        self._log_dir = log_dir

        # Explicit UTF-8 so non-ASCII messages do not depend on the platform encoding.
        file_handler = logging.FileHandler(log_file, file_mode, encoding="utf-8")
        file_handler.setFormatter(ABLFormatter(color=False, datefmt="%Y/%m/%d %H:%M:%S"))
        file_handler.setLevel(log_level)
        file_handler.addFilter(FilterDuplicateWarning(logger_name))
        self.handlers.append(file_handler)
        self._log_file = log_file

    @property
    def log_file(self):
        """
        Get the file path of the log.

        Returns
        -------
        str
            Path of the log.
        """
        return self._log_file

    @property
    def log_dir(self):
        """
        Get the directory where the log is stored.

        Returns
        -------
        str
            Directory where the log is stored.
        """
        return self._log_dir

    @classmethod
    def get_current_instance(cls) -> "ABLLogger":
        """
        Get the latest created ``ABLLogger`` instance.

        Returns
        -------
        ABLLogger
            The latest created ``ABLLogger`` instance. If no instance has been created,
            returns a logger with the instance name "abl".
        """
        if not cls._instance_dict:
            cls.get_instance("abl")
        return super().get_current_instance()

    def callHandlers(self, record: LogRecord) -> None:
        """
        Pass a record to all handlers of this logger.

        Override the ``callHandlers`` method in ``logging.Logger``. Only this logger's
        own handlers are used: parent loggers are not consulted and no
        ``logging.lastResort`` fallback is applied when there are no handlers.

        Parameters
        ----------
        record : LogRecord
            A ``LogRecord`` instance containing the logged message.
        """
        for handler in self.handlers:
            if record.levelno >= handler.level:
                handler.handle(record)

    def setLevel(self, level):
        """
        Set the logging level of this logger.

        Override the ``setLevel`` method to clear caches of all ``ABLLogger`` instances
        managed by ``ManagerMixin``. The level must be an int or a str.

        Parameters
        ----------
        level : Union[int, str]
            The logging level to set.
        """
        self.level = logging._checkLevel(level)  # pylint: disable=protected-access
        _accquire_lock()
        try:
            # The same logic as ``logging.Manager._clear_cache``.
            for logger in ABLLogger._instance_dict.values():
                logger._cache.clear()  # pylint: disable=protected-access
        finally:
            # Release even if clearing fails, so the shared lock is never left held.
            _release_lock()