Source code for megengine.module.quantized.conv

from typing import Tuple, Union

import numpy as np

from ... import module as Float
from ...core.tensor import dtype
from ...functional.nn import conv_bias_activation, pad
from ...functional.quantized import conv_transpose2d
from ...tensor import Parameter
from ..qat import conv as QAT
from .module import QuantizedModule


class Conv2d(Float.Conv2d, QuantizedModule):
    r"""Quantized version of :class:`~.qat.Conv2d`.

    Applies a 2D convolution over a quantized input tensor, used for inference only.

    The parameters are the same as :class:`~.module.Conv2d`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        conv_mode: str = "cross_correlation",
        compute_mode: str = "default",
        dtype=None,
        padding_mode: str = "zeros",
        **kwargs
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            True,
            conv_mode,
            compute_mode,
            padding_mode,
        )
        self.output_dtype = dtype

    def calc_conv_quantized(self, inp, nonlinear_mode="identity"):
        assert self.padding_mode in ["zeros", "reflect", "replicate"]
        inp_scale = dtype.get_scale(inp.dtype)
        w_scale = dtype.get_scale(self.weight.dtype)
        # The int32 bias shares the scale of the int8 accumulator: inp_scale * w_scale.
        bias_scale = inp_scale * w_scale
        if self.padding_mode != "zeros":
            # Non-zero padding modes are not handled by the fused kernel, so pad
            # explicitly first and run the convolution with padding 0.
            # (`get_pad_witdth` is the helper inherited from Float.Conv2d,
            # spelled this way upstream.)
            return conv_bias_activation(
                pad(inp, self.get_pad_witdth(), self.padding_mode),
                self.weight,
                self.bias.astype(dtype.qint32(bias_scale)),
                self.output_dtype,
                self.stride,
                0,
                self.dilation,
                self.groups,
                conv_mode=self.conv_mode,
                compute_mode=self.compute_mode,
                nonlinear_mode=nonlinear_mode,
            )
        return conv_bias_activation(
            inp,
            self.weight,
            self.bias.astype(dtype.qint32(bias_scale)),
            self.output_dtype,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            conv_mode=self.conv_mode,
            compute_mode=self.compute_mode,
            nonlinear_mode=nonlinear_mode,
        )

    @classmethod
    def from_qat_module(cls, qat_module: QAT.Conv2d):
        r"""
        Return a :class:`~.QuantizedModule` instance converted from a
        :class:`~.QATModule` instance.
        """
        output_dtype = qat_module.get_activation_dtype()
        qconv = cls(
            qat_module.in_channels,
            qat_module.out_channels,
            qat_module.kernel_size,
            qat_module.stride,
            qat_module.padding,
            qat_module.dilation,
            qat_module.groups,
            dtype=output_dtype,
            padding_mode=qat_module.padding_mode,
            name=qat_module.name,
        )
        weight = qat_module.weight.astype(qat_module.get_weight_dtype())
        qconv.weight = Parameter(weight.numpy(), name=qat_module.weight.name)
        if qat_module.bias is not None:
            qconv.bias = Parameter(qat_module.bias.numpy(), name=qat_module.bias.name)
        else:
            # calc_conv_quantized always adds a bias, so fill in a zero one.
            qconv.bias = Parameter(
                np.zeros(qat_module._infer_bias_shape(), dtype=np.float32)
            )
        return qconv

    def forward(self, inp):
        return self.calc_conv_quantized(inp, nonlinear_mode="identity")
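For context, a minimal sketch of how instances of this class usually come into being: ``from_qat_module`` is invoked by MegEngine's ``quantize`` pass on a QAT-trained or calibrated network. The toy network and variable names below are illustrative assumptions, not part of this module.

import megengine.module as M
from megengine.quantization import quantize, quantize_qat

class TinyNet(M.Module):
    def __init__(self):
        super().__init__()
        self.conv = M.Conv2d(3, 8, 3, padding=1)

    def forward(self, x):
        return self.conv(x)

net = TinyNet()
qat_net = quantize_qat(net)  # Float.Conv2d -> QAT.Conv2d (adds fake-quant observers)
# ... run calibration / QAT fine-tuning so the observers hold real scales ...
q_net = quantize(qat_net)    # QAT.Conv2d -> this quantized Conv2d, via from_qat_module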
class ConvRelu2d(Conv2d):
    r"""Quantized version of :class:`~.qat.ConvRelu2d`."""

    def forward(self, inp):
        return self.calc_conv_quantized(inp, nonlinear_mode="relu")
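``ConvRelu2d`` fuses the activation into the convolution kernel through the ``nonlinear_mode`` argument of ``conv_bias_activation`` rather than running a separate quantized ReLU. In the dequantized domain it corresponds to the following float computation (a sketch with made-up tensors, not part of this module):

import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.random.randn(1, 3, 8, 8), dtype="float32")
w = mge.tensor(np.random.randn(8, 3, 3, 3), dtype="float32")
b = mge.tensor(np.zeros((1, 8, 1, 1)), dtype="float32")

# Float reference of the fused kernel: convolution and ReLU in one pass.
y = F.relu(F.conv2d(x, w, b, stride=1, padding=1))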
class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule):
    r"""Quantized version of :class:`~.qat.ConvTranspose2d`.

    Applies a 2D transposed convolution over a quantized input tensor, used
    for inference only.

    The parameters are the same as :class:`~.module.ConvTranspose2d` except ``dtype``.

    Args:
        dtype: data type of the output, should be qint8.
    """

    output_padding = 0

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        output_padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "cross_correlation",
        compute_mode: str = "default",
        dtype=None,
        **kwargs
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            conv_mode=conv_mode,
            compute_mode=compute_mode,
        )
        self.output_dtype = dtype

    @classmethod
    def from_qat_module(cls, qat_module: QAT.ConvTranspose2d):
        r"""
        Return a :class:`~.QuantizedModule` instance converted from a
        :class:`~.QATModule` instance.
        """
        output_dtype = qat_module.get_activation_dtype()
        qconv_transpose2d = cls(
            qat_module.in_channels,
            qat_module.out_channels,
            qat_module.kernel_size,
            qat_module.stride,
            qat_module.padding,
            qat_module.output_padding,
            qat_module.dilation,
            qat_module.groups,
            qat_module.bias is not None,
            qat_module.conv_mode,
            qat_module.compute_mode,
            dtype=output_dtype,
            name=qat_module.name,
        )
        weight = qat_module.weight.astype(qat_module.get_weight_dtype())
        qconv_transpose2d.weight = Parameter(
            weight.numpy(), name=qat_module.weight.name
        )
        qconv_transpose2d.bias = (
            Parameter(qat_module.bias.numpy(), name=qat_module.bias.name)
            if qat_module.bias is not None
            else None
        )
        return qconv_transpose2d

    def calc_conv_transpose2d_quantized(self, inp, nonlinear_mode):
        assert nonlinear_mode == "identity", "nonlinear_mode should be 'identity'"
        if self.bias is not None:
            inp_scale = dtype.get_scale(inp.dtype)
            w_scale = dtype.get_scale(self.weight.dtype)
            # As in Conv2d, the int32 bias uses the accumulator scale.
            bias_scale = inp_scale * w_scale
        return conv_transpose2d(
            inp=inp,
            weight=self.weight,
            bias=self.bias.astype(dtype.qint32(bias_scale))
            if self.bias is not None
            else None,
            dtype=self.output_dtype,
            stride=self.stride,
            padding=self.padding,
            output_padding=self.output_padding,
            dilation=self.dilation,
            groups=self.groups,
            conv_mode=self.conv_mode,
            compute_mode=self.compute_mode,
        )

    def forward(self, inp):
        return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="identity")
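Both ``calc_conv_quantized`` and ``calc_conv_transpose2d_quantized`` cast the bias to ``qint32(inp_scale * w_scale)``. An int8 convolution accumulates products ``q_inp * q_w`` in int32, and each product stands for the real value ``(inp_scale * q_inp) * (w_scale * q_w)``, so a bias added inside that accumulator must be quantized with the scale ``inp_scale * w_scale``. A toy numeric check (the scales below are made up):

from megengine.core.tensor import dtype

inp_scale, w_scale = 0.05, 0.02         # assumed activation / weight scales
bias_scale = inp_scale * w_scale        # scale of the int32 accumulator (0.001)

bias_real = 0.013                       # some float bias value
q_bias = round(bias_real / bias_scale)  # integer value stored in the qint32 bias: 13
assert abs(q_bias * bias_scale - bias_real) < 1e-12

print(dtype.get_scale(dtype.qint32(bias_scale)))  # -> 0.001 (up to float rounding)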
class ConvTransposeRelu2d(ConvTranspose2d):
    r"""Quantized version of :class:`~.qat.ConvTransposeRelu2d`."""

    def forward(self, inp):
        # Note: calc_conv_transpose2d_quantized asserts nonlinear_mode == "identity",
        # so passing "relu" here will trip that assertion at runtime.
        return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="relu")