# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod
# avoid circular reference
from ...quantization.fake_quant import FakeQuantize
from ...quantization.observer import Observer
from ...quantization.qconfig import QConfig
from ...quantization.utils import fake_quant_bias
from ...tensor import Tensor
from ..module import Module


class QATModule(Module):
    r"""Base class of quantized-float related :class:`~.Module`,
    basically for QAT and Calibration.

    Use :meth:`from_float_module` to generate an instance from a float
    :class:`~.Module`, or use :func:`~.quantize.quantize_qat` to do it
    recursively and automatically. The result can further be converted to a
    :class:`~.QuantizedModule` for deployment using :func:`~.quantize.quantize`.
    """

    with_weight = True
    with_act = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.weight_observer = None  # type: Observer
        self.act_observer = None  # type: Observer

        self.weight_fake_quant = None  # type: FakeQuantize
        self.act_fake_quant = None  # type: FakeQuantize

    def __repr__(self):
        return "QAT." + super().__repr__()

    def set_qconfig(self, qconfig: QConfig):
        r"""Set quantization related configs with ``qconfig``, including
        observer and fake_quant for weight and activation.
        """

        def safe_call(func):
            return func() if func is not None else None

        if self.with_act:
            self.act_observer = safe_call(qconfig.act_observer)
            self.act_fake_quant = safe_call(qconfig.act_fake_quant)
        if self.with_weight:
            self.weight_observer = safe_call(qconfig.weight_observer)
            self.weight_fake_quant = safe_call(qconfig.weight_fake_quant)
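
    # Illustrative sketch (comments only): applying a qconfig to a QAT module,
    # assuming one of the predefined configs such as
    # ``megengine.quantization.ema_fakequant_qconfig`` is available; the name
    # ``qat_conv`` is hypothetical.
    #
    #     from megengine.quantization import ema_fakequant_qconfig
    #
    #     qat_conv.set_qconfig(ema_fakequant_qconfig)
    #     # afterwards the weight/act observer and fake_quant instances are
    #     # created from the factories stored in the qconfig.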

    def _enable_exec(self, with_module, func, enable):
        if not with_module or not func:
            return
        if enable:
            func.enable()
        else:
            func.disable()

    def set_fake_quant(self, enable):
        self._enable_exec(self.with_act, self.act_fake_quant, enable)
        self._enable_exec(self.with_weight, self.weight_fake_quant, enable)

    def set_observer(self, enable):
        self._enable_exec(self.with_act, self.act_observer, enable)
        self._enable_exec(self.with_weight, self.weight_observer, enable)
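
    # Illustrative sketch (comments only): one common way to drive these two
    # switches, assuming ``qat_net`` is a hypothetical QATModule tree walked
    # with ``Module.modules()``; the schedule itself is an assumption, not
    # something prescribed by this class.
    #
    #     # calibration: collect statistics without perturbing the forward pass
    #     for m in qat_net.modules():
    #         if isinstance(m, QATModule):
    #             m.set_observer(True)
    #             m.set_fake_quant(False)
    #
    #     # QAT: train with fake quantization enabled (observers may stay on
    #     # or be frozen, depending on the training recipe)
    #     for m in qat_net.modules():
    #         if isinstance(m, QATModule):
    #             m.set_fake_quant(True)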

    def _apply_fakequant_with_observer(
        self, target: Tensor, fake_quant: FakeQuantize, observer: Observer
    ):
        # apply the observer first to collect/update statistics
        if observer is None:
            oup = target
            qparams = None
        else:
            oup = observer(target)
            qparams = observer.get_qparams()
        # then apply fake quantization
        if fake_quant is not None:
            oup = fake_quant(oup, qparams)
            # prefer the qparams of fake_quant if it provides them
            if hasattr(fake_quant, "get_qparams"):
                qparams = fake_quant.get_qparams()
        # attach the resulting qparams to the output tensor
        if qparams is not None:
            oup.qparams.update(qparams)
        return oup

    def apply_quant_weight(self, target: Tensor):
        r"""Apply weight's observer and fake_quant from ``qconfig`` on ``target``."""
        return self._apply_fakequant_with_observer(
            target, self.weight_fake_quant, self.weight_observer
        )

    def apply_quant_activation(self, target: Tensor):
        r"""Apply activation's observer and fake_quant from ``qconfig`` on ``target``."""
        return self._apply_fakequant_with_observer(
            target, self.act_fake_quant, self.act_observer
        )

    def apply_quant_bias(self, target: Tensor, inp: Tensor, w_qat: Tensor):
        r"""Use :func:`~.fake_quant_bias` to process ``target``. Only valid when
        ``act_fake_quant`` and ``weight_fake_quant`` are both enabled.
        """
        # bias should have the same dtype as activation, so act_fake_quant can also
        # decide whether to do bias fake quantization
        if (
            self.act_fake_quant
            and self.act_fake_quant.enabled
            and self.weight_fake_quant
            and self.weight_fake_quant.enabled
        ):
            b_qat = fake_quant_bias(target, inp, w_qat)
        else:
            b_qat = target
        return b_qat
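
    # Illustrative sketch (comments only): how a concrete QAT layer typically
    # chains the three ``apply_quant_*`` helpers in its forward pass. The layer
    # shape, attribute names, and the ``_conv_forward`` helper below are
    # hypothetical, not the actual implementation of any MegEngine module.
    #
    #     def forward(self, inp):
    #         w_qat = self.apply_quant_weight(self.weight)
    #         b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
    #         out = self._conv_forward(inp, w_qat, b_qat)  # hypothetical helper
    #         return self.apply_quant_activation(out)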

    def _get_method_result(
        self, method: str, fake_quant: FakeQuantize, observer: Observer
    ):
        # prefer fake_quant's implementation, fall back to the observer's
        if hasattr(fake_quant, method):
            return getattr(fake_quant, method)()
        elif hasattr(observer, method):
            return getattr(observer, method)()
        return None

    def get_weight_dtype(self):
        r"""Get weight's quantization dtype, as given by the fake_quant or observer from ``qconfig``."""
        return self._get_method_result(
            "get_quantized_dtype", self.weight_fake_quant, self.weight_observer
        )

    def get_activation_dtype(self):
        r"""Get activation's quantization dtype, as given by the fake_quant or observer from ``qconfig``."""
        return self._get_method_result(
            "get_quantized_dtype", self.act_fake_quant, self.act_observer
        )

    def get_weight_qparams(self):
        r"""Get weight's quantization parameters."""
        return self._get_method_result(
            "get_qparams", self.weight_fake_quant, self.weight_observer
        )

    def get_activation_qparams(self):
        r"""Get activation's quantization parameters."""
        return self._get_method_result(
            "get_qparams", self.act_fake_quant, self.act_observer
        )
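
    # Illustrative sketch (comments only): these getters are what a conversion
    # step would query when turning a trained QAT module into a deployable
    # quantized one; ``qat_conv`` is hypothetical.
    #
    #     w_dtype = qat_conv.get_weight_dtype()      # quantized dtype for weights
    #     a_dtype = qat_conv.get_activation_dtype()  # quantized dtype for outputs
    #     w_qparams = qat_conv.get_weight_qparams()  # scale / zero_point info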

    @classmethod
    @abstractmethod
    def from_float_module(cls, float_module: Module):
        r"""Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.
        """