autocast
- class autocast(enabled=True, low_prec_dtype='float16', high_prec_dtype='float32')
A class that controls autocast mode for AMP (automatic mixed precision). It can be used either as a context manager or as a decorator.
- Parameters
enabled (bool) – whether autocast mode is enabled. Default: True.
low_prec_dtype (str) – the lower-precision dtype used by AMP autocast mode. It sets the target dtype of tensor casting to gain speed and reduce memory usage. Default: float16.
high_prec_dtype (str) – the higher-precision dtype used by AMP autocast mode. It sets the target dtype of tensor casting where precision matters more. Default: float32.
- Returns
None
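The trade-off between `low_prec_dtype` and `high_prec_dtype` can be illustrated without MegEngine. The sketch below is only an illustration, assuming nothing beyond Python's standard `struct` module, whose `'e'` format is IEEE 754 half precision (float16) and `'f'` is single precision (float32). It shows that float16 halves the storage per element, but drops a small increment that float32 keeps. The helper `round_trip` is hypothetical and not part of MegEngine.

```python
import struct


def round_trip(value, fmt):
    """Hypothetical helper: pack `value` with the given struct format and
    unpack it again, showing what survives storage at that precision."""
    return struct.unpack("<" + fmt, struct.pack("<" + fmt, value))[0]


# Bytes per element: 'e' is float16, 'f' is float32.
half_bytes = struct.calcsize("<e")
single_bytes = struct.calcsize("<f")

# Near 1.0, float16 values are spaced about 0.001 apart, so adding 0.0001
# is rounded away; float32 is fine-grained enough to keep it.
x = 1.0001
half_val = round_trip(x, "e")
single_val = round_trip(x, "f")

print(half_bytes, single_bytes)  # 2 4
print(half_val, single_val)
```

This is why the lower-precision dtype is the one tied to speed and memory, while the higher-precision dtype is kept for computations that are sensitive to rounding.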
Examples
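```python
import megengine.functional as F
from megengine.amp import autocast

# `model`, `gm` (a GradManager attached to model.parameters()) and
# `opt` (an optimizer) are assumed to be defined elsewhere.

# used as decorator
@autocast()
def train_step(image, label):
    with gm:
        logits = model(image)
        loss = F.nn.cross_entropy(logits, label)
        gm.backward(loss)
    opt.step().clear_grad()
    return loss

# used as context manager
def train_step(image, label):
    with autocast():
        with gm:
            logits = model(image)
            loss = F.nn.cross_entropy(logits, label)
            gm.backward(loss)
    opt.step().clear_grad()
    return loss
```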