# Source code for megengine.module.conv_bn

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Tuple, Union

from ..functional import relu
from .batchnorm import BatchNorm2d
from .conv import Conv2d
from .module import Module


class _ConvBnActivation2d(Module):
    r"""Shared base for modules that pair a convolution with batch normalization.

    The constructor builds two sub-modules and stores them as attributes:

    * ``self.conv`` -- a :class:`~.module.Conv2d` created from the
      convolution-related arguments plus any extra ``**kwargs``.
    * ``self.bn`` -- a :class:`~.module.BatchNorm2d` created with
      ``out_channels`` as its first argument, followed by ``eps``,
      ``momentum``, ``affine`` and ``track_running_stats``.

    This class defines no ``forward``; subclasses decide how the two
    sub-modules (and any activation) are combined.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "cross_correlation",
        compute_mode: str = "default",
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        **kwargs
    ):
        super().__init__()

        # Arguments are handed over positionally, in exactly this order.
        conv_positional = (
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            conv_mode,
            compute_mode,
        )
        bn_positional = (out_channels, eps, momentum, affine, track_running_stats)

        # Extra keyword arguments go to the convolution only.
        self.conv = Conv2d(*conv_positional, **kwargs)
        self.bn = BatchNorm2d(*bn_positional)


class ConvBn2d(_ConvBnActivation2d):
    r"""A fused :class:`~.Module` including :class:`~.module.Conv2d` and
    :class:`~.module.BatchNorm2d`.

    Could be replaced with :class:`~.QATModule` version
    :class:`~.qat.ConvBn2d` using :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        """Run ``inp`` through ``self.conv`` and then ``self.bn``; return the result."""
        return self.bn(self.conv(inp))
class ConvBnRelu2d(_ConvBnActivation2d):
    r"""A fused :class:`~.Module` including :class:`~.module.Conv2d`,
    :class:`~.module.BatchNorm2d` and :func:`~.relu`.

    Could be replaced with :class:`~.QATModule` version
    :class:`~.qat.ConvBnRelu2d` using :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        """Run ``inp`` through ``self.conv`` and ``self.bn``, then apply ``relu``."""
        return relu(self.bn(self.conv(inp)))