xla_trace
- class xla_trace(*args, **kwargs)
- Wraps a callable and provides accelerated evaluation compiled by XLA.
This is currently an experimental feature. Refer to trace for more information.
Examples
import numpy as np
import megengine.functional as F
from basecls.models.resnet import resnet18
from megengine.autodiff.grad_manager import GradManager
from megengine.jit import xla_trace
from megengine.optimizer import Adam

model = resnet18()
gm = GradManager()
opt = Adam(model.parameters(), lr=1e-4)
gm.attach(model.parameters())

# Only tensors in the wrapped function's args/kwargs are treated as graph inputs;
# all other tensors are captured as constant values.
# The module, optimizer, and training data/labels should therefore be
# arguments of the wrapped function.
@xla_trace(capture_as_const=True)
def train_step(model, opt, data, label):
    with gm:
        pred = model(data)
        loss = F.loss.cross_entropy(pred, label)
        gm.backward(loss)
    # Apply the update and reset the gradients for the next step
    opt.step().clear_grad()
    return loss
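The comment in the example states the key capture rule: only tensors passed as arguments become graph inputs, and anything else is frozen as a constant. The pure-Python toy below is a minimal sketch of why that matters, assuming a simple "trace once, replay afterwards" model. `ToyTracer`, `Node`, `step`, and `scale` are hypothetical names for illustration only; this is not MegEngine's implementation of `xla_trace` and uses no MegEngine API. On the first call it runs the function with symbolic placeholders for the arguments and records each `+`/`*`; any value read from outside the arguments, such as the global `scale`, is baked into the recording.

```python
import operator
from typing import Any, Callable, List, Optional, Tuple

# Hypothetical toy tracer: NOT MegEngine's implementation of xla_trace.
Operand = Tuple[str, Any]  # ("slot", index) for traced values, ("const", value) for frozen ones


class Node:
    """Symbolic placeholder; arithmetic on it is recorded instead of computed."""

    def __init__(self, tracer: "ToyTracer", index: int) -> None:
        self.tracer = tracer
        self.index = index  # slot in the value table used during replay

    def __add__(self, other: Any) -> "Node":
        return self.tracer.record(operator.add, self, other)

    def __radd__(self, other: Any) -> "Node":
        return self.tracer.record(operator.add, other, self)

    def __mul__(self, other: Any) -> "Node":
        return self.tracer.record(operator.mul, self, other)

    def __rmul__(self, other: Any) -> "Node":
        return self.tracer.record(operator.mul, other, self)


class ToyTracer:
    """Traces `fn` on the first call, then replays the recorded ops."""

    def __init__(self, fn: Callable[..., Any]) -> None:
        self.fn = fn
        self.ops: List[Tuple[Callable[[Any, Any], Any], Operand, Operand, int]] = []
        self.num_slots = 0
        self.output: Optional[Operand] = None

    def _new_node(self) -> Node:
        node = Node(self, self.num_slots)
        self.num_slots += 1
        return node

    @staticmethod
    def _operand(x: Any) -> Operand:
        # Anything that is not a traced Node is baked in as a constant.
        return ("slot", x.index) if isinstance(x, Node) else ("const", x)

    def record(self, fn: Callable[[Any, Any], Any], a: Any, b: Any) -> Node:
        out = self._new_node()
        self.ops.append((fn, self._operand(a), self._operand(b), out.index))
        return out

    def __call__(self, *args: Any) -> Any:
        if self.output is None:
            # First call: only the arguments become graph inputs (slots 0..n-1).
            inputs = [self._new_node() for _ in args]
            self.output = self._operand(self.fn(*inputs))
        return self._replay(args)

    def _replay(self, args: Tuple[Any, ...]) -> Any:
        values: List[Any] = [None] * self.num_slots
        values[: len(args)] = args  # bind the new argument values to the input slots

        def read(operand: Operand) -> Any:
            kind, v = operand
            return values[v] if kind == "slot" else v

        for fn, a, b, out in self.ops:
            values[out] = fn(read(a), read(b))
        return read(self.output)


scale = 2.0  # module-level value, NOT an argument: frozen at trace time


@ToyTracer
def step(x: Any, w: Any) -> Any:
    return x * w + scale


print(step(3.0, 4.0))  # 14.0, traced during this call
scale = 100.0
print(step(3.0, 4.0))  # still 14.0, the old scale was captured as a constant
print(step(1.0, 5.0))  # 7.0, arguments are live graph inputs
```

Changing `scale` after the first call has no effect, while new argument values do. This mirrors the advice in the example's comment to pass the module, optimizer, data, and labels to the wrapped function rather than reading them from an enclosing scope.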