# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
from abc import abstractmethod
from copy import deepcopy
from typing import Union
import numpy as np
from .. import functional as F
from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
from ..distributed import WORLD, get_rank, is_distributed
from ..functional.distributed import all_reduce_max, all_reduce_min
from ..logger import get_logger
from ..module import Module
from ..tensor import Tensor
from .utils import QParams, QParamsModuleMixin, QuantMode, create_qparams
logger = get_logger(__name__)
class Observer(Module, QParamsModuleMixin):
    r"""A base class for observer modules, used to record an input tensor's
    statistics for quantization.

    Args:
        dtype: a string or :class:`QuantDtypeMeta` indicating which dtype to
            collect ``scale`` and ``zero_point`` for.
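
    Examples:

        A minimal sketch of a concrete subclass (``ConstObserver`` is
        hypothetical; it returns a fixed scale, for illustration only):

        .. code-block:: python

            class ConstObserver(Observer):
                def forward(self, x):
                    # observers pass the input through unchanged
                    return x

                def get_qparams(self):
                    # fixed qparams instead of collected statistics
                    return create_qparams(
                        QuantMode.SYMMERTIC, self.dtype, scale=Tensor(1.0)
                    )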
"""
def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
super().__init__()
if isinstance(dtype, str):
            if dtype not in _builtin_quant_dtypes:
raise ValueError(
"unknown dtype: {}, only support {}".format(
dtype, _builtin_quant_dtypes.keys()
)
)
dtype = _builtin_quant_dtypes[dtype]
if "narrow_range" in kwargs:
del kwargs["narrow_range"]
logger.warning(
"FakeQuantize currently has no narrow_range param "
"so it is ignored here",
exc_info=DeprecationWarning,
)
self.dtype = dtype
self.qmin = dtype.qmin
self.qmax = dtype.qmax
self.enabled = True
    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False
    def train(self, mode: bool = True, recursive: bool = True) -> None:
super().train(mode, recursive)
if mode:
self.enable()
else:
self.disable()
    @abstractmethod
def forward(self, x):
pass
class MinMaxObserver(Observer):
    r"""An Observer module that records the input tensor's running min and max
    values to calculate ``scale``.

    Args:
        mode: the quantization mode.
        eps: a lower bound for ``scale`` to avoid division by zero.
        dtype: a string or :class:`QuantDtypeMeta` indicating which dtype to
            collect ``scale`` and ``zero_point`` for.
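
    Examples:

        A minimal usage sketch (the input values are illustrative):

        .. code-block:: python

            import numpy as np
            from megengine import Tensor

            ob = MinMaxObserver(dtype="qint8")
            _ = ob(Tensor(np.random.randn(4, 16), dtype="float32"))
            qparams = ob.get_qparams()  # scale (and zero_point in asymmetric mode)
            ob.eval()  # disables statistics collection; train() re-enables it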
"""
def __init__(
self,
mode: QuantMode = QuantMode.SYMMERTIC,
eps: float = 0.00001,
dtype: Union[str, QuantDtypeMeta] = "qint8",
**kwargs
):
super().__init__(dtype, **kwargs)
self.mode = mode
self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32)
self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32)
self.scale_limit = eps
def _calculate_qparams(self, inp_min_val, inp_max_val):
min_val = F.minimum(0.0, inp_min_val)
max_val = F.maximum(0.0, inp_max_val)
if self.mode == QuantMode.SYMMERTIC:
symmetric_max_vals = F.maximum(-min_val, max_val)
            # use maximum to avoid the scale being too small at the beginning
scale = F.maximum(
symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
)
zero_point = None
else:
            # use maximum to avoid the scale being too small at the beginning
scale = F.maximum(
(max_val - min_val) / (self.qmax - self.qmin), self.scale_limit
)
            # calculate zero_point
zero_point = self.qmin - F.round((min_val / scale))
return create_qparams(self.mode, self.dtype, scale=scale, zero_point=zero_point)
    def get_qparams(self):
return self._calculate_qparams(self.min_val, self.max_val)
    def forward(self, x_orig):
if self.enabled:
# stop gradient
x = x_orig.detach()
# find max and min
self.min_val[...] = F.minimum(self.min_val, x.min())
self.max_val[...] = F.maximum(self.max_val, x.max())
return x_orig
class SyncMinMaxObserver(MinMaxObserver):
    r"""A distributed version of :class:`~.MinMaxObserver`.

    Args:
        mode: the quantization mode.
        eps: a lower bound for ``scale`` to avoid division by zero.
        dtype: a string or :class:`QuantDtypeMeta` indicating which dtype to
            collect ``scale`` and ``zero_point`` for.
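
    Examples:

        A sketch of distributed usage (assuming multiple visible GPUs;
        ``get_local_batch`` is a hypothetical data-loading helper):

        .. code-block:: python

            import megengine.distributed as dist

            @dist.launcher
            def worker():
                ob = SyncMinMaxObserver(dtype="qint8")
                # min/max are all-reduced across ranks inside forward
                _ = ob(get_local_batch())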
"""
    def forward(self, x_orig):
        if self.enabled:
x = x_orig.detach()
if is_distributed():
min_x = all_reduce_min(x.min(), WORLD)
max_x = all_reduce_max(x.max(), WORLD)
else:
min_x = x.min()
max_x = x.max()
self.min_val[...] = F.minimum(self.min_val, min_x)
self.max_val[...] = F.maximum(self.max_val, max_x)
return x_orig
class ExponentialMovingAverageObserver(MinMaxObserver):
    r"""A :class:`~.MinMaxObserver` with momentum support for min/max updating.

    Args:
        momentum: momentum ratio for min/max updating.
        mode: the quantization mode.
        eps: a lower bound for ``scale`` to avoid division by zero.
        dtype: a string or :class:`QuantDtypeMeta` indicating which dtype to
            collect ``scale`` and ``zero_point`` for.
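
    Examples:

        After the first forward pass, statistics follow
        ``stat = stat * momentum + new_value * (1 - momentum)``. A minimal
        sketch (``batches`` is a hypothetical iterable of Tensors):

        .. code-block:: python

            ob = ExponentialMovingAverageObserver(momentum=0.99, dtype="qint8")
            for batch in batches:
                ob(batch)
            qparams = ob.get_qparams()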
"""
def __init__(
self,
momentum: float = 0.9,
mode: QuantMode = QuantMode.SYMMERTIC,
eps: float = 0.00001,
dtype: Union[str, QuantDtypeMeta] = "qint8",
**kwargs
):
super().__init__(mode, eps, dtype, **kwargs)
self.momentum = Tensor(momentum, dtype="float32")
        # used to avoid an if-clause in the first forward, which is not
        # supported in trace mode
self.runtime_momentum = Tensor(0.0)
    def set_momentum(self, momentum):
self.momentum = Tensor(momentum, dtype="float32")
    def forward(self, x_orig):
if self.enabled:
# stop gradient
x = x_orig.detach()
# Exponential Moving Average
self.min_val[...] = (
self.min_val * self.runtime_momentum
+ (1 - self.runtime_momentum) * x.min()
)
self.max_val[...] = (
self.max_val * self.runtime_momentum
+ (1 - self.runtime_momentum) * x.max()
)
self.runtime_momentum[...] = self.momentum
return x_orig
class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
    r"""A distributed version of :class:`~.ExponentialMovingAverageObserver`.

    Args:
        momentum: momentum ratio for min/max updating.
        mode: the quantization mode.
        eps: a lower bound for ``scale`` to avoid division by zero.
        dtype: a string or :class:`QuantDtypeMeta` indicating which dtype to
            collect ``scale`` and ``zero_point`` for.
"""
    def forward(self, x_orig):
if self.enabled:
x = x_orig.detach()
            if is_distributed():
min_x = all_reduce_min(x.min(), WORLD)
max_x = all_reduce_max(x.max(), WORLD)
else:
min_x = x.min()
max_x = x.max()
self.min_val[...] = (
self.min_val * self.runtime_momentum
+ (1 - self.runtime_momentum) * min_x
)
self.max_val[...] = (
self.max_val * self.runtime_momentum
+ (1 - self.runtime_momentum) * max_x
)
self.runtime_momentum[...] = self.momentum
return x_orig
class HistogramObserver(MinMaxObserver):
    r"""A :class:`~.MinMaxObserver` using a running histogram of tensor values
    for min/max updating. Usually used for calibration quantization.

    Args:
        bins: number of bins to use for the histogram.
        upsample_rate: the ratio by which histograms are up-sampled before
            being combined.
        mode: the quantization mode.
        eps: a lower bound for ``scale`` to avoid division by zero.
        dtype: a string or :class:`QuantDtypeMeta` indicating which dtype to
            collect ``scale`` and ``zero_point`` for.
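
    Examples:

        A typical calibration sketch (``calibration_batches`` is a hypothetical
        iterable of Tensors):

        .. code-block:: python

            ob = HistogramObserver(bins=2048, dtype="qint8")
            for batch in calibration_batches:
                ob(batch)  # accumulates the running histogram
            qparams = ob.get_qparams()  # min/max chosen by a non-linear L2 search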
"""
def __init__(
self,
bins: int = 2048,
upsample_rate: int = 128,
mode: QuantMode = QuantMode.SYMMERTIC,
eps: float = 0.00001,
dtype: Union[str, QuantDtypeMeta] = "qint8",
**kwargs
):
super().__init__(mode, eps, dtype, **kwargs)
self.bins = bins
self.upsample_rate = upsample_rate
        # use the resolved qmin/qmax so that QuantDtypeMeta arguments also work
        self.dst_nbins = self.qmax - self.qmin + 1
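        # a leading -1 marks the histogram as uninitialized (checked in sideeffect_forward)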
self.histogram = Tensor([-1] + [0.0] * (bins - 1), dtype="float32")
def _non_linear_param_search(self):
r"""Non-linear parameter search.
An approximation for L2 error minimization for selecting min/max.
By selecting new min/max, we filter out outliers in input distribution.
"""
np_min_val = self.min_val.numpy()
np_max_val = self.max_val.numpy()
np_histogram = self.histogram.numpy()
        assert len(np_histogram) == self.bins, "bins mismatch"
bin_width = (np_max_val - np_min_val) / self.bins
        def _get_norm(delta_begin, delta_end, density, norm_type):
            r"""Compute the norm of the values uniformly distributed between
            delta_begin and delta_end.

            norm = density * \int_{begin}^{end} x^2 dx
                 = density * (end^3 - begin^3) / 3
            """
assert norm_type == "L2", "Only L2 norms are currently supported"
norm = 0.0
if norm_type == "L2":
norm = (
delta_end * delta_end * delta_end
- delta_begin * delta_begin * delta_begin
) / 3
return density * norm
def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
r"""Compute the quantization error if we use start_bin to end_bin as the
min and max to do the quantization.
"""
norm = 0.0
dst_bin_width = (
bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
)
if dst_bin_width == 0.0:
return 0.0
for src_bin in range(self.bins):
# distances from the beginning of first dst_bin to the beginning and
# end of src_bin
src_bin_begin = (src_bin - next_start_bin) * bin_width
src_bin_end = src_bin_begin + bin_width
# which dst_bins the beginning and end of src_bin belong to?
dst_bin_of_begin = min(
self.dst_nbins - 1,
max(0.0, math.floor(src_bin_begin / dst_bin_width)),
)
dst_bin_of_end = min(
self.dst_nbins - 1,
max(0.0, math.floor(src_bin_end / dst_bin_width)),
)
dst_bin_of_begin_center = (
dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
)
density = np_histogram[src_bin] / bin_width
if dst_bin_of_begin == dst_bin_of_end:
# if src_bin is entirely within 1 dst_bin
delta_begin = src_bin_begin - dst_bin_of_begin_center
delta_end = src_bin_end - dst_bin_of_begin_center
norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
else:
delta_begin = src_bin_begin - dst_bin_of_begin_center
delta_end = dst_bin_width / 2
norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
norm = norm + (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
-dst_bin_width / 2, dst_bin_width / 2, density, norm_type
)
dst_bin_of_end_center = (
dst_bin_of_end * dst_bin_width + dst_bin_width / 2
)
delta_begin = -dst_bin_width / 2
delta_end = src_bin_end - dst_bin_of_end_center
norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
return norm
# cumulative sum
total = sum(np_histogram)
cSum = np.cumsum(np_histogram, axis=0)
stepsize = 1e-5 # granularity
alpha = 0.0 # lower bound
beta = 1.0 # upper bound
start_bin = 0
end_bin = self.bins - 1
norm_min = float("inf")
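        # Iteratively tighten the [alpha, beta] quantile bounds: at each step,
        # move whichever boundary (start_bin or end_bin) skips more bins, and
        # stop once the L2 quantization error starts increasing.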
while alpha < beta:
# Find the next step
next_alpha = alpha + stepsize
next_beta = beta - stepsize
# find the left and right bins between the quantile bounds
l = start_bin
r = end_bin
while l < end_bin and cSum[l] < next_alpha * total:
l = l + 1
while r > start_bin and cSum[r] > next_beta * total:
r = r - 1
# decide the next move
next_start_bin = start_bin
next_end_bin = end_bin
if (l - start_bin) > (end_bin - r):
# move the start bin
next_start_bin = l
alpha = next_alpha
else:
# move the end bin
next_end_bin = r
beta = next_beta
if next_start_bin == start_bin and next_end_bin == end_bin:
continue
# calculate the quantization error using next_start_bin and next_end_bin
norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")
if norm > norm_min:
break
norm_min = norm
start_bin = next_start_bin
end_bin = next_end_bin
new_min = self.min_val + Tensor(bin_width * start_bin, dtype=np.float32)
new_max = self.min_val + Tensor(bin_width * (end_bin + 1), dtype=np.float32)
return new_min, new_max
    def get_qparams(self):
new_min, new_max = self._non_linear_param_search()
return self._calculate_qparams(new_min, new_max)
def _combine_histograms(
self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins
):
        # First up-sample the histogram with new data by a factor of upsample_rate.
        # This creates an approximate probability density that is piecewise constant.
upsampled_histogram = new_hist.repeat(upsample_rate)
# Now insert the upsampled histogram into the output
# histogram, which is initialized with zeros.
# The offset at which the histogram is introduced is determined
# by the start index as the output histogram can cover a wider range
histogram_with_output_range = np.zeros((Nbins * downsample_rate))
histogram_with_output_range[
start_idx : Nbins * upsample_rate + start_idx
] = upsampled_histogram
# Compute integral histogram, double precision is needed to ensure
# that there are no overflows
integral_histogram = np.cumsum(histogram_with_output_range, 0)[
downsample_rate - 1 :: downsample_rate
]
# Finally perform interpolation
shifted_integral_histogram = np.zeros((Nbins))
shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1]
interpolated_histogram = (
integral_histogram - shifted_integral_histogram
) / upsample_rate
orig_hist = orig_hist + interpolated_histogram
return orig_hist
def _adjust_min_max(self, combined_min, combined_max, upsample_rate):
# We ensure that:
# (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
# This allows us to have a common grid of resolution s, where we can align
# the input histogram
# start_idx maps min_val to the histogram bin index.
np_min_val = self.min_val.numpy()
np_max_val = self.max_val.numpy()
hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate)
downsample_rate = int(
np.ceil((combined_max - combined_min) / (self.bins * hist_bin_width))
)
e = downsample_rate * (self.bins * hist_bin_width) - (
combined_max - combined_min
)
combined_max = combined_max + e / 2
combined_min = combined_min - e / 2
start_idx = int(np.round((np_min_val - combined_min) / hist_bin_width))
return combined_min, combined_max, downsample_rate, start_idx
    def sideeffect_forward(self, x_orig):
x = x_orig.numpy()
min_val = self.min_val.numpy()
max_val = self.max_val.numpy()
histogram = self.histogram.numpy()
new_min = x.min()
new_max = x.max()
if histogram[0] == -1:
new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
else:
new_min = min(new_min, min_val)
new_max = max(new_max, max_val)
# combine the existing histogram and new histogram into 1 histogram
# We do this by first upsampling the histogram to a dense grid
# and then downsampling the histogram efficiently
(new_min, new_max, downsample_rate, start_idx) = self._adjust_min_max(
new_min, new_max, self.upsample_rate
)
new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
new_histogram = new_histogram.astype(np.float64)
if new_min == min_val and new_max == max_val:
new_histogram += histogram
else:
new_histogram = self._combine_histograms(
new_histogram,
histogram,
self.upsample_rate,
downsample_rate,
start_idx,
self.bins,
)
self.histogram = Tensor(new_histogram, dtype="float32")
self.min_val = Tensor(new_min, dtype="float32")
self.max_val = Tensor(new_max, dtype="float32")
    def forward(self, x_orig):
self.sideeffect_forward(x_orig)
return x_orig
class PassiveObserver(Observer):
    r"""An Observer that supports setting :attr:`scale` directly and performs
    no statistics collection of its own.
def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
super().__init__(dtype, **kwargs)
self.qparams = None
self.orig_scale = None
@property
def scale(self):
return self.qparams.scale
@scale.setter
def scale(self, value: np.ndarray):
assert np.all(value > 0)
self.qparams.scale[...] = Tensor(value)
    def get_qparams(self):
return self.qparams
    def set_qparams(self, qparams: QParams):
        r"""Set the ``qparams``.
Args:
qparams: used to set initial scale.
"""
self.qparams = deepcopy(qparams)
if qparams.scale is None:
raise AssertionError("Can not get an initialized scale")
if qparams.dtype_meta is None:
qparams.dtype_meta = self.dtype
else:
assert (
qparams.dtype_meta is self.dtype
), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
qparams.dtype_meta, self.dtype
)
self.orig_scale = qparams.scale.numpy()
    def forward(self, x):
        r"""Just returns the input, because :attr:`qparams` is set by
        :func:`~.apply_easy_quant`."""
return x