megengine.functional.distributed.reduce_scatter_sum¶
- reduce_scatter_sum(inp, group=WORLD, device=None, axis=0)[源代码]¶
通过求和规约指定组中的张量,并在第一维度将其拆分。
- 参数
- 返回类型
- 返回
分割张量
实际案例
input = Tensor([0 1]) # Rank 0 # input: Tensor([0 1]) # Rank 1 # input: Tensor([0 1]) output = reduce_scatter_sum(input) # Rank 0 # output: Tensor([0]) # Rank 1 # output: Tensor([2]) input = Tensor([0 1]) group = Group([1, 0]) output = reduce_scatter_sum(input, group) # Rank 0 # output: Tensor([2]) # Rank 1 # output: Tensor([0])