import abc
import copy
import weakref
from importlib import import_module
from typing import Any, Dict, List, Tuple, Type
import numpy
from .. import get_logger
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..module import Module
from ..quantization.utils import QParams
from ..tensor import Tensor
from .module_tracer import active_module_tracer
from .tm_config import _get_expr_checker
from .utils import _check_obj_attr
# Module-level logger, named after this module via ``__name__``.
logger = get_logger(__name__)
class Node:
    r"""``Node`` represents the variables (``Tensor``, ``Module``) used in Module's forward method.

    They are inputs/outputs of Expr (the operations on variables).
    """

    expr = None  # type: Expr
    r"""The Expr which produces the Node."""
    # Counter that supplies the ``_id`` of the next Node created. The name is
    # mangled to ``_Node__total_id``; it is always read and written on ``Node``
    # itself so that every subclass shares one counter.
    __total_id = 0  # type: int
    _id = None  # type: int
    _top_graph = None  # type: weakref.ReferenceType
    # Format spec used by ``__repr__`` and by ``__format__`` when none is given.
    _format_spec = ""  # type: str

    def __init__(self, expr, name: str, qualname: str):
        r"""Create a Node and assign it the next free id.

        Args:
            expr: the Expr which produces this Node.
            name: the name of this Node (may be ``None``).
            qualname: the qualified name of this Node.
        """
        self.expr = expr
        self.users = []  # List[Expr]
        self._id = Node.__total_id
        Node.__total_id += 1
        self._name = name
        self._qualname = qualname
        self.actual_node = []  # type: List[Node]

    def __repr__(self):
        format_spec = Node._format_spec
        return self.__format__(format_spec)

    def __format__(self, format_spec: str) -> str:
        r"""Format this Node as a string.

        Args:
            format_spec: one of ``"i"``, ``"p"``, ``"ip"``, ``"pi"``. ``"p"``
                prefixes the name with the ``_name`` of the top graph and
                ``"i"`` prefixes the result with ``%<id>_``. An empty spec
                falls back to the class-wide ``Node._format_spec``. Any other
                spec yields the bare name, or the id when the name is empty.
        """
        if not format_spec:
            format_spec = Node._format_spec
        name = self._name
        if name is None:
            name = ""
        if format_spec in ["i", "p", "ip", "pi"]:
            if "p" in format_spec:
                prefix_name = self.top_graph._name
                name = "{}_{}".format(prefix_name, name)
            if "i" in format_spec:
                name = "%{}_{}".format(self._id, name)
            return name
        else:
            return name if name else ("%d" % self._id)

    @property
    def name(self):
        r"""Return the name of this Node."""
        return self._name

    @name.setter
    def name(self, new_name: str):
        r"""Set a new name to this Node.

        The Node must belong to a graph and ``new_name`` must not already be
        registered in that graph's namespace; both are checked with ``assert``.
        """
        graph = self.top_graph
        assert graph is not None, "The parent graph of this Node cannot be None."
        assert graph._namespace.used_names.get(new_name, None) is None, (
            "The name(%s) is already in use. Please try a different one again."
            % (new_name)
        )
        # Release the old name before registering the new one.
        graph._namespace.unassociate_name_with_obj(self)
        self._name = graph._namespace.create_unique_name(new_name, self)

    @property
    def qualname(self):
        r"""Get the `qualname` of this Node. The `qualname` can be used to get the
        submodule from the traced Module or Module.

        Example:
            .. code-block::

                import megengine.module as M
                import megengine.functional as F
                import megengine.traced_module as tm
                import megengine as mge

                class block(M.Module):
                    def __init__(self):
                        super().__init__()
                        self.param = mge.Tensor([1.])
                        self.relu = M.ReLU()

                    def forward(self, x):
                        x = x + self.param
                        return self.relu(F.relu(x))

                class module(M.Module):
                    def __init__(self):
                        super().__init__()
                        self.block = block()

                    def forward(self, x):
                        x = self.block(x)
                        return x

                net = module()
                traced_net = tm.trace_module(net, mge.Tensor([0.]))
                traced_net = traced_net.flatten()
                out_node = traced_net.graph.outputs[0]

                # qualname : "module.block.relu.[out]"
                qualname = out_node.qualname
                # qualname : "block.relu"
                qualname = qualname.split(".", 1)[-1].rsplit(".", 1)[0]

                assert qualname in list(map(lambda x: x[0], net.named_modules()))
                assert qualname in list(map(lambda x: x[0], traced_net.named_modules()))
        """
        return self._qualname

    @property
    def top_graph(self):
        r"""Get the parent graph of this Node."""
        # ``_top_graph`` is a weak reference; call it to obtain the graph
        # (the call returns ``None`` once the referent is gone).
        if self._top_graph:
            return self._top_graph()
        return None

    @classmethod
    def _set_format_spec(cls, str):
        r"""Set the class-wide format spec and return the previous one."""
        # NOTE: the parameter name shadows the builtin ``str``; it is kept
        # unchanged so existing keyword callers continue to work.
        old_format_spec = cls._format_spec
        cls._format_spec = str
        return old_format_spec

    @classmethod
    def _get_next_id(cls):
        r"""Return the id that the next created Node will receive."""
        return Node.__total_id

    @classmethod
    def _set_next_id(cls, id: int = 0):
        r"""Set the id that the next created Node will receive."""
        assert isinstance(id, int)
        # Write to ``Node`` rather than ``cls``: assigning through a subclass
        # would only create a shadowing attribute on that subclass, while
        # ``__init__`` keeps counting on ``Node``.
        Node.__total_id = id

    def __copy__(self):
        r"""Return a shallow copy that shares all attribute values."""
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        return result

    def __deepcopy__(self, memo):
        r"""Return a deep copy of this Node.

        Weak references and the ``actual_node`` attribute are not copied, so
        the copy has no instance-level value for them.
        """
        cls = self.__class__
        result = cls.__new__(cls)
        state = {}
        # Register the copy first so reference cycles resolve to it.
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if not isinstance(v, weakref.ReferenceType) and k != "actual_node":
                state[k] = copy.deepcopy(v, memo)
        result.__dict__.update(state)
        return result
class ModuleNode(Node):
    r"""``ModuleNode`` represents the Module objects."""

    module_type = Module  # type: Type[Module]
    r"""The type of the Module corresponding to the ModuleNode."""
    _owner = None  # type: weakref.ReferenceType

    def __init__(self, expr, name: str = None, qualname: str = None):
        super().__init__(expr, name, qualname)

    def __getstate__(self):
        r"""Return the state to serialize.

        ``module_type`` is stored as a ``(module name, qualified class name)``
        pair rather than as the class object itself.
        """
        module_type = self.module_type
        # If an earlier ``__setstate__`` could not resolve the class,
        # ``module_type`` is still the stored pair; pass it through unchanged
        # instead of failing on ``tuple.__qualname__``.
        if not isinstance(module_type, tuple):
            module_type = (module_type.__module__, module_type.__qualname__)
        state = {
            "expr": self.expr,
            "users": self.users,
            "_id": self._id,
            "_name": self._name,
            "_qualname": self._qualname,
            "module_type": module_type,
        }
        _check_obj_attr(state)
        return state

    def __setstate__(self, state):
        r"""Restore the state produced by ``__getstate__``.

        A legacy ``_orig_name`` entry is renamed to ``_qualname``. The stored
        ``module_type`` pair is resolved back to a class on a best-effort
        basis; if that fails, ``module_type`` is left as the pair.
        """
        if "_orig_name" in state:
            state["_qualname"] = state.pop("_orig_name")
        self.__dict__.update(state)
        try:
            if isinstance(self.module_type, tuple):
                mname, classname = self.module_type
                # ``classname`` is a ``__qualname__`` and may be dotted for
                # nested classes, so walk it one attribute at a time.
                mtype = import_module(mname)
                for part in classname.split("."):
                    mtype = getattr(mtype, part)
                self.module_type = mtype
        except Exception:
            # Deliberately best-effort: loading must not fail just because
            # the class cannot be imported here.
            pass

    @property
    def owner(self):
        r"""Get the ``Module`` corresponding to this ``ModuleNode``.
        """
        # ``_owner`` is a weak reference; call it to obtain the Module.
        if self._owner:
            return self._owner()
        return None
class TensorNode(Node):
    r"""``TensorNode`` represents the Tensor objects."""

    _shape = None  # type: Tuple[int]
    _dtype = None  # type: numpy.dtype
    _qparams = None  # type: QParams
    _device = None
    _value = None  # type: Tensor

    def __init__(
        self,
        expr,
        name: str = None,
        qualname: str = None,
        shape: Tuple[int] = None,
        dtype: numpy.dtype = None,
        qparams: QParams = None,
    ):
        r"""Create a TensorNode and record the given shape, dtype and qparams."""
        super().__init__(expr, name, qualname)
        self._shape, self._dtype, self._qparams = shape, dtype, qparams

    def __getstate__(self):
        r"""Return the attributes to serialize; the bound ``_value`` is not included."""
        state = dict(
            expr=self.expr,
            users=self.users,
            _id=self._id,
            _qparams=self._qparams,
            _shape=self._shape,
            _dtype=self._dtype,
            _device=self._device,
            _name=self._name,
            _qualname=self._qualname,
        )
        _check_obj_attr(state)
        return state

    def __setstate__(self, state):
        r"""Restore serialized attributes, converting a legacy ``_orig_name``.

        A legacy ``_orig_name`` is turned into ``_qualname``: its last dotted
        component is wrapped in ``[...]`` unless the producing expr's class is
        named ``GetAttr``.
        """
        if "_orig_name" in state:
            legacy_name = state.pop("_orig_name")
            modulepath, sep, leaf = legacy_name.rpartition(".")
            if state["expr"].__class__.__name__ != "GetAttr":
                leaf = "[{}]".format(leaf)
            # ``sep`` is empty when the legacy name had no module path.
            state["_qualname"] = "{}.{}".format(modulepath, leaf) if sep else leaf
        self.__dict__.update(state)

    @property
    def shape(self):
        r"""Get the shape of this Node."""
        return self._shape

    @shape.setter
    def shape(self, shape):
        self._shape = shape

    @property
    def dtype(self):
        r"""Get the dtype of this Node."""
        return self._dtype

    @dtype.setter
    def dtype(self, dtype):
        self._dtype = dtype

    @property
    def device(self):
        r"""Get the device of this Node pointed Tensor."""
        return self._device

    @device.setter
    def device(self, device):
        self._device = device

    @property
    def qparams(self):
        r"""Get the :class:`QParams` of this Node."""
        return self._qparams

    @qparams.setter
    def qparams(self, qparams):
        self._qparams = qparams

    @property
    def value(self):
        r"""Get the bound Tensor of this Node."""
        return self._value

    @value.setter
    def value(self, value):
        r"""Bind a :class:`Tensor` to this Node.

        If ``value`` is a ``RawTensor`` that currently has a node attached,
        that attachment is reset to ``None`` before the value is stored.
        """
        if isinstance(value, RawTensor):
            if NodeMixin.get(value, None) is not None:
                setattr(value, "_NodeMixin__node", None)
        self._value = value
class NodeMixin(abc.ABC):
    r"""Mixin for objects that can have a ``Node`` attached to them.

    The attached node is stored on the object under the attribute
    ``_NodeMixin__node``. The classmethods below also accept ``RawTensor``
    values, which use the same attribute name.
    """

    # Class-level default, so ``get`` returns ``None`` for an instance that
    # has no node attached.
    __node = None

    @abc.abstractmethod
    def _record_wrapped_nodes(self, node):
        # record the nodes which had been bound to this NodeMixin
        pass

    @classmethod
    def _record_tensornode_property(cls, node, value):
        r"""Copy dtype, shape, device and qparams from ``value`` onto ``node``."""
        assert isinstance(node, TensorNode)
        assert isinstance(value, RawTensor)
        # The check is repeated so it still applies when asserts are
        # stripped (``python -O``).
        if isinstance(value, RawTensor):
            try:
                node._dtype = value.dtype
            except RuntimeError:
                node._dtype = None
            node._shape = (
                value._tuple_shape if isinstance(value, Tensor) else value.shape
            )
            node._device = value.device
            if hasattr(value, "_qparams") and value._qparams is not None:
                node._qparams = value.qparams

    @classmethod
    def wrap(cls, value, node):
        r"""Attach ``node`` to ``value``.

        Args:
            value: the object to attach to. Anything that is neither a
                ``NodeMixin`` nor a ``RawTensor`` is ignored.
            node: a ``Node``, or a zero-argument callable returning one. The
                callable is only invoked when ``value`` is accepted.
        """
        if not isinstance(value, (NodeMixin, RawTensor)):
            return
        if not isinstance(node, Node):
            assert callable(node)
            node = node()
            assert isinstance(node, Node)
        if isinstance(value, RawTensor):
            cls._record_tensornode_property(node, value)
        if isinstance(value, NodeMixin):
            value._record_wrapped_nodes(node)
        setattr(value, "_NodeMixin__node", node)
        if _get_expr_checker():
            if isinstance(value, RawTensor):
                active_module_tracer().checker.record_node2value(node, value)
            if isinstance(value, NodeMixin):
                active_module_tracer().checker.record_nodemixin(node, value)

    @classmethod
    def wrap_safe(cls, value, node):
        r"""Attach ``node`` to ``value``, asserting that ``value`` is supported.

        Unlike ``wrap``, ``node`` is used as given (it is never called) and an
        unsupported ``value`` fails an assertion instead of being ignored.
        """
        assert isinstance(value, (NodeMixin, RawTensor))
        if isinstance(value, RawTensor):
            cls._record_tensornode_property(node, value)
        setattr(value, "_NodeMixin__node", node)
        if _get_expr_checker():
            if isinstance(value, RawTensor):
                active_module_tracer().checker.record_node2value(node, value)
            if isinstance(value, NodeMixin):
                active_module_tracer().checker.record_nodemixin(node, value)
        if isinstance(value, NodeMixin):
            value._record_wrapped_nodes(node)

    @classmethod
    def clear_node(cls, value):
        r"""Remove the node attached to ``value``, if there is one."""
        if hasattr(value, "_NodeMixin__node"):
            try:
                delattr(value, "_NodeMixin__node")
            except AttributeError:
                # ``hasattr`` also sees the class-level default ``__node``,
                # so a NodeMixin instance with nothing attached gets here
                # with no instance attribute to delete.
                pass

    @classmethod
    def get(cls, value, *default):
        r"""Return the node attached to ``value``.

        ``default`` is forwarded to ``getattr``: without it, a missing
        attribute raises ``AttributeError``.
        """
        return getattr(value, "_NodeMixin__node", *default)

    @classmethod
    def get_wrapped_type(cls, value):
        r"""Return the Node class that matches ``value``.

        ``TensorNode`` for a ``RawTensor``, ``ModuleNode`` for a ``Module`` or
        ``NodeMixin``, and ``Node`` otherwise.
        """
        if isinstance(value, RawTensor):
            return TensorNode
        if isinstance(value, (Module, NodeMixin)):
            return ModuleNode
        return Node