BatchNorm2d

class BatchNorm2d(num_features, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True, freeze=False, **kwargs)[源代码]

在四维张量上进行批标准化(Batch Normalization)。

\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]

公式中均值和标准差是每个 mini-batch 维度分别计算的,而 \(\gamma\)\(\beta\) 则是可学习的参数。

默认在训练过程中,该层持续估计其计算的均值和方差,随后在评估过程中用于标准化。运行中的估计保持默认的 momentum 值 0.9。

如果 track_running_stats 设置为 False ,此层将不会累计估计统计数据,而是把当前batch的统计数据应用在评估阶段。

因为批标准化是在 C 维进行的,(N, H, W) 切片上进行统计计算,所以通常将此方法称作空域批正则(Spatial Batch Normalization)。

注解

running_meanrunning_var 的更新公式如下(以 running_mean 为例):

\[\textrm{running_mean} = \textrm{momentum} \times \textrm{running_mean} + (1 - \textrm{momentum}) \times \textrm{batch_mean}\]

可能细节上和其它框架部完全一致。值得注意的是,在 PyTorch 中的 momentum 如果为 0.1,则在 MegEngine 中对应 0.9。

参数
  • num_features – 通常是形状为 \((N, C, H, W)\) 输入数据的 \(C\) 或者维度低于四维的输入的最高维。

  • eps – 添加到分母的单个值,增加数值稳定性。默认:1e-5

  • momentum – 用于计算 running_meanrunning_var 的值。默认:0.9

  • affine – 单个布尔值,当设置为 True ,那么这个模块具有可学习的仿射(affine)参数。默认:True

  • track_running_stats – 当设置为 True,则这个模块跟踪运行时的不同batch的均值和方差。当设置为 False,该模块不跟踪这样的统计数据并在训练和eval模式下始终使用当前批统计数据。默认: True

  • freeze – 设置为True时,此模块不会更新运行运行时平均值和运行时方差,使用运行时平均值和方差而不是批次均值和批次方差来标准化输入。这个参数产生作用仅在使用 track_running_stats=True 初始化模块时有效并且模块处于训练模式。默认:False

Shape:
  • Input: \((N, C, H, W)\)

  • Output: \((N, C, H, W)\) (same shape as input)

实际案例

>>> import numpy as np
>>> # With Learnable Parameters
>>> m = M.BatchNorm2d(4)
>>> inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32"))
>>> oup = m(inp)
>>> print(m.weight.numpy().flatten(), m.bias.numpy().flatten())
[1. 1. 1. 1.] [0. 0. 0. 0.]
>>> # Without Learnable Parameters
>>> m = M.BatchNorm2d(4, affine=False)
>>> oup = m(inp)
>>> print(m.weight, m.bias)
None None