BatchNorm2d¶
- class BatchNorm2d(num_features, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True, freeze=False, **kwargs)[源代码]¶
在四维张量上进行批标准化(Batch Normalization)。
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]公式中均值和标准差是每个 mini-batch 维度分别计算的,而 \(\gamma\) 和 \(\beta\) 则是可学习的参数。
默认在训练过程中,该层持续估计其计算的均值和方差,随后在评估过程中用于标准化。运行中的估计保持默认的
momentum
值 0.9。如果
track_running_stats
设置为False
,此层将不会累计估计统计数据,而是把当前batch的统计数据应用在评估阶段。因为批标准化是在 C 维进行的,(N, H, W) 切片上进行统计计算,所以通常将此方法称作空域批正则(Spatial Batch Normalization)。
注解
running_mean
和running_var
的更新公式如下(以running_mean
为例):\[\textrm{running_mean} = \textrm{momentum} \times \textrm{running_mean} + (1 - \textrm{momentum}) \times \textrm{batch_mean}\]可能细节上和其它框架部完全一致。值得注意的是,在 PyTorch 中的
momentum
如果为 0.1,则在 MegEngine 中对应 0.9。- 参数
num_features – 通常是形状为 \((N, C, H, W)\) 输入数据的 \(C\) 或者维度低于四维的输入的最高维。
eps – 添加到分母的单个值,增加数值稳定性。默认:1e-5
momentum – 用于计算
running_mean
和running_var
的值。默认:0.9affine – 单个布尔值,当设置为
True
,那么这个模块具有可学习的仿射(affine)参数。默认:Truetrack_running_stats – 当设置为 True,则这个模块跟踪运行时的不同batch的均值和方差。当设置为 False,该模块不跟踪这样的统计数据并在训练和eval模式下始终使用当前批统计数据。默认: True
freeze – 设置为True时,此模块不会更新运行运行时平均值和运行时方差,使用运行时平均值和方差而不是批次均值和批次方差来标准化输入。这个参数产生作用仅在使用 track_running_stats=True 初始化模块时有效并且模块处于训练模式。默认:False
- Shape:
Input: \((N, C, H, W)\)
Output: \((N, C, H, W)\) (same shape as input)
实际案例
>>> import numpy as np >>> # With Learnable Parameters >>> m = M.BatchNorm2d(4) >>> inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32")) >>> oup = m(inp) >>> print(m.weight.numpy().flatten(), m.bias.numpy().flatten()) [1. 1. 1. 1.] [0. 0. 0. 0.] >>> # Without Learnable Parameters >>> m = M.BatchNorm2d(4, affine=False) >>> oup = m(inp) >>> print(m.weight, m.bias) None None