# -*- coding: utf-8 -*-
import functools
import multiprocessing as mp
from collections import defaultdict
from typing import Callable
from weakref import WeakSet
import numpy as np
from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager
from ..core._imperative_rt.core2 import apply
from ..core._trace_option import use_xla_backend
from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
from ..functional.tensor import copy
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func
from ..utils.future import Future
from . import group as _group
from .functional import _bcast_param, all_reduce_sum, broadcast
from .group import WORLD, Group, group_barrier, is_distributed, override_backend


def param_pack_split(inp: Tensor, offsets: list, shapes: list):
    r"""Splits the input tensor into a list of tensors, as described by
    ``offsets`` and ``shapes``; only used for ``parampack``.

    Args:
        inp: input tensor.
        offsets: offsets of outputs, with length ``2 * n``,
            where ``n`` is the number of tensors to split out,
            in the format ``[begin0, end0, begin1, end1, ...]``.
        shapes: tensor shapes of outputs.

    Returns:
        split tensors.

Examples:
>>> a = F.ones(10)
>>> b, c = dist.helper.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
>>> b
Tensor([1.], device=xpux:0)
>>> c
Tensor([[1. 1. 1.]
[1. 1. 1.]
[1. 1. 1.]], device=xpux:0)
"""
op = ParamPackSplit()
op.offsets = offsets
op.shapes = [s or (1,) for s in shapes]
outputs = apply(op, inp)
return outputs


def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list):
    r"""Returns a concatenated tensor; only used for ``parampack``.

    Args:
        inps: list of input tensors.
        offsets: offsets as a device tensor.
        offsets_val: offsets of inputs, with length ``2 * n``,
            in the format ``[begin0, end0, begin1, end1, ...]``.

    Returns:
        concatenated tensor.

Examples:
>>> a = F.ones(1)
>>> b = F.ones((3, 3))
>>> offsets_val = [0, 1, 1, 10]
>>> offsets = Tensor(offsets_val)
        >>> dist.helper.param_pack_concat([a, b], offsets, offsets_val) # doctest: +SKIP
Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], device=xpux:0)
"""
op = ParamPackConcat()
op.offsets = offsets_val
return apply(op, *inps, offsets)[0]


def get_offsets(shapes):
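    # Build flat ``[begin, end)`` element offsets for each shape, matching the
    # ``offsets_val`` layout used by param_pack_concat / param_pack_split.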
offsets = []
offset = 0
for shape in shapes:
offsets.append(offset)
offset += int(np.prod(shape))
offsets.append(offset)
return offsets


_enable_p2p_cache = None


def _check_enable_p2p():
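    # Probe GPU peer-to-peer connectivity once via ``nvidia-smi topo -p2p w``
    # and cache the result; more than one "OK" entry in the matrix means a
    # usable P2P link exists between at least one pair of devices.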
global _enable_p2p_cache
if _enable_p2p_cache is not None:
return _enable_p2p_cache
cmd = ["nvidia-smi", "topo", "-p2p", "w"]
import subprocess
output = subprocess.run(cmd, stdout=subprocess.PIPE).stdout
if output.count(b"OK") > 1:
_enable_p2p_cache = True
return True
else:
_enable_p2p_cache = False
return False


def pack_allreduce_split(pack_list, shapes, group, reduce_method):
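    r"""Concat the tensors in ``pack_list`` into one flat tensor, all-reduce it
    across ``group``, then split the result back into the original shapes.
    With ``reduce_method == "mean"``, the summed gradients are divided by
    ``group.size``.
    """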
offsets_val = get_offsets(shapes)
offsets = Tensor(offsets_val)
packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
if reduce_method == "mean":
packed_grads /= group.size
grads = param_pack_split(packed_grads, offsets_val, shapes)
return grads


class TensorFuture(Future):
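    r"""Placeholder returned by :class:`AllreduceCallback` while a gradient is
    still queued for the fused allreduce; any access before the future is
    fulfilled raises an error.
    """
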
    def device(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def numpy(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def shape(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def dtype(self):
        raise RuntimeError("Sorry, this tensor is not ready")


def synchronized(func: Callable):
    r"""Decorator. The decorated function synchronizes all ranks when it
    finishes. Specifically, this is used to prevent data races during
    ``hub.load``.
    """
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not is_distributed():
return func(*args, **kwargs)
ret = func(*args, **kwargs)
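        # Every rank blocks here until all ranks have finished running func.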
group_barrier()
return ret
return wrapper


def _check_device_initialized(device_type: str, rank: int):
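    # Creating a small tensor forces device initialization; failure usually
    # means CUDA was already in use before this worker process was forked.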
try:
test = Tensor(1, device=(device_type + str(rank)))
del test
except Exception as e:
errmsg = (
"Device initialization check failed, which may be caused "
"by using CUDA before forking the thread. Please review "
"the code to ensure that no CUDA functions or variables "
"are used before forking."
)
raise RuntimeError(errmsg) from e


def _check_interpreter_status():
from ..core._imperative_rt.core2 import get_option
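    # Querying any interpreter option fails fast if the imperative runtime is
    # not alive in this process.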
_ = get_option("async_level")


get_device_count_by_fork = deprecated_func(
"1.5", "megengine.device", "get_device_count", False
)


def bcast_list_(inps: list, group: Group = WORLD):
    r"""Broadcast tensors within the given group.

    Args:
        inps (List[Tensor]): input tensors.
        group (:class:`.distributed.group.Group`, optional): communication group. Default: WORLD.
    """
for inp in inps:
inp._reset(_bcast_param(inp, group))


class AllreduceCallback:
    r"""Allreduce callback with tensor-fusion optimization.

    Args:
        reduce_method (str): the method used to reduce gradients; must be "sum" or "mean".
        group (:class:`.distributed.group.Group`, optional): communication group. Default: WORLD.
        backend (str, optional): override the distributed backend used in the allreduce. If ``backend`` is None, the backend set in ``dist.launcher`` is used. Default: None.

    Examples:
        .. code-block:: python

            import megengine as mge
            import megengine.autodiff as ad
            import megengine.distributed as dist

            gm = ad.GradManager()
            gm.attach(model.parameters(), callbacks=[dist.make_allreduce_cb("sum")])
    """

def __init__(self, reduce_method: str, group: Group = WORLD, backend: str = None):
reduce_method = reduce_method.lower()
assert reduce_method in ["sum", "mean"], "reduce_method should be sum or mean"
self._reduce_method = reduce_method
self._group = group
self._marked_gm = WeakSet()
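        # Fuse gradients of the same dtype until roughly 10 MiB accumulate,
        # then all-reduce them in one fused operation.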
self._param_pack_thd = 10 * 1024 * 1024
self._reset()
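        # Resolve the backend: default to the one configured via
        # ``dist.launcher`` / ``init_process_group``, mapping "auto" to "nccl".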
if backend is None:
assert _group._sd, "please call init_process_group first"
backend = _group._sd.backend
if backend == "auto":
backend = "nccl"
self._backend = backend

    def _reset(self):
self._params = []
self._gradients_dict = dict()
self._futures_dict = dict()
self._packing_list = defaultdict(list)
self._packing_size = defaultdict(int)
self._grad_origin_device = dict()

    def _pack(self, dtype):
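        # All-reduce every queued gradient of this dtype as a single fused
        # tensor, then write the reduced pieces back to their parameters.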
if len(self._packing_list[dtype]) == 0:
return
grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
shapes = [p._tuple_shape for p in self._packing_list[dtype]]
with override_backend(self._backend):
reduced_grads = pack_allreduce_split(
grad_list, shapes, self._group, self._reduce_method
)
for param, grad in zip(self._packing_list[dtype], reduced_grads):
self._gradients_dict[param] = grad
self._packing_list[dtype] = []
self._packing_size[dtype] = 0

    def __call__(self, param, grad):
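        # With the XLA backend, all-reduce each gradient eagerly instead of
        # queueing it for tensor fusion.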
if use_xla_backend():
self._used_xla = True
grad = all_reduce_sum(grad, self._group)
if self._reduce_method == "mean":
grad /= self._group.size
return grad
        # TODO: unify the XLA allreduce path with the imperative allreduce path.
if getattr(self, "_used_xla", False) and grad._is_external_value():
return grad
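        # Register _flush once per GradManager so the queued gradients are
        # packed and reduced when backward finishes.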
gm = get_backwarding_grad_manager()
assert isinstance(gm, GradManager)
if gm not in self._marked_gm:
gm._register_after_backward_callback(self._flush)
self._marked_gm.add(gm)
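        # Queue the gradient for fusion and hand back a future; _flush will
        # fulfil it with the reduced gradient.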
self._params.append(param)
self._futures_dict[param] = TensorFuture(ack=False)
self._gradients_dict[param] = grad
self._grad_origin_device[param] = str(grad.device)
dtype_str = str(np.dtype(param.dtype))
dtype_size = np.dtype(param.dtype).itemsize
self._packing_list[dtype_str].append(param)
self._packing_size[dtype_str] += int(np.prod(param._tuple_shape)) * dtype_size
if self._packing_size[dtype_str] > self._param_pack_thd:
self._pack(dtype_str)
return self._futures_dict[param]

    def _flush(self):
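        # Pack dtypes in sorted order so every rank issues the same sequence of
        # fused allreduces, then copy each reduced gradient back to its
        # original device and fulfil the pending futures.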
for dtype in sorted(self._packing_list.keys()):
self._pack(dtype)
for param in self._params:
grad = self._gradients_dict[param]
grad = copy(grad, self._grad_origin_device[param])
self._futures_dict[param].set(grad)
self._reset()
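

# Shortcut alias; the docstring example above uses it as ``dist.make_allreduce_cb``.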
make_allreduce_cb = AllreduceCallback