megengine.utils.profiler 源代码

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import json
import os
import re
from contextlib import ContextDecorator, contextmanager
from functools import wraps
from typing import List
from weakref import WeakSet

from .. import _atexit
from ..core._imperative_rt.core2 import (
    pop_scope,
    push_scope,
    start_profile,
    stop_profile,
    sync,
)
from ..logger import get_logger

# The Profiler instance that is currently recording, or None when no profiler
# is running. Assigned in Profiler.start() and reset in Profiler.stop().
_running_profiler = None
# Profilers that have been stopped but whose results have not been dumped yet.
# Held weakly; entries are added in Profiler.stop() and removed in Profiler.dump().
_living_profilers = WeakSet()


class Profiler(ContextDecorator):
    r"""Profile graph execution in imperative mode.

    A profiler can be driven explicitly with :meth:`start` / :meth:`stop` /
    :meth:`dump`, used as a context manager (``with profiler:``), or used as a
    decorator (it derives from :class:`contextlib.ContextDecorator`).

    Args:
        path: directory the profiling results are written into. Each result
            file is named ``"<pid>.<format>"`` inside this directory.
        format: single output format, used only when ``formats`` is empty.
        formats: list of output formats. Every entry must be one of
            :attr:`valid_formats`.
        kwargs: profiler options. Only the keys listed in
            :attr:`valid_options` are read; each value is converted with
            ``int()`` and falls back to the default from
            :attr:`valid_options` when absent.

    Examples:

        .. code-block::

           import megengine as mge
           import megengine.module as M
           from megengine.utils.profiler import Profiler

           # With Learnable Parameters
           profiler = Profiler()

           for iter in range(0, 10):
               # Only profile record of last iter would be saved
               with profiler:
                   # your code here
                   pass

           # Then open the profile file in chrome timeline window
    """

    CHROME_TIMELINE = "chrome_timeline.json"

    # Recognised option names mapped to their default values.
    valid_options = {"sample_rate": 0, "profile_device": 1, "num_tensor_watch": 10}
    # Output formats accepted by the constructor.
    valid_formats = {"chrome_timeline.json", "memory_flow.svg"}

    def __init__(
        self,
        path: str = "profile",
        format: str = "chrome_timeline.json",
        formats: List[str] = None,
        **kwargs
    ) -> None:
        if not formats:
            formats = [format]

        assert not isinstance(formats, str), "formats excepts list, got str"

        for format in formats:
            assert format in Profiler.valid_formats, "unsupported format {}".format(
                format
            )

        self._path = path
        self._formats = formats
        self._options = {}
        for opt, optval in Profiler.valid_options.items():
            self._options[opt] = int(kwargs.pop(opt, optval))
        # Placeholder until start()/stop() record the real process id.
        self._pid = "<PID>"
        # Set by stop(); called by dump() as ``callback(path, format)``.
        self._dump_callback = None

    @property
    def path(self):
        """Path pattern of the result file(s) for the configured formats.

        With several formats the format part is rendered as a brace list,
        e.g. ``"{chrome_timeline.json,memory_flow.svg}"``.
        """
        if len(self._formats) == 0:
            format = "<FORMAT>"
        elif len(self._formats) == 1:
            format = self._formats[0]
        else:
            format = "{" + ",".join(self._formats) + "}"
        return self.format_path(self._path, self._pid, format)

    @property
    def directory(self):
        """Directory the results are written into."""
        return self._path

    @property
    def formats(self):
        """A copy of the list of configured output formats."""
        return list(self._formats)

    def start(self):
        """Start profiling and mark this instance as the running profiler.

        Only one profiler may run at a time; this is enforced by an assertion.

        Returns:
            This profiler instance.
        """
        global _running_profiler

        assert _running_profiler is None
        _running_profiler = self
        self._pid = os.getpid()
        start_profile(self._options)
        return self

    def stop(self):
        """Stop profiling and keep the dump callback for a later :meth:`dump`.

        This instance must be the running profiler; this is enforced by an
        assertion.
        """
        global _running_profiler

        assert _running_profiler is self
        _running_profiler = None
        sync()
        self._dump_callback = stop_profile()
        self._pid = os.getpid()
        # Track the instance so pending results can still be dumped later.
        _living_profilers.add(self)

    def dump(self):
        """Write the pending profiling results, one file per configured format.

        Does nothing when there is no pending result. When the target path
        exists but is not a directory, a warning is logged and the pending
        result is kept.
        """
        if self._dump_callback is None:
            return

        if not os.path.exists(self._path):
            # exist_ok: another process may create the directory between the
            # check above and this call.
            os.makedirs(self._path, exist_ok=True)
        if not os.path.isdir(self._path):
            get_logger().warning(
                "{} is not a directory, cannot write profiling results".format(
                    self._path
                )
            )
            return

        for format in self._formats:
            path = self.format_path(self._path, self._pid, format)
            get_logger().info("process {} generating {}".format(self._pid, format))
            self._dump_callback(path, format)
            get_logger().info("profiling results written to {}".format(path))

        self._dump_callback = None
        # discard() rather than remove(): the weak reference may already be
        # gone (e.g. when dump() is reached through __del__), and a missing
        # entry must not turn a successful dump into a KeyError.
        _living_profilers.discard(self)

    def format_path(self, path, pid, format):
        """Return ``<path>/<pid>.<format>`` joined with the OS path separator."""
        return os.path.join(path, "{}.{}".format(pid, format))

    def __enter__(self):
        # Return the profiler so that ``with Profiler() as p:`` binds it.
        return self.start()

    def __exit__(self, val, tp, trace):
        self.stop()

    def __call__(self, func):
        """Wrap ``func`` so each call is profiled; expose self as ``__profiler__``."""
        func = super().__call__(func)
        func.__profiler__ = self
        return func

    def __del__(self):
        # Flush results that were recorded but never dumped explicitly.
        self.dump()
@contextmanager
def scope(name):
    """Context manager that pushes the scope ``name`` on entry and pops it on exit.

    The scope is popped even when the managed body raises, so pushes and pops
    stay balanced.
    """
    push_scope(name)
    try:
        yield
    finally:
        pop_scope(name)


def profile(*args, **kwargs):
    """Create a :class:`Profiler`, usable with or without call parentheses.

    ``@profile`` (a single callable argument and no keywords) wraps the
    callable with a default ``Profiler()``. Any other call forwards the
    arguments to the ``Profiler`` constructor and returns the instance.
    """
    if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
        return Profiler()(args[0])
    return Profiler(*args, **kwargs)


def merge_trace_events(directory: str):
    """Merge all ``<digits>.chrome_timeline.json`` files found in ``directory``.

    The ``ts`` field of every event is shifted by the difference between its
    file's ``localTime`` metadata and the smallest ``localTime`` across all
    files. The per-file ``Metadata`` events are dropped and the result is
    written to ``merge.chrome_timeline.json`` in the same directory.

    Args:
        directory: directory that holds the per-process trace files.
    """
    # fullmatch: do not pick up leftovers such as "1.chrome_timeline.json.bak".
    # sorted: os.listdir order is arbitrary; keep the merged output stable.
    names = sorted(
        name
        for name in os.listdir(directory)
        if re.fullmatch(r"\d+\.chrome_timeline\.json", name)
    )

    def load_trace_events(name):
        with open(os.path.join(directory, name), "r", encoding="utf-8") as f:
            return json.load(f)

    def get_events(content):
        # A trace is either {"traceEvents": [...]} or a bare list of events.
        if isinstance(content, dict):
            assert "traceEvents" in content
            return content["traceEvents"]
        return content

    def find_metadata(local_events):
        if len(local_events) == 0:
            return None
        assert local_events[0]["name"] == "Metadata"
        return local_events[0]["args"]

    events_list = [get_events(load_trace_events(name)) for name in names]
    metadata_list = [find_metadata(local_events) for local_events in events_list]

    # default=0 covers "no trace files" / "all traces empty"; the loop below
    # then skips every file and an empty merged trace is written.
    min_local_time = min(
        (
            metadata["localTime"]
            for metadata in metadata_list
            if metadata is not None
        ),
        default=0,
    )

    events = []
    for local_events, metadata in zip(events_list, metadata_list):
        if len(local_events) == 0:
            continue

        time_shift = metadata["localTime"] - min_local_time

        for event in local_events:
            if "ts" in event:
                event["ts"] = int(event["ts"] + time_shift)

        events.extend(filter(lambda x: x["name"] != "Metadata", local_events))

    result = {
        "traceEvents": events,
    }

    path = os.path.join(directory, "merge.chrome_timeline.json")
    # The inputs are read as UTF-8 and ensure_ascii=False emits raw non-ASCII
    # characters, so the output encoding must be pinned to UTF-8 as well.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, separators=(",", ":"))

    get_logger().info("profiling results written to {}".format(path))


def is_profiling():
    """Return True while a profiler is running."""
    return _running_profiler is not None


def _stop_current_profiler():
    """Exit hook: stop the running profiler, then dump every pending profiler."""
    if _running_profiler is not None:
        _running_profiler.stop()
    # Snapshot first: dump() removes entries from the weak set.
    living_profilers = [*_living_profilers]
    for profiler in living_profilers:
        profiler.dump()


_atexit(_stop_current_profiler)