# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import contextlib
import logging
import os
import sys
# Every logger configured by get_logger(); set_log_level() updates them all.
_all_loggers = []
def _parse_level(value):
    r"""Converts a level name or an integer string into a numeric logging level.
    Names are matched case-insensitively (``"debug"``, ``"WARN"``); integer
    strings such as ``"15"`` are accepted as-is. Unrecognized values fall back
    to ``logging.INFO`` with a warning. :func:`logging.getLevelName` alone
    returns the string ``"Level <NAME>"`` for unknown names, which would make
    ``Logger.setLevel`` raise ``ValueError`` while this module is imported.
    Args:
        value: level string, normally read from ``MEGENGINE_LOGGING_LEVEL``
    Returns:
        the numeric logging level
    """
    try:
        return int(value)
    except ValueError:
        pass
    level = logging.getLevelName(value.strip().upper())
    if isinstance(level, int):
        return level
    import warnings
    warnings.warn(
        "invalid MEGENGINE_LOGGING_LEVEL {!r}, falling back to INFO".format(value)
    )
    return logging.INFO
_default_level_name = os.getenv("MEGENGINE_LOGGING_LEVEL", "INFO")
_default_level = _parse_level(_default_level_name)
def set_log_file(fout, mode="a"):
    r"""Sets log output file.
    Once set, every record formatted by MegEngineLogFormatter (or a subclass)
    is additionally written to *fout*, uncolored and without line omission.
    A file opened by a previous call is not closed.
    Args:
        fout: file-like object that supports write and flush, or a filename
            given as a string or path-like object (e.g. ``pathlib.Path``)
        mode: mode used to open the log file if *fout* is a filename
    """
    # Path-like objects are opened like strings; storing them as-is would only
    # fail later, on the first write() performed while formatting a record.
    if isinstance(fout, str) or hasattr(fout, "__fspath__"):
        fout = open(fout, mode)
    MegEngineLogFormatter.log_fout = fout
class MegEngineLogFormatter(logging.Formatter):
    r"""Log formatter used by loggers created with :func:`get_logger`.
    Console output is colored by level and prefixed with the time; records of
    at least :attr:`max_lines` lines keep only their first and last
    ``max_lines // 2`` lines there. When :attr:`log_fout` is set (see
    :func:`set_log_file`), each formatted record is also written to it
    uncolored and in full, with a ``[time line@file:logger]`` prefix; records
    of at least :attr:`max_lines` lines are wrapped there in
    ``BEGIN_LONG_LOG_<n>_LINES{`` / ``}END_LONG_LOG_<n>_LINES`` markers.
    """
    # File-like object with write()/flush() receiving the uncolored copy of
    # every record; None disables the copy.
    log_fout = None
    # Format prefix used for the file copy.
    date_full = "[%(asctime)s %(lineno)d@%(filename)s:%(name)s] "
    # Format prefix used for console output.
    date = "%(asctime)s "
    msg = "%(message)s"
    # Line-count threshold for console omission and file BEGIN/END markers.
    max_lines = 256
    def _color_exc(self, msg):
        r"""Sets the color of message as the exception type."""
        return "\x1b[34m{}\x1b[0m".format(msg)
    def _color_dbg(self, msg):
        r"""Sets the color of message as the debugging type."""
        return "\x1b[36m{}\x1b[0m".format(msg)
    def _color_warn(self, msg):
        r"""Sets the color of message as the warning type."""
        return "\x1b[1;31m{}\x1b[0m".format(msg)
    def _color_err(self, msg):
        r"""Sets the color of message as the error type."""
        return "\x1b[1;4;31m{}\x1b[0m".format(msg)
    def _color_omitted(self, msg):
        r"""Sets the color of message as the omitted type."""
        return "\x1b[35m{}\x1b[0m".format(msg)
    def _color_normal(self, msg):
        r"""Sets the color of message as the normal type."""
        return msg
    def _color_date(self, msg):
        r"""Sets the color of message the same as date."""
        return "\x1b[32m{}\x1b[0m".format(msg)
    def _level_style(self, record):
        r"""Returns ``(color_fn, tag)`` for the level of *record*.
        INFO and levels without a dedicated tag get no color and an empty tag.
        """
        levelno = record.levelno
        if levelno == logging.DEBUG:
            return self._color_dbg, "DBG"
        if levelno == logging.WARNING:
            return self._color_warn, "WRN"
        if levelno == logging.ERROR:
            return self._color_err, "ERR"
        if levelno == logging.CRITICAL:
            # At least as severe as ERROR; tag it so it is not rendered like INFO.
            return self._color_err, "CRT"
        return self._color_normal, ""
    def _write_log_file(self, record, mtxt):
        r"""Writes the uncolored, complete form of *record* to :attr:`log_fout`
        followed by a newline, then flushes it."""
        self.__set_fmt(self.date_full + mtxt + self.msg)
        formatted = super(MegEngineLogFormatter, self).format(record)
        nr_line = formatted.count("\n") + 1
        # nr_line > 1 guarantees the head/body split below has two parts even
        # when max_lines has been lowered to 1 or less.
        if nr_line >= self.max_lines and nr_line > 1:
            head, body = formatted.split("\n", 1)
            formatted = "\n".join(
                [
                    head,
                    "BEGIN_LONG_LOG_{}_LINES{{".format(nr_line - 1),
                    body,
                    "}}END_LONG_LOG_{}_LINES".format(nr_line - 1),
                ]
            )
        self.log_fout.write(formatted)
        self.log_fout.write("\n")
        self.log_fout.flush()
    def _color_traceback(self, formatted):
        r"""Indents and colors everything from the first "Traceback " onwards."""
        b = formatted.find("Traceback ")
        if b == -1:
            return formatted
        tb = self._color_exc("  " + formatted[b:].replace("\n", "\n  "))
        return formatted[:b] + tb
    def _omit_middle_lines(self, formatted):
        r"""Keeps only the first and last ``max_lines // 2`` lines of a record
        that spans at least :attr:`max_lines` lines, replacing the rest with a
        notice."""
        nr_line = formatted.count("\n") + 1
        if nr_line < self.max_lines:
            return formatted
        lines = formatted.split("\n")
        remain = self.max_lines // 2
        removed = len(lines) - remain * 2
        if removed <= 0:
            return formatted
        mid_msg = self._color_omitted(
            "[{} log lines omitted (would be written to output file "
            "if set_log_file() has been called;\n"
            " the threshold can be set at "
            "MegEngineLogFormatter.max_lines)]".format(removed)
        )
        # Slice from an explicit start index: lines[-0:] would be the whole
        # list when remain == 0.
        tail = lines[len(lines) - remain:]
        return "\n".join(lines[:remain] + [mid_msg] + tail)
    def format(self, record):
        r"""Formats *record* for the console, mirroring it to :attr:`log_fout`
        first when one is set."""
        mcl, mtxt = self._level_style(record)
        if mtxt:
            mtxt += " "
        if self.log_fout:
            self._write_log_file(record, mtxt)
        self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg))
        formatted = super(MegEngineLogFormatter, self).format(record)
        if record.exc_text or record.exc_info:
            formatted = self._color_traceback(formatted)
        return self._omit_middle_lines(formatted)
    # The active format string lives on the Formatter itself in Python 2 and
    # on its style object in Python 3.
    if sys.version_info.major < 3:
        def __set_fmt(self, fmt):
            self._fmt = fmt
    else:
        def __set_fmt(self, fmt):
            self._style._fmt = fmt
def get_logger(name=None, formatter=MegEngineLogFormatter):
    r"""Gets megengine logger with given name.
    The first call for a given *name* configures the logger. It disables
    propagation, sets the current default level and replaces all existing
    handlers with one ``logging.StreamHandler`` (stderr). That handler uses
    ``formatter(datefmt="%d %H:%M:%S")``. The logger is also registered so
    that set_log_level() can update it. Later calls return the logger as-is;
    their *formatter* argument is ignored.
    Args:
        name: logger name passed to :func:`logging.getLogger`; ``None`` selects
            the root logger
        formatter: formatter class instantiated for the logger's handler
    Returns:
        the configured :class:`logging.Logger`
    """
    logger = logging.getLogger(name)
    # Marker attribute: configure each logger only once.
    if getattr(logger, "_init_done__", None):
        return logger
    logger._init_done__ = True
    logger.propagate = False
    logger.setLevel(_default_level)
    handler = logging.StreamHandler()
    handler.setFormatter(formatter(datefmt="%d %H:%M:%S"))
    # Level 0 (NOTSET): filtering is left entirely to the logger's level.
    handler.setLevel(0)
    del logger.handlers[:]
    logger.addHandler(handler)
    _all_loggers.append(logger)
    return logger
def set_log_level(level, update_existing=True):
    r"""Sets default logging level.
    Loggers created afterwards by get_logger() start at *level*. The native
    megbrain level is not changed here; see set_mgb_log_level().
    Args:
        level: logging level given by python :mod:`logging` module
        update_existing: whether to also apply *level* to every logger already
            created by get_logger()
    """
    global _default_level  # pylint: disable=global-statement
    _default_level = level
    if update_existing:
        for logger in _all_loggers:
            logger.setLevel(level)
# Logger for this module itself.
_logger = get_logger(__name__)
try:
    # On Python 2, take the ImportError fallback path below instead of loading
    # the native logger binding.
    if sys.version_info.major < 3:
        raise ImportError()
    from .core._imperative_rt.utils import Logger as _imperative_rt_logger
    class MegBrainLogFormatter(MegEngineLogFormatter):
        r"""Formatter for the "megbrain" logger: identical to
        MegEngineLogFormatter except for a "[mgb]" console time prefix drawn
        in yellow instead of green."""
        date = "%(asctime)s[mgb] "
        def _color_date(self, msg):
            r"""Colors the console time prefix yellow."""
            return "\x1b[33m{}\x1b[0m".format(msg)
    _megbrain_logger = get_logger("megbrain", MegBrainLogFormatter)
    # Presumably routes native-side log messages through the Python "megbrain"
    # logger -- NOTE(review): confirm against the imperative_rt binding.
    _imperative_rt_logger.set_log_handler(_megbrain_logger)
    def set_mgb_log_level(level):
        r"""Sets megbrain log level
        Updates the Python "megbrain" logger and the native logger. Only
        ``logging.ERROR`` and ``logging.INFO`` map to the native Error and Info
        levels. Every other value, including WARNING and CRITICAL, selects
        native Debug. NOTE(review): that makes the native side more verbose
        than requested for WARNING/CRITICAL; confirm whether a native warning
        level exists and should be used.
        Args:
            level: new log level
        Returns:
            original log level, as returned by the native set_log_level().
            NOTE(review): presumably a native LogLevel value rather than a
            Python logging level; confirm it is accepted by
            _megbrain_logger.setLevel() when passed back in (as
            replace_mgb_log_level does).
        """
        _megbrain_logger.setLevel(level)
        if level == logging.getLevelName("ERROR"):
            rst = _imperative_rt_logger.set_log_level(
                _imperative_rt_logger.LogLevel.Error
            )
        elif level == logging.getLevelName("INFO"):
            rst = _imperative_rt_logger.set_log_level(
                _imperative_rt_logger.LogLevel.Info
            )
        else:
            rst = _imperative_rt_logger.set_log_level(
                _imperative_rt_logger.LogLevel.Debug
            )
        return rst
    # Align the native level with the module default at import time.
    set_mgb_log_level(_default_level)
except ImportError as exc:
    # imperative_rt unavailable (or Python 2): the megbrain level cannot be
    # changed, so the stub always raises. `exc` is unused.
    def set_mgb_log_level(level):
        raise NotImplementedError("imperative_rt has not been imported")
@contextlib.contextmanager
def replace_mgb_log_level(level):
    r"""Temporarily switches the megbrain log level for a ``with`` block.
    On leaving the block, normally or through an exception, the value that
    set_mgb_log_level() returned when applying *level* is passed back to it.
    Args:
        level: log level in effect while the block runs
    """
    saved_level = set_mgb_log_level(level)
    try:
        yield
    finally:
        # Restore unconditionally so an exception inside the block does not
        # leave the temporary level in place.
        set_mgb_log_level(saved_level)
def enable_debug_log():
    r"""Sets logging level to debug for all components.
    Applies DEBUG through set_log_level(), which updates the default level
    and every logger created by get_logger(). It then applies DEBUG through
    set_mgb_log_level(), which raises NotImplementedError when imperative_rt
    is unavailable.
    """
    set_log_level(logging.DEBUG)
    set_mgb_log_level(logging.DEBUG)