megengine.functional.distributed.all_to_all

all_to_all(inp, group=WORLD, device=None, split_axis=0, concat_axis=0)[source]

Each process scatters its input tensor to all processes and returns the gathered result: the input is split into world_size chunks along split_axis, chunk i is sent to rank i, and the chunks received from all processes are concatenated along concat_axis.

Parameters
  • inp (Tensor) – Input tensor.

  • group (Optional[Group]) – The process group to work on. The default group is WORLD, which includes all available processes. You can pass a list of process ranks to create a new group to work on, e.g. [1, 3, 5].

  • device (Optional[str]) – The specific device on which to execute this operator. The default None means the device of inp will be used. Specify "gpu0:1" to execute this operator on a different CUDA stream, where 1 is the stream id; the default stream id is 0.

  • split_axis (int) – The axis along which the collective communication splits the input data. The default axis is 0.

  • concat_axis (int) – The axis along which the collective communication concatenates the received data. The default axis is 0. (See the sketch after this list.)
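
The sketch below illustrates how split_axis and concat_axis interact. It is a minimal sketch, not part of the official docstring: it assumes a machine with at least two GPUs and uses megengine.distributed.launcher to spawn one worker process per device.

import megengine as mge
import megengine.distributed as dist
from megengine.functional.distributed import all_to_all

@dist.launcher(n_gpus=2)  # assumption: at least two GPUs are available
def worker():
    rank = dist.get_rank()
    # Each rank holds a (2, 4) tensor filled with its own rank id.
    inp = mge.Tensor([[float(rank)] * 4] * 2)
    # The input is split into 2 chunks of shape (2, 2) along split_axis=1;
    # chunk i is sent to rank i. Each rank concatenates the 2 chunks it
    # receives along concat_axis=0, so the output has shape (4, 2).
    out = all_to_all(inp, split_axis=1, concat_axis=0)
    print(rank, out.shape)

worker()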

Return type

Tensor

Returns

Result tensor.

Examples

input = Tensor([0, 1]) + rank * 2
# Rank 0 # input: Tensor([0, 1])
# Rank 1 # input: Tensor([2, 3])
output = all_to_all(input)
# Rank 0 # output: Tensor([0, 2])
# Rank 1 # output: Tensor([1, 3])

input = Tensor([0, 1]) + rank * 2
group = Group([1, 0])
output = all_to_all(input, group)
# Rank 0 # output: Tensor([0, 3])
# Rank 1 # output: Tensor([2, 1])
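
The snippets above assume an already-initialized process group and a predefined rank variable. Below is a self-contained version of the first snippet; it is a sketch assuming two GPUs, using megengine.distributed.launcher to set up the group and spawn one worker per device.

import megengine as mge
import megengine.distributed as dist
from megengine.functional.distributed import all_to_all

@dist.launcher(n_gpus=2)  # assumption: at least two GPUs are available
def worker():
    rank = dist.get_rank()
    inp = mge.Tensor([0, 1]) + rank * 2
    # Rank 0 input: [0, 1]; Rank 1 input: [2, 3]
    out = all_to_all(inp)
    # Rank 0 output: [0, 2]; Rank 1 output: [1, 3]
    print(rank, out.numpy())

worker()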