# Source code for megengine.module.module

from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union

import numpy as np

from ..core.tensor.utils import make_shape_tuple
from ..logger import get_logger
from ..tensor import Parameter, Tensor
from ..utils.deprecation import deprecated
from ..utils.hook import HookHandler
from ..utils.naming import AutoNaming

# Module-level logger; used below for warnings such as ignored
# ``requires_grad`` arguments and non-strict ``load_state_dict`` mismatches.
logger = get_logger(__name__)


def _expand_structure(prefix, obj):
    r"""Flattens ``obj`` into ``(dotted_key, leaf)`` pairs for every
    :class:`~.Tensor` or :class:`~.Module` it contains.

    ``obj`` may itself be a Tensor/Module (one pair keyed by ``prefix``) or an
    arbitrarily nested list/tuple/dict of them. List/tuple elements are keyed
    by their index; dict entries by their key, visited in order of
    ``str(key)`` so the result is deterministic. Any other object yields no
    pairs.

    Args:
        prefix: key of ``obj`` itself; nested keys are joined to it with ``"."``.
        obj: the object to expand.

    Returns:
        a list of ``(key, leaf)`` tuples.

    Raises:
        AssertionError: if a dict entry that (transitively) contains a Tensor
            or Module has a non-``str`` key.
    """
    if isinstance(obj, (Tensor, Module)):
        return [(prefix, obj)]
    if not isinstance(obj, (list, tuple, dict)):
        return []

    if isinstance(obj, dict):
        # ``key=str`` leaves the order of all-str keys unchanged, but no longer
        # crashes on dicts whose keys are not mutually comparable (e.g. int and
        # str mixed), which every Module attribute assignment would hit.
        targets = ((k, obj[k]) for k in sorted(obj, key=str))
    else:
        targets = ((str(k), v) for k, v in enumerate(obj))

    ret = []
    for k, o in targets:
        if not isinstance(k, str):
            # A non-str key cannot form a dotted path. It is only an error when
            # the entry actually holds a Tensor or Module; probe with str(k) so
            # the check itself never concatenates a non-str prefix.
            if _expand_structure(str(k), o):
                raise AssertionError(
                    "keys for Tensor and Module must be str, error key: {}".format(k)
                )
            continue
        for sub_key, leaf in _expand_structure(k, o):
            ret.append((prefix + "." + sub_key, leaf))
    return ret


def _access_structure(obj, key, callback=None):
    r"""Walks the dotted path ``key`` starting at ``obj``.

    Each path component indexes a list/tuple (after ``int`` conversion), looks
    up a dict, or reads an attribute, depending on the type of the current
    object.

    Args:
        obj: the root object.
        key: a dotted path such as ``"layers.0.weight"``.
        callback: optional ``callback(parent, last_key, value)``. When given,
            its return value is returned instead of ``value``. For a
            list/tuple parent ``last_key`` is an ``int``.

    Returns:
        the object found at ``key``, or ``callback(parent, last_key, value)``.
    """
    parent = None
    cur = obj
    k = None
    for k in key.split("."):
        parent = cur
        if isinstance(cur, (list, tuple)):
            k = int(k)
            cur = cur[k]
        elif isinstance(cur, dict):
            cur = cur[k]
        else:
            cur = getattr(cur, k)
    # Test the ``callback`` argument, not the builtin ``callable``.
    if callback is None:
        return cur
    return callback(parent, k, cur)


def _is_parameter(obj):
    """Returns whether ``obj`` is a :class:`~.Parameter`."""
    is_param = isinstance(obj, (Parameter,))
    return is_param


def _is_tensor(obj):
    """Returns whether ``obj`` is a :class:`~.Tensor` (Parameters included)."""
    is_tensor = isinstance(obj, (Tensor,))
    return is_tensor


def _is_buffer(obj):
    """Returns whether ``obj`` is a buffer: a :class:`~.Tensor` that is not a
    :class:`~.Parameter`."""
    if not isinstance(obj, Tensor):
        return False
    return not isinstance(obj, Parameter)


def _is_module(obj):
    """Returns whether ``obj`` is a :class:`Module` instance."""
    is_module = isinstance(obj, (Module,))
    return is_module


def _get_XNorm_typeclass():
    r"""Returns the tuple of normalization module classes whose states
    ``Module.state_dict`` flattens with ``reshape(-1)`` and whose size-compatible
    shape mismatches ``Module.load_state_dict`` reshapes instead of rejecting.

    The imports happen at call time rather than at module import time;
    presumably this avoids a circular import with the normalization modules
    (TODO confirm).
    """
    from .batchnorm import _BatchNorm
    from .normalization import GeneralNorm, GroupNorm, InstanceNorm, LayerNorm  # noqa: F401

    # NOTE(review): GeneralNorm is imported but deliberately (?) left out of
    # the returned tuple; confirm whether its states should be reshaped too.
    return (_BatchNorm, GroupNorm, LayerNorm, InstanceNorm)


class Module(metaclass=ABCMeta):
    r"""Base Module class.

    Args:
        name: module's name, can be initialized by the ``kwargs`` parameter of
            child class.
    """

    def __init__(self, name=None):
        # Must be assigned first: ``__setattr__`` consults it for every other
        # attribute assignment.
        self._modules = []

        if name is not None:
            assert (
                isinstance(name, str) and name.strip()
            ), "Module's name must be a non-empty string"

        self.name = name

        # runtime attributes
        self.training = True
        self.quantize_disabled = False

        # hooks
        self._forward_pre_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()

        # used for profiler and automatic naming
        self._name = None
        self._short_name = None

    @abstractmethod
    def forward(self, inputs):
        pass

    def register_forward_pre_hook(self, hook: Callable) -> HookHandler:
        """Registers a hook to handle forward inputs. `hook` should be a function.

        Args:
            hook: a function that receives ``module`` and ``inputs``, then
                returns modified ``inputs`` or ``None``.

        Returns:
            a handler with :meth:`~.HookHandler.remove` interface to delete the hook.
        """
        return HookHandler(self._forward_pre_hooks, hook)

    def register_forward_hook(self, hook: Callable) -> HookHandler:
        """Registers a hook to handle forward results. `hook` should be a function
        that receives ``module``, ``inputs`` and ``outputs``, then returns modified
        ``outputs`` or ``None``.

        This method returns a handler with :meth:`~.HookHandler.remove` interface
        to delete the hook.
        """
        return HookHandler(self._forward_hooks, hook)

    def __call__(self, *inputs, **kwargs):
        AutoNaming.push_scope(self.name if self.name is not None else self._short_name)
        try:
            for hook in self._forward_pre_hooks.values():
                modified_inputs = hook(self, inputs)
                if modified_inputs is not None:
                    # A pre-hook may return a single, non-tuple input.
                    if not isinstance(modified_inputs, tuple):
                        modified_inputs = (modified_inputs,)
                    inputs = modified_inputs

            outputs = self.forward(*inputs, **kwargs)

            for hook in self._forward_hooks.values():
                modified_outputs = hook(self, inputs, outputs)
                if modified_outputs is not None:
                    outputs = modified_outputs
        finally:
            # Pop even when a hook or ``forward`` raises, so every push_scope
            # above stays paired with a pop_scope.
            AutoNaming.pop_scope()
        return outputs

    def _flatten(
        self,
        *,
        recursive: bool = True,
        with_key: bool = False,
        with_parent: bool = False,
        prefix: Optional[str] = None,
        predicate: Callable[[Any], bool] = lambda _: True,
        seen: Optional[Set[int]] = None
    ) -> Union[Iterable[Any], Iterable[Tuple[str, Any]]]:
        """Scans the module object and returns an iterable for the :class:`~.Tensor`
        and :class:`~.Module` attributes that agree with the ``predicate``. For multiple
        calls of this function with same arguments, the order of objects within the
        returned iterable is guaranteed to be identical, as long as all the involved
        module objects' ``__dict__`` does not change throughout those calls.

        Every object (compared by ``id``) is visited at most once per top-level
        call, even when it is reachable through several attributes. Submodules are
        recursed into (when ``recursive``) whether or not they satisfy ``predicate``.

        Args:
            recursive: whether to recursively scan all the submodules.
            with_key: whether to yield keys along with yielded objects.
            with_parent: whether to yield ``self`` along with yielded objects.
            prefix: prefix prepended (joined with ``"."``) to the yielded keys.
            predicate: the predication function applied to scanned objects.
            seen: a set holding the ``id`` of every object already traversed.
        """
        if seen is None:
            seen = {id(self)}

        module_dict = vars(self)
        _prefix = "" if prefix is None else prefix + "."

        for key in sorted(module_dict):
            for expanded_key, leaf in _expand_structure(key, module_dict[key]):
                leaf_id = id(leaf)
                if leaf_id in seen:
                    continue
                seen.add(leaf_id)

                if predicate(leaf):
                    if with_key and with_parent:
                        yield _prefix + expanded_key, leaf, self
                    elif with_key:
                        yield _prefix + expanded_key, leaf
                    elif with_parent:
                        yield leaf, self
                    else:
                        yield leaf

                if recursive and isinstance(leaf, Module):
                    yield from leaf._flatten(
                        recursive=recursive,
                        with_key=with_key,
                        with_parent=with_parent,
                        prefix=_prefix + expanded_key if with_key else None,
                        predicate=predicate,
                        seen=seen,
                    )

    def parameters(self, recursive: bool = True, **kwargs) -> Iterable[Parameter]:
        r"""Returns an iterable for the :class:`~.Parameter` of the module.

        Args:
            recursive: If ``True``, returns all :class:`~.Parameter` within this
                module, else only returns :class:`~.Parameter` that are direct
                attributes of this module.
        """
        if "requires_grad" in kwargs:
            del kwargs["requires_grad"]
            logger.warning(
                "Tensor currently has no requires_grad attribute "
                "so requires_grad argument is ignored here"
            )

        yield from self._flatten(
            with_key=False, predicate=_is_parameter, recursive=recursive, **kwargs
        )

    def named_parameters(
        self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
    ) -> Iterable[Tuple[str, Parameter]]:
        r"""Returns an iterable for key :class:`~.Parameter` pairs of the module, where
        ``key`` is the dotted path from this module to the :class:`~.Parameter`.

        Args:
            prefix: prefix prepended to the keys.
            recursive: if ``True``, returns all :class:`~.Parameter` within this
                module, else only returns :class:`~.Parameter` that are direct
                attributes of this module.
        """
        if "requires_grad" in kwargs:
            del kwargs["requires_grad"]
            logger.warning(
                "Tensor currently has no requires_grad attribute "
                "so requires_grad argument is ignored here"
            )

        yield from self._flatten(
            with_key=True,
            prefix=prefix,
            predicate=_is_parameter,
            recursive=recursive,
            **kwargs,
        )

    def buffers(self, recursive: bool = True, **kwargs) -> Iterable[Tensor]:
        r"""Returns an iterable for the buffers of the module.

        Buffer is defined to be :class:`~.Tensor` excluding :class:`~.Parameter`.

        Args:
            recursive: if ``True``, returns all buffers within this module, else
                only returns buffers that are direct attributes of this module.
        """
        yield from self._flatten(
            with_key=False, predicate=_is_buffer, recursive=recursive, **kwargs
        )

    def named_buffers(
        self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
    ) -> Iterable[Tuple[str, Tensor]]:
        r"""Returns an iterable for key buffer pairs of the module, where
        ``key`` is the dotted path from this module to the buffer.

        Buffer is defined to be :class:`~.Tensor` excluding :class:`~.Parameter`.

        Args:
            prefix: prefix prepended to the keys.
            recursive: if ``True``, returns all buffers within this module, else
                only returns buffers that are direct attributes of this module.
        """
        yield from self._flatten(
            with_key=True,
            prefix=prefix,
            predicate=_is_buffer,
            recursive=recursive,
            **kwargs,
        )

    def tensors(self, recursive: bool = True, **kwargs) -> Iterable[Tensor]:
        r"""Returns an iterable for the :class:`~.Tensor` of the module
        (parameters and buffers alike).

        Args:
            recursive: If ``True``, returns all :class:`~.Tensor` within this
                module, else only returns :class:`~.Tensor` that are direct
                attributes of this module.
        """
        yield from self._flatten(
            with_key=False, predicate=_is_tensor, recursive=recursive, **kwargs
        )

    def named_tensors(
        self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
    ) -> Iterable[Tuple[str, Tensor]]:
        r"""Returns an iterable for key tensor pairs of the module, where
        ``key`` is the dotted path from this module to the tensor.

        Args:
            prefix: prefix prepended to the keys.
            recursive: if ``True``, returns all tensors within this module, else
                only returns tensors that are direct attributes of this module.
        """
        yield from self._flatten(
            with_key=True,
            prefix=prefix,
            predicate=_is_tensor,
            recursive=recursive,
            **kwargs,
        )

    def children(self, **kwargs) -> "Iterable[Module]":
        r"""Returns an iterable for all the submodules that are direct attributes of
        this module.
        """
        yield from self._flatten(
            with_key=False, predicate=_is_module, recursive=False, **kwargs
        )

    def named_children(self, **kwargs) -> "Iterable[Tuple[str, Module]]":
        r"""Returns an iterable of key-submodule pairs for all the submodules that are
        direct attributes of this module, where 'key' is the attribute name of
        submodules.
        """
        yield from self._flatten(
            with_key=True, predicate=_is_module, recursive=False, **kwargs
        )

    def modules(self, **kwargs) -> "Iterable[Module]":
        r"""Returns an iterable for all the modules within this module, including
        itself (yielded first, paired with ``None`` when ``with_parent=True``).
        """
        if kwargs.get("with_parent"):
            yield self, None
        else:
            yield self
        yield from self._flatten(with_key=False, predicate=_is_module, **kwargs)

    def named_modules(
        self, prefix: Optional[str] = None, **kwargs
    ) -> "Iterable[Tuple[str, Module]]":
        r"""Returns an iterable of key-module pairs for all the modules within this
        module, including itself, where 'key' is the dotted path from this module to
        the submodules. This module itself is yielded first, keyed by ``prefix``
        (or ``""``).

        Args:
            prefix: prefix prepended to the path.
        """
        own_key = "" if prefix is None else prefix
        if kwargs.get("with_parent"):
            yield own_key, self, None
        else:
            yield own_key, self
        yield from self._flatten(
            with_key=True, prefix=prefix, predicate=_is_module, **kwargs
        )

    def apply(self, fn: "Callable[[Module], Any]") -> None:
        r"""Applies function ``fn`` to all the modules within this module, including
        itself.

        Args:
            fn: the function to be applied on modules.
        """
        for it in self.modules():
            fn(it)

    @deprecated(version="1.0")
    def zero_grad(self) -> None:
        r"""Sets all parameters' grads to zero"""
        for param in self.parameters():
            if param.grad is not None:
                param.grad.reset_zero()

    def train(self, mode: bool = True, recursive: bool = True) -> None:
        r"""Sets training mode of all the modules within this module (including itself) to
        ``mode``. This effectively sets the ``training`` attributes of those modules
        to ``mode``, but only has effect on certain modules (e.g.
        :class:`~.BatchNorm2d`, :class:`~.Dropout`, :class:`~.Observer`)

        Args:
            mode: the training mode to be set on modules.
            recursive: whether to recursively call submodules' ``train()``.
        """
        if not recursive:
            self.training = mode
            return

        def fn(module: Module) -> None:
            module.train(mode, recursive=False)

        self.apply(fn)

    def eval(self) -> None:
        r"""Sets training mode of all the modules within this module (including itself) to
        ``False``. See :meth:`~.Module.train` for details.
        """
        self.train(False)

    def disable_quantize(self, value=True):
        r"""Sets the ``quantize_disabled`` attribute of this module and of every
        module within it to ``value``, and returns this module so the call can be
        chained.

        Args:
            value: the value stored in ``quantize_disabled``.

        Returns:
            this module.
        """

        def fn(module: Module) -> None:
            module.quantize_disabled = value

        self.apply(fn)
        return self

    @deprecated(version="1.0")
    def replace_param(
        self, params: dict, start_pos: int, seen: Optional[Set[int]] = None
    ):
        r"""Replaces module's parameters with ``params``, used by :class:`~.ParamPack` to
        speedup multimachine training.

        Returns:
            the number of parameters visited, used as an offset by recursive calls.
        """
        offset = 0
        if seen is None:
            seen = set([id(self)])
        module_dict = vars(self)
        for key in sorted(module_dict):
            hash_id = id(module_dict[key])
            if hash_id in seen:
                continue
            seen.add(hash_id)
            if isinstance(module_dict[key], Parameter):
                if start_pos + offset in params:
                    assert make_shape_tuple(module_dict[key].shape) == make_shape_tuple(
                        params[start_pos + offset].shape
                    )
                    module_dict[key] = params[start_pos + offset]
                offset += 1
            if isinstance(module_dict[key], Module):
                offset += module_dict[key].replace_param(
                    params, start_pos + offset, seen
                )
        return offset

    def state_dict(self, rst=None, prefix="", keep_var=False):
        r"""Returns a dictionary containing whole states of the module.

        Keys are the dotted paths (prefixed by ``prefix``) of every
        :class:`~.Parameter` and buffer in this module and its submodules. Values
        are ``numpy`` arrays, or the tensors themselves when ``keep_var`` is
        ``True``.

        Args:
            rst: optional dictionary into which the intermediate
                ``(module_type, key)`` entries are collected before being flattened
                into the returned dictionary.
            prefix: prefix prepended to every key.
            keep_var: whether to keep tensors instead of converting to ``numpy``.

        Raises:
            AssertionError: if two different states map to the same key.
        """
        _rst = self._state_dict(rst=rst, prefix=prefix, keep_var=keep_var)

        rst = OrderedDict()
        XNorm_typeclass = _get_XNorm_typeclass()
        for (module_type, k), v in _rst.items():
            # for performance reasons, parameters in XNorm (e.g., BatchNorm2d) are 4-dim tensors,
            # however they will be reshaped to 1-dim tensors before returned by `state_dict()`
            if issubclass(module_type, XNorm_typeclass):
                v = v.reshape(-1)
            # Entries from different module types may still collide on the
            # name alone; catch that here instead of silently overwriting.
            assert k not in rst, "duplicated state: {}".format(k)
            rst[k] = v
        return rst

    def _state_dict(self, rst=None, prefix="", keep_var=False):
        r"""Collects the states of this module and, recursively, of its submodules
        into ``rst`` and returns it.

        Keys are ``(module_type, key)`` pairs, where ``module_type`` is the class of
        the module that directly owns the state; :meth:`state_dict` uses it to
        decide which values to reshape.
        """

        def is_state(obj):
            return _is_parameter(obj) or _is_buffer(obj)

        module_type = self.__class__
        if rst is None:
            rst = OrderedDict()

        for k, v in self._flatten(recursive=False, with_key=True, predicate=is_state):
            full_key = (module_type, prefix + k)
            # Keys are tuples, so the membership test must use the tuple too.
            assert full_key not in rst, "duplicated state: {}".format(k)
            if keep_var:
                rst[full_key] = v
            else:
                rst[full_key] = v.numpy()

        for k, submodule in self._flatten(
            recursive=False,
            with_key=True,
            predicate=_is_module,
        ):
            # Recurse into the private collector directly. Going through the
            # public ``state_dict`` would also rebuild (and discard) a reshaped
            # copy of everything collected so far, once per submodule.
            submodule._state_dict(rst, prefix + k + ".", keep_var)

        return rst

    def load_state_dict(
        self,
        state_dict: Union[dict, Callable[[str, Tensor], Optional[np.ndarray]]],
        strict=True,
    ):
        r"""Loads a given dictionary created by :func:`state_dict` into this module.
        If ``strict`` is ``True``, the keys of the given dictionary must exactly match
        the keys returned by this module's :func:`state_dict`.

        Users can also pass a closure: ``Function[key: str, var: Tensor] -> Optional[np.ndarray]``
        as a `state_dict`, in order to handle complex situations. For example, load everything
        except for the final linear classifier:

        .. code-block::

            state_dict = {...}  #  Dict[str, np.ndarray]
            model.load_state_dict({
                k: None if k.startswith('fc') else v
                for k, v in state_dict.items()
            }, strict=False)

        Here returning ``None`` means skipping parameter ``k``.

        To prevent shape mismatch (e.g. load PyTorch weights), we can reshape before loading:

        .. code-block::

            state_dict = {...}
            def reshape_accordingly(k, v):
                return state_dict[k].reshape(v.shape)
            model.load_state_dict(reshape_accordingly)

        We can also perform inplace re-initialization or pruning:

        .. code-block::

            def reinit_and_pruning(k, v):
                if 'bias' in k:
                    M.init.zero_(v)
                if 'conv' in k:
                    return v.numpy() * (np.abs(v.numpy()) > 1e-3).astype("float32")

            model.load_state_dict(reinit_and_pruning, strict=False)

        Args:
            state_dict: a dict mapping keys to ``np.ndarray`` (a ``None`` value or a
                missing key skips that state), or a closure as described above.
            strict: whether unused keys or skipped states raise instead of only
                logging a warning.

        Raises:
            ValueError: if ``state_dict`` is neither a dict nor callable, or a
                loaded array's shape is incompatible with the existing state.
            KeyError: in strict mode, if some given keys were not loaded or some
                states of the module were skipped.
        """
        unused = []
        if isinstance(state_dict, dict):
            unused = state_dict.keys()

            def closure(k, _):  # var unused
                return state_dict[k] if k in state_dict else None

        elif callable(state_dict):
            closure = state_dict
        else:
            raise ValueError(
                "`state_dict` must load a dict or callable, got {}".format(
                    type(state_dict)
                )
            )

        loaded, skipped = self._load_state_dict_with_closure(closure)
        unused = set(unused) - loaded

        if len(unused) != 0:
            if strict:
                raise KeyError(
                    "Unused params violate `strict=True`, unused={}".format(unused)
                )
            else:
                logger.warning(
                    "Unused params in `strict=False` mode, unused={}".format(unused)
                )

        if len(skipped) != 0:
            if strict:
                raise KeyError(
                    "Missing params violate `strict=True`, missing={}".format(skipped)
                )
            else:
                logger.warning(
                    "Missing params in `strict=False` mode, missing={}".format(skipped)
                )

    def _load_state_dict_with_closure(self, closure):
        r"""Advance state_dict load through callable ``closure`` whose signature is
        ``closure(key: str, var: Tensor) -> Union[np.ndarray, None]``

        Returns:
            a pair ``(loaded, skipped)`` of key sets; a key is skipped when
            ``closure`` returns ``None`` for it.
        """
        XNorm_typeclass = _get_XNorm_typeclass()
        assert callable(closure), "closure must be a function"

        loaded = []
        skipped = []

        local_state_dict = self._state_dict(keep_var=True)
        for (module_type, k), var in local_state_dict.items():
            to_be_load = closure(k, var)
            if to_be_load is None:
                skipped.append(k)
                continue
            assert isinstance(
                to_be_load, np.ndarray
            ), "closure should return a `np.ndarray`, now `{}` get {}".format(
                k, to_be_load
            )
            var_shape = make_shape_tuple(var.shape)
            to_be_load_shape = make_shape_tuple(to_be_load.shape)
            if var_shape != to_be_load_shape:
                # weight and bias in BatchNorm1d, BatchNorm2d and SyncBatchNorm are 1-dim tensors in v1.0, and
                # since v1.1 they are 4-dim tensors. The following special rule for these modules preserves the
                # backward compatibility.
                if issubclass(module_type, XNorm_typeclass):
                    if np.prod(var_shape) == np.prod(to_be_load_shape):
                        to_be_load = to_be_load.reshape(var_shape)
                    else:
                        raise ValueError(
                            "param `{}` size mismatch, should be {}, get {}".format(
                                k, np.prod(var_shape), np.prod(to_be_load_shape)
                            )
                        )
                else:
                    raise ValueError(
                        "param `{}` shape mismatch, should be {}, get {}".format(
                            k, var_shape, to_be_load_shape
                        )
                    )
            var._reset(
                type(var)(
                    to_be_load, dtype=to_be_load.dtype, device=var.device, no_cache=True
                )
            )
            loaded.append(k)

        return set(loaded), set(skipped)

    def __setattr__(self, name: str, value):
        is_module_like = _is_module(value) or isinstance(value, (list, tuple, dict))
        if name != "_modules":
            modules = self.__dict__.get("_modules")
            if modules is None and is_module_like:
                raise AttributeError(
                    "cannot assign module before Module.__init__() call"
                )
            # ``_modules`` records the attributes holding module-like values
            # (Module, list, tuple, dict); ``__repr__`` walks it.
            if is_module_like:
                if name not in modules:
                    modules.append(name)
            else:
                if modules is not None and name in modules:
                    modules.remove(name)

            def append_name(prefix, name):
                if prefix is None or prefix == "":
                    return name
                return prefix + "." + name

            def set_name(parent, prefix, name, obj):
                # A Tensor that already carries a non-empty user name keeps it
                # as its short name.
                if isinstance(obj, Tensor):
                    assert obj.name is not None
                    if obj.name != "":
                        name = obj.name
                full_name = append_name(prefix, name)
                if obj._short_name and obj._short_name != name:
                    logger.warning(
                        "try setting the submodule `{}` to `{}`'s new attribute `{}`, its name `{}` will remain unchanged".format(
                            obj._short_name, type(parent), name, obj._short_name
                        )
                    )
                    return
                if isinstance(obj, Tensor):
                    obj._prefix = prefix
                    obj._name = full_name
                    obj._short_name = name
                    obj._set_name(obj._name)
                    return obj._name
                elif isinstance(obj, Module):
                    obj._name = full_name
                    obj._short_name = name
                    # Propagate the new full name to the submodule's own
                    # direct tensors and submodules.
                    for k, v in obj._flatten(recursive=False, with_key=True):
                        set_name(obj, full_name, k, v)
                    return obj._name
                else:
                    assert False

            for k, v in _expand_structure(name, value):
                prefix = self._name if self._name else self.name
                set_name(self, prefix, k, v)

        super().__setattr__(name, value)

    def __setstate__(self, state):
        # Pickles without ``_short_name`` store the short name in ``_name``.
        if "_short_name" not in state:
            state["_short_name"] = state["_name"]
            state["_name"] = None
        self.__dict__.update(state)

    def __delattr__(self, name: str):
        # ``__setattr__`` tracks every attribute holding a Module, list, tuple
        # or dict in ``_modules``; untrack any of them (not only Modules) so
        # ``__repr__`` never looks up a deleted attribute.
        modules = self.__dict__.get("_modules")
        if modules is not None and name in modules and name in self.__dict__:
            modules.remove(name)
        super().__delattr__(name)

    def _module_info_string(self) -> str:
        r"""Returns the extra representation of the module shown by ``__repr__``;
        empty by default, meant to be overridden."""
        return ""

    def __repr__(self):
        # Same width for child-line padding and for nested repr indentation so
        # nested modules line up.
        indent = 2

        def add_indent(repr_str, num_spaces):
            s = repr_str.split("\n")
            # don't do anything for single-line stuff
            if len(s) == 1:
                return repr_str
            first = s.pop(0)
            s = [(num_spaces * " ") + line for line in s]
            s = "\n".join(s)
            s = first + "\n" + s
            return s

        extra_lines = []
        extra_repr = self._module_info_string()
        if extra_repr:
            extra_lines = extra_repr.split("\n")

        child_lines = []
        for name in self._modules:
            if _is_module(self.__dict__[name]):
                child_lines.append(
                    "(" + name + "): " + add_indent(repr(self.__dict__[name]), indent)
                )
            else:
                for k, v in _expand_structure(name, self.__dict__[name]):
                    if _is_module(v):
                        child_lines.append(
                            "(" + k + "): " + add_indent(repr(v), indent)
                        )

        lines = extra_lines + child_lines
        main_str = self.__class__.__name__ + "("
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                pad = "\n" + " " * indent
                main_str += pad + pad.join(lines) + "\n"

        main_str += ")"
        return main_str