autocast

class autocast(enabled=True, low_prec_dtype='float16', high_prec_dtype='float32')

Controls autocast mode for automatic mixed precision (amp). It can be used either as a context manager or as a decorator.

Parameters
  • enabled (bool) – whether autocast mode is enabled. Default: True.

  • low_prec_dtype (str) – the lower-precision dtype used in autocast mode. This is the target dtype when tensors are cast down for better speed and lower memory usage. Default: float16.

  • high_prec_dtype (str) – the higher-precision dtype used in autocast mode. This is the target dtype when tensors are cast up for better precision. Default: float32.

Returns

None
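
To make the trade-off behind low_prec_dtype and high_prec_dtype concrete, the sketch below compares the two default dtypes using only the Python standard library. It does not call autocast or any amp API; the helper name roundtrip is made up for this illustration. It shows that a float16 value takes half the storage of a float32 value but cannot represent every number that float32 can.

```python
import struct


def roundtrip(fmt, value):
    """Store `value` in the given IEEE 754 format and read it back."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


# "<e" is IEEE 754 half precision (float16), "<f" is single precision (float32).
HALF_BYTES = struct.calcsize("<e")    # bytes per float16 element
SINGLE_BYTES = struct.calcsize("<f")  # bytes per float32 element

# 2049 needs 12 significant bits; float16 only has 11, so it gets rounded.
half_value = roundtrip("<e", 2049.0)
single_value = roundtrip("<f", 2049.0)

print(HALF_BYTES, SINGLE_BYTES)   # storage per element
print(half_value, single_value)   # value after a round trip through each dtype
```

Halving the bytes per element is where the speed and memory benefit of the lower-precision dtype comes from, while the rounding of 2049.0 is the kind of error the higher-precision dtype avoids.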

Examples

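import megengine.functional as F
from megengine.amp import autocast

# Assumes model, gm (a GradManager attached to the model's parameters) and
# opt (an optimizer) have already been created.

# used as decorator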
@autocast()
def train_step(image, label):
    with gm:
        logits = model(image)
        loss = F.nn.cross_entropy(logits, label)
        gm.backward(loss)
    opt.step().clear_grad()
    return loss

# used as context manager
def train_step(image, label):
    with autocast():
        with gm:
            logits = model(image)
            loss = F.nn.cross_entropy(logits, label)
            gm.backward(loss)
    opt.step().clear_grad()
    return loss
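
The two usage forms above rely on one object acting as both a decorator and a context manager. The following is a minimal standard-library sketch of that general pattern, not the library's actual implementation: toy_autocast, _state and current_state are hypothetical names invented for this illustration, and the "state" is just a dictionary rather than real tensor-casting behavior. It also shows that a nested block with enabled=False turns the mode off locally and that the previous settings come back on exit.

```python
from contextlib import ContextDecorator

# Hypothetical module-level settings; a stand-in for whatever flags a real
# library would consult when deciding how to cast tensors.
_state = {"enabled": False, "low": "float16", "high": "float32"}


class toy_autocast(ContextDecorator):
    """Illustrative only: apply settings on entry, restore the old ones on exit."""

    def __init__(self, enabled=True, low_prec_dtype="float16",
                 high_prec_dtype="float32"):
        self._new = {"enabled": enabled, "low": low_prec_dtype,
                     "high": high_prec_dtype}
        self._saved = []  # stack, so the same instance can be re-entered

    def __enter__(self):
        self._saved.append(dict(_state))  # remember the current settings
        _state.update(self._new)
        return self

    def __exit__(self, *exc):
        _state.clear()
        _state.update(self._saved.pop())  # restore the previous settings
        return False  # do not swallow exceptions


def current_state():
    """Return a copy of the settings that are active right now."""
    return dict(_state)


@toy_autocast()
def step_as_decorator():
    return current_state()


def step_as_context_manager():
    with toy_autocast():
        inside = current_state()
        with toy_autocast(enabled=False):  # disable locally
            nested = current_state()
        after_nested = current_state()
    return inside, nested, after_nested
```

Inheriting from contextlib.ContextDecorator is what lets the same class be written once and used in both forms; keeping the saved settings on a stack lets nested and repeated uses unwind in the right order.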