Source code for megengine.jit.xla_backend

from collections import OrderedDict, defaultdict

import numpy as np

from .. import _full_sync, tensor
from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
    is_external_convert,
    set_external_convert,
    set_external_convert_hook,
    set_py_external_type,
    unset_external_convert,
)
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from ..core._trace_option import set_use_xla_backend
from ..device import get_default_device
from ..tensor import Tensor
from ..utils.dlpack import from_dlpack, to_dlpack
from .tracing import trace

try:
    from mge_xlalib.xla_extension import ArrayImpl
    from ..xla.lib import xla_client as xc
except ImportError as e:
    pass

xla_client_compute_stream = None


def apply_external_convert_hook(input, cn):
    stream = xla_client_compute_stream
    assert isinstance(input, ArrayImpl)
    dlpack_capsule = xc._xla.buffer_to_dlpack_managed_tensor(input, take_ownership=True)
    output = from_dlpack(dlpack_capsule, stream).to(cn, _borrow=True)
    return output
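
# NOTE: the hook above is registered in xla_trace.__init__ via
# set_external_convert_hook, so that XLA device arrays (ArrayImpl) produced by
# the compiled executable can be borrowed back into MegEngine tensors through
# DLPack without an extra copy. A minimal sketch of the same round trip,
# assuming `arr` is an ArrayImpl living on the current compute stream and
# "gpu0" is the target comp node (both names are illustrative only):
#
#     capsule = xc._xla.buffer_to_dlpack_managed_tensor(arr, take_ownership=True)
#     mge_tensor = from_dlpack(capsule, xla_client_compute_stream).to("gpu0", _borrow=True)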


class xla_trace(trace):
    r"""Wraps a callable, and provides accelerated evaluation compiled by xla.

    Currently it is an experimental feature.
    Refer to :class:`~.jit.tracing.trace` for more information.

    Examples:

        .. code-block:: python

            import numpy as np
            from basecls.models.resnet import resnet18

            import megengine.functional as F
            from megengine.autodiff.grad_manager import GradManager
            from megengine.jit import xla_trace
            from megengine.optimizer import Adam

            model = resnet18()
            gm = GradManager()
            opt = Adam(model.parameters(), lr=1e-4)
            gm.attach(model.parameters())

            # Only tensors in wrapped func args/kwargs will be treated as graph inputs,
            # and other tensors will be captured as const value.
            # Module, optimizer, and train data/label should be arguments of the wrapped function.
            @xla_trace(capture_as_const=True)
            def train_step(model, opt, data, label):
                with gm:
                    pred = model(data)
                    loss = F.loss.cross_entropy(pred, label)
                    gm.backward(loss)
                opt.step().clear_grad()
                return loss
    """

    third_party_backend = True

    def __init__(self, function, *, without_host=True, symbolic_shape=False, **kwargs):
        assert without_host, "xla trace only support without host mode"
        assert not symbolic_shape, "xla doesn't support dynamic shape currently"

        set_external_convert_hook(apply_external_convert_hook)
        set_py_external_type(ArrayImpl)
        set_external_convert()

        super().__init__(
            function, without_host=without_host, symbolic_shape=symbolic_shape, **kwargs
        )

    def setup_env(self):
        self.orig_use_xla = set_use_xla_backend(True)

    def unset_env(self):
        set_use_xla_backend(self.orig_use_xla)

    def convert_params_to_xla(self):
        from ..utils.module_utils import get_expand_structure
        from ..tensor import Tensor

        backend = self.xla_exec.backend
        devices = backend.local_devices()
        default_cn = CompNode(get_default_device())
        _, device_id, _ = default_cn.physical_locator
        device_index = (
            0 if len(devices) == 0 else [d.id for d in devices].index(device_id)
        )
        device = devices[device_index]

        for attr, _ in self.attr_to_key.items():
            param = get_expand_structure(attr[0], attr[1])
            param._reset(param.to("cpux"))

        for tensor, _ in self.opt_param_dict.items():
            tensor._reset(tensor.to("cpux"))

        def as_xla_array(tensor, backend, device):
            np_array = tensor.numpy()
            if np_array.shape == ():
                np_array = np_array[np.newaxis]
            xla_array = backend.buffer_from_pyval(np_array, device)
            tensor._reset(Tensor(xla_array, device=default_cn))

        for attr, _ in self.attr_to_key.items():
            param = get_expand_structure(attr[0], attr[1])
            as_xla_array(param, backend, device)

        for tensor, _ in self.opt_param_dict.items():
            as_xla_array(tensor, backend, device)

    def compile(self):
        from ..xla import build_xla
        from ..traced_module.pytree import SUPPORTED_LEAF_TYPE, register_supported_type
        from ..utils.module_utils import get_expand_structure
        from ..xla.device import get_xla_backend_and_device
        from ..tensor import Tensor
        from ..distributed import get_mm_server_addr, is_distributed, get_rank
        from ..device import coalesce_free_memory

        assert self.traced

        coalesce_free_memory()
        _full_sync()

        self.xla_exec, self.inp_ids, self.out_ids = build_xla(
            self,
            return_with_io=True,
            return_device_array=True,
            ip=get_mm_server_addr()[0] if is_distributed() else None,
            port=get_mm_server_addr()[1] + 1 if is_distributed() else None,
        )

        if self.overall:
            self.convert_params_to_xla()

        coalesce_free_memory()
        _full_sync()

        id2inpidx = defaultdict(list)
        id2outidx = defaultdict(list)
        for idx, id in enumerate(self.inp_ids):
            id2inpidx[id].append(idx)
        for idx, id in enumerate(self.out_ids):
            id2outidx[id].append(idx)

        self.inpkey2idx = {}
        self.outkey2idx = {}
        if self.input_num == len(set(self.inp_ids)) - 1:
            self.has_randomstate = True
            default_rng_seed = _get_global_rng_seed()
            high = default_rng_seed >> 32
            low = default_rng_seed & 0xFFFFFFFF
            self.random_seed = Tensor([[high, low], [low, high]], dtype="int32")
        else:
            assert self.input_num == len(set(self.inp_ids)), (
                self.input_num,
                len(self.inp_ids),
            )
            self.has_randomstate = False

        inpmark2id = dict()
        outmark2id = dict()
        for var in self.vars:
            if var.kind == "external":
                for mark in var.inp_mark:
                    inpmark2id[mark] = var.id
            elif var.data_required and var.out_mark:
                for mark in var.out_mark:
                    outmark2id[mark] = var.id

        for k, v in inpmark2id.items():
            for idx in id2inpidx[v]:
                self.inpkey2idx[k] = idx
        for k, v in outmark2id.items():
            for idx in id2outidx[v]:
                self.outkey2idx[k] = idx
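
    # NOTE: after compile(), self.inpkey2idx maps a trace input mark to the
    # positional index expected by the XLA executable, and self.outkey2idx does
    # the same for outputs. A hypothetical example: if two external vars carry
    # inp_mark 0 and 1 and their ids sit at positions 2 and 0 of self.inp_ids,
    # then
    #
    #     self.inpkey2idx == {0: 2, 1: 0}
    #
    # prepare_xla_inputs below uses this mapping to reorder user tensors,
    # parameters and optimizer states into the executable's argument order.
    # The concrete values are illustrative only.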

    def prepare_xla_inputs(self, tensors):
        from ..utils.module_utils import get_expand_structure

        inp_count = 0
        inp_list = [0] * self.input_num
        for idx, t in enumerate(tensors):
            inp = self.inpkey2idx[self.arg_list[idx]]
            inp_list[inp] = t
            inp_count += 1
        if self.overall:
            for attr, key in self.attr_to_key.items():
                param = get_expand_structure(attr[0], attr[1])
                inp = self.inpkey2idx[key]
                inp_list[inp] = param
                inp_count += 1
            for tensor, k in self.opt_param_dict.items():
                inp = self.inpkey2idx[k]
                inp_list[inp] = tensor
                inp_count += 1
            opt_hyper_inps = []
            for opt in self.optimizers:
                opt_hyper_inps.extend([Tensor(pg["lr"]) for pg in opt.param_groups])
            for tensor, k in zip(opt_hyper_inps, self.capture_optimizer_hyper_param):
                inp = self.inpkey2idx[k]
                inp_list[inp] = tensor
                inp_count += 1
        assert inp_count == self.input_num
        if self.has_randomstate:
            inp_list.append(self.random_seed)
        return inp_list

    def to_dlpack(self, x, take_ownership: bool = True):
        from ..xla.lib import xla_client as xc

        return xc._xla.buffer_to_dlpack_managed_tensor(x, take_ownership=take_ownership)
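
    # NOTE: to_dlpack exposes an XLA buffer as a DLPack capsule. With
    # take_ownership=True the capsule takes over the buffer, so the original
    # ArrayImpl is typically no longer usable afterwards; pass
    # take_ownership=False to keep it alive. A hedged usage sketch, where
    # `traced` is a hypothetical xla_trace instance and `xla_out` an ArrayImpl
    # returned by its executable (both names are illustrative only):
    #
    #     capsule = traced.to_dlpack(xla_out, take_ownership=False)
    #     mge_tensor = from_dlpack(capsule)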

    def execute(self, *args, **kwargs):
        from ..tensor import Tensor
        from ..optimizer import Optimizer
        from ..traced_module.pytree import tree_flatten
        from ..utils.module_utils import get_expand_structure

        inputs, _ = tree_flatten((args, kwargs))
        arrays = []
        cn = CompNode(get_default_device())
        stream = dict(self.xla_exec.backend.get_compute_compnode())
        device_kind, device_id, stream_id = cn.physical_locator
        xla_stream = stream[device_id]
        xla_comp_cn = "gpu{}:{}".format(device_id, xla_stream)

        self.optimizers = []
        for t in inputs:
            if isinstance(t, RawTensor):
                if not t._is_external_value():
                    assert cn == t.device
                    arrays.append(t.to(xla_comp_cn, _borrow=True))
                else:
                    arrays.append(t)
            if isinstance(t, Optimizer):
                self.optimizers.append(t)

        arrays = self.prepare_xla_inputs(arrays)
        outputs = self.xla_exec(*arrays)

        global xla_client_compute_stream
        xla_client_compute_stream = xla_stream

        return_vals = []
        for i in self.out_list:
            if i == -1:
                if not hasattr(self, "outdef"):
                    return_vals.append(None)
            else:
                return_vals.append(outputs[self.outkey2idx[i]])
        if not self.out_list:
            return_vals = [
                None,
            ]

        keeped_features = []
        for i in self.keeped_activation:
            keeped_features.append(tensor(outputs[self.outkey2idx[i]], device=cn))

        out_tensors = []
        for array in return_vals:
            if array is not None:
                t = tensor(array, device=cn)
                out_tensors.append(t)
            else:
                out_tensors.append(array)

        if self.overall:
            for attr, key in self.update_param_dict.items():
                param = get_expand_structure(attr[0], attr[1])
                xla_array = outputs[self.outkey2idx[key]]
                t = tensor(xla_array, device=cn)
                param._reset(t)

            for state, key in self.update_opt_param_dict.items():
                xla_array = outputs[self.outkey2idx[key]]
                t = tensor(xla_array, device=cn)
                state._reset(t)
        elif hasattr(self, "input_need_update_dict"):
            for index, out_mark in self.input_need_update_dict.items():
                inputs[index]._reset(outputs[self.outkey2idx[out_mark]])

        rst = (
            self.outdef.unflatten(out_tensors)
            if hasattr(self, "outdef")
            else out_tensors
        )
        if keeped_features:
            return rst, keeped_features
        else:
            return rst
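

# NOTE: a hedged sketch of the runtime flow, assuming `train_step` is decorated
# with @xla_trace(capture_as_const=True) as in the class docstring:
#
#     loss = train_step(model, opt, data, label)  # early calls: traced on the MegEngine side
#     loss = train_step(model, opt, data, label)  # afterwards: compile() runs once, then execute()
#
# execute() flattens the inputs, borrows MegEngine tensors onto the XLA compute
# stream, reorders them with prepare_xla_inputs, runs self.xla_exec, and wraps
# the returned device arrays back into MegEngine tensors, resetting parameters
# and optimizer states in place when self.overall is set.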