BatchNorm 差异对比¶
torch.nn.BatchNorm2d
torch.nn.BatchNorm2d(
num_features,
eps=1e-05,
momentum=0.1,
affine=True,
track_running_stats=True,
device=None,
dtype=None
)
更多请查看 torch.nn.BatchNorm2d
.
megengine.module.BatchNorm2d
megengine.module.BatchNorm2d(
num_features,
eps=1e-05,
momentum=0.9,
affine=True,
track_running_stats=True,
freeze=False,
**kwargs
)
更多请查看 megengine.module.BatchNorm2d
.
功能差异¶
momentum 差异¶
Warning
MegEngine 的 momentum
参数默认值为 0.9, 而 PyTorch 的默认值为 0.1.
在实际计算时效果一致,这表明该参数在 MegEngine 中的含义与 PyTorch 中的含义不同。
running_mean 和 running_var 的计算方式
MegEngine 中 running_mean
和 running_var
的更新公式如下:
\[ \begin{align}\begin{aligned}\begin{aligned}\\\begin{split}\textrm{running_mean} = &\textrm{momentum} \times \textrm{running_mean} \\
&+ (1 - \textrm{momentum}) \times \textrm{batch_mean}\end{split}\\\end{aligned}\end{aligned}\end{align} \]
running_var
的更新过程与上同理,MegEngine 的 momentum
的含义更符合其惯性的实际语义,
而 PyTorch 的 momentum
参数是指数加权平均的衰减率,即 1 - momentum
(此处指 MegEngine 中的 momentum
)。
冻结参数¶
megengine.module.BatchNorm2d
支持 freeze
参数,用于冻结 BN 层的参数。
在 Pytorch 中可能需要使用类似的方法进行冻结:
for child in model.children():
if isinstance (child, torch.nn.BatchNorm2d):
for param in child.parameters():
param.requires_grad = False
而在 MegEngine 中,只需要将 freeze
参数设置为 True
即可。