# -*- coding: utf-8 -*-
from typing import Optional
import numpy as np
from ..distributed.group import WORLD, Group
from ..functional.nn import batch_norm, sync_batch_norm
from ..tensor import Parameter, Tensor
from . import init
from .module import Module
class _BatchNorm(Module):
def __init__(
self,
num_features,
eps=1e-5,
momentum=0.9,
affine=True,
track_running_stats=True,
freeze=False,
**kwargs
):
super(_BatchNorm, self).__init__(**kwargs)
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.affine = affine
self.track_running_stats = track_running_stats
self._track_running_stats_saved = track_running_stats
self.freeze = freeze
if self.freeze:
assert (
self._track_running_stats_saved
), "track_running_stats must be initilized to True if freeze is True"
tshape = (1, self.num_features, 1, 1)
if self.affine:
self.weight = Parameter(np.ones(tshape, dtype=np.float32))
self.bias = Parameter(np.zeros(tshape, dtype=np.float32))
else:
self.weight = None
self.bias = None
if self.track_running_stats:
self.running_mean = Tensor(np.zeros(tshape, dtype=np.float32))
self.running_var = Tensor(np.ones(tshape, dtype=np.float32))
else:
self.running_mean = None
self.running_var = None
def reset_running_stats(self) -> None:
if self.track_running_stats:
init.zeros_(self.running_mean)
init.ones_(self.running_var)
def reset_parameters(self) -> None:
self.reset_running_stats()
if self.affine:
init.ones_(self.weight)
init.zeros_(self.bias)
def _check_input_ndim(self, inp):
raise NotImplementedError
def forward(self, inp):
self._check_input_ndim(inp)
if self._track_running_stats_saved == False:
assert (
self.track_running_stats == False
), "track_running_stats can not be initilized to False and changed to True later"
_weight = self.weight
_bias = self.bias
if self.freeze:
if _weight is not None:
_weight = _weight.detach()
if _bias is not None:
_bias = _bias.detach()
# fastpath excution for freeze
scale = (self.running_var + self.eps) ** (-0.5)
if _weight is not None:
scale *= _weight
bias = -self.running_mean * scale
if _bias is not None:
bias += _bias
return inp * scale + bias
if self.training and self.track_running_stats:
exponential_average_factor = self.momentum
else:
exponential_average_factor = 0.0 # useless
output = batch_norm(
inp,
self.running_mean if self.track_running_stats else None,
self.running_var if self.track_running_stats else None,
_weight,
_bias,
training=self.training
or ((self.running_mean is None) and (self.running_var is None)),
momentum=exponential_average_factor,
eps=self.eps,
)
return output
def _module_info_string(self) -> str:
s = (
"{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
"track_running_stats={track_running_stats}"
)
return s.format(**self.__dict__)
[文档]class SyncBatchNorm(_BatchNorm):
r"""Applies Synchronized Batch Normalization for distributed training.
Args:
num_features: usually :math:`C` from an input of shape
:math:`(N, C, H, W)` or the highest ranked dimension of an input
less than 4D.
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the ``running_mean`` and ``running_var`` computation.
Default: 0.9
affine: a boolean value that when set to True, this module has
learnable affine parameters. Default: True
track_running_stats: when set to True, this module tracks the
running mean and variance. When set to False, this module does not
track such statistics and always uses batch statistics in both training
and eval modes. Default: True
freeze: when set to True, this module does not update the
running mean and variance, and uses the running mean and variance instead of
the batch mean and batch variance to normalize the input. The parameter takes effect
only when the module is initilized with track_running_stats as True.
Default: False
group: communication group, caculate mean and variance between this group.
Default: :obj:`~.distributed.WORLD`
"""
def __init__(
self,
num_features,
eps=1e-5,
momentum=0.9,
affine=True,
track_running_stats=True,
freeze=False,
group: Optional[Group] = WORLD,
**kwargs
) -> None:
super().__init__(
num_features, eps, momentum, affine, track_running_stats, freeze, **kwargs
)
self.group = group
def _check_input_ndim(self, inp):
if len(inp.shape) not in {2, 3, 4}:
raise ValueError(
"expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
)
def forward(self, inp):
self._check_input_ndim(inp)
inp_shape = inp.shape
_ndims = len(inp_shape)
if _ndims != 4:
new_shape = Tensor([1, 1, 1, 1], device=inp.device)
origin_shape = inp_shape
if _ndims == 2:
new_shape[:2] = origin_shape[:2]
elif _ndims == 3:
new_shape[:3] = origin_shape[:3]
else:
raise ValueError(
"expected 2D, 3D or 4D input (got {}D input)".format(len(inp_shape))
)
inp = inp.reshape(new_shape)
if self.training and self.track_running_stats:
exponential_average_factor = self.momentum
else:
exponential_average_factor = 0.0 # useless
_weight = self.weight
_bias = self.bias
if self.freeze:
if _weight is not None:
_weight = _weight.detach()
if _bias is not None:
_bias = _bias.detach()
output = sync_batch_norm(
inp,
self.running_mean,
self.running_var,
_weight,
_bias,
training=(self.training and not self.freeze)
or ((self.running_mean is None) and (self.running_var is None)),
momentum=exponential_average_factor,
eps=self.eps,
group=self.group,
)
if _ndims != 4:
output = output.reshape(origin_shape)
return output
[文档]class BatchNorm1d(_BatchNorm):
r"""Applies Batch Normalization over a 2D or 3D input.
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated per-dimension over
the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
of size `C` (where `C` is the number of features or channels of the input). By default, the
elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0. The
standard-deviation is calculated via the biased estimator, equivalent to `torch.var(input, unbiased=False)`.
By default, during training this layer keeps running estimates of its
computed mean and variance, which are then used for normalization during
evaluation. The running estimates are kept with a default :attr:`momentum`
of 0.9.
If :attr:`track_running_stats` is set to ``False``, this layer then does not
keep running estimates, and batch statistics are instead used during
evaluation time as well.
Because the Batch Normalization is done over the `C` dimension, computing statistics
on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.
.. note::
The update formula for ``running_mean`` and ``running_var`` (taking ``running_mean`` as an example) is
.. math::
\textrm{running_mean} = \textrm{momentum} \times \textrm{running_mean} + (1 - \textrm{momentum}) \times \textrm{batch_mean}
which could be defined differently in other frameworks. Most notably, ``momentum`` of 0.1 in PyTorch
is equivalent to ``mementum`` of 0.9 here.
Shape:
- Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size,
:math:`C` is the number of features or channels, and :math:`L` is the sequence length
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
"""
def _check_input_ndim(self, inp):
if len(inp.shape) not in {2, 3}:
raise ValueError(
"expected 2D or 3D input (got {}D input)".format(len(inp.shape))
)
[文档]class BatchNorm2d(_BatchNorm):
r"""Applies Batch Normalization over a 4D tensor.
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated per-dimension over
the mini-batches and :math:`\gamma` and :math:`\beta` are learnable
parameter vectors.
By default, during training this layer keeps running estimates of its
computed mean and variance, which are then used for normalization during
evaluation. The running estimates are kept with a default :attr:`momentum`
of 0.9.
If :attr:`track_running_stats` is set to ``False``, this layer will not
keep running estimates, batch statistics is used during
evaluation time instead.
Because the Batch Normalization is done over the `C` dimension, computing
statistics on `(N, H, W)` slices, it's common terminology to call this
Spatial Batch Normalization.
.. note::
The update formula for ``running_mean`` and ``running_var`` (taking ``running_mean`` as an example) is
.. math::
\textrm{running_mean} = \textrm{momentum} \times \textrm{running_mean} + (1 - \textrm{momentum}) \times \textrm{batch_mean}
which could be defined differently in other frameworks. Most notably, ``momentum`` of 0.1 in PyTorch
is equivalent to ``mementum`` of 0.9 here.
Args:
num_features: usually :math:`C` from an input of shape
:math:`(N, C, H, W)` or the highest ranked dimension of an input
less than 4D.
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the ``running_mean`` and ``running_var`` computation.
Default: 0.9
affine: a boolean value that when set to True, this module has
learnable affine parameters. Default: True
track_running_stats: when set to True, this module tracks the
running mean and variance. When set to False, this module does not
track such statistics and always uses batch statistics in both training
and eval modes. Default: True
freeze: when set to True, this module does not update the
running mean and variance, and uses the running mean and variance instead of
the batch mean and batch variance to normalize the input. The parameter takes effect
only when the module is initilized with track_running_stats as True.
Default: False
Shape:
- Input: :math:`(N, C, H, W)`
- Output: :math:`(N, C, H, W)` (same shape as input)
Examples:
>>> import numpy as np
>>> # With Learnable Parameters
>>> m = M.BatchNorm2d(4)
>>> inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32"))
>>> oup = m(inp)
>>> print(m.weight.numpy().flatten(), m.bias.numpy().flatten())
[1. 1. 1. 1.] [0. 0. 0. 0.]
>>> # Without Learnable Parameters
>>> m = M.BatchNorm2d(4, affine=False)
>>> oup = m(inp)
>>> print(m.weight, m.bias)
None None
"""
def _check_input_ndim(self, inp):
if len(inp.shape) != 4:
raise ValueError("expected 4D input (got {}D input)".format(len(inp.shape)))