Observer#

class Observer(dtype, **kwargs)[源代码]#

所有 Observer 的基类。用于记录输入 Tensor 的统计信息以进行量化。

参数:

dtype (Union[str, QuantDtypeMeta]) – 字符串,表明按何种dtype来收集scale和zero_point。

train(mode=True, recursive=True)[源代码]#

将该模块中的所有模块(包括它自身)的训练模式设置为 mode 。 可便捷地将这些模块的 training 属性设置为 mode ,但仅对某些模块有效(例如 BatchNorm2d, Dropout, Observer)

参数:
  • mode (bool) – 为模块设置的训练模式。

  • recursive (bool) – 是否要递归调用子模块的 train()

返回类型:

None