Module¶
- class Module(name=None)[源代码]¶
Module基类。
- #: megengine.module.module.Module megengine.module.module.Module.apply #: megengine.module.module.Module.buffers #: megengine.module.module.Module.named_buffers #: megengine.module.module.Module.named_modules #: megengine.module.module.Module.named_parameters #: megengine.module.module.Module.named_tensors #: megengine.module.module.Module.parameters #: megengine.module.module.Module.register_forward_pre_hook #: megengine.module.module.Module.tensors megengine.module.module.Module.train #: of
name – module 的名称,可以通过子类的
kwargs
参数初始化。
- apply(fn)[源代码]¶
对当前模块中的所有模块应用函数
fn
,包括当前模块本身。- #: megengine.module.module.Module megengine.module.module.Module.apply #: megengine.module.module.Module.buffers #: megengine.module.module.Module.named_buffers #: megengine.module.module.Module.named_modules #: megengine.module.module.Module.named_parameters #: megengine.module.module.Module.named_tensors #: megengine.module.module.Module.parameters #: megengine.module.module.Module.register_forward_pre_hook #: megengine.module.module.Module.tensors megengine.module.module.Module.train #: of
- 返回类型
- buffers(recursive=True, **kwargs)[源代码]¶
返回该模块中对于buffers的一个可迭代对象。
Buffer被定义为是
Tensor
且不是Parameter
- #: megengine.module.module.Module megengine.module.module.Module.apply #: megengine.module.module.Module.buffers #: megengine.module.module.Module.named_buffers #: megengine.module.module.Module.named_modules #: megengine.module.module.Module.named_parameters #: megengine.module.module.Module.named_tensors #: megengine.module.module.Module.parameters #: megengine.module.module.Module.register_forward_pre_hook #: megengine.module.module.Module.tensors megengine.module.module.Module.train #: of
recursive (
bool
) – 如果设置为True
,返回此 module 及其所有子 module 的 buffer,否则仅返回作为此 module 属性的 buffer。- 返回类型
- load_state_dict(state_dict, strict=True)[源代码]¶
向当前模块中加载由
state_dict
创建的给定字典。若strict
为True
,state_dict
的键则必须与state_dict
返回的键准确匹配。为了处理复杂情况,用户可以传入闭包 Function[key: str, var: Tensor] -> Optional[np.ndarray] 作为 state_dict 。例如,欲加载除了最后线性分类器外的所有部分:
state_dict = {...} # Dict[str, np.ndarray] model.load_state_dict({ k: None if k.startswith('fc') else v for k, v in state_dict.items() }, strict=False)
这里返回
None
意味着忽略参数 k 。为了防止形状不匹配(例如加载PyTorch权重),我们可以在加载之前重塑(reshape):
state_dict = {...} def reshape_accordingly(k, v): return state_dict[k].reshape(v.shape) model.load_state_dict(reshape_accordingly)
我们还可以进行原位重初始化或修剪(pruning):
def reinit_and_pruning(k, v): if 'bias' in k: M.init.zero_(v) if 'conv' in k:
- named_buffers(prefix=None, recursive=True, **kwargs)[源代码]¶
返回可遍历模块中 key 与 buffer 的键值对的可迭代对象,其中
key
为从该模块至 buffer 的点路径(dotted path)。Buffer被定义为是
Tensor
且不是Parameter
- #: megengine.module.module.Module megengine.module.module.Module.apply #: megengine.module.module.Module.buffers #: megengine.module.module.Module.named_buffers #: megengine.module.module.Module.named_modules #: megengine.module.module.Module.named_parameters #: megengine.module.module.Module.named_tensors #: megengine.module.module.Module.parameters #: megengine.module.module.Module.register_forward_pre_hook #: megengine.module.module.Module.tensors megengine.module.module.Module.train #: of
- 返回类型
- named_children(**kwargs)[源代码]¶
返回可迭代对象,可以遍历属于当前模块的直接属性的所有子模块(submodule)与键(key)组成的”key-submodule”对,其中’key’是子模块对应的属性名。
- named_modules(prefix=None, **kwargs)[源代码]¶
返回可迭代对象,可以遍历当前模块包括自身在内的所有其内部模块所组成的key-module键-模块对,其中’key’是从当前模块到各子模块的点路径(dotted path)。
- #: megengine.module.module.Module megengine.module.module.Module.apply #: megengine.module.module.Module.buffers #: megengine.module.module.Module.named_buffers #: megengine.module.module.Module.named_modules #: megengine.module.module.Module.named_parameters #: megengine.module.module.Module.named_tensors #: megengine.module.module.Module.parameters #: megengine.module.module.Module.register_forward_pre_hook #: megengine.module.module.Module.tensors megengine.module.module.Module.train #: of
- 返回类型
- named_parameters(prefix=None, recursive=True, **kwargs)[源代码]¶
返回一个可迭代对象,可以遍历当前模块中key与
Parameter
组成的键值对。其中key
是从模块到Parameter
的点路径(dotted path)。- #: megengine.module.module.Module megengine.module.module.Module.apply #: megengine.module.module.Module.buffers #: megengine.module.module.Module.named_buffers #: megengine.module.module.Module.named_modules #: megengine.module.module.Module.named_parameters #: megengine.module.module.Module.named_tensors #: megengine.module.module.Module.parameters #: megengine.module.module.Module.register_forward_pre_hook #: megengine.module.module.Module.tensors megengine.module.module.Module.train #: of
- 返回类型
- named_tensors(prefix=None, recursive=True, **kwargs)[源代码]¶
返回一个以从 module 到 tensor 的点路径作为
key
的键值对的可迭代对象,- #: megengine.module.module.Module megengine.module.module.Module.apply #: megengine.module.module.Module.buffers #: megengine.module.module.Module.named_buffers #: megengine.module.module.Module.named_modules #: megengine.module.module.Module.named_parameters #: megengine.module.module.Module.named_tensors #: megengine.module.module.Module.parameters #: megengine.module.module.Module.register_forward_pre_hook #: megengine.module.module.Module.tensors megengine.module.module.Module.train #: of
- 返回类型
- parameters(recursive=True, **kwargs)[源代码]¶
返回一个可迭代对象,遍历当前模块中的所有
Parameter
- #: megengine.module.module.Module megengine.module.module.Module.apply #: megengine.module.module.Module.buffers #: megengine.module.module.Module.named_buffers #: megengine.module.module.Module.named_modules #: megengine.module.module.Module.named_parameters #: megengine.module.module.Module.named_tensors #: megengine.module.module.Module.parameters #: megengine.module.module.Module.register_forward_pre_hook #: megengine.module.module.Module.tensors megengine.module.module.Module.train #: of
recursive (
bool
) – 如果为True
, 则返回在此模块内的所有Parameter
; 否则,只返回属于当前模块直接属性的Parameter
。- 返回类型
- register_forward_hook(hook)[源代码]¶
注册一个处理推理输出的钩子函数。此函数接受 module, inputs 和 outputs 作为输入,返回修改过的 outputs 或是 None.
一个带有
remove
接口以删除该钩子函数的句柄。- 返回类型
HookHandler
- register_forward_pre_hook(hook)[源代码]¶
注册一个处理推理输入的钩子函数。
- #: megengine.module.module.Module megengine.module.module.Module.apply #: megengine.module.module.Module.buffers #: megengine.module.module.Module.named_buffers #: megengine.module.module.Module.named_modules #: megengine.module.module.Module.named_parameters #: megengine.module.module.Module.named_tensors #: megengine.module.module.Module.parameters #: megengine.module.module.Module.register_forward_pre_hook #: megengine.module.module.Module.tensors megengine.module.module.Module.train #: of
hook (
Callable
) – 一个接受 module 和 inputs 作为输入,返回修改过的 inputs 或是 None 的函数。- 返回类型
HookHandler
- 返回
一个带有
remove
接口以删除该钩子函数的句柄。
- replace_param(params, start_pos, seen=None)[源代码]¶
- 用
params
替代此 module 的各参数,被ParamPack
使用以加速多机训练。 提速多机训练
1.0 版后已移除.
- 用
- tensors(recursive=True, **kwargs)[源代码]¶
返回一个此 module 的
Tensor
的可迭代对象。- #: megengine.module.module.Module megengine.module.module.Module.apply #: megengine.module.module.Module.buffers #: megengine.module.module.Module.named_buffers #: megengine.module.module.Module.named_modules #: megengine.module.module.Module.named_parameters #: megengine.module.module.Module.named_tensors #: megengine.module.module.Module.parameters #: megengine.module.module.Module.register_forward_pre_hook #: megengine.module.module.Module.tensors megengine.module.module.Module.train #: of
recursive (
bool
) – 如果设置为True
, 则返回此 module 以及其子 module 的所有Tensor
, 否则仅返回此 module 本身的Tensor
.- 返回类型
- train(mode=True, recursive=True)[源代码]¶
将该模块中的所有模块(包括它自身)的训练模式设置为
mode
。 可便捷地将这些模块的training
属性设置为mode
,但仅对某些模块有效(例如BatchNorm2d
,Dropout
,Observer
)- #: megengine.module.module.Module megengine.module.module.Module.apply #: megengine.module.module.Module.buffers #: megengine.module.module.Module.named_buffers #: megengine.module.module.Module.named_modules #: megengine.module.module.Module.named_parameters #: megengine.module.module.Module.named_tensors #: megengine.module.module.Module.parameters #: megengine.module.module.Module.register_forward_pre_hook #: megengine.module.module.Module.tensors megengine.module.module.Module.train #: of
- 返回类型