# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import contextlib
import functools
import itertools
import json
import os
import pickle
from typing import Any
import numpy as np
from ..core._imperative_rt import GraphProfiler, GraphProfiler2, SerializationMetadata
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
TensorWeakRef,
apply,
set_tracing,
skip_tracing,
unset_tracing,
)
from ..core._imperative_rt.ops import (
AssertEqual,
CollectiveComm,
ExternOpr,
RemoteRecv,
RemoteSend,
)
from ..core._trace_option import set_symbolic_shape
from ..core._wrap import as_device
from ..core.ops.builtin import BatchNorm, OpDef
from ..core.tensor import megbrain_graph as G
from ..core.tensor.utils import setscalar
from ..utils.naming import AutoNaming
from ..utils.profiler import is_profiling
from .dtr_config import DTRConfig
from .graph_opt_config import GraphOptimizationConfig
from .sublinear_memory_config import SublinearMemoryConfig
def _input_node_use_static_shape():
return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None
class TraceMismatchError(RuntimeError):
    """Raised when a replayed call diverges from the recorded trace."""
# The ``trace`` instance currently recording or replaying, or None when no
# trace is running.  Assigned and cleared by ``trace._set_active``, which
# rejects nested traces.
active_trace = None
def is_tracing():
    """Return True when a trace is active and tracing is not being skipped."""
    return active_trace is not None and not skip_tracing
@contextlib.contextmanager
def exclude_from_trace():
    """Context manager that runs its body outside of the active trace.

    When no trace is active, or tracing is already being skipped, the body
    runs unchanged.  Otherwise tracing is switched off for the duration of the
    body and switched back on afterwards, even if the body raises.
    """
    global skip_tracing
    nothing_to_exclude = skip_tracing or (active_trace is None)
    if nothing_to_exclude:
        yield
    else:
        try:
            skip_tracing = True
            unset_tracing()
            if active_trace is not None:
                # let the trace prepare its live tensors for untraced access
                active_trace._begin_excluded_region()
            yield
        finally:
            skip_tracing = False
            set_tracing()
class TensorInfo:
    """Bookkeeping record for one tensor handle of a trace.

    Instances live in ``trace._tinfo`` and are addressed by their index in
    that list (the "handle").
    """

    __slots__ = (
        # collected attributes
        "name",
        "external",
        "data_read",
        "shape_read",
        "value_read",
        "exported",
        "device",
        "dtype",
        "shape",
        "is_const",
        "bound_data",
        # resources for execution
        "varnode",
        "data_setter",
        "shape_reader",
        "value_reader",
        "data_reader",
    )

    def __init__(self):
        # Only the slots listed here receive a default.  The remaining ones
        # (external, device, dtype, shape, is_const, varnode) stay unset;
        # other code in this file probes ``varnode`` with ``hasattr`` and
        # ``is_const`` with ``getattr(..., False)``, so they must not be
        # pre-populated.
        for slot in (
            "name",
            "exported",
            "data_read",
            "shape_read",
            "value_read",
            "bound_data",
            "data_setter",
            "shape_reader",
            "value_reader",
            "data_reader",
        ):
            setattr(self, slot, None)
# Operator types treated as I/O.  Wherever ``type(op) in _io_op_types`` holds,
# the code in this file threads an explicit link from one such op to the next
# (via ``G.VirtualDepNode``) instead of leaving them unordered.
_io_op_types = {AssertEqual, CollectiveComm, RemoteSend, RemoteRecv}
class trace:
    """Wraps a callable and provides:

    * tracing via :meth:`.trace` and :meth:`.dump`
    * accelerated evaluation via :meth:`.__call__`

    Args:
        function: the function to be traced.
        symbolic: whether to apply symbolic execution for tracing. Default: False
        capture_as_const: capture global vars or closures as const value. Default: False
        record_only: if True, won't run even if call the function. Default: False
        sublinear_memory_config: configuration for sublinear memory optimization.
            If not None, it enables sublinear memory optimization with given setting.
        dtr_config: configuration for DTR memory optimization.
            If not None, it enables DTR memory optimization with given setting.
        profiling: whether to profile compiled trace. Default: False
        opt_level: optimization level for compiling trace. Default: 2
        graph_opt_config: configuration for graph optimization. Default: None
        symbolic_shape: whether to use symbolic shape for tracing. Default: True
    """
def __new__(cls, *args, **kwargs):
    """Allocate the instance, or defer construction when no function is given.

    Called with keyword options only (e.g. ``@trace(symbolic=True)``), this
    returns a ``functools.partial`` of the class that still awaits the
    function; otherwise a regular instance is allocated.
    """
    if args:
        return super().__new__(cls)
    return functools.partial(cls, **kwargs)
def __init__(
    self,
    function,
    symbolic=False,
    capture_as_const=False,
    record_only=False,
    sublinear_memory_config: SublinearMemoryConfig = None,
    dtr_config: DTRConfig = None,
    profiling: bool = False,
    opt_level: int = 2,
    graph_opt_config: GraphOptimizationConfig = None,
    symbolic_shape: bool = True,
):
    """Store the tracing options and start with an empty (untraced) state.

    Args:
        function: the callable to trace; kept as ``__wrapped__``.
        symbolic: use symbolic execution while recording.
        capture_as_const: capture external tensors as constants.
        record_only: record without running; this also forces ``symbolic``
            and ``capture_as_const`` on.
        sublinear_memory_config: if not None, enables sublinear memory
            optimization on the graphs this trace builds.
        dtr_config: if not None, enables DTR memory optimization on the
            graphs this trace builds.
        profiling: attach a ``GraphProfiler`` to the graphs this trace builds.
        opt_level: value assigned to ``graph_opt_level`` of those graphs.
        graph_opt_config: optional JIT fusion settings.
        symbolic_shape: value passed to ``set_symbolic_shape`` while tracing.
    """
    self.__wrapped__ = function
    # record_only implies both symbolic execution and constant capture
    self._symbolic = symbolic or record_only
    self._capture_as_const = capture_as_const or record_only
    self._record_only = record_only
    self._sublinear_memory_config = sublinear_memory_config
    self._dtr_config = dtr_config
    self._profiling = profiling
    self._profiler = None
    self._profiler2 = None
    self._graph_opt_level = opt_level
    self._graph_opt_config = graph_opt_config
    self._symbolic_shape = symbolic_shape
    # handles produced by replayed ops; not cleared by _reset()
    self._output_handles = set()
    self._reset()
def _reset(self):
    """Forget everything recorded or compiled so the next call records anew."""
    self._untraced = True

    # recorded program
    self._tinfo = []  # handle -> TensorInfo
    self._seq = []
    self._pc = 0

    # compiled graph
    self._graph = self._need_reset_nodes = None

    # symbolic (lazy evaluation) recording state
    self._lazy_eval_graph = None
    self._lazy_eval_tensors = set()
    self._lazy_eval_links = None

    # per-call tensor tracking
    self._active_tensors = set()
    self._tensor_remaps = self._inputs_to_restore = None

    # input/output bindings used with capture_as_const
    self._arg_bindings = self._kwarg_bindings = None
    self._output_bindings = self._output_names = None
def _new_handle(self):
    """Append a fresh ``TensorInfo`` to the table and return ``(handle, info)``.

    The handle is the record's index in ``self._tinfo``.
    """
    record = TensorInfo()
    self._tinfo.append(record)
    return len(self._tinfo) - 1, record
def _apply_op(self, op, args):
    """Replay one operator call against the recorded sequence.

    Checks that ``op`` and its inputs match the record at the current
    program counter, feeds external inputs into the compiled graph, and
    returns output tensors backed by the recorded varnodes.

    Args:
        op: the operator being applied in this call.
        args: its input tensors.

    Returns:
        list of ``RawTensor``, one per recorded output handle.

    Raises:
        TraceMismatchError: if the call differs from what was recorded.
    """
    assert not self._untraced
    # check against trace
    if self._pc >= len(self._seq):
        raise TraceMismatchError("trace should end here, but more op observed")
    record = self._seq[self._pc]
    op_, ihandles, ohandles = record
    # a recorded "Const" entry can never match a real operator
    if (isinstance(op_, str) and op_ == "Const") or (op != op_):
        raise TraceMismatchError("op different from last time")
    if len(ihandles) != len(args):
        raise TraceMismatchError("op input size different from last time")

    # check all inputs of current op
    for h, x in zip(ihandles, args):
        info = self._tinfo[h]
        if info.external:
            # recorded as coming from outside the trace
            if (
                x._compiled_info is not None
                and not self._tinfo[x._mixin_handle].exported
            ):
                raise TraceMismatchError(
                    "failed to capture: input was an external tensor "
                    "last time, got an internal tensor this time"
                )
            if info.bound_data:
                # captured as a constant: the value must not have changed
                if x._compiled_info is not None:
                    raise TraceMismatchError(
                        "const capture violated: was an external tensor "
                        "last time, got an internal tensor this time"
                    )
                if x._handle != info.bound_data._handle:
                    # different tensor object: fall back to comparing values
                    if not np.array_equal(x.numpy(), info.bound_data.numpy()):
                        raise TraceMismatchError(
                            "const capture violated: got "
                            "a different tensor this time"
                        )
            else:
                # plain external input: check metadata, then feed the graph
                if info.dtype != x.dtype:
                    raise TraceMismatchError(
                        "failed to capture: different dtype from last time"
                    )
                if info.device != x.device:
                    raise TraceMismatchError(
                        "failed to capture: different device from last time"
                    )
                info.data_setter.set_value(x._dev_tensor())
        else:
            # recorded as produced inside the trace
            if x._mixin_handle == -1:
                if x._handle not in self._tensor_remaps:
                    raise TraceMismatchError(
                        "unexpected capture: trying to use an external tensor as "
                        "input, but that input was an internal tensor last time"
                    )
                else:
                    # a bound call argument: recover its handle from the proxy
                    x._mixin_handle = self._tensor_remaps[
                        x._handle
                    ]._CompiledTensorProxy__handle
            if x._mixin_handle != h:
                raise TraceMismatchError(
                    "mis-wiring: input edge to an data flow "
                    "graph node is different from last time"
                )

    self._pc += 1
    outputs = []
    for h in ohandles:
        info = self._tinfo[h]
        # generate output tensor and create compiled info
        y = RawTensor(info.varnode)
        y._compiled_info = CompiledTensorProxy(h)
        y._mixin_handle = h
        outputs += [y]
        self._active_tensors.add(TensorWeakRef(y))
    self._output_handles.update(ohandles)

    return outputs
def _apply_const(self, value, dtype, device):
    """Replay a constant creation against the recorded sequence.

    The constant must have the shape, dtype and contents it had when the
    trace was recorded; the tensor bound at recording time is returned.

    Raises:
        TraceMismatchError: if the sequence is exhausted or the constant
            differs from the recorded one.
    """
    assert not self._untraced
    # check against trace
    if self._pc >= len(self._seq):
        raise TraceMismatchError("trace should end here, but more op observed")
    op_, ihandles, ohandles = self._seq[self._pc]
    # Const op is represented by a str
    assert isinstance(op_, str) and op_ == "Const"

    recorded = self._tinfo[ohandles[0]].bound_data.numpy()
    dims = value.shape
    if dims != recorded.shape or dtype != recorded.dtype:
        same = False
    elif dims == ():
        same = recorded.item() == value.item()
    elif dims == (1,):
        same = recorded[0] == value[0]
    else:
        same = np.all(value == recorded)
    if not same:
        raise TraceMismatchError(
            "const tensor violated: got a different tensor this time"
        )

    self._pc += 1
    (handle,) = ohandles
    return [self._tinfo[handle].bound_data]
# run in first step, record information for trace
def _record_op(self, op, inputs, outputs):
    """Record one operator application while the trace is being captured.

    Inputs without a handle (or, unless capturing as const, inputs that
    were exported) get a new external ``TensorInfo``; every output gets a
    new internal one.  The triple ``(op, input handles, output handles)``
    is appended to ``self._seq``.

    When tracing is being skipped nothing is recorded; traced inputs are
    only marked as having had their data read.

    Args:
        op: the operator that was applied.
        inputs: its input tensors.
        outputs: the tensors it produced.
    """
    if skip_tracing:
        for x in inputs:
            h = getattr(x, "_mixin_handle", -1)
            if h >= 0:
                # Was ``.data = True``: TensorInfo declares __slots__ without
                # a ``data`` entry, so that assignment raised AttributeError.
                # ``data_read`` is the flag _record_const sets in this case.
                self._tinfo[h].data_read = True
        return

    ihandles = []
    for x in inputs:
        h = getattr(x, "_mixin_handle", -1)
        if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
            h, info = self._new_handle()
            name = AutoNaming.gen_name(x)
            info.name = name
            info.external = True
            info.device = x.device
            info.dtype = x.dtype
            info.shape = x.shape
            if self._capture_as_const:
                # keep a copy of the value to bake into the graph later
                info.bound_data = RawTensor(
                    x.numpy(), x.dtype, x.device, False, name
                )

        ihandles.append(h)

    ohandles = []
    for x in outputs:
        h, info = self._new_handle()
        ohandles.append(h)
        info.external = False
        x._mixin_handle = h
        x._recording = True
        x._trace_mixin_info = info
        self._active_tensors.add(TensorWeakRef(x))
        if self._symbolic:
            self._lazy_eval_tensors.add(TensorWeakRef(x))

    self._seq.append((op, tuple(ihandles), tuple(ohandles)))
def _record_const(self, outputs):
    """Record the creation of a constant tensor while capturing the trace.

    ``outputs`` must hold exactly one tensor.  When tracing is being
    skipped, an already-traced tensor is only marked as data-read.
    """
    (tensor,) = outputs
    if skip_tracing:
        known = getattr(tensor, "_mixin_handle", -1)
        if known >= 0:
            self._tinfo[known].data_read = True
        return

    handle, info = self._new_handle()
    info.external = True
    info.device = tensor.device
    info.dtype = tensor.dtype
    info.shape = tensor.shape
    info.bound_data = tensor
    info.is_const = True
    tensor._mixin_handle = handle
    tensor._recording = True
    tensor._trace_mixin_info = info
    if self._symbolic:
        self._lazy_eval_tensors.add(TensorWeakRef(tensor))
    # a constant has no inputs and a single output
    self._seq.append(("Const", (), (handle,)))
def _set_active(self, active: bool):
    """Install or remove this trace as the module-wide ``active_trace``.

    Raises:
        NotImplementedError: when activating while another trace is active.
    """
    global active_trace
    if not active:
        assert active_trace is self
        active_trace = None
        return
    if active_trace:
        raise NotImplementedError("sorry, not implemented: nested trace")
    active_trace = self
def _init_trace(self, symbolic: bool):
    """Prepare recording state; only symbolic mode needs a lazy-eval graph."""
    if not symbolic:
        return
    lazy_graph = G.Graph()
    self._lazy_eval_graph = lazy_graph
    self._apply_graph_options(lazy_graph)
    self._lazy_eval_links = ()
def _take_escaped_tensors(self):
    """Return the weak references that are still alive and clear the set."""
    still_alive = tuple(ref for ref in self._active_tensors if ref() is not None)
    self._active_tensors.clear()
    return still_alive
def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
    """Run the lazily built graph and give each live tensor a concrete value.

    Args:
        lazy_eval_graph: graph accumulated during symbolic recording.
        lazy_eval_tensors: weak references to tensors backed by that graph.
        lazy_eval_links: extra varnodes that must be part of the compilation.
    """
    # resolve weak references, dropping tensors that no longer exist
    live = [t for t in (ref() for ref in lazy_eval_tensors) if t is not None]
    readers = [G.OutputNode(t._varnode).outputs[0] for t in live]

    self._apply_graph_options(lazy_eval_graph)
    lazy_eval_graph.options.graph_opt_level = self._graph_opt_level
    lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers])
    lazy_eval_graph.compile(*lazy_eval_links, *readers)
    self._execute_graph(lazy_eval_graph)
    lazy_eval_graph.wait()

    for reader, tensor in zip(readers, live):
        # get values from lazy_eval_graph and assign to lazy_eval tensor
        tensor._handle = RawTensor(reader.op.get_value())._handle
        tensor._reset_varnode()
@contextlib.contextmanager
def _setup(self):
    """Context manager wrapping one call of the traced function.

    On entry tracing is switched on and, for an already recorded trace, the
    graph is compiled (if needed) and started.  On exit the trace state is
    finalized; if the body or the exit checks raised, the whole trace is
    reset so that the next call records again.

    NOTE(review): the nesting below was reconstructed from a SOURCE that
    lost its indentation -- confirm against the upstream file.
    """
    interrupted = False

    def do_enter():
        # switch tracing on and make this trace the active one
        set_tracing()
        self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape)
        self._set_active(True)
        if self._untraced:
            self._init_trace(self._symbolic)
        else:
            # replay: compile lazily on first replay, then start execution
            if self._graph is None:
                self._compile()
            self._execute_graph(self._graph)

    def do_finalize():
        # always runs, whether or not the call succeeded
        escaped_tensors = self._take_escaped_tensors()
        if self._untraced:
            if self._record_only:
                self._lazy_eval_graph = None
                self._lazy_eval_tensors = None
                self._lazy_eval_links = None
            else:
                # tensors still alive after recording must be readable on
                # replay, so mark them and detach them from the trace
                for x in escaped_tensors:
                    if x():
                        info = self._tinfo[x()._mixin_handle]
                        info.data_read = True
                        x()._mixin_handle = -1
                        x()._recording = False
                if self._inputs_to_restore:
                    for x in self._inputs_to_restore:
                        x._mixin_handle = -1
                        x._recording = False
                if self._symbolic and (
                    self._lazy_eval_tensors or self._lazy_eval_links
                ):
                    # eval lazy eval tensors
                    self._lazy_eval(
                        self._lazy_eval_graph,
                        self._lazy_eval_tensors,
                        self._lazy_eval_links,
                    )
                    self._lazy_eval_graph = None
                    self._lazy_eval_tensors = None
                    self._lazy_eval_links = None
            self._untraced = False
        else:
            # compiled_tensor leaks
            if self._pc == len(self._seq):
                # full replay: turn surviving tensors into ordinary tensors
                for x in escaped_tensors:
                    try:
                        x().__init__(RawTensor(x()._dev_tensor()))
                    except RuntimeError:
                        # TraceMismatchError thrown in do_exit
                        pass
                self._graph.wait()
                self._reset_exec_env()

        # reset status
        self._pc = 0
        self._tensor_remaps = None
        self._set_active(False)
        set_symbolic_shape(self._save_symbolic_shape)
        unset_tracing()

    def do_exit():
        # runs only when the body finished without raising
        unset_tracing()
        if not self._untraced and self._pc != len(self._seq):
            raise TraceMismatchError("premature end")
        if not self._symbolic or not self._untraced:
            # reset output tensors
            for x in self._active_tensors.copy():
                strong_x = x()
                if strong_x is not None:
                    strong_x._dev_tensor()
                    strong_x._reset_varnode()
                    strong_x._mixin_handle = -1
                    strong_x._recording = False
                    strong_x._trace_mixin_info = None

    try:
        do_enter()
        yield
        do_exit()
    except:
        # remember the failure so the trace is reset after finalization
        interrupted = True
        raise
    finally:
        do_finalize()
        if interrupted:
            self._reset()
def _begin_excluded_region(self):
    """Prepare the live tensors for code that runs outside of the trace.

    While recording, every live tensor is marked exported and data-read;
    while replaying, the data of every live tensor is fetched.

    Raises:
        RuntimeError: if the trace was created with ``capture_as_const``.
    """
    if self._capture_as_const:
        raise RuntimeError(
            "exclude_from_trace cannot be used with capture_as_const"
        )
    recording = self._untraced
    for ref in self._active_tensors:
        tensor = ref()
        if not tensor:
            continue
        if recording:
            # conditionally reading a compiled tensor in excluded region
            # is permitted, so we have to assume every tensor might be read
            info = self._tinfo[tensor._mixin_handle]
            info.exported = True
            info.data_read = True
        else:
            tensor._dev_tensor()
def _apply_graph_options(self, graph):
    """Copy this trace's configuration onto ``graph.options``.

    Covers the optimization level, DTR, JIT fusion, sublinear memory and
    profiling settings.
    """
    options = graph.options
    options.no_force_inplace = True
    options.seq_opt.enable_seq_comp_node_opt = False
    options.graph_opt_level = self._graph_opt_level

    dtr = self._dtr_config
    if dtr is not None:
        options.enable_dtr_memory_opt = True
        for field in (
            "eviction_threshold",
            "evictee_minimum_size",
            "recomp_memory_factor",
            "recomp_time_factor",
        ):
            setattr(options.dtr_config, field, getattr(dtr, field))

    # graph optimization
    opt_config = self._graph_opt_config
    if opt_config is not None:
        # tri-state user setting -> integer code expected by the JIT config
        tri_state = {None: 0, False: 1, True: 2}
        jit_config = options.graph_opt.jit_config
        jit_config.fuse_dimshuffle = tri_state[opt_config.jit_fuse_dimshuffle]
        jit_config.fuse_reduce = tri_state[opt_config.jit_fuse_reduce]

    # sublinear
    sublinear = self._sublinear_memory_config
    if sublinear is not None:
        options.enable_sublinear_memory_opt = True
        sublinear_config = options.sublinear_mem_config
        for field in (
            "lb_memory_mb",
            "genetic_nr_iter",
            "genetic_pool_size",
            "thresh_nr_try",
            "num_worker",
        ):
            setattr(sublinear_config, field, getattr(sublinear, field))

    # profile
    if self._profiling:
        self._profiler = GraphProfiler(graph)
    # NOTE(review): SOURCE lost its indentation; this reset is assumed to be
    # unconditional so a new graph never reuses a stale GraphProfiler2.
    self._profiler2 = None
    if int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")):
        options.var_sanity_check_first_run = False
def _execute_graph(self, graph: G.Graph, *args):
    """Execute ``graph``, keeping ``_profiler2`` in step with ``is_profiling()``.

    A ``GraphProfiler2`` is created when profiling is on and none exists;
    an existing one is dropped when profiling is off.
    """
    if is_profiling():
        if self._profiler2 is None:
            self._profiler2 = GraphProfiler2(graph)
    elif self._profiler2 is not None:
        self._profiler2 = None
    graph.execute(*args)
def _compile(self):
    """Build and compile a static graph from the recorded op sequence.

    Creates input nodes for external tensors, constants for captured data,
    applies every recorded op, and adds output nodes for each tensor whose
    data, value or shape was read during recording.  The result is stored
    in ``self._graph``; nodes that must be reset after each run are kept in
    ``self._need_reset_nodes``.

    NOTE(review): the nesting below was reconstructed from a SOURCE that
    lost its indentation -- confirm against the upstream file.
    """
    graph = self._graph = G.Graph()
    graph.options.async_exec_level = 0b100
    self._apply_graph_options(graph)
    need_reset_nodes = self._need_reset_nodes = []
    # links enforce ordering of I/O nodes
    in_out_links = ()
    io_links = ()
    readers = []

    if self._capture_as_const:
        # bound call arguments get their input nodes up front
        for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
            info = self._tinfo[h]
            opnode = info.data_setter = G.InputNode(
                device=info.device,
                dtype=info.dtype,
                shape=info.shape or (1,),
                graph=graph,
                use_static_shape=_input_node_use_static_shape(),
            )
            need_reset_nodes.append(opnode)
            info.varnode = opnode.outputs[0]
            in_out_links += opnode.outputs[1:]

    for op, ihandles, ohandles in self._seq:
        if isinstance(op, str) and op == "Const":
            # recorded constant: no inputs, exactly one output
            assert len(ihandles) == 0
            (h,) = ohandles
            info = self._tinfo[h]
            if not hasattr(info, "varnode"):
                assert info.external
                assert info.bound_data
                info.varnode = graph.make_const(
                    info.bound_data.numpy(),
                    info.bound_data.dtype,
                    info.bound_data.device,
                )
            continue

        require_links = type(op) in _io_op_types
        ivars = []
        for i, h in enumerate(ihandles):
            info = self._tinfo[h]
            # TensorInfo leaves ``varnode`` unset until one is created here
            if not hasattr(info, "varnode"):
                assert info.external
                if info.bound_data:
                    if getattr(info, "is_const", False):
                        info.varnode = graph.make_const(
                            info.bound_data.numpy(),
                            info.bound_data.dtype,
                            info.bound_data.device,
                        )
                    else:
                        info.varnode = graph.make_const(
                            info.bound_data._dev_tensor()
                            # info.bound_data.numpy()
                        )
                else:
                    # external input fed on every call, chained after the
                    # previous in/out node
                    opnode = info.data_setter = G.InputNode(
                        *in_out_links,
                        device=info.device,
                        dtype=info.dtype,
                        shape=info.shape or (1,),
                        graph=graph,
                        use_static_shape=_input_node_use_static_shape(),
                    )
                    need_reset_nodes.append(opnode)
                    info.varnode, *in_out_links = opnode.outputs
            if require_links and i == 0 and len(io_links) > 0:
                # make the first input of an I/O op depend on the previous
                # I/O op
                opnode = G.VirtualDepNode(
                    [info.varnode, *io_links], str(io_links[0].device)
                )
                info.varnode = opnode.outputs[0]
                io_links = (info.varnode,)

            ivars.append(info.varnode)

        ovars = G.apply_normal_varnode(op, *ivars)

        if require_links and len(ovars) > 0:
            io_links = (ovars[0],)
        assert len(ovars) == len(ohandles)
        for h, v in zip(ohandles, ovars):
            info = self._tinfo[h]
            info.varnode = v

            def add_reader(opnode):
                # register an output node and chain later nodes after it
                nonlocal in_out_links
                need_reset_nodes.append(opnode)
                readers.append(opnode.outputs[0])
                in_out_links = opnode.outputs

            if info.data_read:
                # Shape can be obtained from data so doesn't need its own
                # output node. On the other hand, value is read separately
                # to leverage eager h2d copy
                info.shape_read = False
                opnode = info.data_reader = G.OutputNode(v, *in_out_links)
                add_reader(opnode)
            if info.value_read:
                opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
                add_reader(opnode)
            if info.shape_read:
                opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
                add_reader(opnode)

    graph.options.graph_opt_level = self._graph_opt_level
    graph._set_priority_to_id([*readers, *in_out_links, *io_links])
    graph.compile(*readers, *in_out_links, *io_links)
def _reset_exec_env(self):
    """Reset every node registered in ``_need_reset_nodes``."""
    for node in self._need_reset_nodes:
        node.reset()
def __call__(self, *args, **kwargs):
    """Call the wrapped function under tracing and return its result.

    With ``capture_as_const`` the arguments and return value are also
    bound to (or checked against) the recorded handles.
    """
    with self._setup():
        bind_io = self._capture_as_const
        if bind_io:
            self._process_inputs(*args, **kwargs)
        result = self.__wrapped__(*args, **kwargs)
        if bind_io:
            self._process_outputs(result)
        return result
def dump(
    self,
    file,
    *,
    arg_names=None,
    output_names=None,
    append=False,
    keep_var_name: int = 1,
    keep_opr_name: bool = False,
    keep_param_name: bool = False,
    keep_opr_priority: bool = False,
    strip_info_file=None,
    append_json=False,
    optimize_for_inference=True,
    user_info: Any = None,
    enable_metadata: bool = True,
    **kwargs
):
    r"""Serializes trace to file system.

    Args:
        file: output file, could be file object or filename. When a filename
            is given, the file is opened only after the graph has been
            serialized and is closed again before returning.
        arg_names: names of the input tensors in the traced function.
        output_names: names of the output tensors in the traced function,
            use the default name if not specified.
        append: whether output is appended to ``file``.
            Only works when ``file`` is str.
        keep_var_name: level for keeping variable names:

            * 0: none of the names are kept
            * 1: (default)keep names of output vars
            * 2: keep names of all (output and internal) vars

        keep_opr_name: whether to keep operator names.
        keep_param_name: whether to keep param names, so param values can be
            easily manipulated after loading model
        keep_opr_priority: whether to keep priority setting for operators
        strip_info_file: a string for path or a file handler. if is not None,
            then the dump information for code strip would be written to ``strip_info_file``
        append_json: will be check when `strip_info_file` is not None. if set
            true, the information for code strip will be append to strip_info_file.
            if set false, will rewrite strip_info_file
        optimize_for_inference: enable optimizations,
            will skip all optimize options if this is False. Default: True
        user_info: any type object, which will be pickled to bytes.
        enable_metadata: whether to save metadata into output file.

    Keyword Arguments:

        * enable_io16xc32 --
          whether to use float16 for I/O between oprs and use
          float32 as internal computation precision. Note the output var would be
          changed to float16.
        * enable_ioc16 --
          whether to use float16 for both I/O and computation
          precision.
        * enable_hwcd4 --
          whether to use NHWCD4 data layout. This is faster on some
          OpenCL backend.
        * enable_nchw88 --
          whether to use NCHW88 data layout, currently
          used in X86 AVX backend.
        * enable_nchw44 --
          whether to use NCHW44 data layout, currently
          used in arm backend.
        * enable_nchw44_dot --
          whether to use NCHW44_dot data layout, currently
          used in armv8.2+dotprod backend.
        * enable_nchw4 --
          whether to use NCHW4 data layout, currently
          used in nvidia backend(based on cudnn).
        * enable_nchw32 --
          whether to use NCHW32 data layout, currently
          used in nvidia backend with tensorcore(based on cudnn).
        * enable_chwn4 --
          whether to use CHWN4 data layout, currently
          used in nvidia backend with tensorcore.
        * enable_nchw64 --
          whether to use NCHW64 data layout, used for fast int4
          support on Nvidia GPU.
        * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
          into one opr.
        * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
          input for inference on nvidia backend(this optimization pass will
          result in mismatch of the precision of output of training and
          inference)

    Returns:
        the dump information returned by ``G.dump_graph``.

    Raises:
        ValueError: if the trace was not created with ``capture_as_const``,
            or ``arg_names`` / ``output_names`` have the wrong length.
        RuntimeError: if nothing has been recorded yet.
        TypeError: if ``output_names`` is given although the traced function
            already returned a dict.
    """
    if not self._capture_as_const:
        raise ValueError(
            "you must specify capture_as_const=True at __init__ to use dump"
        )
    if self._untraced and len(self._seq) == 0:
        raise RuntimeError("should do record first before dump")
    if self._output_names and output_names:
        raise TypeError(
            "cannot specify output_names when output is already in dict format"
        )
    # accept a single name in place of a sequence of names
    if output_names and not isinstance(output_names, collections.abc.Sequence):
        output_names = (output_names,)
    if output_names and len(output_names) != len(self._output_bindings):
        raise ValueError(
            "wrong number of output_names, should be {} values".format(
                len(self._output_bindings)
            )
        )
    without_arg_names = arg_names is None
    if without_arg_names:
        arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))]
    if arg_names and not isinstance(arg_names, collections.abc.Sequence):
        arg_names = (arg_names,)
    if arg_names and len(arg_names) != len(self._arg_bindings):
        raise ValueError(
            "wrong number of arg_names, should be {} values".format(
                len(self._arg_bindings)
            )
        )
    output_names = output_names or self._output_names

    def dumped_device(info):
        # cpu/gpu/xpu devices are dumped as the generic "xpux" device
        device_name = info.device.logical_name
        if device_name[:3] in ("cpu", "gpu", "xpu"):
            return as_device("xpux")
        return info.device

    h2v = {}
    graph = G.Graph()

    # apply graph_opt_level in dump
    if self._graph_opt_level is not None:
        graph.options.graph_opt_level = self._graph_opt_level

    # positional inputs; a recorded tensor name wins over the generated
    # default, an explicit arg_names entry wins over both
    for i, h in enumerate(self._arg_bindings):
        info = self._tinfo[h]
        h2v[h] = graph.make_h2d(
            dtype=info.dtype,
            device=dumped_device(info),
            shape=info.shape or (1,),
            name=info.name if without_arg_names and info.name else arg_names[i],
        )
    # keyword inputs are named after their keyword
    for k, h in self._kwarg_bindings.items():
        info = self._tinfo[h]
        h2v[h] = graph.make_h2d(
            dtype=info.dtype,
            device=dumped_device(info),
            shape=info.shape or (1,),
            name=k,
        )

    for op, ihandles, ohandles in self._seq:
        if isinstance(op, str) and op == "Const":
            assert len(ihandles) == 0
            (h,) = ohandles
            info = self._tinfo[h]
            if h not in h2v:
                assert info.external
                assert info.bound_data
                h2v[h] = graph.make_const(
                    info.bound_data.numpy(),
                    dtype=info.dtype,
                    device=dumped_device(info),
                    name=info.name,
                )
            continue
        ivars = []
        for h in ihandles:
            info = self._tinfo[h]
            if h not in h2v:
                # any input that is not a bound argument must be a captured
                # constant
                assert info.external
                assert info.bound_data
                h2v[h] = graph.make_const(
                    info.bound_data.numpy(),
                    dtype=info.dtype,
                    device=dumped_device(info),
                    name=info.name,
                )
            ivars.append(h2v[h])
        if isinstance(op, BatchNorm):
            assert (
                op.fwd_mode == BatchNorm.FwdMode.INFERENCE
            ), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?"
        ovars = G.apply_normal_varnode(op, *ivars)

        AutoNaming.record_opnode(ovars[0].op)

        assert len(ovars) == len(ohandles)
        h2v.update(zip(ohandles, ovars))

        for i in ohandles:
            name = AutoNaming.get_var_name(i)
            if name is not None:
                h2v[i].name = name

    AutoNaming.remove_duplicate_names()

    dest_vars = []
    for i, h in enumerate(self._output_bindings):
        v = h2v[h]
        if output_names:
            v.name = output_names[i]
        dest_vars.append(v)

    if optimize_for_inference:
        dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs)

    metadata = SerializationMetadata()
    if enable_metadata:
        metadata.user_info = pickle.dumps(user_info)
        metadata.is_valid = True
        metadata.graph_modified = False
        if optimize_for_inference:
            metadata.optimize_options = optimize_options

    if keep_opr_priority:
        graph._set_priority_to_id(dest_vars)

    dump_content, dump_info = G.dump_graph(
        dest_vars,
        keep_var_name=keep_var_name,
        keep_opr_name=keep_opr_name,
        keep_param_name=keep_param_name,
        keep_opr_priority=keep_opr_priority,
        strip_info_file=strip_info_file,
        append_json=append_json,
        metadata=metadata,
    )

    if isinstance(file, str):
        # Open the path only once the content exists, so a failed dump does
        # not truncate an existing file, and close the handle we opened
        # (it used to be leaked).
        with open(file, "ab" if append else "wb") as fout:
            fout.write(dump_content)
    else:
        # caller-owned file object: write but leave it open
        file.write(dump_content)
    return dump_info
def _process_inputs(self, *args, **kwargs):
    """Bind the call arguments to handles (recording) or feed them (replay).

    While recording, every positional tensor and every tensor-valued
    keyword argument gets a new handle, remembered in ``_arg_bindings`` /
    ``_kwarg_bindings``.  On replay the arguments are checked against those
    bindings and their data is pushed into the graph's input nodes.

    Raises:
        TypeError: a positional argument is not a tensor, or a dtype or
            device differs from the recording.
        TraceMismatchError: the number of positional arguments or the set
            of tensor keyword arguments differs from the recording.
    """
    if self._untraced:
        self._inputs_to_restore = []

        def bind(tensor):
            # allocate a handle for an input tensor and mark it as recording
            if tensor is None:
                return
            handle, info = self._new_handle()
            info.external = False
            info.name = tensor.c_name
            info.device = tensor.device
            info.dtype = tensor.dtype
            info.shape = tensor.numpy().shape
            tensor._mixin_handle = handle
            tensor._recording = True
            tensor._trace_mixin_info = info
            self._inputs_to_restore.append(tensor)
            return handle

        self._arg_bindings = []
        for pos, tensor in enumerate(args):
            if not isinstance(tensor, RawTensor):
                raise TypeError(
                    "positional arguments should all be tensor "
                    "but args[%d] cannot be recognized as one" % pos
                )
            self._arg_bindings.append(bind(tensor))

        self._kwarg_bindings = {}
        for key, tensor in kwargs.items():
            # non-tensor keyword arguments are not bound
            if isinstance(tensor, RawTensor):
                self._kwarg_bindings[key] = bind(tensor)
        return

    # replay: validate against the recorded bindings and feed the graph
    if len(args) != len(self._arg_bindings):
        raise TraceMismatchError("positional argument length mismatch")

    self._tensor_remaps = {}

    for pos, (handle, tensor) in enumerate(zip(self._arg_bindings, args)):
        if not isinstance(tensor, RawTensor):
            raise TypeError(
                "positional arguments should all be tensor "
                "but args[%d] cannot be recognized as one" % pos
            )
        info = self._tinfo[handle]
        if tensor.dtype != info.dtype:
            raise TypeError("args[%d].dtype different from last time" % pos)
        if tensor.device != info.device:
            raise TypeError("args[%d].device different from last time" % pos)
        info.data_setter.set_value(tensor._dev_tensor())
        self._tensor_remaps[tensor._handle] = CompiledTensorProxy(handle)

    kwargs_tensors = {
        key: tensor for key, tensor in kwargs.items() if isinstance(tensor, RawTensor)
    }
    given = set(kwargs_tensors)
    recorded = set(self._kwarg_bindings)
    if given != recorded:
        too_many = given - recorded
        too_few = recorded - given
        if too_many:
            raise TraceMismatchError(
                "keyword arguments found to be tensor this time "
                "but were non-tensor previously: %s" % " ".join(too_many)
            )
        if too_few:
            raise TraceMismatchError(
                "keyword arguments found to be non-tensor this time "
                "but were tensor previously: %s" % " ".join(too_few)
            )

    for key, handle in self._kwarg_bindings.items():
        tensor = kwargs_tensors[key]
        info = self._tinfo[handle]
        if tensor.dtype != info.dtype:
            raise TypeError("kwargs[%s].dtype different from last time" % key)
        if tensor.device != info.device:
            raise TypeError("kwargs[%s].device different from last time" % key)
        info.data_setter.set_value(tensor._dev_tensor())
        self._tensor_remaps[tensor._handle] = CompiledTensorProxy(handle)
def _process_outputs(self, outputs):
    """Bind the return value to handles (recording) or check it (replay).

    A mapping is flattened into values ordered by sorted key, with the keys
    kept as output names; a single non-sequence value is wrapped in a tuple.

    Args:
        outputs: the value returned by the traced function.

    Raises:
        TypeError: an item of the return value is not a tensor.
        RuntimeError: an output was not computed inside the trace.
        TraceMismatchError: on replay, the output keys, the number of
            outputs, or the tensor at some position differ from the
            recording.
    """
    output_names = None
    if isinstance(outputs, collections.abc.Mapping):
        # Split explicitly instead of ``zip(*sorted(...))`` so that an empty
        # mapping yields two empty tuples rather than an unpacking error.
        items = sorted(outputs.items())
        output_names = tuple(key for key, _ in items)
        outputs = tuple(value for _, value in items)
    elif not isinstance(outputs, collections.abc.Sequence):
        outputs = (outputs,)

    if not self._untraced:
        if output_names != self._output_names:
            # Either side is None when that call did not return a mapping;
            # treat it as "no keys" so the change is reported as a trace
            # mismatch instead of failing in ``set(None)``.
            current = set(output_names or ())
            previous = set(self._output_names or ())
            too_many = current - previous
            too_few = previous - current
            if too_many:
                raise TraceMismatchError(
                    "output has more keys than last time: %s" % " ".join(too_many)
                )
            if too_few:
                raise TraceMismatchError(
                    "output has less keys than last time: %s" % " ".join(too_few)
                )
        if len(outputs) != len(self._output_bindings):
            raise TraceMismatchError("output size differs from last time")
    else:
        self._output_names = output_names
        self._output_bindings = []

    for i, x in enumerate(outputs):
        if not isinstance(x, RawTensor):
            raise TypeError("every item of return value should be tensor")
        if self._untraced:
            h = x._mixin_handle
            if h < 0:
                raise RuntimeError("output is not computed from inputs")
            self._output_bindings.append(h)
        else:
            h = x._mixin_handle
            if h not in self._output_handles:
                raise RuntimeError("output is not computed from inputs")
            if h != self._output_bindings[i]:
                raise TraceMismatchError(
                    "retval[%s] is a different tensor than last time"
                    % (output_names[i] if output_names else i)
                )
def get_profile(self):
    r"""Get profiling result for compiled trace.

    Return:
        a json compatible object.

    Raises:
        RuntimeError: if no profiler is attached to this trace.
    """
    profiler = self._profiler
    if not profiler:
        raise RuntimeError("trace is not set with profiling=True")
    return json.loads(profiler.get())
class CompiledTensorProxy:
    r"""Duck-typed RawTensor

    Stands in for the tensor behind one handle of the active trace.  Shape,
    value and device data are fetched on first access through the readers
    stored in the handle's ``TensorInfo`` and cached afterwards; whatever
    was fetched is dropped again when the proxy is deleted.
    """

    def __init__(self, handle):
        self.__handle = handle
        self._isscalar = False
        self.__info = active_trace._tinfo[handle]
        # caches, filled on first access
        self.__shape = None
        self.__data = None
        self.__value = None

    @property
    def dtype(self):
        return self.__info.varnode.dtype

    @property
    def device(self):
        return self.__info.varnode.device

    @property
    def shape(self):
        if self._isscalar:
            return ()
        if self.__shape is not None:
            return self.__shape
        info = self.__info
        if info.shape_read:
            self.__shape = info.shape_reader.get_value().shape
        elif info.data_read:
            self.__shape = self._dev_tensor().shape
        else:
            # c++ will throw TraceReadError
            return None
        return self.__shape

    def numpy(self):
        if self.__value is not None:
            # c++ side will handle scalar case
            return self.__value
        info = self.__info
        if info.value_read:
            self.__value = info.value_reader.get_value()
        elif info.data_read:
            self.__value = self._dev_tensor().numpy()
        else:
            # c++ will throw TraceReadError
            return None
        # c++ side will handle scalar case
        return self.__value

    def _dev_tensor(self):
        if self.__data is not None:
            return self.__data
        if not self.__info.data_read:
            # c++ will throw TraceReadError
            return None
        self.__data = self.__info.data_reader.get_value()
        return self.__data

    def __del__(self):
        # release only what this proxy actually fetched
        info = self.__info
        if info.shape_read and self.__shape is not None:
            info.shape_reader.drop_value()
        if info.value_read and self.__value is not None:
            info.value_reader.drop_value()
        if info.data_read and self.__data is not None:
            info.data_reader.drop_value()
def apply_symbolic_mode(op: OpDef, *args: RawTensor):
    """Apply ``op`` symbolically on the active trace's lazy-eval graph.

    Inputs that already carry a varnode are used as is; the others are fed
    through a new input node.  For I/O ops the first input is additionally
    linked to the previous I/O op, and the trace's link is then moved to
    this op's first output.

    Returns:
        list of ``RawTensor`` wrapping the op's output varnodes.
    """
    graph = active_trace._lazy_eval_graph
    input_vars = []
    for tensor in args:
        var = getattr(tensor, "_varnode", None)
        if not var:
            feeder = G.InputNode(
                device=tensor.device,
                dtype=tensor.dtype,
                shape=tensor.numpy().shape or (1,),
                graph=graph,
                use_static_shape=True,
            )
            var = feeder.outputs[0]
            input_vars.append(var)
            feeder.set_value(tensor._dev_tensor())
            continue
        input_vars.append(var)

    is_io_op = type(op) in _io_op_types

    if is_io_op and active_trace._lazy_eval_links:
        assert len(input_vars) > 0, "op should has at least one input"
        dep_node = G.VirtualDepNode(
            [input_vars[0], *active_trace._lazy_eval_links],
            str(active_trace._lazy_eval_links[0].device),
        )
        input_vars[0] = dep_node.outputs[0]
        active_trace._lazy_eval_links = (input_vars[0],)

    output_vars = G.apply_normal_varnode(op, *input_vars)
    results = [RawTensor(var) for var in output_vars]

    if is_io_op:
        active_trace._lazy_eval_links = (G.VarNode(results[0]._varnode),)

    return results
def apply_const_symbolic_mode(value, dtype, device, name):
    """Create a constant on the lazy-eval graph; return it as a 1-tuple.

    A zero-dimensional ``value`` is additionally marked as scalar.
    """
    lazy_graph = active_trace._lazy_eval_graph
    # don't need to unset tracing
    # because varnode construction will ignore tracing flag
    const_var = lazy_graph.make_const(value, dtype=dtype, device=device, name=name)
    tensor = RawTensor(const_var)
    if np.array(value).ndim == 0:
        setscalar(tensor)
    return (tensor,)
def apply_compiled_mode(op: OpDef, *args: RawTensor):
    """Apply ``op`` while a compiled trace is replaying.

    Normally the call is checked against the trace.  When tracing is being
    skipped, the op runs eagerly with tracing switched off, after turning
    any ``CompiledTensorProxy`` argument into a real tensor.
    """
    if not skip_tracing:
        return active_trace._apply_op(op, args)

    concrete = [
        RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
        for x in args
    ]
    unset_tracing()
    result = apply(op, *concrete)
    set_tracing()
    return result
def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name):
    """Create a constant while a compiled trace is replaying.

    Normally the constant is checked against the trace.  When tracing is
    being skipped, a plain tensor is built with tracing switched off.
    """
    if not skip_tracing:
        return active_trace._apply_const(value, dtype, device)

    unset_tracing()
    tensor = RawTensor(value, dtype, device, False, name)
    set_tracing()
    return tensor
def apply_with_tracing(op: OpDef, *args: RawTensor):
    """Apply ``op`` under the active trace and return its outputs as a list.

    A trace that already has a compiled graph replays the op.  Otherwise the
    op is executed, symbolically or eagerly depending on the trace, and then
    recorded.
    """
    trace_ = active_trace
    if trace_._graph:
        # if member _graph exists, then is_compiled
        return apply_compiled_mode(op, *args)

    if hasattr(op, "scope"):
        op.scope = AutoNaming.get_scope()

    if not trace_._symbolic:
        # eager execution with tracing switched off
        unset_tracing()
        outputs = apply(op, *args)
        set_tracing()
    else:
        outputs = apply_symbolic_mode(op, *args)

    trace_._record_op(op, args, outputs)
    return list(outputs)
def apply_const_with_tracing(value, dtype, device, is_const, no_cache, name):
    """Create a constant under the active trace.

    A trace that already has a compiled graph replays the constant.
    Otherwise the constant is built, symbolically or eagerly depending on
    the trace, recorded, and returned as a one-element list.
    """
    trace_ = active_trace
    if trace_._graph:
        return apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name)

    if not trace_._symbolic:
        unset_tracing()
        const = RawTensor(value, dtype, device, False, name)
        if np.array(value).ndim == 0:
            setscalar(const)
        outputs = (const,)
        set_tracing()
    else:
        outputs = apply_const_symbolic_mode(value, dtype, device, name)

    trace_._record_const(outputs)
    return list(outputs)