from typing import Tuple, Union
from ..functional import relu
from .batchnorm import BatchNorm2d
from .conv import Conv2d
from .module import Module
class _ConvBnActivation2d(Module):
    """Shared base for the fused ``Conv2d`` + ``BatchNorm2d`` modules.

    The constructor builds two child modules:

    * ``self.conv``: a :class:`~.module.Conv2d` built from the convolution
      arguments plus any extra ``**kwargs``.
    * ``self.bn``: a :class:`~.module.BatchNorm2d` over ``out_channels``,
      built from ``eps``, ``momentum``, ``affine`` and
      ``track_running_stats``.

    This class defines no ``forward`` of its own. Subclasses decide how the
    two children are chained.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "cross_correlation",
        compute_mode: str = "default",
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        padding_mode: str = "zeros",
        **kwargs
    ):
        # The same kwargs are forwarded both here and to Conv2d below.
        super().__init__(**kwargs)

        # These arguments are passed positionally, so this order must track
        # Conv2d's parameter order. Note that padding_mode comes last here,
        # even though it is the final named parameter of *this* signature.
        conv_positional = (
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            conv_mode,
            compute_mode,
            padding_mode,
        )
        # NOTE(review): kwargs reach both Module.__init__ and Conv2d, so any
        # option in it (presumably e.g. a module name) is applied to both;
        # confirm this is intended.
        self.conv = Conv2d(*conv_positional, **kwargs)

        # The batch norm normalizes the conv's output channels; its
        # hyper-parameters are passed through unchanged (positionally).
        bn_positional = (out_channels, eps, momentum, affine, track_running_stats)
        self.bn = BatchNorm2d(*bn_positional)
class ConvBn2d(_ConvBnActivation2d):
    r"""A fused :class:`~.Module` including :class:`~.module.Conv2d` and :class:`~.module.BatchNorm2d`.

    Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvBn2d` using
    :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        """Apply ``self.conv`` to ``inp``, then ``self.bn`` to the result.

        No activation is applied.
        """
        return self.bn(self.conv(inp))
class ConvBnRelu2d(_ConvBnActivation2d):
    r"""A fused :class:`~.Module` including :class:`~.module.Conv2d`, :class:`~.module.BatchNorm2d` and :func:`~.relu`.

    Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvBnRelu2d` using
    :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        """Apply ``self.conv``, then ``self.bn``, then ``relu`` to ``inp``."""
        return relu(self.bn(self.conv(inp)))