megengine.functional.cond_take

cond_take(mask, x)

Takes the elements of x at positions where the corresponding element of mask is True. This operator has two outputs. The first holds the selected elements. The second holds the indices of those elements in the flattened input. Both outputs are 1-dimensional, and inputs with more than one dimension are flattened first.
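
The behaviour described above can be modelled with a plain-Python reference sketch. It assumes row-major (C-order) flattening, which is consistent with the indices in the example below. The helpers `flatten` and `cond_take_reference` are hypothetical illustrations, not MegEngine's implementation.

```python
def flatten(nested):
    """Flatten an arbitrarily nested list in row-major (C) order."""
    if not isinstance(nested, (list, tuple)):
        return [nested]
    out = []
    for item in nested:
        out.extend(flatten(item))
    return out


def cond_take_reference(mask, x):
    """Hypothetical pure-Python model of cond_take semantics.

    Returns (values, indices): the elements of x where mask is true,
    and their positions in the flattened input. Both are 1-D lists.
    """
    flat_mask = flatten(mask)
    flat_x = flatten(x)
    # Simplification: only the element counts are compared here,
    # whereas cond_take requires mask and x to have the same shape.
    if len(flat_mask) != len(flat_x):
        raise ValueError("mask and x must have the same number of elements")
    indices = [i for i, keep in enumerate(flat_mask) if keep]
    values = [flat_x[i] for i in indices]
    return values, indices


mask = [[True, False], [False, True]]
x = [[1.0, float("inf")], [float("nan"), 4.0]]
values, indices = cond_take_reference(mask, x)
print(values, indices)  # [1.0, 4.0] [0, 3]
```

Elements at unselected positions (here `inf` and `nan`) never reach the output, so their values do not matter.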

Parameters
  • mask (Tensor) – boolean condition tensor; must have the same shape as x.

  • x (Tensor) – input tensor from which to take elements.

Examples

>>> import numpy as np
>>> from megengine import Tensor
>>> import megengine.functional as F
>>> mask = Tensor(np.array([[True, False], [False, True]], dtype=np.bool_))
>>> x = Tensor(np.array([[1, np.inf], [np.nan, 4]],
...     dtype=np.float32))
>>> v, index = F.cond_take(mask, x)
>>> print(v.numpy(), index.numpy())
[1. 4.] [0 3]

Return type

Tuple[Tensor, Tensor]
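
The second output holds positions in the flattened input, so recovering multi-dimensional coordinates requires the original shape. Below is a minimal sketch, assuming row-major flattening as in the example. The helper `unflatten_index` is hypothetical; NumPy's `np.unravel_index` performs the same conversion.

```python
def unflatten_index(flat_index, shape):
    """Convert a row-major flat index into a coordinate tuple for `shape`."""
    coords = []
    # Peel off the last (fastest-varying) dimension first.
    for dim in reversed(shape):
        flat_index, rem = divmod(flat_index, dim)
        coords.append(rem)
    if flat_index != 0:
        raise IndexError("flat index out of range for shape")
    return tuple(reversed(coords))


# Flat indices [0, 3] from the 2x2 example map back to the diagonal.
shape = (2, 2)
coords = [unflatten_index(i, shape) for i in [0, 3]]
print(coords)  # [(0, 0), (1, 1)]
```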