megengine.functional.where#

where(mask, x, y)[源代码]#

根据mask选出张量x或张量y中的元素。

\[\textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i\]
参数:
  • mask (Tensor) – 用于选择x或y的 mask。

  • x (Tensor) – 第一个选择。

  • y (Tensor) – 第二个选择。

返回类型:

Tensor

返回:

输出张量。

实际案例

>>> import numpy as np
>>> mask = Tensor(np.array([[True, False], [False, True]], dtype=np.bool))
>>> x = Tensor(np.array([[1, np.inf], [np.nan, 4]],
...     dtype=np.float32))
>>> y = Tensor(np.array([[5, 6], [7, 8]], dtype=np.float32))
>>> out = F.where(mask, x, y)
>>> out.numpy()
array([[1., 6.],
       [7., 4.]], dtype=float32)