xla_trace

class xla_trace(*args, **kwargs)[source]
Wraps a callable and provides accelerated evaluation compiled by XLA.

This is currently an experimental feature. Refer to `trace` for more information.

Examples

import megengine.functional as F  # needed for F.loss.cross_entropy below
import numpy as np
from basecls.models.resnet import resnet18
from megengine.autodiff.grad_manager import GradManager
from megengine.jit import xla_trace
from megengine.optimizer import Adam

model = resnet18()
gm = GradManager()
opt = Adam(model.parameters(), lr=1e-4)
# Register the model parameters with the grad manager so gm.backward()
# produces gradients for them.
gm.attach(model.parameters())

# Only tensors in wrapped func args/kwargs will be treated as graph inputs,
# and other tensors will be captured as const value.
# Module, optimizer, and train data/label should be arguments of the wrapped function.
@xla_trace(capture_as_const=True)
def train_step(model, opt, data, label):
    """Run one training step: forward, cross-entropy loss, backward, update.

    Returns the loss computed for this batch.
    """
    with gm:
        pred = model(data)
        loss = F.loss.cross_entropy(pred, label)
        gm.backward(loss)
    # step() returns the optimizer, which lets clear_grad() be chained to
    # reset the gradients before the next call.
    opt.step().clear_grad()
    return loss

# Usage: pass the module, optimizer and the batch tensors on every call, e.g.
#     loss = train_step(model, opt, data, label)