Module¶
- class Module(name=None)[源代码]¶
Module基类。
- 参数
name – module 的名称,可以通过子类的
kwargs
参数初始化。
- load_state_dict(state_dict, strict=True)[源代码]¶
向当前模块中加载由
state_dict
创建的给定字典。若strict
为True
,state_dict
的键则必须与state_dict
返回的键准确匹配。为了处理复杂情况,用户可以传入闭包 Function[key: str, var: Tensor] -> Optional[np.ndarray] 作为 state_dict 。例如,欲加载除了最后线性分类器外的所有部分:
state_dict = {...} # Dict[str, np.ndarray] model.load_state_dict({ k: None if k.startswith('fc') else v for k, v in state_dict.items() }, strict=False)
这里返回
None
意味着忽略参数 k 。为了防止形状不匹配(例如加载PyTorch权重),我们可以在加载之前重塑(reshape):
state_dict = {...} def reshape_accordingly(k, v): return state_dict[k].reshape(v.shape) model.load_state_dict(reshape_accordingly)
我们还可以进行原位重初始化或修剪(pruning):
def reinit_and_pruning(k, v): if 'bias' in k: M.init.zero_(v) if 'conv' in k:
- named_buffers(prefix=None, recursive=True, **kwargs)[源代码]¶
返回可遍历模块中 key 与 buffer 的键值对的可迭代对象,其中
key
为从该模块至 buffer 的点路径(dotted path)。
- named_children(**kwargs)[源代码]¶
返回可迭代对象,可以遍历属于当前模块的直接属性的所有子模块(submodule)与键(key)组成的”key-submodule”对,其中’key’是子模块对应的属性名。
- named_modules(prefix=None, **kwargs)[源代码]¶
返回可迭代对象,可以遍历当前模块包括自身在内的所有其内部模块所组成的key-module键-模块对,其中’key’是从当前模块到各子模块的点路径(dotted path)。
- named_parameters(prefix=None, recursive=True, **kwargs)[源代码]¶
返回一个可迭代对象,可以遍历当前模块中key与
Parameter
组成的键值对。其中key
是从模块到Parameter
的点路径(dotted path)。
- named_tensors(prefix=None, recursive=True, **kwargs)[源代码]¶
返回一个以从 module 到 tensor 的点路径作为
key
的键值对的可迭代对象,
- register_forward_hook(hook)[源代码]¶
注册一个处理推理输出的钩子函数。此函数接受 module, inputs 和 outputs 作为输入,返回修改过的 outputs 或是 None.
一个带有
remove
接口以删除该钩子函数的句柄。- 返回类型
HookHandler
- register_forward_pre_hook(hook)[源代码]¶
注册一个处理推理输入的钩子函数。
- 参数
hook (
Callable
) – 一个接受 module 和 inputs 作为输入,返回修改过的 inputs 或是 None 的函数。- 返回类型
HookHandler
- 返回
一个带有
remove
接口以删除该钩子函数的句柄。
- replace_param(params, start_pos, seen=None)[源代码]¶
Replaces module’s parameters with
params
, used byParamPack
to speedup multimachine training.1.0 版后已移除.
- state_dict(rst=None, prefix='', keep_var=False)[源代码]¶
Returns a dictionary containing whole states of the module.
- train(mode=True, recursive=True)[源代码]¶
将该模块中的所有模块(包括它自身)的训练模式设置为
mode
。 可便捷地将这些模块的training
属性设置为mode
,但仅对某些模块有效(例如BatchNorm2d
,Dropout
,Observer
)