# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import json
import os
import weakref
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from .. import _imperative_rt
from .._imperative_rt import GraphOptimizeOptions
from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
from .._wrap import as_device
from ..ops.builtin import OpDef
def set_priority_to_id(dest_vars):
    r"""Set the priority of every opr reachable from ``dest_vars`` to its id.

    Only oprs whose original priority is zero are affected (as stated by the
    original documentation; the work itself is done by
    ``_imperative_rt.graph._set_priority_to_id``).

    Args:
        dest_vars: target vars representing the graph. Every element must be
            a raw ``_imperative_rt.VarNode``.
    """
    raw_vars = list(dest_vars)
    for var in raw_vars:
        assert isinstance(var, _imperative_rt.VarNode)
    _imperative_rt.graph._set_priority_to_id(raw_vars)
class Graph(_imperative_rt.ComputingGraph):
    r"""Python-side computing graph built on ``_imperative_rt.ComputingGraph``.

    Raw ``VarNode`` / ``OperatorNode`` objects coming from the runtime are mapped
    to their Python wrappers (:class:`VarNode`, :class:`OpNode`) through two
    per-graph weak-key caches. The compiled function is run on a single worker
    thread: :meth:`execute` submits the work and :meth:`wait` collects it.
    """

    def __init__(self):
        super().__init__()
        # raw node -> Python wrapper, filled by ``_wrap`` and by the wrappers'
        # own constructors
        self._var_cache = weakref.WeakKeyDictionary()
        self._op_cache = weakref.WeakKeyDictionary()
        # single worker thread used by ``execute``
        self._executor = ThreadPoolExecutor(1)
        self._function = None  # set by ``compile``
        self._future = None  # pending execution; cleared by ``wait``

    def _wrap(self, obj):
        r"""Return the Python wrapper of a raw node, creating it on first use.

        Args:
            obj: a raw ``_imperative_rt.VarNode`` or
                ``_imperative_rt.OperatorNode`` (exact type match).

        Raises:
            TypeError: if ``obj`` is of any other type.
        """
        if type(obj) is _imperative_rt.VarNode:
            wrapper, cache = VarNode, self._var_cache
        elif type(obj) is _imperative_rt.OperatorNode:
            wrapper, cache = OpNode, self._op_cache
        else:
            raise TypeError(type(obj))
        if obj not in cache:
            cache[obj] = wrapper(obj)
        return cache[obj]

    def _set_priority_to_id(self, dest_vars):
        r"""Unwrap ``dest_vars`` and forward them to :func:`set_priority_to_id`."""
        set_priority_to_id(_unwrap(dest_vars))

    def compile(self, *args):
        r"""Compile the graph for the given (wrapped or raw) vars.

        Returns:
            ``self``, so that calls can be chained.
        """
        self._function = super().compile(_unwrap(args))
        return self

    def execute(self, *args):
        r"""Submit one run of the compiled function to the worker thread.

        Must not be called again before :meth:`wait` has collected the previous
        run (enforced by an assertion on ``self._future``).
        """
        assert self._future is None

        def wrapped(*args):
            try:
                self._function.execute(*args)
            except Exception as exc:
                # Propagate the failure message to every rendezvous of the
                # function before re-raising.
                # NOTE(review): presumably this unblocks readers waiting on
                # them - confirm against the runtime.
                for i in self._function._all_rendezvous:
                    i.set_exception(str(exc))
                raise exc

        self._future = self._executor.submit(wrapped, *args)

    def wait(self):
        r"""Block until the run started by :meth:`execute` has finished.

        Returns:
            The result of the submitted call (re-raising its exception, if any).
        """
        assert self._future is not None
        # ``Future.exception()`` waits for completion without raising.
        self._future.exception()
        self._function.wait()
        try:
            return self._future.result()
        finally:
            self._future = None

    def __call__(self, *args):
        r"""Run the compiled function synchronously: ``execute`` then ``wait``."""
        self.execute(*args)
        return self.wait()

    def _make_const_for_backward(self, data):
        r"""Create a constant var from ``data`` on the device ``data.comp_node``.

        ``data`` must provide ``comp_node`` and ``numpy()``.
        """
        device = as_device(data.comp_node).to_c()
        data = data.numpy()
        return self._wrap(_imperative_rt.make_const(self, data, device, data.dtype))

    def make_const(self, data, dtype=None, device=None, name=None):
        r"""Create a constant var in this graph.

        Args:
            data: either an ``_imperative_rt.DeviceTensorND`` (shared as is) or
                anything accepted by ``numpy.asarray``.
            dtype: dtype used for the numpy conversion. Must be ``None`` when
                ``data`` is a ``DeviceTensorND``.
            device: target device. Must be ``None`` when ``data`` is a
                ``DeviceTensorND``.
            name: optional name forwarded to the runtime.

        Returns:
            The wrapped constant var.
        """
        if isinstance(data, _imperative_rt.DeviceTensorND):
            assert dtype is None and device is None
            return self._wrap(_imperative_rt.make_shared(self, data))
        else:
            data = np.asarray(data, dtype=dtype)
            # 64-bit float/int host data is narrowed to 32 bit
            if data.dtype == np.float64:
                data = data.astype(np.float32)
            elif data.dtype == np.int64:
                data = data.astype(np.int32)
            device = as_device(device).to_c()
            # NOTE(review): the ``dtype`` argument (possibly None), not
            # ``data.dtype``, is forwarded here - confirm this is intended.
            return self._wrap(
                _imperative_rt.make_const(self, data, device, dtype, name)
            )

    def make_h2d(self, *, dtype, device, shape=None, name=None):
        r"""Wrap the var returned by ``_imperative_rt.make_h2d`` for ``device``."""
        device = as_device(device).to_c()
        return self._wrap(_imperative_rt.make_h2d(self, device, dtype, shape, name))

    def _to_json(self, filename):
        r"""Debug interface: write the compiled function as JSON to ``filename``.

        Only prints a hint if the graph has not been compiled yet.
        """
        if self._function:
            js = json.loads(self._function._to_json())
            # close the file deterministically instead of leaking the handle
            with open(filename, "w") as fout:
                json.dump(js, fout)
        else:
            print("this function should be called after compilation.")
class VarNode:
    r"""Python wrapper around a raw ``_imperative_rt.VarNode``.

    On construction the wrapper registers itself in the owning graph's
    ``_var_cache`` when the graph provides one.
    """

    def __init__(self, node: _imperative_rt.VarNode, isscalar=False):
        self._node = node
        self._isscalar = isscalar
        owner = self.graph
        if hasattr(owner, "_var_cache"):
            owner._var_cache[node] = self

    @property
    def graph(self) -> Graph:
        """Graph the underlying node belongs to."""
        return self._node.graph

    @property
    def op(self):
        """Owner operator, wrapped if the graph supports wrapping."""
        owner = self.graph
        if not hasattr(owner, "_wrap"):
            return self._node.owner
        return owner._wrap(self._node.owner)

    @property
    def name(self):
        """Name of the underlying node."""
        return self._node.name

    @name.setter
    def name(self, name):
        self._node.name = name

    @property
    def id(self):
        """Id of the underlying node."""
        return self._node.id

    @property
    def dtype(self):
        """Dtype of the underlying node."""
        return self._node.dtype

    @property
    def device(self):
        """Device obtained by passing the node's ``comp_node`` to ``as_device``."""
        return as_device(self._node.comp_node)

    @property
    def shape(self):
        """Shape reported by the underlying node."""
        return self._node.shape

    @property
    def value(self):
        """Value reported by the underlying node."""
        return self._node.value
class OpNode:
    r"""Python wrapper around a raw ``_imperative_rt.OperatorNode``.

    On construction the wrapper registers itself in the owning graph's
    ``_op_cache`` when the graph provides one.
    """

    def __init__(self, node: _imperative_rt.OperatorNode):
        self._node = node
        owner = self.graph
        if hasattr(owner, "_op_cache"):
            owner._op_cache[node] = self

    @property
    def graph(self) -> Graph:
        """Graph the underlying operator belongs to."""
        return self._node.graph

    @property
    def name(self):
        """Name of the underlying operator."""
        return self._node.name

    @name.setter
    def name(self, name):
        self._node.name = name

    @property
    def id(self):
        """Id of the underlying operator."""
        return self._node.id

    @property
    def inputs(self):
        """Input vars: a tuple of wrappers, or the raw inputs if the graph cannot wrap."""
        owner = self.graph
        if not hasattr(owner, "_wrap"):
            return self._node.inputs
        return tuple(map(owner._wrap, self._node.inputs))

    @property
    def outputs(self):
        """Output vars: a tuple of wrappers, or the raw outputs if the graph cannot wrap."""
        owner = self.graph
        if not hasattr(owner, "_wrap"):
            return self._node.outputs
        return tuple(map(owner._wrap, self._node.outputs))

    @property
    def params(self):
        """Operator params, decoded from the node's JSON string."""
        return json.loads(self._node.params)

    @property
    def type(self):
        """Type reported by the underlying operator."""
        return self._node.type
def optimize_for_inference(dest_vars, **kwargs):
    r"""Applies optimize_for_inference pass for computing graph.

    Args:
        dest_vars: list of output vars in the computing graph

    Keyword Arguments:

        * enable_io16xc32 --
            whether to use float16 for I/O between oprs and use
            float32 as internal computation precision. Note the output var would be
            changed to float16.
        * enable_ioc16 --
            whether to use float16 for both I/O and computation
            precision.
        * enable_hwcd4 --
            whether to use NHWCD4 data layout. This is faster on some
            OpenCL backend.
        * enable_nchw88 --
            whether to use NCHW88 data layout, currently
            used in X86 AVX backend.
        * enable_nchw44 --
            whether to use NCHW44 data layout, currently
            used in arm backend.
        * enable_nchw44_dot --
            whether to use NCHW44_dot data layout, currently
            used in armv8.2+dotprod backend.
        * enable_nchw4 --
            whether to use NCHW4 data layout, currently
            used in nvidia backend(based on cudnn).
        * enable_nchw32 --
            whether to use NCHW32 data layout, currently
            used in nvidia backend with tensorcore(based on cudnn).
        * enable_chwn4 --
            whether to use CHWN4 data layout, currently
            used in nvidia backend with tensorcore.
        * enable_nchw64 --
            whether to use NCHW64 data layout, used for fast int4
            support on Nvidia GPU.
        * enable_fuse_conv_bias_nonlinearity --
            whether to fuse conv+bias+nonlinearity into one opr.
        * enable_fuse_conv_bias_with_z --
            whether to fuse conv_bias with z input for inference on nvidia
            backend (this optimization pass will result in mismatch of the
            precision of output of training and inference)
        * enable_fuse_preprocess --
            whether to set the ``fuse_preprocess`` option.

    If several layout flags are truthy, the one listed last in the internal
    layout table takes effect, because each assignment overwrites the previous.

    Returns:
        A tuple ``(vars, options)``: the optimized output vars (wrapped) and the
        serialized form of the options that were applied.

    Raises:
        ValueError: if an unknown keyword argument is given.
    """
    layouts = GraphOptimizeOptions.LayoutTransform
    layout_by_flag = {
        "enable_hwcd4": layouts.NHWCD4,
        "enable_nchw4": layouts.NCHW4,
        "enable_nchw88": layouts.NCHW88,
        "enable_nchw32": layouts.NCHW32,
        "enable_nchw44": layouts.NCHW44,
        "enable_nchw44_dot": layouts.NCHW44_DOT,
        "enable_chwn4": layouts.CHWN4,
        "enable_nchw64": layouts.NCHW64,
    }
    # keyword flag -> boolean attribute of GraphOptimizeOptions
    attr_by_flag = {
        "enable_io16xc32": "f16_io_f32_comp",
        "enable_ioc16": "f16_io_comp",
        "enable_fuse_conv_bias_nonlinearity": "fuse_conv_bias_nonlinearity",
        "enable_fuse_conv_bias_with_z": "fuse_conv_bias_with_z",
        "enable_fuse_preprocess": "fuse_preprocess",
    }

    options = GraphOptimizeOptions()
    for flag, transform in layout_by_flag.items():
        if kwargs.pop(flag, False):
            options.layout_transform = transform
    for flag, attr in attr_by_flag.items():
        if kwargs.pop(flag, False):
            setattr(options, attr, True)

    # every recognised flag has been popped; leftovers are errors
    if kwargs:
        raise ValueError("unknown options: %s" % list(kwargs))

    optimized = _imperative_rt.optimize_for_inference(_unwrap(dest_vars), options)
    return _wrap(optimized), options.serialize()
def deserialize_infer_option(x: int) -> Dict[str, bool]:
    r"""Deserialize optimize options generated by ``imperative_rt.GraphOptimizeOptions``.

    Args:
        x: inference options represented by int.

    Returns:
        inference options represented by dict. Only enabled options are
        present, each mapped to ``True``.
    """
    options = GraphOptimizeOptions.deserialize(x)
    layouts = GraphOptimizeOptions.LayoutTransform
    flag_by_layout = {
        layouts.NHWCD4: "enable_hwcd4",
        layouts.NCHW4: "enable_nchw4",
        layouts.NCHW88: "enable_nchw88",
        layouts.NCHW32: "enable_nchw32",
        layouts.NCHW44: "enable_nchw44",
        layouts.NCHW44_DOT: "enable_nchw44_dot",
        layouts.CHWN4: "enable_chwn4",
        layouts.NCHW64: "enable_nchw64",
    }
    # boolean attribute of the options object -> keyword flag
    flag_by_attr = (
        ("f16_io_f32_comp", "enable_io16xc32"),
        ("f16_io_comp", "enable_ioc16"),
        ("fuse_conv_bias_nonlinearity", "enable_fuse_conv_bias_nonlinearity"),
        ("fuse_conv_bias_with_z", "enable_fuse_conv_bias_with_z"),
        ("fuse_preprocess", "enable_fuse_preprocess"),
    )

    enabled = {}
    transform = options.layout_transform
    if transform != layouts.DEFAULT:
        enabled[flag_by_layout[transform]] = True
    for attr, flag in flag_by_attr:
        if getattr(options, attr):
            enabled[flag] = True
    return enabled
def modify_opr_algo_strategy_inplace(dest_vars, strategy: str):
    r"""C++ graph version of :func:`~.set_execution_strategy`.

    Modifies the fast-run strategy of a dumped graph in place.

    Args:
        dest_vars: list of output vars in the computing graph.
        strategy: fast-run algorithms strategy.
    """
    raw_vars = _unwrap(dest_vars)
    _imperative_rt.modify_opr_algo_strategy_inplace(raw_vars, strategy)
# Statistics and name lists produced when a computing graph is dumped.
CompGraphDumpResult = collections.namedtuple(
    "CompGraphDumpResult",
    "nr_opr tot_bytes tensor_value_bytes content_hash inputs outputs params",
)
def _update_strip_info(strip_file, new_strip_dict, content_hash, append_json):
    r"""Write code-strip information as JSON into an already opened file.

    Args:
        strip_file: readable and writable text file object.
        new_strip_dict: strip information of the graph that was just dumped.
        content_hash: value stored under the ``"hash"`` key of the result.
        append_json: if true and the file already has content, the new values
            are merged into the existing content instead of replacing it.
    """
    ori_strip_dict = new_strip_dict
    json_content = strip_file.read()
    if append_json and len(json_content) != 0:
        # if there are contents in json file. Read them first and then append new information
        ori_strip_dict = json.loads(json_content)
        # NOTE(review): only keys already present in the file are merged; keys
        # that exist solely in ``new_strip_dict`` are dropped - confirm intended.
        for k in ori_strip_dict:
            new_strip_dict_v = new_strip_dict.get(k)
            if new_strip_dict_v is not None:
                for value in new_strip_dict_v:
                    if value not in ori_strip_dict[k]:
                        ori_strip_dict[k].append(value)
    ori_strip_dict["hash"] = content_hash
    # rewrite the file from scratch with the merged content
    strip_file.seek(0)
    strip_file.truncate()
    json.dump(ori_strip_dict, strip_file)


def dump_graph(
    output_vars: Union[Dict[str, VarNode], List[VarNode]],
    *,
    keep_var_name: int = 1,
    keep_opr_name: bool = False,
    keep_param_name: bool = False,
    keep_opr_priority: bool = False,
    strip_info_file=None,
    append_json=False,
    metadata=None
) -> Tuple[bytes, CompGraphDumpResult]:
    r"""serialize the computing graph of `output_vars` and get byte result.

    Args:
        output_vars: output variables which are the graph's end point.
        keep_var_name: level for keeping variable names:

            * 0: none of the names are kept
            * 1: (default)keep names of output vars
            * 2: keep names of all (output and internal) vars

        keep_opr_name: whether to keep operator names.
        keep_param_name: whether to keep param names, so param values can be
            easily manipulated after loading model
        keep_opr_priority: whether to keep priority setting for operators
        strip_info_file: a path (``str`` or ``os.PathLike``) or a file handler.
            if is not None, then the dump information for code strip would be
            written to ``strip_info_file``. A file opened here from a path is
            closed before returning; a handler passed by the caller is left open.
        append_json: will be check when `strip_info_file` is not None. if set
            true, the information for code strip will be append to strip_info_file.
            if set false, will rewrite strip_info_file
        metadata: forwarded unchanged to the underlying dump call.

    Note:
        The underlying C++ API only accepts a var list. If a dict is given,
        the vars would be renamed to the given names.

    Returns:
        dump result as byte string, and an instance of namedtuple
        :class:`CompGraphDumpResult`, whose fields are:

        * ``nr_opr`` number of operators dumped
        * ``tot_bytes`` total bytes for the whole graph
        * ``tensor_value_bytes`` bytes consumed for dumping tensor values
        * ``content_hash`` hash of the dump, also stored as ``"hash"`` in the
          strip information
        * ``inputs`` names of input tensors
        * ``params`` list of names of dumped params
        * ``outputs`` names of output vars
    """
    if isinstance(output_vars, dict):
        used_vars = set()
        for name, var in output_vars.items():
            assert var.id not in used_vars, (
                "var name is associated with a var object, so we can not have "
                "two names given to the same var: {}".format(var)
            )
            used_vars.add(var.id)
            var.name = name
        output_vars = list(output_vars.values())
    else:
        output_vars = list(output_vars)

    ov = _unwrap(output_vars)

    # filled in place by the runtime
    stat = []
    inputs = []
    outputs = []
    params = []

    dump_content = _imperative_rt.dump_graph(
        ov,
        keep_var_name,
        keep_opr_name,
        keep_param_name,
        keep_opr_priority,
        metadata,
        stat,
        inputs,
        outputs,
        params,
    )

    dump_info = CompGraphDumpResult(*stat, inputs, outputs, params)

    if strip_info_file is not None:
        new_strip_dict = json.loads(_imperative_rt.get_info_for_strip(ov))
        if isinstance(strip_info_file, (str, os.PathLike)):
            # Mode "a" creates a missing file without truncating an existing
            # one; unlike os.mknod it works on every platform.
            with open(strip_info_file, "a"):
                pass
            with open(strip_info_file, "r+") as fobj:
                _update_strip_info(
                    fobj, new_strip_dict, dump_info.content_hash, append_json
                )
        else:
            _update_strip_info(
                strip_info_file, new_strip_dict, dump_info.content_hash, append_json
            )

    return dump_content, dump_info
# Bundle describing a computing graph loaded from its serialized form.
CompGraphLoadResult = collections.namedtuple(
    "CompGraphLoadResult",
    "graph output_vars_dict output_vars_list metadata",
)
def load_graph(fpath) -> CompGraphLoadResult:
    r"""Load a serialized computing graph from file.

    Args:
        fpath: Path (``str`` or ``os.PathLike``) or Handle of the input file.
            A handle must provide ``read()`` returning the serialized bytes.

    Returns:
        An instance of namedtuple :class:`CompGraphLoadResult`,
        whose fields are:

        * ``graph`` loaded CompGraph
        * ``output_vars_dict`` A Python dict, mapping name to output SymbolVar
        * ``output_vars_list`` A Python list, containing output vars in the
          order passed to serialize_comp_graph_to_file
        * ``metadata`` metadata object returned by the underlying loader
    """
    # both lists are filled in place by the runtime
    output_vars_map = []
    output_vars_list = []
    if isinstance(fpath, (str, os.PathLike)):
        # close the file as soon as its content is read
        with open(fpath, "rb") as fin:
            buf = fin.read()
    else:
        buf = fpath.read()
    cg, metadata = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list)
    return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list, metadata)
def _wrap(x):
    r"""Wrap a raw node, or every element of a (nested) sequence of them.

    A sequence is rebuilt with its own type from the wrapped elements. A single
    object is passed to ``x.graph._wrap`` when the graph offers that method and
    is returned unchanged otherwise.
    """
    if isinstance(x, collections.abc.Sequence):
        container = type(x)
        return container(map(_wrap, x))
    graph = x.graph
    return graph._wrap(x) if hasattr(graph, "_wrap") else x
def _unwrap(x):
    r"""Replace :class:`VarNode` wrappers by their raw nodes.

    A sequence is rebuilt with its own type from the unwrapped elements; any
    object that is not a :class:`VarNode` is returned unchanged.
    """
    if isinstance(x, collections.abc.Sequence):
        container = type(x)
        return container(map(_unwrap, x))
    return x._node if isinstance(x, VarNode) else x
def apply_normal_varnode(op: OpDef, *args: VarNode):
    r"""Invoke ``op`` on the given vars and return the wrapped outputs.

    If ``op`` carries a truthy ``op`` attribute, that inner op is invoked
    instead (for PyOp like RemoteSend/Recv).
    """
    inner_op = getattr(op, "op", None)
    if inner_op:
        op = inner_op
    raw_outputs = _imperative_rt.invoke_op(op, _unwrap(args))
    return _wrap(raw_outputs)
def output_callback(callback, var, *args):
    r"""Attach ``callback`` to ``var`` (plus any extra vars in ``args``).

    Returns:
        The wrapped var returned by ``_imperative_rt.output_callback``.
    """
    raw_vars = _unwrap((var,) + args)
    return _wrap(_imperative_rt.output_callback(callback, raw_vars))
class OutputNode(OpNode):
    r"""Operator node exposing the value of ``var`` through a
    ``_imperative_rt.DeviceTensorNDRendezvous``.
    """

    def __init__(self, var, *args):
        rendezvous = _imperative_rt.DeviceTensorNDRendezvous()
        callback_var = _imperative_rt.output_callback(
            rendezvous, _unwrap((var,) + args)
        )
        super().__init__(callback_var.owner)
        self._rendezvous = rendezvous

    def get_value(self):
        """Return whatever the rendezvous' ``get()`` yields."""
        return self._rendezvous.get()

    def drop_value(self):
        """Call ``drop()`` on the rendezvous."""
        self._rendezvous.drop()

    def reset(self):
        """Call ``reset()`` on the rendezvous."""
        self._rendezvous.reset()
class ValueOutputNode(OpNode):
    r"""Operator node exposing the value of ``var`` through a
    ``_imperative_rt.HostTensorNDRendezvous``.
    """

    def __init__(self, var, *args):
        rendezvous = _imperative_rt.HostTensorNDRendezvous()
        callback_var = _imperative_rt.value_output_callback(
            rendezvous, _unwrap((var,) + args)
        )
        super().__init__(callback_var.owner)
        self._rendezvous = rendezvous

    def get_value(self):
        """Wait for the event delivered with the tensor, then return it as numpy."""
        host_tensor, ready_event = self._rendezvous.get()
        ready_event.wait()
        return host_tensor.numpy()

    def drop_value(self):
        """Call ``drop()`` on the rendezvous."""
        self._rendezvous.drop()

    def reset(self):
        """Call ``reset()`` on the rendezvous."""
        self._rendezvous.reset()
class TensorAttr:
    r"""Plain record holding the ``shape``, ``dtype`` and ``device`` of a tensor."""

    def __init__(self, shape, dtype, device):
        self.shape, self.dtype, self.device = shape, dtype, device
class AttrOutputNode(OpNode):
    r"""Operator node exposing the attributes of ``var`` through a
    ``_imperative_rt.TensorAttrRendezvous``.
    """

    def __init__(self, var, *args):
        rendezvous = _imperative_rt.TensorAttrRendezvous()
        callback_var = _imperative_rt.attr_output_callback(
            rendezvous, _unwrap((var,) + args)
        )
        super().__init__(callback_var.owner)
        self._rendezvous = rendezvous

    def get_value(self):
        """Return the received attributes as a :class:`TensorAttr`."""
        raw_attr = self._rendezvous.get()
        device = as_device(raw_attr.comp_node)
        return TensorAttr(raw_attr.shape, raw_attr.dtype, device)

    def drop_value(self):
        """Call ``drop()`` on the rendezvous."""
        self._rendezvous.drop()

    def reset(self):
        """Call ``reset()`` on the rendezvous."""
        self._rendezvous.reset()
class VirtualDepNode(OpNode):
    r"""Operator node built from ``_imperative_rt.virtual_dep``.

    Args:
        vars: vars handed (unwrapped) to ``virtual_dep``.
        device: device string handed to ``virtual_dep``; empty by default.
    """

    def __init__(self, vars, device=""):
        # the result of virtual_dep is used directly as the wrapped node
        dep_node = _imperative_rt.virtual_dep(_unwrap(vars), device)
        super().__init__(dep_node)