autocast
- class autocast(enabled=True, low_prec_dtype='float16', high_prec_dtype='float32')
A class used as a context manager or decorator to control AMP (automatic mixed precision) autocast mode.
- Parameters
  - enabled (bool) – Whether autocast mode is enabled. Default: True.
  - low_prec_dtype (str) – The lower-precision dtype for AMP autocast mode. It sets the target dtype of tensor casts made for better speed and lower memory use. Default: float16.
  - high_prec_dtype (str) – The higher-precision dtype for AMP autocast mode. It sets the target dtype of tensor casts made for better precision. Default: float32.
- Returns
  None
Examples
import megengine.functional as F
from megengine.amp import autocast

# gm, model, and opt are assumed to be created elsewhere
# (a gradient manager, a module, and an optimizer).

# used as decorator
@autocast()
def train_step(image, label):
    with gm:
        logits = model(image)
        loss = F.nn.cross_entropy(logits, label)
        gm.backward(loss)
    opt.step().clear_grad()
    return loss

# used as context manager
def train_step(image, label):
    with autocast():
        with gm:
            logits = model(image)
            loss = F.nn.cross_entropy(logits, label)
            gm.backward(loss)
        opt.step().clear_grad()
    return loss
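The following is a minimal, framework-free sketch of how one class can act as both a context manager and a decorator, which is the usage pattern described above. It is not MegEngine's implementation: `autocast_sketch`, `_state`, and `current_mode` are hypothetical names invented for illustration, and the "mode" here is only a dictionary of settings, with no tensor casting involved.

```python
import functools

# Hypothetical global settings; a real framework keeps its own internal state.
_state = {"enabled": False, "low_prec_dtype": "float16", "high_prec_dtype": "float32"}


def current_mode():
    """Return a copy of the settings that are active right now."""
    return dict(_state)


class autocast_sketch:
    """Toy stand-in that only demonstrates the enter/exit/decorate pattern."""

    def __init__(self, enabled=True, low_prec_dtype="float16", high_prec_dtype="float32"):
        self._new = {
            "enabled": enabled,
            "low_prec_dtype": low_prec_dtype,
            "high_prec_dtype": high_prec_dtype,
        }
        self._saved = None

    def __enter__(self):
        # Remember the active settings, then install the requested ones.
        self._saved = dict(_state)
        _state.update(self._new)

    def __exit__(self, exc_type, exc, tb):
        # Restore the previous settings, even if the body raised.
        _state.update(self._saved)
        return False

    def __call__(self, func):
        # Decorator form: run the whole function inside the context.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return wrapper


@autocast_sketch()
def decorated_step():
    return current_mode()
```

In this sketch the settings apply only while the `with` body or the decorated function runs, and the previous settings come back afterwards, so an inner `autocast_sketch(enabled=False)` block can temporarily switch the mode off inside an enabled region.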