megengine.quantization.LSQ

class LSQ(dtype, enable=True, eps=1e-05, **kwargs) [source]

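LSQ (Learned Step Size Quantization, https://arxiv.org/pdf/1902.08153.pdf) estimates and scales the task loss gradient at each weight and activation layer's quantizer step size, so that the step size can be learned together with the other network parameters.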
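The per-element arithmetic behind LSQ can be sketched in plain Python. This is a minimal illustration of the formulas in the paper, not MegEngine's implementation. It assumes symmetric quantization with no zero point and round-half-up rounding, and the function names `lsq_quantize` and `lsq_backward` are hypothetical. Q_N and Q_P are the numbers of negative and positive quantization levels (for a signed 8-bit range [-128, 127], Q_N = 128 and Q_P = 127). The step-size gradient is multiplied by g = 1/sqrt(N * Q_P), where N is the number of weights in the layer (or the number of features, for activations).

```python
import math


def _round_half_up(v):
    # Round to the nearest integer, ties toward +inf
    # (Python's built-in round() uses ties-to-even instead).
    return math.floor(v + 0.5)


def lsq_quantize(x, s, q_n, q_p):
    """Forward pass: x_hat = round(clip(x / s, -Q_N, Q_P)) * s."""
    v = min(max(x / s, -q_n), q_p)
    return _round_half_up(v) * s


def lsq_backward(xs, upstream, s, q_n, q_p):
    """Backward pass for one layer (sketch).

    xs: the layer's inputs; upstream: dL/dx_hat for each input.
    Returns (dL/dx for each input, gradient-scaled dL/ds).
    """
    grad_x = []
    grad_s = 0.0
    for x, dy in zip(xs, upstream):
        v = x / s
        if v <= -q_n:
            grad_x.append(0.0)   # clipped below: no gradient reaches x
            grad_s += dy * -q_n
        elif v >= q_p:
            grad_x.append(0.0)   # clipped above: no gradient reaches x
            grad_s += dy * q_p
        else:
            grad_x.append(dy)    # straight-through estimator inside the range
            grad_s += dy * (-v + _round_half_up(v))
    g = 1.0 / math.sqrt(len(xs) * q_p)  # LSQ gradient scale, with N = len(xs)
    return grad_x, grad_s * g


xs = [0.3, -1.2, 2.5, 100.0]
print([lsq_quantize(x, 0.5, 128, 127) for x in xs])  # [0.5, -1.0, 2.5, 63.5]
```

An input that already sits on a quantization level (2.5 = 5 * s) contributes nothing to the step-size gradient, while a clipped input contributes -Q_N or Q_P. The paper motivates the scale g as a way to balance the size of step-size updates against weight updates.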

Parameters
  • dtype (Union[str, QuantDtypeMeta]) – a string or QuantDtypeMeta indicating the target quantization dtype of input.

  • enable (bool) – whether to run fake_quant_forward (True) or normal_forward (False).

  • eps (float) – a small value to avoid division by zero. Default: 1e-5
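To make the roles of `dtype`, `enable`, and `eps` concrete, here is a hypothetical pure-Python stand-in, `LSQSketch`, that mirrors the constructor arguments. It is not MegEngine code. The dtype-to-range table, the `|s| + eps` guard on the step size, and the pass-through behaviour of `normal_forward` are assumptions made for illustration, and the optional `qparams` argument of the real methods is omitted.

```python
import math

# ASSUMED integer ranges for a few dtype strings; consult megengine's
# QuantDtypeMeta definitions for the authoritative values.
ASSUMED_RANGES = {
    "qint8": (-128, 127),
    "qint8_narrow": (-127, 127),
    "quint8": (0, 255),
}


class LSQSketch:
    """Hypothetical stand-in mirroring LSQ's constructor; not MegEngine code."""

    def __init__(self, dtype, enable=True, eps=1e-5):
        self.qmin, self.qmax = ASSUMED_RANGES[dtype]
        self.enabled = enable
        self.eps = eps
        self.step_size = 1.0  # a learnable Parameter in the real module

    def normal_forward(self, x):
        # Assumed pass-through: no quantization is simulated.
        return x

    def fake_quant_forward(self, x):
        # |s| + eps keeps the divisor strictly positive even if s reaches 0.
        s = abs(self.step_size) + self.eps
        q = min(max(math.floor(x / s + 0.5), self.qmin), self.qmax)
        return q * s

    def forward(self, x):
        # `enable` selects which path runs.
        if self.enabled:
            return self.fake_quant_forward(x)
        return self.normal_forward(x)


m = LSQSketch("qint8")
m.step_size = 0.0          # a collapsed step size no longer divides by zero
print(m.forward(3e-5))     # quantized with the effective step size eps
m.enabled = False
print(m.forward(2.0))      # 2.0
```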

Methods

apply(fn)

Applies function fn to all the modules within this module, including itself.

buffers([recursive])

Returns an iterable for the buffers of the module.

children(**kwargs)

Returns an iterable for all the submodules that are direct attributes of this module.

disable()

disable_quantize([value])

Sets the module's quantize_disabled attribute and returns the module.

enable()

eval()

Sets training mode of all the modules within this module (including itself) to False.

fake_quant_forward(inp[, qparams])

forward(inp[, qparams])

get_qparams()

get_quantized_dtype()

load_state_dict(state_dict[, strict])

Loads a given dictionary created by state_dict into this module.

modules(**kwargs)

Returns an iterable for all the modules within this module, including itself.

named_buffers([prefix, recursive])

Returns an iterable for key-buffer pairs of the module, where key is the dotted path from this module to the buffer.

named_children(**kwargs)

Returns an iterable of key-submodule pairs for all the submodules that are direct attributes of this module, where 'key' is the attribute name of submodules.

named_modules([prefix])

Returns an iterable of key-module pairs for all the modules within this module, including itself, where 'key' is the dotted path from this module to the submodules.

named_parameters([prefix, recursive])

Returns an iterable for key-Parameter pairs of the module, where key is the dotted path from this module to the Parameter.

named_tensors([prefix, recursive])

Returns an iterable for key-tensor pairs of the module, where key is the dotted path from this module to the tensor.

normal_forward(inp[, qparams])

parameters([recursive])

Returns an iterable for the Parameters of the module.

register_forward_hook(hook)

Registers a hook to handle forward results.

register_forward_pre_hook(hook)

Registers a hook to handle forward inputs.

replace_param(params, start_pos[, seen])

Replaces module's parameters with params, used by ParamPack to

set_qparams(qparams)

state_dict([rst, prefix, keep_var])

tensors([recursive])

Returns an iterable for the Tensors of the module.

train([mode, recursive])

Sets training mode of all the modules within this module (including itself) to mode.

zero_grad()

Sets the gradients of all parameters to zero.