import math
from typing import Union

from .. import functional as F
from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
from ..logger import get_logger
from ..module import Module
from ..tensor import Parameter, Tensor
from .utils import (
    LSQParams,
    QParams,
    QParamsModuleMixin,
    QuantMode,
    create_qparams,
    fake_quant_tensor,
    lsq_forward,
    tqt_forward,
)

logger = get_logger(__name__)


class _FakeQuantize(Module):
    r"""Base class for fake-quantize modules: ``forward`` dispatches to
    ``fake_quant_forward`` when enabled, and to the identity
    ``normal_forward`` otherwise.
    """

    def __init__(self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs):
        super().__init__()
        if isinstance(dtype, str):
            if dtype not in _builtin_quant_dtypes:
                raise ValueError(
                    "unknown dtype: {}, only support {}".format(
                        dtype, _builtin_quant_dtypes.keys()
                    )
                )
            dtype = _builtin_quant_dtypes[dtype]
        if "narrow_range" in kwargs:
            del kwargs["narrow_range"]
            logger.warning(
                "FakeQuantize currently has no narrow_range param, "
                "so it is ignored here",
                exc_info=DeprecationWarning,
            )
        self.dtype = dtype
        self.qmin = dtype.qmin
        self.qmax = dtype.qmax
        self.enabled = enable

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False

    def fake_quant_forward(self, inp, qparams: QParams = None):
        raise NotImplementedError

    def normal_forward(self, inp, qparams: QParams = None):
        return inp

    def forward(self, inp, qparams: QParams = None):
        if self.enabled:
            return self.fake_quant_forward(inp, qparams=qparams)
        else:
            return self.normal_forward(inp, qparams=qparams)
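
# Usage sketch (illustrative, not part of the module): the enable/disable
# contract of _FakeQuantize. ``PassThrough`` is a hypothetical subclass; only
# ``fake_quant_forward`` needs to be implemented. The import path assumes this
# file is megengine/quantization/fake_quant.py, as its relative imports suggest.
#
#     from megengine.quantization.fake_quant import _FakeQuantize
#
#     class PassThrough(_FakeQuantize):
#         def fake_quant_forward(self, inp, qparams=None):
#             return inp  # a real subclass would quantize/dequantize here
#
#     m = PassThrough(dtype="qint8")   # qmin/qmax are taken from the dtype
#     m.disable()                      # forward() now returns inp unchanged
#     m.enable()                       # forward() calls fake_quant_forward again
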
class TQT(_FakeQuantize, QParamsModuleMixin):
    r"""TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds
    for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks.

    Args:
        dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
            quantization dtype of input.
        enable: whether to do ``normal_forward`` or ``fake_quant_forward``.
    """

    def __init__(self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs):
        super().__init__(dtype, enable, **kwargs)
        self.scale = Parameter(0.0, dtype="float32")

    def fake_quant_forward(self, inp, qparams: QParams = None):
        # when enabled, TQT does a fake-quant forward pass and finetunes the scale
        return tqt_forward(self.qmin, self.qmax, inp, self.scale)

    def set_qparams(self, qparams: QParams):
        assert (
            qparams.mode == QuantMode.SYMMERTIC
        ), "only symmetric quantization is supported by TQT"
        if qparams.scale is None:
            raise AssertionError("Cannot get an initialized scale")
        # the learnable parameter stores the threshold in the log2 domain
        self.scale[...] = F.log(qparams.scale) / math.log(2)

    def get_qparams(self):
        return create_qparams(QuantMode.SYMMERTIC, self.dtype, scale=2 ** self.scale)
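
# Usage sketch (illustrative, not part of the module): TQT keeps the threshold
# in the log2 domain, so get_qparams() reports 2 ** scale. Import paths assume
# the public module locations implied by this file's relative imports.
#
#     import megengine as mge
#     from megengine.quantization.fake_quant import TQT
#     from megengine.quantization.utils import QuantMode, create_qparams
#
#     tqt = TQT(dtype="qint8")
#     tqt.set_qparams(
#         create_qparams(QuantMode.SYMMERTIC, "qint8", scale=mge.tensor(0.125))
#     )
#     print(tqt.scale.item())                 # log2(0.125) == -3.0
#     print(tqt.get_qparams().scale.item())   # 2 ** -3.0 == 0.125
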
class FakeQuantize(_FakeQuantize):
    r"""A module that quantizes and then dequantizes input according to the
    observer's ``scale`` and ``zero_point``.

    Args:
        dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
            quantization dtype of input.
        enable: whether to do ``normal_forward`` or ``fake_quant_forward``.
    """

    def fake_quant_forward(self, inp, qparams: QParams = None):
        assert (
            qparams.dtype_meta is self.dtype
        ), "input qparams' dtype is not equal to self.dtype\nqparams.dtype_meta={}\nself.dtype={}".format(
            qparams.dtype_meta, self.dtype
        )
        return fake_quant_tensor(inp, qparams)
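
# Usage sketch (illustrative, not part of the module): with a symmetric qint8
# scale of 0.01, values are rounded to the nearest multiple of the scale and
# clamped to [qmin * scale, qmax * scale] = [-1.28, 1.27].
#
#     import megengine as mge
#     from megengine.quantization.fake_quant import FakeQuantize
#     from megengine.quantization.utils import QuantMode, create_qparams
#
#     fq = FakeQuantize(dtype="qint8")
#     qparams = create_qparams(QuantMode.SYMMERTIC, "qint8", scale=mge.tensor(0.01))
#     x = mge.tensor([0.004, -0.013, 5.0])
#     y = fq(x, qparams=qparams)   # ~[0.0, -0.01, 1.27]; 5.0 clamps at 127 * 0.01
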
class LSQ(_FakeQuantize, QParamsModuleMixin):
    r"""LSQ: https://arxiv.org/pdf/1902.08153.pdf Learned Step Size Quantization,
    which estimates and scales the task loss gradient at each weight and
    activation layer's quantizer step size.

    Args:
        dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
            quantization dtype of input.
        enable: whether to do ``normal_forward`` or ``fake_quant_forward``.
        eps: a small value to avoid division by zero. Default: 1e-5
    """

    def __init__(
        self,
        dtype: Union[str, QuantDtypeMeta],
        enable: bool = True,
        eps: float = 1e-5,
        **kwargs
    ):
        super().__init__(dtype=dtype, enable=enable, **kwargs)
        self.eps = Tensor(eps, dtype="float32")
        self.step_size = Parameter(1.0, dtype="float32")
        self.mode = None
        self.zero_point = Tensor(0.0, dtype="float32")
        self.grad_scale = Tensor(1.0, dtype="float32")

    def set_qparams(self, qparams: QParams):
        self.mode = qparams.mode
        if qparams.mode == QuantMode.ASYMMERTIC:
            self.zero_point = qparams.zero_point
        else:
            self.zero_point = Tensor(0.0, dtype="float32")
        if qparams.scale is None:
            raise AssertionError("Cannot get an initialized scale")
        # store scale - eps so that fake_quant_forward's |step_size| + eps
        # reproduces the observer scale
        init_step_size = qparams.scale
        if init_step_size < self.eps:
            init_step_size = Tensor(0.0, dtype="float32")
        else:
            init_step_size = Tensor(init_step_size - self.eps)
        self.step_size = Parameter(init_step_size.item(), dtype="float32")

        if isinstance(qparams, LSQParams):
            self.grad_scale = qparams.grad_scale

    def fake_quant_forward(self, inp, qparams: LSQParams = None):
        step_size = F.abs(self.step_size) + self.eps
        return lsq_forward(
            self.qmin, self.qmax, inp, step_size, self.zero_point, self.grad_scale
        )

    def get_qparams(self):
        return LSQParams(
            mode=self.mode,
            dtype_meta=self.dtype,
            scale=F.abs(self.step_size.detach()) + self.eps,
            zero_point=self.zero_point,
            grad_scale=self.grad_scale,
        )
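
# Usage sketch (illustrative, not part of the module): the effective scale is
# |step_size| + eps, so an observer scale round-trips through set_qparams()
# and get_qparams(); step_size itself is learned during training.
#
#     import megengine as mge
#     from megengine.quantization.fake_quant import LSQ
#     from megengine.quantization.utils import QuantMode, create_qparams
#
#     lsq = LSQ(dtype="qint8", eps=1e-5)
#     lsq.set_qparams(
#         create_qparams(QuantMode.SYMMERTIC, "qint8", scale=mge.tensor(0.02))
#     )
#     y = lsq(mge.tensor([0.05, -0.31]))      # gradients flow into lsq.step_size
#     print(lsq.get_qparams().scale.item())   # ~0.02 == |step_size| + eps
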