megengine.module.SyncBatchNorm

class SyncBatchNorm(num_features, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True, freeze=False, group=WORLD, **kwargs)[源代码]

对于分布式训练执行组同步Batch Normalization操作。

参数
  • num_features – usually \(C\) from an input of shape \((N, C, H, W)\) or the highest ranked dimension of an input less than 4D.

  • eps – a value added to the denominator for numerical stability. Default: 1e-5

  • momentum – the value used for the running_mean and running_var computation. Default: 0.9

  • affine – a boolean value that when set to True, this module has learnable affine parameters. Default: True

  • track_running_stats – when set to True, this module tracks the running mean and variance. When set to False, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: True

  • freeze – when set to True, this module does not update the running mean and variance, and uses the running mean and variance instead of the batch mean and batch variance to normalize the input. The parameter takes effect only when the module is initilized with track_running_stats as True. Default: False

  • group (Optional[Group]) – communication group, caculate mean and variance between this group. Default: WORLD

方法

apply(fn)

对当前模块中的所有模块应用函数 fn,包括当前模块本身。

buffers([recursive])

返回该模块中对于buffers的一个可迭代对象。

children(**kwargs)

返回一个可迭代对象,可遍历所有属于当前模块的直接属性的子模块。

disable_quantize([value])

设置 modulequantize_diabled 属性,并返回 module

eval()

当前模块中所有模块的 training 属性(包括自身)置为 False ,并将其切换为推理模式。

forward(inp)

load_state_dict(state_dict[, strict])

加载一个参数字典,这个字典通常使用 state_dict 得到。

modules(**kwargs)

返回一个可迭代对象,可以遍历当前模块中的所有模块,包括其本身。

named_buffers([prefix, recursive])

返回可遍历模块中 key 与 buffer 的键值对的可迭代对象,其中 key 为从该模块至 buffer 的点路径(dotted path)。

named_children(**kwargs)

返回可迭代对象,可以遍历属于当前模块的直接属性的所有子模块(submodule)与键(key)组成的”key-submodule”对,其中'key'是子模块对应的属性名。

named_modules([prefix])

返回可迭代对象,可以遍历当前模块包括自身在内的所有其内部模块所组成的key-module键-模块对,其中'key'是从当前模块到各子模块的点路径(dotted path)。

named_parameters([prefix, recursive])

返回一个可迭代对象,可以遍历当前模块中key与 Parameter 组成的键值对。其中 key 是从模块到 Parameter 的点路径(dotted path)。

named_tensors([prefix, recursive])

Returns an iterable for key tensor pairs of the module, where key is the dotted path from this module to the tensor.

parameters([recursive])

返回一个可迭代对象,遍历当前模块中的所有 Parameter

register_forward_hook(hook)

给模块输出注册一个回调函数。

register_forward_pre_hook(hook)

给模块输入注册一个回调函数。

replace_param(params, start_pos[, seen])

Replaces module's parameters with params, used by ParamPack to

reset_parameters()

rtype

None

reset_running_stats()

rtype

None

state_dict([rst, prefix, keep_var])

tensors([recursive])

Returns an iterable for the Tensor of the module.

train([mode, recursive])

当前模块中所有模块的 training 属性(包括自身)置为 mode

zero_grad()

将所有参数的梯度置0。