megengine.functional.where

where(mask, x, y)[源代码]

根据mask选出张量x或张量y中的元素。

\[\textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i\]
参数
  • mask (Tensor) – 用于选择x或y的 mask。

  • x (Tensor) – 第一个选择。

  • y (Tensor) – 第二个选择。

返回类型

Tensor

返回

输出张量。

例如:

import numpy as np
from megengine import tensor
import megengine.functional as F
mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool))
x = tensor(np.array([[1, np.inf], [np.nan, 4]],
    dtype=np.float32))
y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32))
out = F.where(mask, x, y)
print(out.numpy())

输出:

[[1. 6.]
 [7. 4.]]