Source code for megengine.distributed.group

# -*- coding: utf-8 -*-
import time
from contextlib import contextmanager
from typing import List, Optional, Tuple

from mprop import mproperty

from ..device import _sh, set_default_device, what_is_xpu
from ..random import seed
from .server import Client, Server


class StaticData:
    """Per-process record of the distributed configuration.

    Every field defaults to ``None``; ``init_process_group`` fills in the
    values on the instance it creates.
    """

    # RPC endpoints. ``client`` is the ``Client`` connected to the python
    # server; ``server`` is not assigned anywhere in this module
    # (NOTE(review): presumably set by the launcher -- confirm).
    server = client = None
    # Address of the master node, the python server port, and the port
    # reported by ``client.get_mm_server_port()``.
    master_ip = py_server_port = mm_server_port = None
    # Job size and the rank of the current process.
    world_size = proc_rank = None
    # Device id, communicator backend name and device type string.
    device = backend = device_type = None
    # Ranks recorded through ``_set_machine_ranks``.
    machine_ranks = None

# Module-wide distributed state: stays None until init_process_group()
# installs a StaticData instance here.
_sd = None


class Group:
    r"""Include ranked nodes running collective communication (See :mod:`~.functional.distributed`).

    By default collectives operate on the default group (also called ``WORLD``)
    and require all processes to enter the distributed function call.

    Constructing a group from an empty rank list yields an *invalid*
    placeholder (``proc_ranks`` and ``stream`` are ``None``); it has to be
    populated with :meth:`reset` before any property is read, otherwise the
    properties fail with ``AssertionError("invalid group")``.

    Args:
        proc_ranks: rank list of the group, the first one is root rank.
    """

    def __init__(self, proc_ranks):
        if len(proc_ranks) == 0:  # empty group
            self.proc_ranks = None
            self.stream = None
            # Keep the attribute set identical to a populated group so that
            # ``is_single_machine`` reports "invalid group" instead of
            # failing with an AttributeError.
            self.is_single_machine_cache = None
        else:
            self.reset(proc_ranks)

    def reset(self, proc_ranks):
        r"""Validate ``proc_ranks`` and bind this group to them.

        Clears the cached ``is_single_machine`` answer and takes a new stream
        from ``_sh.get_next()``.
        """
        self.check(proc_ranks)
        self.proc_ranks = proc_ranks
        self.is_single_machine_cache = None
        self.stream = _sh.get_next()

    def check(self, proc_ranks):
        r"""Assert that ``proc_ranks`` is usable by the current process.

        Every rank must be an ``int`` in ``[0, world_size)`` and the rank of
        the current process must be part of the list.

        Raises:
            AssertionError: if the process group is not initialized or any of
                the conditions above does not hold.
        """
        assert _sd is not None, "please call init_process_group first"
        for rank in proc_ranks:
            assert isinstance(rank, int)
            assert rank >= 0 and rank < _sd.world_size
        assert _sd.proc_rank in proc_ranks

    def _assert_valid(self):
        r"""Raise ``AssertionError("invalid group")`` for an empty group."""
        # ``proc_ranks`` is None for a group built from an empty list, so it
        # must be tested before ``len`` is taken.
        assert (
            self.proc_ranks is not None and len(self.proc_ranks) > 0
        ), "invalid group"

    @property
    def size(self):
        r"""Number of ranks in the group."""
        self._assert_valid()
        return len(self.proc_ranks)

    @property
    def key(self):
        r"""Comma-joined rank list, e.g. ``"0,1,2"``."""
        self._assert_valid()
        return ",".join(map(str, self.proc_ranks))

    @property
    def rank(self):
        r"""Position of the current process inside ``proc_ranks``."""
        self._assert_valid()
        return self.proc_ranks.index(_sd.proc_rank)

    @property
    def comp_node(self):
        r"""String ``"<device_type><device>:<stream>"`` for this group."""
        self._assert_valid()
        return "{}{}:{}".format(_sd.device_type, _sd.device, self.stream)

    @property
    def is_single_machine(self):
        r"""Whether every rank of the group is listed in ``_sd.machine_ranks``.

        The answer is computed once and cached until :meth:`reset` is called.
        """
        if self.is_single_machine_cache is not None:
            return self.is_single_machine_cache
        self._assert_valid()
        assert _sd is not None, "please call init_process_group first"
        machine_ranks = _sd.machine_ranks
        # Unknown machine ranks (None) count as "not a single machine".
        self.is_single_machine_cache = machine_ranks is not None and all(
            rank in machine_ranks for rank in self.proc_ranks
        )
        return self.is_single_machine_cache
# Default group. Created empty at import time; init_process_group() resets it
# to cover every rank of the job.
WORLD = Group([])

# Physical device types accepted by init_process_group().
_devices = {"gpu", "cuda", "rocm"}
# Communicator backend names accepted by init_process_group().
_backends = {"nccl", "rccl", "auto"}
def init_process_group(
    master_ip: str,
    port: int,
    world_size: int,
    rank: int,
    device: int,
    backend: Optional[str] = "auto",
    device_type: str = "xpu",
) -> None:
    r"""Initialize the distributed process group and specify the device used in the current process

    Args:
        master_ip: ip address of the master node.
        port: port available for all processes to communicate.
        world_size: total number of processes participating in the job.
        rank: rank of the current process.
        device: the GPU device id to bind this process to.
        backend: communicator backend, one of 'nccl', 'rccl' or 'auto'.
        device_type: device type of the current process. The default
            ``"xpu"`` is resolved with ``what_is_xpu()``; the resolved type
            must be one of 'gpu', 'cuda' or 'rocm'.

    Raises:
        TypeError: if ``master_ip``, ``port``, ``world_size``, ``rank`` or
            ``device`` has the wrong type.
        ValueError: if ``backend`` or the resolved device type is not
            supported.
        AssertionError: if the group is already initialized, or
            ``world_size``, ``rank`` or ``port`` is out of range.
    """
    global _sd

    physical_device_type = what_is_xpu() if device_type == "xpu" else device_type

    # One table drives all type checks so the value reported in the message
    # is always the one that was actually tested.
    typed_arguments = (
        (master_ip, str),
        (port, int),
        (world_size, int),
        (rank, int),
        (device, int),
    )
    for value, expected in typed_arguments:
        if not isinstance(value, expected):
            raise TypeError(
                "Expect type {} but got {}".format(expected.__name__, type(value))
            )
    if backend not in _backends:
        raise ValueError(
            "backend should be one of {} but got {}".format(_backends, backend)
        )
    if physical_device_type not in _devices:
        raise ValueError(
            "{} is not a valid distributed device type".format(device_type)
        )

    assert _sd is None, "init_process_group should be called only once"
    assert world_size > 1
    assert rank >= 0 and rank < world_size
    assert port > 0

    # Build the state locally and publish it only once it is complete, so a
    # failed range check or connection attempt does not leave a
    # half-initialized ``_sd`` that blocks a later retry.
    state = StaticData()
    state.client = Client(master_ip, port)
    state.master_ip = master_ip
    state.py_server_port = port
    state.mm_server_port = state.client.get_mm_server_port()
    state.world_size = world_size
    state.proc_rank = rank
    state.device = device
    state.backend = backend
    state.device_type = device_type
    _sd = state

    # WORLD.reset() validates against ``_sd``, so it has to run after the
    # state has been published.
    WORLD.reset(list(range(world_size)))
    set_default_device("{}{}".format(device_type, device))
    # Offset the time-based seed by rank so processes do not share a seed.
    seed(int(time.time()) + rank)

    if backend == "nccl":
        # init nccl env
        from ..core._imperative_rt.common import init_nccl_env

        group_barrier()
        init_nccl_env(master_ip, _sd.mm_server_port, world_size, rank, 0)
def _set_machine_ranks(ranks) -> None:
    """Store ``ranks`` as ``machine_ranks`` on the module-wide state.

    Fails with ``AssertionError`` when the process group has not been
    initialized yet.
    """
    # Only an attribute of the shared state is written; the module-level
    # name itself is never rebound here.
    assert _sd is not None
    _sd.machine_ranks = ranks
@contextmanager
def override_backend(new_backend: str):
    r"""Override distributed backend

    The previous backend is restored when the context exits, even if the
    body raises.

    Args:
        new_backend: communicator backend set in this context.
    """
    assert _sd, "please call init_process_group first"
    previous_backend = _sd.backend
    _sd.backend = new_backend
    try:
        yield
    finally:
        _sd.backend = previous_backend
def is_distributed() -> bool:
    r"""Return True if the distributed process group has been initialized."""
    initialized = _sd is not None
    return initialized
def get_rank() -> int:
    r"""Get the rank of the current process.

    Returns ``0`` when the process group has not been initialized.
    """
    if _sd is None:
        return 0
    return _sd.proc_rank
def get_world_size() -> int:
    r"""Get the total number of processes participating in the job.

    Returns ``1`` when the process group has not been initialized.
    """
    if _sd is None:
        return 1
    return _sd.world_size
def get_backend() -> str:
    r"""Get the backend str.

    Raises ``AssertionError`` when the process group has not been
    initialized.
    """
    assert _sd is not None, "please call init_process_group first"
    # This guard is only reachable when assertions are disabled
    # (``python -O``); in that case ``None`` is returned.
    if _sd is None:
        return None
    return _sd.backend
def get_py_server_addr() -> Tuple[str, int]:
    r"""Get master_ip and port of python XML RPC server."""
    assert _sd is not None, "please call init_process_group first"
    address = (_sd.master_ip, _sd.py_server_port)
    return address
def get_mm_server_addr() -> Tuple[str, int]:
    r"""Get master_ip and port of C++ mm_server."""
    assert _sd is not None, "please call init_process_group first"
    address = (_sd.master_ip, _sd.mm_server_port)
    return address
def get_client() -> Client:
    r"""Get client of python XML RPC server."""
    assert _sd is not None, "please call init_process_group first"
    rpc_client = _sd.client
    return rpc_client
def new_group(proc_ranks: List[int]) -> Group:
    r"""Build a subgroup containing certain ranks.

    Args:
        proc_ranks: ranks forming the new group; passed unchanged to
            :class:`Group`.
    """
    subgroup = Group(proc_ranks)
    return subgroup
def group_barrier(group: Group = WORLD) -> None:
    r"""Block until all ranks in the group reach this barrier.

    Args:
        group: group whose ranks are synchronized; defaults to ``WORLD``.
    """
    # Nothing to synchronize when the process group was never initialized
    # (running with a single node).
    if _sd is not None:
        assert isinstance(group, Group)
        _sd.client.group_barrier(group.key, group.size)