megengine.functional.where

where(mask, x=None, y=None)[源代码]

根据mask选出张量x或张量y中的元素。

\[\textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i\]
参数
返回类型

Tensor

返回

输出张量。

实际案例

>>> import numpy as np
>>> mask = Tensor(np.array([[True, False], [False, True]], dtype=np.bool))
>>> x = Tensor(np.array([[1, np.inf], [np.nan, 4]],
...     dtype=np.float32))
>>> y = Tensor(np.array([[5, 6], [7, 8]], dtype=np.float32))
>>> out = F.where(mask, x, y)
>>> out.numpy()
array([[1., 6.],
       [7., 4.]], dtype=float32)