# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
from enum import Enum
from functools import partial, update_wrapper, wraps
from typing import Union
import numpy as np
from .. import functional as F
from ..autodiff import Function
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.tensor.dtype import (
QuantDtypeMeta,
_builtin_quant_dtypes,
create_quantized_dtype,
)
from ..tensor import Tensor
class Round(Function):
r"""The functional round have no grad and can not use for quantization-aware-training.
We use Function and STE(Straight-Through Estimator) to implement backward propagation.
"""
def forward(self, x):
return F.round(x)
def backward(self, output_grads):
return output_grads
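
# Example (illustrative sketch, not part of the public API): rounding with a
# straight-through gradient; the values below are made up.
#
#     x = Tensor([0.4, 1.6], dtype=np.float32)
#     y = Round()(x)      # forward gives [0., 2.]
#     # backward: the incoming gradient is passed through unchanged (STE)
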
def tqt_forward(qmin, qmax, inp, scale):
    r"""Forward pass of TQT (Trained Quantization Thresholds) fake quantization."""
op = builtin.TQT(qmin=qmin, qmax=qmax)
(output,) = apply(op, inp, scale)
return output
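
# Example (a hedged sketch): calling the TQT fake-quant primitive with the builtin
# qint8 bounds. ``scale`` is normally a learnable parameter owned by a TQT
# fake-quant module; the zero tensor below is for illustration only.
#
#     dt = _builtin_quant_dtypes["qint8"]
#     x = Tensor(np.random.randn(16).astype("float32"))
#     s = Tensor([0.0], dtype=np.float32)
#     y = tqt_forward(dt.qmin, dt.qmax, x, s)
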
def lsq_forward(qmin, qmax, inp, step_size, zero_point=None, scale_grad=None):
    r"""Forward pass of LSQ (Learned Step Size Quantization) fake quantization."""
if zero_point is None:
zero_point = Tensor([0.0], dtype=np.float32)
if scale_grad is None:
scale_grad = Tensor([1.0], dtype=np.float32)
op = builtin.LSQ(qmin=qmin, qmax=qmax)
(output,) = apply(op, inp, step_size, zero_point, scale_grad)
return output
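
# Example (a hedged sketch): LSQ fake quantization with a hand-picked step size.
# ``step_size`` is normally a learnable parameter; ``zero_point`` and ``scale_grad``
# default to 0 and 1 when omitted, as implemented above.
#
#     dt = _builtin_quant_dtypes["qint8"]
#     x = Tensor(np.random.randn(16).astype("float32"))
#     step = Tensor([0.05], dtype=np.float32)
#     y = lsq_forward(dt.qmin, dt.qmax, x, step)
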
def register_method_to_class(cls):
    r"""Return a decorator that registers the decorated function as a method of ``cls``."""
def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
return func(self, *args, **kwargs)
if isinstance(func, partial):
update_wrapper(func, func.func)
setattr(cls, func.__name__, wrapper)
return func
return decorator
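
# Example (illustrative only): attaching a function to an existing class.
# ``Demo`` is a hypothetical class used purely for demonstration.
#
#     class Demo:
#         pass
#
#     @register_method_to_class(Demo)
#     def double(self, x):
#         return x * 2
#
#     Demo().double(3)  # -> 6
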
class QuantMode(Enum):
    r"""Quantization mode enumeration class."""
SYMMERTIC = 1
ASYMMERTIC = 2
class QParams:
    r"""To standardize the qparams format of FakeQuant, Observer and Tensor. If custom
    qparams are needed, inherit this class and add custom ``__slots__``.
    """
__slots__ = "mode", "dtype_meta", "scale", "zero_point"
def __init__(
self,
mode: QuantMode,
dtype_meta: QuantDtypeMeta,
scale: Tensor,
zero_point: Tensor,
):
self.mode = mode
self.dtype_meta = dtype_meta
self.scale = scale
self.zero_point = zero_point
    def update(self, qparams: "QParams"):
for key in self.__slots__:
setattr(self, key, getattr(qparams, key))
def __eq__(self, other):
if len(self.__slots__) != len(other.__slots__):
return False
for key in self.__slots__:
if not hasattr(other, key) or getattr(self, key) != getattr(other, key):
return False
return True
def __repr__(self):
content = ", ".join(
["{}={}".format(key, getattr(self, key)) for key in self.__slots__]
)
return "QParams({})".format(content)
class LSQParams:
r"""To standardize LSQ's qparams format. If custom
qparams is needed, inherit this class and add custom ``__slots__``.
"""
__slots__ = "mode", "dtype_meta", "scale", "zero_point", "grad_scale"
def __init__(
self,
mode: QuantMode,
dtype_meta: QuantDtypeMeta,
scale: Tensor,
zero_point: Tensor,
grad_scale: Tensor,
):
self.mode = mode
self.dtype_meta = dtype_meta
self.scale = scale
self.zero_point = zero_point
self.grad_scale = grad_scale
def update(self, lsqparams: "LSQParams"):
for key in self.__slots__:
setattr(self, key, getattr(lsqparams, key))
def __eq__(self, other):
if len(self.__slots__) != len(other.__slots__):
return False
for key in self.__slots__:
if not hasattr(other, key) or getattr(self, key) != getattr(other, key):
return False
return True
def __repr__(self):
content = ", ".join(
["{}={}".format(key, getattr(self, key)) for key in self.__slots__]
)
return "LSQParams({})".format(content)
class QParamsModuleMixin(abc.ABC):
    def get_quantized_dtype(self):
        r"""Return the quantized dtype built from this module's current qparams."""
qparams = self.get_qparams()
dtype = qparams.dtype_meta
scale = float(qparams.scale.numpy()) if qparams.scale is not None else None
zero_point = (
int(qparams.zero_point.numpy()) if qparams.zero_point is not None else None
)
return create_quantized_dtype(dtype, scale, zero_point)
@abc.abstractmethod
def get_qparams(self) -> QParams:
pass
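
# Example (illustrative only): a minimal user of the mixin. ``_ConstObserver`` is a
# hypothetical class; real observers also track statistics of the tensors they see.
# ``create_qparams`` is defined below in this module.
#
#     class _ConstObserver(QParamsModuleMixin):
#         def get_qparams(self):
#             return create_qparams(
#                 QuantMode.SYMMERTIC, "qint8", scale=Tensor([0.01], dtype=np.float32)
#             )
#
#     dtype = _ConstObserver().get_quantized_dtype()
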
_builtin_qparams = {
QuantMode.SYMMERTIC: partial(QParams, mode=QuantMode.SYMMERTIC),
QuantMode.ASYMMERTIC: partial(QParams, mode=QuantMode.ASYMMERTIC),
}
def create_qparams(
mode: QuantMode = QuantMode.SYMMERTIC,
dtype_meta: Union[str, QuantDtypeMeta] = None,
scale: Tensor = None,
zero_point: Tensor = None,
):
r"""
Args:
mode: QuantMode:
dtype_meta: Union[str:
QuantDtypeMeta]:
scale: Tensor:
zero_point: Tensor:
"""
if isinstance(dtype_meta, str):
dtype_meta = _builtin_quant_dtypes[dtype_meta]
if mode is None:
return QParams(mode, dtype_meta, scale, zero_point)
assert isinstance(mode, QuantMode)
return _builtin_qparams[mode](
dtype_meta=dtype_meta, scale=scale, zero_point=zero_point
)
def fake_quant_tensor(inp: Tensor, qparams: QParams) -> Tensor:
    """Apply fake quantization to the input tensor.

    Args:
        inp: the input tensor to be fake quantized.
        qparams: the :class:`~.QParams` from which mode, qmin, qmax, scale and zero_point are taken.
    """
scale = qparams.scale
if qparams.mode == QuantMode.ASYMMERTIC:
zero_point = qparams.zero_point
else:
zero_point = Tensor([0.0], dtype=np.float32)
qmin = qparams.dtype_meta.qmin
qmax = qparams.dtype_meta.qmax
op = builtin.FakeQuant(qmin=qmin, qmax=qmax)
return apply(op, inp, scale, zero_point)[0]
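
# Example (illustrative only): fake quantizing an activation. The output stays
# float32, but its values are clamped to [qmin * scale, qmax * scale] and snapped
# to multiples of ``scale``.
#
#     x = Tensor(np.random.randn(8).astype("float32"))
#     qp = create_qparams(QuantMode.SYMMERTIC, "qint8", scale=Tensor([0.05], dtype=np.float32))
#     x_fq = fake_quant_tensor(x, qp)
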
def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
    """Apply fake quantization to the bias, using a scale derived from the input
    tensor and the weight tensor (``inp_scale * w_scale``); the quantized dtype is
    set to qint32.

    Args:
        bias: the bias tensor to be fake quantized.
        inp: the input tensor which contains the quantization parameters.
        w_qat: the weight tensor which contains the quantization parameters.

    Warning:
        Only works with the symmetric quantization method for now.
    """
b_qat = bias
if (
getattr(inp, "qparams", None) is not None
and getattr(w_qat, "qparams", None) is not None
and bias is not None
):
inp_params = inp.qparams
w_params = w_qat.qparams
if inp_params.scale is not None and w_params.scale is not None:
assert inp_params.mode == w_params.mode, "incompatible QuantMode"
# TODO: support quint8 dtype.
assert (
inp_params.dtype_meta.np_dtype_str == "int8"
and w_params.dtype_meta.np_dtype_str == "int8"
), "fake_quant_bias only support int8 like dtype now"
# use the same mode with weight.
# TODO: avoid hardcode
b_dtype = _builtin_quant_dtypes["qint32"]
b_param = create_qparams(
w_params.mode, b_dtype, scale=inp_params.scale * w_params.scale
)
b_qat = fake_quant_tensor(bias, b_param)
b_qat.qparams.update(b_param)
return b_qat
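
# Example (a hedged sketch): in QAT modules, the input and the fake-quantized weight
# typically carry ``qparams`` attached by their observers / fake-quant steps. Given
# such tensors, the bias is fake quantized with scale ``inp_scale * w_scale``:
#
#     b_qat = fake_quant_bias(bias, inp, w_qat)
#     # if the bias or either qparams is missing, the bias is returned unchanged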