megengine.amp.autocast
- class autocast(enabled=True, low_prec_dtype='float16', high_prec_dtype='float32') [source]
A class that controls autocast mode for automatic mixed precision (AMP). It can be used either as a context manager or as a decorator.
- Parameters
enabled (bool) – Whether autocast mode is enabled. Default: True.
low_prec_dtype (str) – The lower-precision dtype used in autocast mode. Tensors are cast to this dtype for better speed and lower memory usage. Default: float16.
high_prec_dtype (str) – The higher-precision dtype used in autocast mode. Tensors are cast to this dtype where better precision is needed. Default: float32.
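The speed and memory versus precision trade-off behind low_prec_dtype and high_prec_dtype can be illustrated with Python's standard struct module, which supports IEEE 754 half precision ('e', the format of float16) and single precision ('f', the format of float32). This is only an illustration of the two number formats; it does not call MegEngine and says nothing about which operations autocast casts to which dtype.

```python
import struct


def round_trip(value, fmt):
    """Store `value` in the given struct format and read it back as a Python float."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


# Bytes per element: a float16 value takes half the memory of a float32 value.
half_bytes = struct.calcsize("e")    # 2
single_bytes = struct.calcsize("f")  # 4

# A small increment near 1.0 survives in float32 but is rounded away in float16,
# because float16 has only 10 explicit mantissa bits.
x = 1.0001
half_value = round_trip(x, "e")    # 1.0
single_value = round_trip(x, "f")  # close to 1.0001

print(half_bytes, single_bytes)
print(half_value, single_value)
```

Halving the bytes per element is where the memory saving of the lower-precision dtype comes from, and the lost increment shows why some computations are kept in the higher-precision dtype.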
Examples
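    # Assumes `model`, `gm` (a megengine.autodiff.GradManager) and `opt`
    # (a megengine.optimizer.Optimizer) have already been created.
    import megengine.functional as F
    from megengine.amp import autocast

    # used as decorator: the whole function body runs in autocast mode
    @autocast()
    def train_step(image, label):
        with gm:
            logits = model(image)
            loss = F.nn.cross_entropy(logits, label)
            gm.backward(loss)
        opt.step().clear_grad()
        return loss

    # used as context manager: only the enclosed block runs in autocast mode
    def train_step(image, label):
        with autocast():
            with gm:
                logits = model(image)
                loss = F.nn.cross_entropy(logits, label)
                gm.backward(loss)
        opt.step().clear_grad()
        return loss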
Methods