Module
- class Module(name=None)
Base Module class.
- Parameters
name – module’s name, can be initialized by the kwargs parameter of child class.
- children(**kwargs)
Returns an iterable for all the submodules that are direct attributes of this module.
- disable_quantize(value=True)
Sets the module’s quantize_disabled attribute and returns the module. Can be used as a decorator.
- eval()
Sets training mode of all the modules within this module (including itself) to False. See train for details.
- Return type
- load_state_dict(state_dict, strict=True)
Loads a given dictionary created by state_dict into this module. If strict is True, the keys of the state_dict argument must exactly match the keys returned by this module’s state_dict method.
Users can also pass a closure: Function[key: str, var: Tensor] -> Optional[np.ndarray] as a state_dict, in order to handle complex situations. For example, load everything except for the final linear classifier:

    state_dict = {...}  # Dict[str, np.ndarray]
    model.load_state_dict({
        k: None if k.startswith('fc') else v
        for k, v in state_dict.items()
    }, strict=False)

Here returning None means skipping parameter k.
To prevent shape mismatch (e.g. when loading PyTorch weights), we can reshape before loading:

    state_dict = {...}
    def reshape_accordingly(k, v):
        return state_dict[k].reshape(v.shape)
    model.load_state_dict(reshape_accordingly)

We can also perform inplace re-initialization or pruning:

    def reinit_and_pruning(k, v):
        if 'bias' in k:
            M.init.zero_(v)
        if 'conv' in k:
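The skip-on-None and strict rules described for load_state_dict can be illustrated without the library. The sketch below is plain Python and is not MegEngine's implementation: `load_state`, its return value, and the use of lists in place of tensors are all hypothetical, and the way strict interacts with a closure is an assumption made for the sketch.

```python
def load_state(params, state, strict=True):
    """Hypothetical loader over a plain dict; mimics the described semantics only."""
    # Accept either a mapping or a closure (key, current_value) -> new value or None.
    lookup = state if callable(state) else (lambda key, current: state.get(key))
    loaded, skipped = [], []
    for key, current in params.items():
        new_value = lookup(key, current)
        if new_value is None:
            # None means "leave this parameter untouched".
            skipped.append(key)
            continue
        params[key] = new_value
        loaded.append(key)
    if strict:
        # Assumption: strict mode rejects skipped keys and unexpected extra keys.
        unused = [] if callable(state) else sorted(set(state) - set(params))
        if skipped or unused:
            raise KeyError(f"missing: {skipped}, unexpected: {unused}")
    return loaded, skipped


params = {"conv.weight": [0.0, 0.0], "fc.weight": [0.0]}
pretrained = {"conv.weight": [1.0, 2.0], "fc.weight": [9.0]}

# Skip every key that starts with "fc" by returning None for it.
loaded, skipped = load_state(
    params,
    lambda key, current: None if key.startswith("fc") else pretrained[key],
    strict=False,
)
print(loaded, skipped)
```

With strict=False the classifier weight keeps its initial value, while the same call with strict=True would raise in this sketch because one key was skipped.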
- modules(**kwargs)
Returns an iterable for all the modules within this module, including itself.
- named_buffers(prefix=None, recursive=True, **kwargs)
Returns an iterable of key-buffer pairs of the module, where key is the dotted path from this module to the buffer.
- named_children(**kwargs)
Returns an iterable of key-submodule pairs for all the submodules that are direct attributes of this module, where key is the attribute name of the submodule.
- named_modules(prefix=None, **kwargs)
Returns an iterable of key-module pairs for all the modules within this module, including itself, where key is the dotted path from this module to the submodule.
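The difference between direct children and the full module tree, and the way dotted-path keys are formed, can be shown with a small stand-in. This sketch is plain Python rather than MegEngine code: `ToyModule`, `Block`, and `Net` are hypothetical, submodules are assumed to be found by scanning instance attributes, and the empty-string key for the root module is a choice made for the sketch.

```python
class ToyModule:
    """Hypothetical stand-in that mimics the traversal semantics described above."""

    def named_children(self):
        # Direct attributes that are themselves modules, keyed by attribute name.
        for name, value in vars(self).items():
            if isinstance(value, ToyModule):
                yield name, value

    def children(self):
        for _, child in self.named_children():
            yield child

    def named_modules(self, prefix=""):
        # Yield this module first, then recurse, joining names with dots.
        yield prefix, self
        for name, child in self.named_children():
            child_prefix = f"{prefix}.{name}" if prefix else name
            yield from child.named_modules(child_prefix)

    def modules(self):
        for _, module in self.named_modules():
            yield module


class Block(ToyModule):
    def __init__(self):
        self.conv = ToyModule()
        self.bn = ToyModule()


class Net(ToyModule):
    def __init__(self):
        self.block = Block()
        self.fc = ToyModule()


net = Net()
child_keys = [key for key, _ in net.named_children()]
module_keys = [key for key, _ in net.named_modules()]
print(child_keys)   # ['block', 'fc']
print(module_keys)  # ['', 'block', 'block.conv', 'block.bn', 'fc']
```

Only `block` and `fc` are direct children of `net`, while the recursive traversal also reaches `block.conv` and `block.bn`.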
- named_parameters(prefix=None, recursive=True, **kwargs)
Returns an iterable of key-Parameter pairs of the module, where key is the dotted path from this module to the Parameter.
- named_tensors(prefix=None, recursive=True, **kwargs)
Returns an iterable of key-tensor pairs of the module, where key is the dotted path from this module to the tensor.
- register_forward_hook(hook)
Registers a hook to handle forward results. hook should be a function that receives module, inputs and outputs, then returns modified outputs or None.
This method returns a handler with a remove interface to delete the hook.
- Return type
HookHandler
- register_forward_pre_hook(hook)
Registers a hook to handle forward inputs. hook should be a function.
- Parameters
hook (Callable) – a function that receives module and inputs, then returns modified inputs or None.
- Return type
HookHandler
- Returns
a handler with a remove interface to delete the hook.
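The hook contract described for register_forward_pre_hook and register_forward_hook (a hook may return a replacement value or None, and the returned handler can remove it) can be sketched in plain Python. Nothing here is MegEngine code: `ToyLayer`, `HookHandle`, and the doubling `forward` are hypothetical, and wrapping a non-tuple return value from a pre-hook into a tuple is an assumption of the sketch.

```python
class HookHandle:
    """Hypothetical handle; remove() deletes the hook from the owning registry."""

    def __init__(self, hooks, key):
        self._hooks, self._key = hooks, key

    def remove(self):
        self._hooks.pop(self._key, None)


class ToyLayer:
    def __init__(self):
        self._pre_hooks, self._post_hooks = {}, {}
        self._next_id = 0

    def _register(self, hooks, hook):
        key = self._next_id
        self._next_id += 1
        hooks[key] = hook
        return HookHandle(hooks, key)

    def register_forward_pre_hook(self, hook):
        return self._register(self._pre_hooks, hook)

    def register_forward_hook(self, hook):
        return self._register(self._post_hooks, hook)

    def forward(self, x):
        return x * 2

    def __call__(self, *inputs):
        # Pre-hooks see (module, inputs) and may return replacement inputs.
        for hook in list(self._pre_hooks.values()):
            modified = hook(self, inputs)
            if modified is not None:
                inputs = modified if isinstance(modified, tuple) else (modified,)
        outputs = self.forward(*inputs)
        # Forward hooks see (module, inputs, outputs) and may return replacement outputs.
        for hook in list(self._post_hooks.values()):
            modified = hook(self, inputs, outputs)
            if modified is not None:
                outputs = modified
        return outputs


layer = ToyLayer()
pre = layer.register_forward_pre_hook(lambda m, inputs: (inputs[0] + 1,))
post = layer.register_forward_hook(lambda m, inputs, outputs: outputs * 10)
with_hooks = layer(3)     # (3 + 1) * 2 * 10
pre.remove()
post.remove()
without_hooks = layer(3)  # 3 * 2
print(with_hooks, without_hooks)
```

Keeping the handle returned at registration time is what makes it possible to undo a hook later without touching the module's internals.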
- replace_param(params, start_pos, seen=None)
Replaces the module’s parameters with params. Used by ParamPack to speed up multi-machine training.
Deprecated since version 1.0.
- state_dict(rst=None, prefix='', keep_var=False)
Returns a dictionary containing whole states of the module.
- train(mode=True, recursive=True)
Sets training mode of all the modules within this module (including itself) to mode. This effectively sets the training attributes of those modules to mode, but only has an effect on certain modules (e.g. BatchNorm2d, Dropout, Observer).
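How a training flag propagates through a module tree, and why it matters only to some modules, can be sketched in plain Python. The classes below are hypothetical stand-ins, not MegEngine code; in particular, treating recursive=False as "change only this module" is an assumption read off the signature, and `ToyDropout` uses a fixed mask instead of random dropout so that the sketch stays deterministic.

```python
class ToyModule:
    """Hypothetical stand-in with a training flag and attribute-based submodules."""

    def __init__(self, **children):
        self.training = True
        for name, child in children.items():
            setattr(self, name, child)

    def train(self, mode=True, recursive=True):
        self.training = mode
        if recursive:
            # Propagate the flag to every submodule stored as an attribute.
            for value in vars(self).values():
                if isinstance(value, ToyModule):
                    value.train(mode, recursive=True)

    def eval(self):
        self.train(False)


class ToyDropout(ToyModule):
    def forward(self, xs):
        if not self.training:
            # Evaluation mode: pass values through unchanged.
            return list(xs)
        # Deterministic stand-in for dropout: zero odd positions, rescale the rest.
        return [x * 2 if i % 2 == 0 else 0.0 for i, x in enumerate(xs)]


net = ToyModule(drop=ToyDropout(), inner=ToyModule(drop=ToyDropout()))

net.eval()
flags_eval = [net.training, net.drop.training, net.inner.drop.training]
eval_out = net.drop.forward([1.0, 1.0, 1.0])

net.train(True, recursive=False)
flags_shallow = [net.training, net.drop.training, net.inner.drop.training]

net.train()
train_out = net.drop.forward([1.0, 1.0, 1.0])
print(flags_eval, flags_shallow, eval_out, train_out)
```

The plain `ToyModule` containers carry the flag but behave identically in both modes; only `ToyDropout` changes its output.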