megengine.functional.gather¶
- gather(inp, axis, index)[源代码]¶
根据给定的索引从输入 Tensor 中收集数据。
For a 3-D tensor, the output is specified by:
out[i][j][k] = inp[index[i][j][k]][j][k] # if axis == 0 out[i][j][k] = inp[i][index[i][j][k]][k] # if axis == 1 out[i][j][k] = inp[i][j][index[i][j][k]] # if axis == 2
如果
inp
是一个尺寸为 \((x_0,x_1,...,x_{i-1},x_i,x_{i+1},...,x_{n-1})\) 且 axis=i 的 n 维 Tensor 则index
必须是一个尺寸为 \((x_0,x_1,...,x_{i-1},y,x_{i+1},...,x_{n-1})\) 的 n 维 Tensor,这里的 \(y\ge 1\) 和输出的尺寸都必须必须与index
的尺寸相同。实际案例
import megengine.functional as F from megengine import tensor inp = tensor([ [1,2], [3,4], [5,6], ]) index = tensor([[0,2], [1,0]]) oup = F.gather(inp, 0, index) print(oup.numpy())
输出:
[[1 6] [3 2]]