megengine.functional.flatten

flatten(inp, start_axis=0, end_axis=-1)

Reshapes the tensor by merging the dimensions from start_axis through end_axis (inclusive) into a single dimension.

Parameters
  • inp (Tensor) – input tensor.

  • start_axis (int) – first dimension of the sub-tensor to be flattened. Default: 0

  • end_axis (int) – last dimension of the sub-tensor to be flattened. Default: -1

Return type

Tensor

Returns

output tensor.

Examples

>>> import numpy as np
>>> from megengine import Tensor
>>> import megengine.functional as F
>>> inp_shape = (2, 2, 3, 3)
>>> x = Tensor(
...     np.arange(36, dtype=np.int32).reshape(inp_shape),
... )
>>> out = F.flatten(x, 2)
>>> x.numpy().shape
(2, 2, 3, 3)
>>> out.numpy().shape
(2, 2, 9)
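
The shape rule illustrated above can be sketched in plain Python, without MegEngine. The helper flattened_shape below is hypothetical and is not part of the library. It assumes that negative axes count from the end (as the default end_axis=-1 suggests) and that an out-of-range or reversed axis pair is an error; the library's behavior in those cases may differ.

```python
from math import prod


def flattened_shape(shape, start_axis=0, end_axis=-1):
    """Hypothetical helper: shape obtained by merging axes
    start_axis..end_axis (inclusive) into one dimension."""
    ndim = len(shape)
    # Assumption: negative axes are counted from the end.
    start = start_axis + ndim if start_axis < 0 else start_axis
    end = end_axis + ndim if end_axis < 0 else end_axis
    # Assumption: invalid axis ranges are rejected.
    if not (0 <= start <= end < ndim):
        raise ValueError("invalid axis range for shape %r" % (shape,))
    # The merged dimension holds the product of the flattened sizes.
    merged = prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


print(flattened_shape((2, 2, 3, 3), 2))     # (2, 2, 9)
print(flattened_shape((2, 2, 3, 3)))        # (36,)
print(flattened_shape((2, 2, 3, 3), 1, 2))  # (2, 6, 3)
```

The first call mirrors F.flatten(x, 2) from the example: the leading dimensions are kept and the trailing 3 x 3 block collapses into a single dimension of size 9.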