megengine.functional.nn.indexing_one_hot

indexing_one_hot(src, index, axis=1, keepdims=False)[源代码]

对一些轴进行One-hot索引。

参数
  • src (Tensor) – 输入张量。

  • index (Tensor) – 索引张量。

  • axis (int) – 源数据上的轴,索引值为其索引。 默认: 1

  • keepdims – 是否在结果数据中删除该轴。 默认: False

返回类型

Tensor

返回

输出张量。

例如:

import megengine.functional as F
from megengine import tensor

src = tensor([[1.0, 2.0]])
index = tensor([0])
val = F.indexing_one_hot(src, index)
print(val.numpy())

输出:

[1.]