megengine.functional.distributed.all_to_all

all_to_all(inp, group=WORLD, device=None, split_axis=0, concat_axis=0)[源代码]

每个进程将输入张量分散到所有进程,并返回收集的张量。

参数
  • inp (Tensor) – 输入张量

  • group (Optional[Group]) – 需要处理的组,默认为包含所有进程的 WORLD 组。你可以使用进程序号来创建新的组并使用,例如 [1,3,5] 。

  • device (Optional[str]) – 执行此操作的设备。默认为输入张量所在的设备。可以通过指定设备为 ”gpu0:1“ 以在不同的 cuda 流上执行此操作,其中1是 cuda 流的编号,默认 cuda 流编号为0。

  • split_axis (int) – 集合通信拆分数据的默认拆分维度为维度0

返回类型

Tensor

返回

结果张量

实际案例

input = Tensor([0 1]) + rank*2
# Rank 0 # input: Tensor([0 1])
# Rank 1 # input: Tensor([2 3])
output = all_to_all(input)
# Rank 0 # output: Tensor([0 2])
# Rank 1 # output: Tensor([1 3])

input = Tensor([0 1]) + rank*2
group = Group([1, 0])
output = all_to_all(input, group)
# Rank 0 # output: Tensor([0 3])
# Rank 1 # output: Tensor([2 1])