megengine.functional.where¶
- where(mask, x=None, y=None)[源代码]¶
根据mask选出张量x或张量y中的元素。
\[\textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i\]- 参数
- 返回类型
- 返回
输出张量。
实际案例
>>> import numpy as np >>> mask = Tensor(np.array([[True, False], [False, True]], dtype=np.bool)) >>> x = Tensor(np.array([[1, np.inf], [np.nan, 4]], ... dtype=np.float32)) >>> y = Tensor(np.array([[5, 6], [7, 8]], dtype=np.float32)) >>> out = F.where(mask, x, y) >>> out.numpy() array([[1., 6.], [7., 4.]], dtype=float32)