megengine.module.BatchNorm2d¶
- class BatchNorm2d(num_features, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True, freeze=False, compute_mode='default', **kwargs)[源代码]¶
在四维张量上进行批标准化(Batch Normalization)。
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]The mean and standard-deviation are calculated per-dimension over the mini-batches and \(\gamma\) and \(\beta\) are learnable parameter vectors.
默认在训练过程中,该层持续估计其计算的均值和方差,随后在评估过程中用于标准化。运行中的估计保持默认的
momentum
值 0.9。如果
track_running_stats
设置为False
,此层将不会累计估计统计数据,而是把当前batch的统计数据应用在评估阶段。因为批标准化是在 C 维进行的,(N, H, W) 切片上进行统计计算,所以通常将此方法称作空域批正则(Spatial Batch Normalization)。
- 参数
num_features – 通常是形状为 \((N, C, H, W)\) 输入数据的 \(C\) 或者维度低于四维的输入的最高维。
eps – 添加到分母的单个值,增加数值稳定性。默认:1e-5
momentum – 用于计算
running_mean
和running_var
的值。默认:0.9affine – 单个布尔值,当设置为
True
,那么这个模块具有可学习的仿射(affine)参数。默认:Truetrack_running_stats – 当设置为 True,则这个模块跟踪运行时的不同batch的均值和方差。当设置为 False,该模块不跟踪这样的统计数据并在训练和eval模式下始终使用当前批统计数据。默认: True
freeze – 设置为True时,此模块不会更新运行运行时平均值和运行时方差,使用运行时平均值和方差而不是批次均值和批次方差来标准化输入。这个参数产生作用仅在使用 track_running_stats=True 初始化模块时有效并且模块处于训练模式。默认:False
实际案例
import numpy as np import megengine as mge import megengine.module as M # With Learnable Parameters m = M.BatchNorm2d(4) inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32")) oup = m(inp) print(m.weight.numpy().flatten(), m.bias.numpy().flatten()) # Without L`e`arnable Parameters m = M.BatchNorm2d(4, affine=False) oup = m(inp) print(m.weight, m.bias)
输出:
[1. 1. 1. 1.] [0. 0. 0. 0.] None None
方法
apply
(fn)对当前模块中的所有模块应用函数
fn
,包括当前模块本身。buffers
([recursive])返回该模块中对于buffers的一个可迭代对象。
children
(**kwargs)返回一个可迭代对象,可遍历所有属于当前模块的直接属性的子模块。
disable_quantize
([value])设置
module
的quantize_diabled
属性,并返回module
。eval
()当前模块中所有模块的
training
属性(包括自身)置为False
,并将其切换为推理模式。forward
(inp)load_state_dict
(state_dict[, strict])加载一个参数字典,这个字典通常使用
state_dict
得到。modules
(**kwargs)返回一个可迭代对象,可以遍历当前模块中的所有模块,包括其本身。
named_buffers
([prefix, recursive])返回可遍历模块中 key 与 buffer 的键值对的可迭代对象,其中
key
为从该模块至 buffer 的点路径(dotted path)。named_children
(**kwargs)返回可迭代对象,可以遍历属于当前模块的直接属性的所有子模块(submodule)与键(key)组成的”key-submodule”对,其中'key'是子模块对应的属性名。
named_modules
([prefix])返回可迭代对象,可以遍历当前模块包括自身在内的所有其内部模块所组成的key-module键-模块对,其中'key'是从当前模块到各子模块的点路径(dotted path)。
named_parameters
([prefix, recursive])返回一个可迭代对象,可以遍历当前模块中key与
Parameter
组成的键值对。其中key
是从模块到Parameter
的点路径(dotted path)。named_tensors
([prefix, recursive])Returns an iterable for key tensor pairs of the module, where
key
is the dotted path from this module to the tensor.parameters
([recursive])返回一个可迭代对象,遍历当前模块中的所有
Parameter
register_forward_hook
(hook)给模块输出注册一个回调函数。
给模块输入注册一个回调函数。
replace_param
(params, start_pos[, seen])Replaces module's parameters with
params
, used byParamPack
to- rtype
- rtype
state_dict
([rst, prefix, keep_var])tensors
([recursive])Returns an iterable for the
Tensor
of the module.train
([mode, recursive])当前模块中所有模块的
training
属性(包括自身)置为mode
。将所有参数的梯度置0。