megengine.functional.distributed.all_to_all¶
- all_to_all(inp, group=WORLD, device=None, split_axis=0, concat_axis=0)[源代码]¶
每个进程将输入张量分散到所有进程,并返回收集的张量。
- 参数
- 返回类型
- 返回
结果张量
实际案例
input = Tensor([0 1]) + rank*2 # Rank 0 # input: Tensor([0 1]) # Rank 1 # input: Tensor([2 3]) output = all_to_all(input) # Rank 0 # output: Tensor([0 2]) # Rank 1 # output: Tensor([1 3]) input = Tensor([0 1]) + rank*2 group = Group([1, 0]) output = all_to_all(input, group) # Rank 0 # output: Tensor([0 3]) # Rank 1 # output: Tensor([2 1])