Module

class Module(name=None)[源代码]

Module基类。

参数

name – module 的名称,可以通过子类的 kwargs 参数初始化。

apply(fn)[源代码]

对当前模块中的所有模块应用函数 fn,包括当前模块本身。

参数

fn (Callable[[Module], Any]) – 多个模块上要应用的函数。

返回类型

None

buffers(recursive=True, **kwargs)[源代码]

返回该模块中对于buffers的一个可迭代对象。

Buffer被定义为是 Tensor 且不是 Parameter

参数

recursive (bool) – 如果设置为 True,返回此 module 及其所有子 module 的 buffer,否则仅返回作为此 module 属性的 buffer。

返回类型

Iterable[Tensor]

children(**kwargs)[源代码]

返回一个可迭代对象,可遍历所有属于当前模块的直接属性的子模块。

返回类型

Iterable[Module]

disable_quantize(value=True)[源代码]

设置 modulequantize_diabled 属性,并返回 module 。可以作为装饰器使用。

eval()[源代码]

当前模块中所有模块的 training 属性(包括自身)置为 False ,并将其切换为推理模式。请参阅 train 了解详情。

返回类型

None

load_state_dict(state_dict, strict=True)[源代码]

向当前模块中加载由 state_dict 创建的给定字典。若 strictTruestate_dict 的键则必须与 state_dict 返回的键准确匹配。

为了处理复杂情况,用户可以传入闭包 Function[key: str, var: Tensor] -> Optional[np.ndarray] 作为 state_dict 。例如,欲加载除了最后线性分类器外的所有部分:

state_dict = {...}  #  Dict[str, np.ndarray]
model.load_state_dict({
    k: None if k.startswith('fc') else v
    for k, v in state_dict.items()
}, strict=False)

这里返回 None 意味着忽略参数 k

为了防止形状不匹配(例如加载PyTorch权重),我们可以在加载之前重塑(reshape):

state_dict = {...}
def reshape_accordingly(k, v):
    return state_dict[k].reshape(v.shape)
model.load_state_dict(reshape_accordingly)

我们还可以进行原位重初始化或修剪(pruning):

def reinit_and_pruning(k, v):
    if 'bias' in k:
        M.init.zero_(v)
    if 'conv' in k:
modules(**kwargs)[源代码]

返回一个可迭代对象,可以遍历当前模块中的所有模块,包括其本身。

返回类型

Iterable[Module]

named_buffers(prefix=None, recursive=True, **kwargs)[源代码]

返回可遍历模块中 key 与 buffer 的键值对的可迭代对象,其中 key 为从该模块至 buffer 的点路径(dotted path)。

Buffer被定义为是 Tensor 且不是 Parameter

参数
  • prefix (Optional[str]) – 加在每个键(key)前的前缀。

  • recursive (bool) – 如果为 True ,则返回所有当前模块中的buffer,否则只返回属于该模块的直接属性。

  • prefix – Optional[str]:

返回类型

Iterable[Tuple[str, Tensor]]

named_children(**kwargs)[源代码]

返回可迭代对象,可以遍历属于当前模块的直接属性的所有子模块(submodule)与键(key)组成的”key-submodule”对,其中’key’是子模块对应的属性名。

返回类型

Iterable[Tuple[str, Module]]

named_modules(prefix=None, **kwargs)[源代码]

返回可迭代对象,可以遍历当前模块包括自身在内的所有其内部模块所组成的key-module键-模块对,其中’key’是从当前模块到各子模块的点路径(dotted path)。

参数

prefix (Optional[str]) – 加在路径前的前缀。

返回类型

Iterable[Tuple[str, Module]]

named_parameters(prefix=None, recursive=True, **kwargs)[源代码]

返回一个可迭代对象,可以遍历当前模块中key与 Parameter 组成的键值对。其中 key 是从模块到 Parameter 的点路径(dotted path)。

参数
  • prefix (Optional[str]) – 加在每个键(key)前的前缀。

  • recursive (bool) – 如果为 True , 则返回在此模块内的所有 Parameter ; 否则,只返回属于当前模块直接属性的 Parameter

返回类型

Iterable[Tuple[str, Parameter]]

named_tensors(prefix=None, recursive=True, **kwargs)[源代码]

返回一个以从 module 到 tensor 的点路径作为 key 的键值对的可迭代对象,

参数
  • prefix (Optional[str]) – 加在每个键(key)前的前缀。

  • recursive (bool) – 如果设置为 True, 返回此 module 及其所有子 module 的 tensor, 否则仅返回作为此 module 属性的 tensor.

返回类型

Iterable[Tuple[str, Tensor]]

parameters(recursive=True, **kwargs)[源代码]

返回一个可迭代对象,遍历当前模块中的所有 Parameter

参数

recursive (bool) – 如果为 True , 则返回在此模块内的所有 Parameter ; 否则,只返回属于当前模块直接属性的 Parameter

返回类型

Iterable[Parameter]

register_forward_hook(hook)[源代码]

注册一个处理推理输出的钩子函数。此函数接受 module, inputsoutputs 作为输入,返回修改过的 outputs 或是 None.

一个带有 remove 接口以删除该钩子函数的句柄。

返回类型

HookHandler

register_forward_pre_hook(hook)[源代码]

注册一个处理推理输入的钩子函数。

参数

hook (Callable) – 一个接受 moduleinputs 作为输入,返回修改过的 inputs 或是 None 的函数。

返回类型

HookHandler

返回

一个带有 remove 接口以删除该钩子函数的句柄。

replace_param(params, start_pos, seen=None)[源代码]

Replaces module’s parameters with params, used by ParamPack to speedup multimachine training.

1.0 版后已移除.

state_dict(rst=None, prefix='', keep_var=False)[源代码]

Returns a dictionary containing whole states of the module.

tensors(recursive=True, **kwargs)[源代码]

返回一个此 module 的 Tensor 的可迭代对象。

参数

recursive (bool) – 如果设置为 True, 则返回此 module 以及其子 module 的所有 Tensor, 否则仅返回此 module 本身的 Tensor.

返回类型

Iterable[Parameter]

train(mode=True, recursive=True)[源代码]

将该模块中的所有模块(包括它自身)的训练模式设置为 mode 。 可便捷地将这些模块的 training 属性设置为 mode ,但仅对某些模块有效(例如 BatchNorm2d, Dropout, Observer)

参数
  • mode (bool) – 为模块设置的训练模式。

  • recursive (bool) – 是否要递归调用子模块的 train()

返回类型

None

zero_grad()[源代码]

将所有参数的梯度置0。

1.0 版后已移除.

返回类型

None