# -*- coding: utf-8 -*-
import time
from contextlib import contextmanager
from typing import List, Optional, Tuple
from mprop import mproperty
from ..device import _sh, set_default_device, what_is_xpu
from ..random import seed
from .server import Client, Server
class StaticData:
    r"""Record holding this process's distributed configuration.

    One instance is created by ``init_process_group`` and its fields are
    filled in there; every field defaults to ``None`` at class level.
    """

    # RPC endpoints (``client`` is the python XML RPC client; ``server`` is
    # not assigned anywhere in this module).
    server = client = None
    # Master node address and the two server ports.
    master_ip = py_server_port = mm_server_port = None
    # Job topology.
    world_size = proc_rank = None
    # Device binding and communicator backend.
    device = backend = device_type = None
    # Ranks recorded via ``_set_machine_ranks``.
    machine_ranks = None
# Process-wide distributed state. It stays ``None`` until ``init_process_group``
# stores a ``StaticData`` here; the helpers in this module treat ``None`` as
# "not running distributed".
_sd: Optional[StaticData] = None
class Group:
    r"""Include ranked nodes running collective communication (See :mod:`~.functional.distributed`).

    By default collectives operate on the default group (also called ``WORLD``)
    and require all processes to enter the distributed function call.

    Args:
        proc_ranks: rank list of the group, the first one is root rank.
            An empty list creates a placeholder group that must be
            :meth:`reset` before any of its properties are used.
    """

    def __init__(self, proc_ranks):
        # Define every attribute up front. Previously an empty (placeholder)
        # group had no ``is_single_machine_cache`` and kept ``proc_ranks=None``,
        # so property access failed with AttributeError/TypeError instead of
        # the intended "invalid group" assertion.
        self.proc_ranks = None
        self.stream = None
        self.is_single_machine_cache = None
        if len(proc_ranks) > 0:
            self.reset(proc_ranks)

    def reset(self, proc_ranks):
        r"""Bind the group to ``proc_ranks`` and take the next stream from ``_sh``.

        Also clears the cached :attr:`is_single_machine` result.
        """
        self.check(proc_ranks)
        self.proc_ranks = proc_ranks
        self.is_single_machine_cache = None
        self.stream = _sh.get_next()

    def check(self, proc_ranks):
        r"""Assert that ``proc_ranks`` is a valid rank list for this process.

        Every rank must be an ``int`` in ``[0, world_size)``, and the current
        process rank must be a member. Requires :func:`init_process_group`.
        """
        assert _sd is not None, "please call init_process_group first"
        for rank in proc_ranks:
            assert isinstance(rank, int), "rank must be int but got {}".format(
                type(rank)
            )
            assert 0 <= rank < _sd.world_size, "rank {} out of range [0, {})".format(
                rank, _sd.world_size
            )
        assert _sd.proc_rank in proc_ranks, "current rank {} is not in group {}".format(
            _sd.proc_rank, proc_ranks
        )

    def _check_valid(self):
        # Shared guard for the properties below: a placeholder group has no ranks.
        assert (
            self.proc_ranks is not None and len(self.proc_ranks) > 0
        ), "invalid group"

    @property
    def size(self):
        r"""Number of ranks in the group."""
        self._check_valid()
        return len(self.proc_ranks)

    @property
    def key(self):
        r"""Comma-joined rank list, e.g. ``"0,1,2"``."""
        self._check_valid()
        return ",".join(map(str, self.proc_ranks))

    @property
    def rank(self):
        r"""Index of the current process inside ``proc_ranks``."""
        self._check_valid()
        return self.proc_ranks.index(_sd.proc_rank)

    @property
    def comp_node(self):
        r"""Comp-node string formatted as ``"{device_type}{device}:{stream}"``."""
        self._check_valid()
        return "{}{}:{}".format(_sd.device_type, _sd.device, self.stream)

    @property
    def is_single_machine(self):
        r"""Whether every rank of the group is in the recorded machine ranks.

        Returns ``False`` when no machine ranks were recorded. The result is
        cached until :meth:`reset`.
        """
        # NOTE(review): the cache is not cleared when ``_set_machine_ranks`` is
        # called later, so a result computed before that call can be stale.
        if self.is_single_machine_cache is not None:
            return self.is_single_machine_cache
        self._check_valid()
        assert _sd is not None, "please call init_process_group first"
        machine_ranks = _sd.machine_ranks
        result = machine_ranks is not None and all(
            rank in machine_ranks for rank in self.proc_ranks
        )
        self.is_single_machine_cache = result
        return result
# Default group spanning every process. It is created empty here and filled
# with ranks ``0..world_size-1`` by ``init_process_group``.
WORLD = Group([])
# Physical device types accepted by ``init_process_group`` (after "xpu" is
# resolved through ``what_is_xpu``).
_devices = {"gpu", "cuda", "rocm"}
# Communicator backends accepted by ``init_process_group``.
_backends = {"nccl", "rccl", "auto"}
def init_process_group(
    master_ip: str,
    port: int,
    world_size: int,
    rank: int,
    device: int,
    backend: Optional[str] = "auto",
    device_type: str = "xpu",
) -> None:
    r"""Initialize the distributed process group and specify the device used in the current process

    Args:
        master_ip: ip address of the master node.
        port: port available for all processes to communicate.
        world_size: total number of processes participating in the job (> 1).
        rank: rank of the current process, in ``[0, world_size)``.
        device: the GPU device id to bind this process to.
        backend: communicator backend, one of 'nccl', 'rccl' or 'auto'.
        device_type: device type string. ``"xpu"`` is resolved with
            ``what_is_xpu()`` for validation, and the resolved type must be one
            of 'gpu', 'cuda', 'rocm'. The unresolved string is what gets stored
            and used to build the default device name.

    Raises:
        TypeError: if ``master_ip`` is not a str, or ``port``/``world_size``/
            ``rank``/``device`` is not an int.
        ValueError: if ``backend`` or the device type is not supported.
        AssertionError: if called more than once or a numeric argument is out of range.
    """
    global _sd

    if not isinstance(master_ip, str):
        raise TypeError("Expect type str but got {}".format(type(master_ip)))
    for value in (port, world_size, rank, device):
        if not isinstance(value, int):
            # Report the offending argument's own type (previously ``device``
            # errors mistakenly reported ``type(backend)``).
            raise TypeError("Expect type int but got {}".format(type(value)))
    if backend not in _backends:
        raise ValueError(
            "backend should be one of {} but got {}".format(_backends, backend)
        )
    physical_device_type = what_is_xpu() if device_type == "xpu" else device_type
    if physical_device_type not in _devices:
        raise ValueError(
            "{} is not a valid distributed device type".format(device_type)
        )

    assert _sd is None, "init_process_group should be called only once"
    # Check ranges before touching global state, so a rejected call does not
    # leave ``_sd`` set and block a corrected retry.
    assert world_size > 1
    assert rank >= 0 and rank < world_size
    assert port > 0

    sd = StaticData()
    sd.client = Client(master_ip, port)
    sd.master_ip = master_ip
    sd.py_server_port = port
    sd.mm_server_port = sd.client.get_mm_server_port()
    sd.world_size = world_size
    sd.proc_rank = rank
    sd.device = device
    sd.backend = backend
    sd.device_type = device_type
    # Publish only after the client was created and the mm server port was
    # fetched. If either raises, the module stays uninitialized.
    _sd = sd

    WORLD.reset(list(range(world_size)))
    set_default_device("{}{}".format(device_type, device))
    seed(int(time.time()) + rank)
    if backend == "nccl":
        # init nccl env
        from ..core._imperative_rt.common import init_nccl_env

        group_barrier()
        init_nccl_env(master_ip, _sd.mm_server_port, world_size, rank, 0)
def _set_machine_ranks(ranks) -> None:
    r"""Record which process ranks are on the current machine.

    ``Group.is_single_machine`` reads this value. Requires
    ``init_process_group`` to have been called.
    """
    assert _sd is not None
    state = _sd
    state.machine_ranks = ranks
@contextmanager
def override_backend(new_backend: str):
    r"""Temporarily replace the distributed backend.

    The previous backend is restored when the context exits, even if the body
    raises.

    Args:
        new_backend: communicator backend set in this context.
    """
    assert _sd is not None, "please call init_process_group first"
    saved, _sd.backend = _sd.backend, new_backend
    try:
        yield
    finally:
        _sd.backend = saved
def is_distributed() -> bool:
    r"""Return True if the distributed process group has been initialized."""
    if _sd is None:
        return False
    return True
def get_rank() -> int:
    r"""Get the rank of the current process (0 when not distributed)."""
    if _sd is None:
        return 0
    return _sd.proc_rank
def get_world_size() -> int:
    r"""Get the total number of processes participating in the job (1 when not distributed)."""
    if _sd is None:
        return 1
    return _sd.world_size
def get_backend() -> str:
    r"""Get the backend str. Requires ``init_process_group`` to have been called."""
    assert _sd is not None, "please call init_process_group first"
    if _sd is None:
        # Only reachable when assertions are stripped (``python -O``).
        return None
    return _sd.backend
def get_py_server_addr() -> Tuple[str, int]:
    r"""Get master_ip and port of python XML RPC server.

    Returns:
        ``(master_ip, py_server_port)`` as passed to ``init_process_group``.
    """
    assert _sd is not None, "please call init_process_group first"
    addr = (_sd.master_ip, _sd.py_server_port)
    return addr
def get_mm_server_addr() -> Tuple[str, int]:
    r"""Get master_ip and port of C++ mm_server.

    Returns:
        ``(master_ip, mm_server_port)``; the port was obtained from the RPC
        client during ``init_process_group``.
    """
    assert _sd is not None, "please call init_process_group first"
    addr = (_sd.master_ip, _sd.mm_server_port)
    return addr
def get_client() -> Client:
    r"""Get client of python XML RPC server created by ``init_process_group``."""
    assert _sd is not None, "please call init_process_group first"
    client = _sd.client
    return client
def new_group(proc_ranks: List[int]) -> Group:
    r"""Build a subgroup containing certain ranks.

    Args:
        proc_ranks: ranks in the subgroup; the first one is the root rank.

    Returns:
        A new :class:`Group` over ``proc_ranks``.
    """
    group = Group(proc_ranks)
    return group
def group_barrier(group: Group = WORLD) -> None:
    r"""Block until all ranks in the group reach this barrier.

    This is a no-op when the process group was never initialized (single node).

    Args:
        group: group to synchronize; defaults to ``WORLD``.
    """
    if _sd is not None:
        assert isinstance(group, Group)
        _sd.client.group_barrier(group.key, group.size)