import builtins
import collections
import copy
import inspect
import re
import weakref
from importlib import import_module
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union
from ..core._imperative_rt import OpDef
from ..core._imperative_rt.core2 import Const
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
apply,
is_tracing_module,
set_module_trace_hook,
set_module_tracing,
unset_module_tracing,
)
from ..core.ops.builtin import FakeQuant
from ..module import Module
from ..tensor import Parameter, Tensor
from ..version import __version__
from .module_tracer import active_module_tracer, module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
from .serialization import _ModuleState
from .tm_config import _exclude_from_trace, _get_expr_checker
from .utils import _check_builtin_module_attr, _check_obj_attr, _convert_kwargs_to_args
def rstrip(s: str, __chars: str):
    """Strip every trailing repetition of the substring ``__chars`` from ``s``.

    Unlike :meth:`str.rstrip`, ``__chars`` is treated as one whole substring,
    not as a set of characters, e.g. ``rstrip("abcbc", "bc") == "a"``.
    Works for strings containing newlines. An empty ``__chars`` returns ``s``
    unchanged.
    """
    if not __chars:
        return s
    step = len(__chars)
    end = len(s)
    # Walk back one whole copy of the suffix at a time and slice only once.
    while end >= step and s.endswith(__chars, 0, end):
        end -= step
    return s[:end]
def get_suffix_name(prefix: str, name: str):
    """Return the part of the dotted ``name`` that follows ``prefix``.

    Returns ``""`` when ``name == prefix``, the remainder after ``prefix + "."``
    when ``name`` lies under ``prefix``, and ``None`` otherwise. ``prefix`` is
    compared literally, so characters such as ``.`` are not treated as regex
    syntax.
    """
    if prefix == name:
        return ""
    head = prefix + "."
    if not name.startswith(head):
        return None
    return name[len(head):]
def is_call_module(expr, module_cls: Module = None):
    """Return whether ``expr`` is a ``CallMethod`` that invokes ``__call__`` on a module node.

    Args:
        expr: the expression to inspect.
        module_cls: optional class (or tuple of classes, as accepted by
            ``isinstance``); when given, the called module's ``owner`` must be an
            instance of it.
    """
    if not isinstance(expr, CallMethod) or expr.method != "__call__":
        return False
    # A CallMethod created from a Tensor type starts with no node inputs, so
    # guard before indexing instead of raising IndexError.
    if not expr.inputs or not isinstance(expr.inputs[0], ModuleNode):
        return False
    return module_cls is None or isinstance(expr.inputs[0].owner, module_cls)
def is_call_tensor_method(expr, method: Iterable[str] = None):
    """Return whether ``expr`` is a ``CallMethod`` that is not a module call.

    Args:
        expr: the expression to inspect.
        method: a single method name or an iterable of names; when given,
            ``expr.method`` must equal one of them.
    """
    names = (method,) if (method and isinstance(method, str)) else method
    if not isinstance(expr, CallMethod) or is_call_module(expr):
        return False
    if names is None:
        return True
    return any(expr.method == candidate for candidate in names)
def is_call_function(expr, func: Iterable[Callable] = None):
    """Return whether ``expr`` is a ``CallFunction``, optionally of a specific function.

    Args:
        expr: the expression to inspect.
        func: a single callable or an iterable of callables; when given,
            ``expr.func`` must compare equal to one of them.
    """
    wrap_single = bool(func) and not isinstance(func, Iterable)
    targets = (func,) if wrap_single else func
    matched = isinstance(expr, CallFunction)
    if matched and targets is not None:
        matched = any(expr.func == target for target in targets)
    return matched
def is_constant(expr):
    """Return ``True`` when ``expr`` is a :class:`Constant` expression."""
    expected_type = Constant
    return isinstance(expr, expected_type)
def is_getattr(expr):
    """Return ``True`` when ``expr`` is a :class:`GetAttr` expression."""
    expected_type = GetAttr
    return isinstance(expr, expected_type)
def is_apply_def(expr, opdef=None):
    """Return whether ``expr`` is an :class:`Apply`, optionally of a given ``OpDef`` class.

    Args:
        expr: the expression to inspect.
        opdef: optional class (or tuple of classes) that ``expr.opdef`` must be
            an instance of.
    """
    if not isinstance(expr, Apply):
        return False
    if opdef is None:
        return True
    return isinstance(expr.opdef, opdef)
def is_input(expr):
    """Return ``True`` when ``expr`` is an ``Input`` expression."""
    expected_type = Input
    return isinstance(expr, expected_type)
class Expr:
    r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``,
    ``GetAttr``, ``Input``, ``Constant``) on ``Node``.
    """

    inputs = None  # type: List[Node]
    r"""The input Nodes of this Expr."""
    outputs = None  # type: List[Node]
    r"""The output Nodes of this Expr."""
    const_val = None  # type: List[Any]
    r"""The non-tensor object in the input of the operation."""
    arg_def = None  # type: TreeDef
    r"""The :class:`TreeDef` used to reconstruct the input of the operation."""
    out_def = None  # type: TreeDef
    r"""The :class:`TreeDef` used to reconstruct the output of the operation."""
    _top_graph = None  # type: weakref.ReferenceType

    # Id counter shared by every Expr subclass. It is always read and written
    # through ``Expr`` itself so a subclass can never shadow it.
    __total_id = 0

    def __init__(self) -> None:
        self._id = Expr.__total_id
        Expr.__total_id += 1
        self._disable_remove = False

    def enable_remove(self):
        r"""Clear the ``_disable_remove`` flag of this Expr."""
        self._disable_remove = False

    def disable_remove(self):
        r"""Set the ``_disable_remove`` flag of this Expr (presumably consulted by graph
        removal passes — confirm at the call sites)."""
        self._disable_remove = True

    def add_inputs(self, vals):
        r"""Append ``vals`` to the inputs of this Expr.

        A value already bound to a ``TensorNode``/``ModuleNode`` becomes a node
        input, and this Expr is registered in that node's ``users``. Any other
        value must be a constant leaf. It is stored in ``const_val`` as
        ``(position, value)``, where position counts all arguments seen so far.

        Args:
            vals: a single value or a sequence of values.
        """
        if not isinstance(vals, collections.abc.Sequence):
            vals = (vals,)
        for val in vals:
            node = NodeMixin.get(val, None)
            if isinstance(node, (TensorNode, ModuleNode)):
                self.inputs.append(node)
                node.users.append(self)
            else:
                assert node is None
                assert not isinstance(val, (Module, RawTensor))
                assert _is_leaf(val) and _is_const_leaf(val)
                idx = len(self.inputs) + len(self.const_val)
                self.const_val.append((idx, val))

    def add_outputs(self, outputs):
        r"""Create output Nodes for ``outputs`` in the current tracing scope.

        Each output tensor gets a new node (of the type chosen by
        ``NodeMixin.get_wrapped_type``) whose ``expr`` is this Expr. The node is
        bound to the tensor, and the scope's namespace names the outputs.
        ``None`` leaves ``outputs`` empty.
        """
        assert active_module_tracer() is not None
        self.outputs = []
        if outputs is None:
            return
        current_graph = active_module_tracer().current_scope()
        if not isinstance(outputs, collections.abc.Sequence):
            outputs = (outputs,)
        for i in outputs:
            assert isinstance(i, RawTensor), "The output must be a Tensor"
            node = NodeMixin.get_wrapped_type(i)(expr=self, name="", qualname="",)
            NodeMixin.wrap_safe(i, node)
            self.outputs.append(node)
        current_graph._namespace.auto_naming_for_outputs(self)

    def unflatten_args(self, inputs):
        r"""Rebuild ``(args, kwargs)`` from ``inputs``.

        The values in ``const_val`` are re-inserted at their recorded positions,
        then ``arg_def`` unflattens the merged list.
        """
        assert self.arg_def is not None, "{} expr doesn't have args/kwargs".format(
            type(self).__name__
        )
        inputs = list(inputs)
        for idx, val in self.const_val:
            inputs.insert(idx, val)
        args, kwargs = self.arg_def.unflatten(inputs)
        return args, kwargs

    def replace_inputs(self, repl_dict: Dict[Node, Node]):
        r"""Replace the input Nodes of this Expr.

        Args:
            repl_dict: the map {old_Node: new_Node} that specifies how to replace the input Nodes.
                The dict itself is left unmodified. Each replacement node must be
                in the same graph as the node it replaces and be produced before
                this Expr.
        """
        # Visit entries in the same (last-inserted-first) order the previous
        # ``popitem`` loop used, without consuming the caller's dict.
        for node, repl_node in reversed(list(repl_dict.items())):
            assert type(node) == type(repl_node)
            assert node in self.inputs, "({}) is not in the ({})".format(node, self)
            assert (
                repl_node.top_graph == node.top_graph
            ), "({}) and ({}) are not in the same graph".format(node, repl_node)
            graph = self.top_graph
            repl_expr_idx = graph._exprs.index(repl_node.expr)
            self_idx = graph._exprs.index(self)
            assert (
                repl_expr_idx < self_idx
            ), "({}) must be generated before ({})".format(repl_node, self)
            idx = self.inputs.index(node)
            self.inputs[idx] = repl_node
            node.users.remove(self)
            repl_node.users.append(self)

    @property
    def _support_set_args_kwargs(self):
        return False

    def set_args_kwargs(self, *args, **kwargs):
        r""" Set args and kwargs for Expr.

        Keyword arguments are first normalized with ``_convert_kwargs_to_args``
        against the called function. Node values become ``inputs`` and other
        leaves become ``const_val``. The ``users`` lists of nodes entering or
        leaving the inputs are updated accordingly.
        """
        assert (
            self._support_set_args_kwargs
        ), "Doesn't support set args/kwargs for {} expr".format(type(self).__name__)
        args, kwargs = _convert_kwargs_to_args(self._get_func(), args, kwargs)
        inputs, arg_def = tree_flatten((args, kwargs))
        orig_inputs = self.inputs
        self.inputs = []
        self.const_val = []
        for val in inputs:
            if isinstance(val, (TensorNode, ModuleNode)):
                self.inputs.append(val)
            else:
                assert _is_leaf(val) and _is_const_leaf(val)
                idx = len(self.inputs) + len(self.const_val)
                self.const_val.append((idx, val))
        for n in orig_inputs:
            if n not in self.inputs:
                n.users.remove(self)
        for n in self.inputs:
            if n not in orig_inputs:
                n.users.append(self)
        self.arg_def = arg_def

    @property
    def kwargs(self):
        r"""Get the keyword arguments of the operation corresponding to this Expr."""
        _, kwargs = self.unflatten_args(self.inputs)
        return kwargs

    @property
    def args(self):
        r"""Get the positional arguments of the operation corresponding to this Expr."""
        args, _ = self.unflatten_args(self.inputs)
        return args

    def _get_func(self):
        # get called function when the expr is interpreted
        raise NotImplementedError

    @property
    def named_args(self):
        r"""Map every parameter name of the called function to its bound value
        (via :func:`inspect.getcallargs`)."""
        func = self._get_func()
        return inspect.getcallargs(func, *self.args, **self.kwargs)

    def set_arg(self, name, val):
        r"""Set the argument ``name`` of the operation to ``val``.

        ``name`` may be one of the following:

        * a keyword argument already present in :attr:`kwargs`;
        * a positional parameter of the called function;
        * its ``*args`` parameter, in which case ``val`` (a value or a sequence)
          replaces all variadic positional arguments;
        * a keyword-only parameter or a ``**kwargs`` entry.

        All other current arguments, including existing keyword arguments, are kept.
        """
        func = self._get_func()
        args, kwargs = self.args, self.kwargs
        if name in kwargs:
            kwargs[name] = val
            self.set_args_kwargs(*args, **kwargs)
            return
        arg_spec = inspect.getfullargspec(func)
        if name in arg_spec.args:
            ind = arg_spec.args.index(name)
            new_args = list(args)
            new_args[ind] = val
            # Forward the current keyword arguments so they are not lost.
            self.set_args_kwargs(*new_args, **kwargs)
        elif name == arg_spec.varargs:
            assert arg_spec.varargs is not None
            assert len(args) >= len(arg_spec.args)
            val = (val,) if not isinstance(val, Sequence) else val
            self.set_args_kwargs(*args[0 : len(arg_spec.args)], *val, **kwargs)
        else:
            # Keyword-only parameters are settable even without ``**kwargs``.
            assert (
                name in arg_spec.kwonlyargs or arg_spec.varkw is not None
            ), "func {} does't have argument named {}".format(func, name)
            kwargs[name] = val
            self.set_args_kwargs(*args, **kwargs)

    @property
    def return_val(self):
        r"""The outputs of this Expr rebuilt into their original structure with ``out_def``."""
        return self.out_def.unflatten(self.outputs)

    @return_val.setter
    def return_val(self, new_outputs):
        outputs, out_def = tree_flatten(
            new_outputs, is_leaf=lambda x: isinstance(x, Node)
        )
        assert all(
            isinstance(o, Node) for o in outputs
        ), "Return values of expr must be ModuleNode or TensorNode or Container with them"
        assert all(
            o.expr in (None, self) for o in outputs
        ), "Some nodes are produced by other expr, can not be output of expr {}".format(
            self
        )
        self.outputs = outputs
        self.out_def = out_def

    @property
    def top_graph(self):
        r"""Get the parent graph of this Expr."""
        if self._top_graph:
            return self._top_graph()
        return None

    @classmethod
    def _get_next_id(cls):
        # Always the shared counter, even when invoked on a subclass.
        return Expr.__total_id

    @classmethod
    def _set_next_id(cls, id: int = 0):
        assert isinstance(id, int)
        Expr.__total_id = id

    def __copy__(self):
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        return result

    def __deepcopy__(self, memo):
        # Weak references (e.g. ``_top_graph``) are not copied, so the copy
        # falls back to the class-level default for those attributes.
        cls = self.__class__
        result = cls.__new__(cls)
        state = {}
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if not isinstance(v, weakref.ReferenceType):
                state[k] = copy.deepcopy(v, memo)
        result.__dict__.update(state)
        return result
# expr: None (i.e. fake expression which is used to mark input)
# expr: outputs = getattr(inputs[0], self.name)
class GetAttr(Expr):
    r"""``GetAttr`` fetches an attribute, possibly nested, out of a ``Module`` hierarchy.

    Its single input is the ``ModuleNode`` the lookup starts from. Its single
    output is a ``TensorNode`` or ``ModuleNode`` for the fetched attribute.
    """

    name = None
    r"""name: the qualified name of the attribute to be retrieved."""

    def __init__(
        self, module: ModuleNode, type: Union[Node], attr_name: str, name: str = "",
    ):
        super().__init__()
        assert isinstance(module, ModuleNode)
        assert type in [TensorNode, ModuleNode]
        self.inputs = [module]
        module.users.append(self)
        self.name = attr_name
        full_qualname = "{}.{}".format(module.qualname, attr_name)
        self.outputs = [type(self, name=name, qualname=full_qualname)]

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Build a ``GetAttr`` in the active tracing scope and return its output node."""
        assert active_module_tracer() is not None
        scope = active_module_tracer().current_scope()
        new_expr = cls(*args, **kwargs)
        scope._namespace.auto_naming_for_outputs(new_expr)
        scope._insert(new_expr)
        return new_expr.outputs[0]

    def interpret(self, *inputs):
        r"""Resolve the dotted ``name`` starting from ``inputs[0]``.

        Every intermediate hop must yield a ``Module``; otherwise
        ``AttributeError`` is raised.
        """
        owner = inputs[0]
        prefix, _, attr = self.name.rpartition(".")
        hops = prefix.split(".") if prefix else []
        for hop in hops:
            owner = getattr(owner, hop)
            if not isinstance(owner, Module):
                raise AttributeError("`{}` is not an Module".format(hop))
        return (getattr(owner, attr),)

    def __repr__(self):
        result_node = self.outputs[0]
        if isinstance(result_node, ModuleNode):
            m_type = result_node.module_type
            if isinstance(m_type, type):
                out_type = m_type.__name__
            else:
                out_type = m_type[1]
        else:
            out_type = "Tensor"
        return '%{}:\t{} = getattr({}, "{}") -> ({})'.format(
            self._id, result_node, self.inputs[0], self.name, out_type
        )

    def __getstate__(self):
        state = dict(
            _id=self._id,
            _disable_remove=self._disable_remove,
            inputs=self.inputs,
            outputs=self.outputs,
            name=self.name,
        )
        _check_obj_attr(state)
        return state
# expr: outputs = getattr(inputs[0], self.method)(*inputs[1:])  (self.method defaults to "__call__")
class CallMethod(Expr):
    r"""``CallMethod`` represents a call to the ``__call__`` method of ``Module`` or a method of ``Tensor``.

    Args:
        node: the Node to be called, or a ``Tensor`` subclass (e.g. for tensor
            construction), which is then stored as a constant first argument.
        method: the method name.
            Default: "__call__"
    """

    def __init__(self, node, method="__call__"):
        super().__init__()
        if isinstance(node, type):
            assert issubclass(node, Tensor)
            cls = Parameter if issubclass(node, Parameter) else Tensor
            # The class itself is recorded as constant argument 0; there is no node input.
            self.inputs = []
            self.const_val = [(0, cls)]
        else:
            assert isinstance(node, (TensorNode, ModuleNode))
            node.users.append(self)
            self.inputs = [
                node,
            ]
            self.const_val = []
        self.arg_def = tree_flatten(((node,), {}))[1]
        self.method = method

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Build a ``CallMethod`` and insert it into the active tracing scope."""
        assert active_module_tracer() is not None
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope()._insert(expr)
        return expr

    @property
    def graph(self):
        r"""For a module call, the graph stored in the owner's ``argdef_graph_map``
        under this call's ``arg_def``; ``None`` otherwise."""
        if isinstance(self.inputs[0], ModuleNode):
            m_node = self.inputs[0]
            if (
                hasattr(m_node.owner, "argdef_graph_map")
                and m_node.owner.argdef_graph_map
            ):
                assert self.arg_def in m_node.owner.argdef_graph_map
                return m_node.owner.argdef_graph_map[self.arg_def]
        return None

    def interpret(self, *inputs):
        r"""Call ``self.method`` on the first argument with the remaining arguments.

        Returns the flattened tensor outputs (``None`` if the call returned
        ``None``). ``__setitem__`` returns the mutated object itself.
        """
        args, kwargs = self.unflatten_args(inputs)
        obj = args[0]
        meth = getattr(obj, self.method)
        # A bound method already carries ``obj``; drop it from the argument list.
        if inspect.ismethod(meth):
            args = args[1:]
        outputs = meth(*args, **kwargs)
        if self.method == "__setitem__":
            outputs = obj
        if outputs is None:
            return outputs
        outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
        return outputs

    def _get_func(self):
        # The function whose signature describes args/kwargs: ``forward`` for
        # modules, otherwise the named method on the (Tensor) type.
        if isinstance(self.args[0], type):
            obj_type = self.args[0]
        elif isinstance(self.args[0], ModuleNode):
            obj_type = self.args[0].module_type
        else:
            assert isinstance(self.args[0], TensorNode)
            obj_type = Tensor
        meth = getattr(
            obj_type, "forward" if issubclass(obj_type, Module) else self.method
        )
        return meth

    @property
    def _support_set_args_kwargs(self):
        # only expr call tensor method or builtin module support modify args/kwargs
        return (
            isinstance(self.args[0], (TensorNode, type))
            or self.args[0].module_type is not Module
        )

    def __repr__(self):
        args = ", ".join(str(i) for i in self.args[1:])
        kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
        # Join only the non-empty parts so no dangling ", " is printed.
        arg_str = ", ".join(part for part in (args, kwargs) if part)
        outputs = self.outputs
        if self.out_def:
            outputs = self.out_def.unflatten(outputs)
        method = ".%s" % self.method
        if method == ".__call__":
            method = ""
        return "%{}:\t{}{}{}({})".format(
            self._id,
            str(outputs) + " = " if outputs else "",
            self.args[0],
            method,
            arg_str,
        )

    def __getstate__(self):
        state = {
            "_id": self._id,
            "_disable_remove": self._disable_remove,
            "inputs": self.inputs,
            "const_val": self.const_val,
            "method": self.method,
            "arg_def": self.arg_def,
            "out_def": self.out_def,
            "outputs": self.outputs,
            "version": __version__,
        }
        _check_obj_attr(state)
        return state
# expr: outputs = apply(self.opdef, *inputs)
class Apply(Expr):
    r"""``Apply`` represents a call to :func:`apply`.

    Args:
        opdef: the applied :class:`OpDef`.
    """

    opdef = None

    def __init__(self, opdef):
        super().__init__()
        assert isinstance(opdef, OpDef)
        self.opdef = opdef
        self.inputs = []

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Build an ``Apply`` and insert it into the active tracing scope."""
        assert active_module_tracer() is not None
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope()._insert(expr)
        return expr

    def interpret(self, *inputs):
        return apply(self.opdef, *inputs)

    def __repr__(self):
        return "%{}:\t{} = {}({})".format(
            self._id,
            ", ".join(str(i) for i in self.outputs),
            self.opdef,
            ", ".join(str(i) for i in self.inputs),
        )

    def __getstate__(self):
        # The OpDef is stored as its own state dict plus its type, so that
        # ``__setstate__`` can rebuild it.
        opdef_state = self.opdef.__getstate__()
        opdef_state["opdef_type"] = type(self.opdef)
        state = {
            "_id": self._id,
            "_disable_remove": self._disable_remove,
            "opdef_state": opdef_state,
            "inputs": self.inputs,
            "outputs": self.outputs,
            "version": __version__,
        }
        _check_obj_attr(state)
        return state

    def __setstate__(self, state):
        # compat with mge 1.6
        if "opdef" in state and "opdef_state" not in state:
            opdef_state = state.pop("opdef")
            opdef_state["opdef_type"] = opdef_state.pop("type")
            state["opdef_state"] = opdef_state
        self.__dict__.update(state)
        assert isinstance(state["opdef_state"], dict)
        opdef_state = state["opdef_state"].copy()
        opdef_type = opdef_state.pop("opdef_type")
        opdef_obj = opdef_type()
        opdef_obj.__setstate__(opdef_state)
        setattr(self, "opdef", opdef_obj)

    @classmethod
    def apply_module_trace_hook(cls, opdef, *inputs):
        r"""Record ``apply(opdef, *inputs)`` as an ``Apply`` expr and run the op.

        Inputs that are not yet bound to a node are first captured as
        ``Constant`` exprs. For ``FakeQuant``, every input after the first is
        recorded as a fresh ``Constant``. The op is then executed with module
        tracing switched off, and the results are bound to new output nodes.

        Returns:
            the op outputs as a list.
        """
        for i in inputs:
            node = NodeMixin.get(i, None)
            if node is None:  # capture as constant
                NodeMixin.wrap_safe(i, Constant.make(i))

        if isinstance(opdef, FakeQuant):
            inp_nodes = [NodeMixin.get(inputs[0])]
            for i in inputs[1:]:
                node = Constant.make(i)
                if _get_expr_checker():
                    active_module_tracer().checker.record_node2value(node, Tensor(i))
                inp_nodes.append(node)
            apply_node = cls.make(opdef)
            for n in inp_nodes:
                n.users.append(apply_node)
            apply_node.inputs = inp_nodes
        else:
            apply_node = cls.make(opdef)
            apply_node.add_inputs(inputs)

        assert not apply_node.const_val

        # Execute the real op with module tracing disabled (presumably so the
        # nested ``apply`` is not captured again). Tracing is restored even if
        # the op raises, so a failing op cannot leave the tracer switched off.
        unset_module_tracing()
        try:
            outputs = apply(opdef, *inputs)
        finally:
            set_module_tracing()

        apply_node.add_outputs(outputs)
        for n, v in zip(apply_node.outputs, outputs):
            NodeMixin.wrap_safe(v, n)

        if _get_expr_checker():
            with _exclude_from_trace():
                active_module_tracer().checker.check_apply(apply_node, outputs, opdef)
        return list(outputs)
class CallFunction(Expr):
    r"""``CallFunction`` represents a call to a built-in function.

    Args:
        func: a built-in function.
    """

    def __init__(self, func):
        super().__init__()
        assert isinstance(func, Callable)
        self.func = func
        self.const_val = []
        self.inputs = []

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Build a ``CallFunction`` and insert it into the active tracing scope."""
        assert active_module_tracer() is not None
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope()._insert(expr)
        return expr

    def interpret(self, *inputs):
        r"""Call ``func`` with the reconstructed args/kwargs.

        While a module is being traced, the function wrapped by the tracer's
        patcher is called instead. Returns the flattened tensor outputs, or
        ``None`` if the call returned ``None``.
        """
        args, kwargs = self.unflatten_args(inputs)
        func = (
            self.func
            if not is_tracing_module()
            else active_module_tracer().patcher.wrap_fn(self.func)
        )
        outputs = func(*args, **kwargs)
        if outputs is None:
            return outputs
        outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
        return outputs

    def _get_func(self):
        return self.func

    @property
    def _support_set_args_kwargs(self):
        return True

    def __repr__(self):
        args = ", ".join(str(i) for i in self.args)
        kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
        # Join only the non-empty parts so no dangling ", " is printed.
        arg_str = ", ".join(part for part in (args, kwargs) if part)
        outputs = self.outputs
        if self.out_def:
            outputs = self.out_def.unflatten(outputs)
        return "%{}:\t{}{}({})".format(
            self._id,
            str(outputs) + " = " if outputs else "",
            # last component of the defining module, e.g. "functional.nn" -> "nn"
            self.func.__module__.rsplit(".", 1)[-1] + "." + self.func.__name__,
            arg_str,
        )

    def __getstate__(self):
        # ``func`` is stored by reference as (module name, qualified name).
        state = {
            "_id": self._id,
            "_disable_remove": self._disable_remove,
            "func": (self.func.__module__, self.func.__qualname__),
            "const_val": self.const_val,
            "inputs": self.inputs,
            "arg_def": self.arg_def,
            "out_def": self.out_def,
            "outputs": self.outputs,
            "version": __version__,
        }
        _check_obj_attr(state)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Best effort: re-import the function from its (module, qualname)
        # reference. If that fails, ``func`` stays the stored tuple.
        try:
            if isinstance(self.func, tuple):
                mname, fname = self.func
                f = import_module(mname)
                for i in fname.split("."):
                    f = getattr(f, i)
                self.func = f
        except Exception:
            pass
# expr: outputs = self.value
class Constant(Expr):
    r"""``Constant`` records a ``Tensor`` or ``Module`` value that is captured directly,
    rather than fetched as an attribute of a Module.

    Args:
        c: a const Tensor or Module.
        name: the name of output Node.
    """

    value = None
    r"""The const Tensor or Module"""

    # TODO: constant cache to reduce the size of dumped model
    _constant_cache = {}

    def __init__(self, c, name: str = "", qualname: str = ""):
        super().__init__()
        assert isinstance(c, (RawTensor, Module))
        if isinstance(c, Module):
            assert module_tracer.is_builtin(c) or c.is_qat
        if type(c) is RawTensor:
            # Promote a bare RawTensor to Tensor outside of tracing.
            with _exclude_from_trace():
                c = Tensor(c)
        self.value = c
        self.name = name
        self.inputs = []
        out_node_type = NodeMixin.get_wrapped_type(c)
        self.outputs = [out_node_type(self, name=name, qualname=qualname)]

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Build a ``Constant`` in the active scope, cache its value, and return its output node."""
        assert active_module_tracer() is not None
        new_expr = cls(*args, **kwargs)
        scope = active_module_tracer().current_scope()
        scope._namespace.auto_naming_for_outputs(new_expr)
        scope._insert(new_expr)
        active_module_tracer().current_constant_cache().append(new_expr.value)
        return new_expr.outputs[0]

    def interpret(self, *inputs):
        stored = self.value
        if not isinstance(stored, RawTensor):
            return (stored,)
        return (Const(stored.numpy(), None, None),)

    def __repr__(self):
        label = type(self.value) if self.name is None else self.name
        node_kind = "Tensor" if isinstance(self.outputs[0], TensorNode) else "Module"
        return "%{}:\t{} = Constant({}) -> ({})".format(
            self._id, self.outputs[0], label, node_kind
        )

    def __getstate__(self):
        state = dict(
            _id=self._id,
            _disable_remove=self._disable_remove,
            value=self.value,
            name=self.name,
            inputs=self.inputs,
            outputs=self.outputs,
        )
        _check_obj_attr(state)
        stored = self.value
        if isinstance(stored, RawTensor):
            state["value"] = Tensor(stored)
        # Builtin modules are pickled as a lightweight module state instead.
        if isinstance(stored, Module) and module_tracer.is_builtin(stored):
            _check_builtin_module_attr(stored)
            state["value"] = _ModuleState.get_module_state(stored)
        return state

    def __setstate__(self, state):
        # Turn any pickled module states back into real modules (in place).
        for key, item in list(state.items()):
            if isinstance(item, _ModuleState):
                state[key] = item.to_module()
        self.__dict__.update(state)
def _module_trace_capture(value):
    """Record ``value`` as a ``Constant`` in the active trace and bind the new node to it.

    Returns:
        the output node of the created ``Constant`` expr.
    """
    const_node = Constant.make(value)
    NodeMixin.wrap_safe(value, const_node)
    return const_node
# Register ``Apply.apply_module_trace_hook`` with the core runtime. Presumably
# core2 invokes it for each ``apply`` issued while module tracing is enabled,
# so the op is recorded as an ``Apply`` expr. TODO: confirm in core2.
set_module_trace_hook(Apply.apply_module_trace_hook)