megengine.distributed.functional.
all_gather
Create all_gather operator for collective communication.
inp (Tensor) – input tensor.
group (Optional[Group]) – communication group.
device (Optional[str]) – execution device.
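Example (illustrative sketch, not part of the original reference): how all_gather might be called from a worker function once the process group has been initialized; the worker function name and the rank-dependent input are assumptions for the example.
import numpy as np
from megengine import tensor
from megengine.distributed.group import get_rank
from megengine.distributed.functional import all_gather

def worker():
    # each rank contributes a rank-dependent tensor; all_gather returns the
    # inputs from every rank gathered into one tensor on every rank
    rank = get_rank()
    x = tensor(np.full((2,), rank, dtype="float32"))
    out = all_gather(x)
    print(rank, out.numpy())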
all_reduce_max
Create all_reduce_max operator for collective communication.
all_reduce_min
Create all_reduce_min operator for collective communication.
all_reduce_sum
Create all_reduce_sum operator for collective communication.
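Example (illustrative sketch, not part of the original reference): summing a per-rank value with all_reduce_sum; the worker function and the chosen values are assumptions for the example.
import numpy as np
from megengine import tensor
from megengine.distributed.group import get_rank, get_world_size
from megengine.distributed.functional import all_reduce_sum

def worker():
    # every rank holds the value rank + 1; after the reduction each rank
    # sees the same total 1 + 2 + ... + world_size
    x = tensor(np.array([get_rank() + 1], dtype="float32"))
    total = all_reduce_sum(x)
    print(total.numpy()[0], sum(range(1, get_world_size() + 1)))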
all_to_all
Create all_to_all operator for collective communication.
broadcast
Create broadcast operator for collective communication.
gather
Create gather operator for collective communication.
reduce_scatter_sum
Create reduce_scatter_sum operator for collective communication.
reduce_sum
Create reduce_sum operator for collective communication.
remote_recv
Receive a Tensor from a remote process.
src_rank (int) – source process rank.
shape (Tuple[int]) – the shape of the tensor to receive.
dtype (type) – the data type of the tensor to receive.
device (Optional[str]) – the device to place the received tensor.
inp – dummy input to determine the received tensor type.
remote_send
Send a Tensor to a remote process.
inp (Tensor) – tensor to send.
dest_rank (int) – destination process rank.
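Example (illustrative sketch, not part of the original reference): pairing remote_send on one rank with remote_recv on another, using the parameter names documented above; the two-process setup and the tensor contents are assumptions for the example.
import numpy as np
from megengine import tensor
from megengine.distributed.group import get_rank
from megengine.distributed.functional import remote_send, remote_recv

def worker():
    if get_rank() == 0:
        # rank 0 sends a 2x2 tensor to rank 1
        x = tensor(np.arange(4, dtype="float32").reshape(2, 2))
        remote_send(x, dest_rank=1)
    else:
        # rank 1 receives it, describing the expected shape and dtype
        y = remote_recv(src_rank=0, shape=(2, 2), dtype=np.float32)
        print(y.numpy())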
scatter
Create scatter operator for collective communication.
megengine.distributed.group.
Group
Bases: object
check
comp_node
key
rank
reset
size
StaticData
backend
client
device
master_ip
mm_server_port
next_stream
proc_rank
py_server_port
server
world_size
get_backend
Get the backend string.
get_client
Get the client of the Python XML RPC server.
Client
get_mm_server_addr
Get master_ip and port of C++ mm_server.
Tuple[str, int]
get_py_server_addr
Get the master_ip and port of the Python XML RPC server.
get_rank
Get the rank of the current process.
get_world_size
Get the total number of processes participating in the job.
group_barrier
Block until all ranks in the group reach this barrier.
None
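Example (illustrative sketch, not part of the original reference): using group_barrier so that work done on a single rank, here a hypothetical checkpoint write, completes before any rank proceeds.
from megengine.distributed.group import get_rank, group_barrier

def save_checkpoint_on_rank0(save_fn):
    # hypothetical helper: only rank 0 writes the checkpoint; every rank
    # blocks at the barrier until the write has completed
    if get_rank() == 0:
        save_fn()
    group_barrier()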
init_process_group
Initialize the distributed process group and specify the device to be used in the current process.
master_ip (str) – ip address of the master node.
port (int) – port available for all processes to communicate.
world_size (int) – total number of processes participating in the job.
rank (int) – rank of the current process.
device (int) – the GPU device id to bind this process to.
backend (Optional[str]) – communicator backend, currently supports ‘nccl’ and ‘ucx’.
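Example (illustrative sketch, not part of the original reference): wiring the parameters documented above together for a manual, non-launcher setup; the address, port, world size, and GPU binding are placeholders.
from megengine.distributed.group import (
    init_process_group,
    get_rank,
    get_world_size,
)

def setup(rank):
    # placeholder values; in practice the master address and port must be
    # reachable by every participating process
    init_process_group(
        master_ip="127.0.0.1",
        port=23456,
        world_size=2,
        rank=rank,
        device=rank,      # bind this process to GPU `rank`
        backend="nccl",
    )
    assert get_rank() == rank and get_world_size() == 2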
is_distributed
Return True if the distributed process group has been initialized.
bool
new_group
Build a subgroup containing certain ranks.
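Example (illustrative sketch, not part of the original reference): building a subgroup and reducing only within it; passing the ranks as a list is an assumption based on the one-line description above.
from megengine.distributed.group import new_group
from megengine.distributed.functional import all_reduce_sum

def reduce_in_subgroup(x):
    # build a subgroup containing ranks 0 and 1 (list argument assumed)
    # and reduce only inside that subgroup
    sub = new_group([0, 1])
    return all_reduce_sum(x, group=sub)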
megengine.distributed.helper.
AllreduceCallback
Allreduce Callback with tensor fusion optimization.
reduce_method (str) – the method used to reduce gradients.
group (Group) – communication group.
TensorFuture
Bases: megengine.utils.future.Future
dtype
numpy
shape
apply
bcast_list_
Broadcast a list of tensors within the given group.
inps (list) – input tensors.
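Example (illustrative sketch, not part of the original reference): using bcast_list_ to synchronize a model's parameters across ranks; treating the call as in-place is assumed from the trailing underscore, and the model argument is hypothetical.
from megengine.distributed.helper import bcast_list_

def sync_parameters(model):
    # broadcast every parameter so all ranks start from identical weights
    # (in-place update assumed from the trailing "_")
    params = list(model.parameters())
    bcast_list_(params)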
get_device_count_by_fork
Get the device count in a forked subprocess, so that CUDA is not initialized in the current process. See https://stackoverflow.com/questions/22950047/cuda-initialization-error-after-fork for more information.
get_offsets
make_allreduce_cb
alias of megengine.distributed.helper.AllreduceCallback
pack_allreduce_split
param_pack_concat
Returns the concatenated tensor; only used for parampack.
offsets (Tensor) – device value of offsets.
offsets_val (list) – offsets of inputs, length of 2 * n, format [begin0, end0, begin1, end1].
the concatenated tensor.
Examples:
import numpy as np
from megengine import tensor
from megengine.distributed.helper import param_pack_concat

a = tensor(np.ones((1,), np.int32))
b = tensor(np.ones((3, 3), np.int32))
offsets_val = [0, 1, 1, 10]
offsets = tensor(offsets_val, np.int32)
c = param_pack_concat([a, b], offsets, offsets_val)
print(c.numpy())
Outputs:
[1 1 1 1 1 1 1 1 1 1]
param_pack_split
Split a tensor into multiple tensors according to the given offsets and shapes; only used for parampack.
offsets (list) – offsets of outputs, of length 2 * n, where n is the number of output tensors, in the format [begin0, end0, begin1, end1].
shapes (list) – tensor shapes of outputs.
the split tensors.
Examples:
import numpy as np
from megengine import tensor
from megengine.distributed.helper import param_pack_split

a = tensor(np.ones((10,), np.int32))
b, c = param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
print(b.numpy())
print(c.numpy())
Outputs:
[1]
[[1 1 1]
 [1 1 1]
 [1 1 1]]
synchronized
Decorator. The decorated function will synchronize when finished. Specifically, this is used to prevent data races during hub.load.
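Example (illustrative sketch, not part of the original reference): applying the decorator to work that must not race across processes; the decorated function and its arguments are hypothetical.
from megengine.distributed.helper import synchronized

@synchronized
def download_pretrained_weights(url, path):
    # hypothetical body: work that must not race across processes (e.g.
    # writing a file during hub.load); all processes synchronize once the
    # decorated call returns
    ...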
megengine.distributed.launcher.
launcher
Decorator for launching multiple processes in single-machine multi-gpu training.
func – the function you want to launch in distributed mode.
n_gpus – number of devices on each node.
world_size – total number of devices.
rank_start – starting rank number.
master_ip – IP address of the master node (where rank 0 runs).
port – server port for distributed server.
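Example (illustrative sketch, not part of the original reference): decorator usage based on the parameters listed above; whether keyword arguments can be supplied at decoration time is assumed from that parameter list, and the training body is omitted.
from megengine.distributed.group import get_rank
from megengine.distributed.launcher import launcher

@launcher(n_gpus=2)
def train():
    # each spawned process runs this body on its own GPU with the process
    # group already initialized
    print("worker", get_rank())

if __name__ == "__main__":
    train()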
megengine.distributed.server.
Client
Distributed Client for distributed training.
master_ip – IP address of the master node.
port – port of the server at the master node.
check_is_grad
Check whether send/recv needs gradients.
key – key to match send/recv op.
check_remote_tracer
Get tracer dict for send/recv op.
connect
Check connection success.
get_mm_server_port
Get multiple machine server port.
group_barrier
A barrier that waits for all group members.
key – group key to match each other.
size – group size.
set_is_grad
Mark whether send/recv needs gradients by key.
is_grad – whether this op need grad.
set_remote_tracer
Set tracer dict for tracing send/recv op.
tracer_set – valid tracer set.
user_get
Get user-defined key-value pairs across processes.
user_set
Set user-defined key-value pairs across processes.
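Example (illustrative sketch, not part of the original reference): exchanging a small value between ranks through the client returned by get_client(); the key name, value, and the exact user_set/user_get signatures are assumptions for the example.
from megengine.distributed.group import get_client, get_rank, group_barrier

def share_metric():
    client = get_client()
    if get_rank() == 0:
        # rank 0 publishes a value under a user-defined key
        client.user_set("best_acc", 0.98)
    group_barrier()                      # ensure the value has been published
    return client.user_get("best_acc")   # every rank can now read it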
Methods
Distributed server methods, used for exchanging information between distributed nodes.
mm_server_port – multiple machine rpc server port.
Method for checking connection success.
Get multiple machine rpc server port.
Server
Distributed server for distributed training. Should run at the master node.
port – python server port.
ThreadXMLRPCServer
Bases: socketserver.ThreadingMixIn, xmlrpc.server.SimpleXMLRPCServer
main
megengine.distributed.util.
get_free_ports
Get one or more free ports.
List[int]
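Example (illustrative sketch, not part of the original reference): requesting a single unused port to pass to init_process_group; the integer count argument is assumed from the "one or more" wording and the List[int] return type.
from megengine.distributed.util import get_free_ports

# request one unused port (count argument assumed) and take the first entry
port = get_free_ports(1)[0]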