# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import json
import os
import re
from contextlib import ContextDecorator, contextmanager
from functools import wraps
from typing import List
from weakref import WeakSet
from .. import _atexit
from ..core._imperative_rt.core2 import (
pop_scope,
push_scope,
start_profile,
stop_profile,
sync,
)
from ..logger import get_logger
# The Profiler currently between start() and stop(), or None when none is running.
_running_profiler = None
# Profilers that have been stopped but not yet dumped. A WeakSet holds only weak
# references, so membership here does not keep a profiler object alive.
_living_profilers = WeakSet()
class Profiler(ContextDecorator):
    r"""Profile graph execution in imperative mode.

    Profiling runs between :meth:`start` and :meth:`stop`. It also runs for the
    duration of a ``with`` block, or for each call of a function decorated
    with the profiler.

    :meth:`dump` writes the results of the last :meth:`stop` to
    ``<path>/<pid>.<format>`` for every requested format. A stopped profiler
    that was never dumped explicitly is dumped when it is garbage collected.

    Args:
        path: directory that profiling results are written to.
        format: output format used when ``formats`` is not given. Must be one
            of :attr:`valid_formats`.
        formats: list of output formats. Each must be one of
            :attr:`valid_formats`.
        **kwargs: profiling options. Recognised keys are those of
            :attr:`valid_options` (``sample_rate``, ``profile_device``,
            ``num_tensor_watch``), and their values are converted with
            ``int``. Unrecognised keys are ignored with a warning.

    Examples:

        .. code-block::

            import megengine as mge
            import megengine.module as M
            from megengine.utils.profiler import Profiler

            # With Learnable Parameters
            profiler = Profiler()
            for iter in range(0, 10):
                # Only the record of the last iteration is saved
                with profiler:
                    # your code here
                    ...

            # Then open the profile file in the chrome timeline window
    """

    CHROME_TIMELINE = "chrome_timeline.json"

    # Option name -> default value; values passed via **kwargs are cast to int.
    valid_options = {"sample_rate": 0, "profile_device": 1, "num_tensor_watch": 10}
    valid_formats = {"chrome_timeline.json", "memory_flow.svg"}

    def __init__(
        self,
        path: str = "profile",
        format: str = "chrome_timeline.json",
        formats: List[str] = None,
        **kwargs
    ) -> None:
        if not formats:
            formats = [format]
        assert not isinstance(formats, str), "formats expects list, got str"
        # Copy, so later mutation of the caller's list cannot change what is dumped.
        formats = list(formats)
        for fmt in formats:
            assert fmt in Profiler.valid_formats, "unsupported format {}".format(fmt)
        self._path = path
        self._formats = formats
        self._options = {}
        for opt, optval in Profiler.valid_options.items():
            self._options[opt] = int(kwargs.pop(opt, optval))
        # Placeholder until start()/stop() record the real process id.
        self._pid = "<PID>"
        # Set by stop(): callable(path, format) that writes one result file.
        self._dump_callback = None
        if kwargs:
            # Misspelled option names would otherwise be dropped silently.
            get_logger().warning(
                "ignoring unknown profiler options: {}".format(sorted(kwargs))
            )

    @property
    def path(self):
        """Result path pattern; several formats are shown as ``{fmt1,fmt2}``."""
        if len(self._formats) == 0:
            fmt = "<FORMAT>"
        elif len(self._formats) == 1:
            fmt = self._formats[0]
        else:
            fmt = "{" + ",".join(self._formats) + "}"
        return self.format_path(self._path, self._pid, fmt)

    @property
    def directory(self):
        """Directory that result files are written to."""
        return self._path

    @property
    def formats(self):
        """A copy of the list of output formats."""
        return list(self._formats)

    @staticmethod
    def format_path(path, pid, format):
        """Return the result file path ``<path>/<pid>.<format>``."""
        return os.path.join(path, "{}.{}".format(pid, format))

    def start(self):
        """Start profiling. Only one profiler may run at a time.

        Returns:
            this profiler, so ``p = Profiler().start()`` works.
        """
        global _running_profiler
        assert _running_profiler is None
        _running_profiler = self
        self._pid = os.getpid()
        start_profile(self._options)
        return self

    def stop(self):
        """Stop profiling and keep the results until :meth:`dump` writes them."""
        global _running_profiler
        assert _running_profiler is self
        _running_profiler = None
        sync()
        self._dump_callback = stop_profile()
        self._pid = os.getpid()
        _living_profilers.add(self)

    def dump(self):
        """Write the results of the last :meth:`stop`, one file per format.

        Does nothing if there is nothing to write. If ``directory`` exists but
        is not a directory, a warning is logged and the results are kept, so a
        later call can retry.
        """
        # getattr: __del__ calls dump() even when __init__ failed before
        # _dump_callback was assigned.
        dump_callback = getattr(self, "_dump_callback", None)
        if dump_callback is None:
            return
        if not os.path.exists(self._path):
            # exist_ok: another process may create the directory concurrently.
            os.makedirs(self._path, exist_ok=True)
        if not os.path.isdir(self._path):
            get_logger().warning(
                "{} is not a directory, cannot write profiling results".format(
                    self._path
                )
            )
            return
        for fmt in self._formats:
            path = self.format_path(self._path, self._pid, fmt)
            get_logger().info("process {} generating {}".format(self._pid, fmt))
            dump_callback(path, fmt)
            get_logger().info("profiling results written to {}".format(path))
        self._dump_callback = None
        _living_profilers.discard(self)

    def __enter__(self):
        return self.start()

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning None lets any exception from the body propagate.
        self.stop()

    def __call__(self, func):
        """Decorate ``func`` so that each call is profiled by this profiler."""
        func = super().__call__(func)
        func.__profiler__ = self
        return func

    def __del__(self):
        self.dump()
@contextmanager
def scope(name):
    """Wrap the body of a ``with`` block in ``push_scope(name)``/``pop_scope(name)``.

    Args:
        name: scope name passed to both ``push_scope`` and ``pop_scope``.
    """
    push_scope(name)
    try:
        yield
    finally:
        # Pop even if the body raises, so push/pop calls stay balanced.
        pop_scope(name)
def profile(*args, **kwargs):
    """Create a :class:`Profiler`, or apply a default one as a bare decorator.

    Used as ``@profile`` (without parentheses), it wraps the decorated
    function with ``Profiler()``. Any other call forwards all arguments to
    ``Profiler`` and returns the new instance.
    """
    bare_decorator_use = not kwargs and len(args) == 1 and callable(args[0])
    if bare_decorator_use:
        (target,) = args
        return Profiler()(target)
    return Profiler(*args, **kwargs)
def merge_trace_events(directory: str):
    """Merge per-process chrome timelines in ``directory`` into one file.

    Reads every ``<digits>.chrome_timeline.json`` file in ``directory``. Each
    file's events are shifted by the difference between its ``localTime`` (from
    the leading ``Metadata`` event) and the earliest ``localTime`` across all
    files. The ``Metadata`` events are dropped, and the combined events are
    written to ``<directory>/merge.chrome_timeline.json``.

    Args:
        directory: directory holding the per-process timeline files.
    """
    pattern = re.compile(r"\d+\.chrome_timeline\.json")
    # fullmatch: re.match would also accept e.g. "123.chrome_timeline.json.bak".
    # Sorted so the merged event order does not depend on os.listdir order.
    names = sorted(name for name in os.listdir(directory) if pattern.fullmatch(name))

    def load_trace_events(name):
        """Load one timeline file and return its list of trace events."""
        with open(os.path.join(directory, name), "r", encoding="utf-8") as f:
            content = json.load(f)
        # Accept both a {"traceEvents": [...]} object and a bare event list.
        if isinstance(content, dict):
            assert "traceEvents" in content
            content = content["traceEvents"]
        return content

    def find_metadata(events):
        """Return the args of the leading Metadata event, or None if empty."""
        if len(events) == 0:
            return None
        assert events[0]["name"] == "Metadata"
        return events[0]["args"]

    event_lists = [load_trace_events(name) for name in names]
    metadata_list = [find_metadata(events) for events in event_lists]
    # default=0: with no non-empty timeline there is nothing to shift.
    min_local_time = min(
        (meta["localTime"] for meta in metadata_list if meta is not None),
        default=0,
    )

    merged = []
    for local_events, metadata in zip(event_lists, metadata_list):
        if metadata is None:  # empty timeline
            continue
        time_shift = metadata["localTime"] - min_local_time
        for event in local_events:
            if "ts" in event:
                event["ts"] = int(event["ts"] + time_shift)
        merged.extend(e for e in local_events if e.get("name") != "Metadata")

    path = os.path.join(directory, "merge.chrome_timeline.json")
    # ensure_ascii=False can emit non-ASCII text, so don't rely on the locale encoding.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(
            {"traceEvents": merged}, f, ensure_ascii=False, separators=(",", ":")
        )
    get_logger().info("profiling results written to {}".format(path))
def is_profiling():
    """Return whether some :class:`Profiler` has been started and not yet stopped."""
    active = _running_profiler
    return active is not None
def _stop_current_profiler():
    """Stop the running profiler, if any, then dump every stopped profiler.

    Registered through megengine's ``_atexit`` hook below (presumably invoked
    at interpreter shutdown).
    """
    current = _running_profiler
    if current is not None:
        current.stop()
    # Snapshot first: dump() removes profilers from the WeakSet while we iterate.
    for pending in list(_living_profilers):
        pending.dump()


_atexit(_stop_current_profiler)