megengine.functional.nn.warp_perspective

warp_perspective(inp, mat, out_shape, mat_idx=None, border_mode='replicate', border_val=0.0, format='NCHW', interp_mode='linear')

Applies a perspective transformation to batched 2D images. A perspective transformation projects an image onto a new view plane.

The input images are transformed to the output images by the transformation matrix:

\[\text{output}(n, c, h, w) = \text{input} \left( n, c, \frac{M_{00}w + M_{01}h + M_{02}}{M_{20}w + M_{21}h + M_{22}}, \frac{M_{10}w + M_{11}h + M_{12}}{M_{20}w + M_{21}h + M_{22}} \right)\]
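To make the mapping concrete, the following is a minimal NumPy sketch of the formula for a single-channel image. It is not MegEngine's implementation. It assumes that the first fraction gives the source column and the second the source row (the OpenCV x/y convention), and it uses bilinear interpolation with “replicate” border handling. The function name warp_perspective_ref is hypothetical.

```python
import numpy as np

def warp_perspective_ref(img, M, out_shape):
    """Illustrative reference: img is (H, W), M is (3, 3), out_shape is (oh, ow)."""
    H, W = img.shape
    oh, ow = out_shape
    # Output pixel grid: hs[i, j] = i (row index), ws[i, j] = j (column index).
    hs, ws = np.meshgrid(np.arange(oh), np.arange(ow), indexing="ij")
    denom = M[2, 0] * ws + M[2, 1] * hs + M[2, 2]
    # Assumption: the first fraction is the source column, the second the source row.
    src_x = (M[0, 0] * ws + M[0, 1] * hs + M[0, 2]) / denom
    src_y = (M[1, 0] * ws + M[1, 1] * hs + M[1, 2]) / denom
    # "replicate" border: clamp the sampling coordinates to the image.
    src_x = np.clip(src_x, 0, W - 1)
    src_y = np.clip(src_y, 0, H - 1)
    # Bilinear interpolation between the four neighbouring pixels.
    x0 = np.floor(src_x).astype(int)
    y0 = np.floor(src_y).astype(int)
    x1 = np.minimum(x0 + 1, W - 1)
    y1 = np.minimum(y0 + 1, H - 1)
    fx = src_x - x0
    fy = src_y - y0
    top = img[y0, x0] * (1 - fx) + img[y0, x1] * fx
    bottom = img[y1, x0] * (1 - fx) + img[y1, x1] * fx
    return top * (1 - fy) + bottom * fy

x = np.arange(16, dtype=np.float32).reshape(4, 4)
M = np.array([[1., 0., 1.],
              [0., 1., 1.],
              [0., 0., 1.]], dtype=np.float32)
out = warp_perspective_ref(x, M, (2, 2))  # samples x[h + 1, w + 1]
```

With this translation matrix each output pixel (h, w) reads the input pixel at (h+1, w+1), which is the same setup as in the Examples section.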

Optionally, mat_idx can be set to assign several different transformations to the same image; otherwise, the input images and transformation matrices must be in one-to-one correspondence.
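The pairing that mat_idx describes can be illustrated with plain NumPy indexing. This is only a sketch of the bookkeeping, under the assumption that entry k of mat_idx names the batch index of the image that matrix k is applied to, so the result holds one image per matrix. The array names are illustrative.

```python
import numpy as np

# Two input images in NCHW layout: (N=2, C=1, H=2, W=2).
images = np.arange(2 * 1 * 2 * 2, dtype=np.float32).reshape(2, 1, 2, 2)

# Three (3, 3) matrices (identity here, purely as placeholders).
mats = np.tile(np.eye(3, dtype=np.float32), (3, 1, 1))

# Matrices 0 and 1 are assigned to image 0, matrix 2 to image 1.
mat_idx = np.array([0, 0, 1])

# Gather the image that each matrix is paired with: one image per matrix.
paired_images = images[mat_idx]  # shape (3, 1, 2, 2)
```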

Parameters
  • inp (Tensor) – input image.

  • mat (Tensor) – (batch, 3, 3) transformation matrix.

  • out_shape (Union[Tuple[int, int], int, Tensor]) – (h, w) size of the output image.

  • mat_idx (Union[Iterable[int], Tensor, None]) – batch index of the image assigned to each matrix. Default: None

  • border_mode (str) – pixel extrapolation method. Default: “replicate”. “constant”, “reflect”, “reflect_101” and “wrap” are also supported.

  • border_val (float) – value used in case of a constant border. Default: 0

  • format (str) – data layout of the input. Default: “NCHW”. “NHWC” is also supported.

  • interp_mode (str) – interpolation method. Default: “linear”. Currently only “linear” is supported.

Return type

Tensor

Returns

output tensor.

Note

The transformation matrix is the inverse of that used by cv2.warpPerspective.
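As a small sketch of what this note implies, a matrix built for cv2.warpPerspective can be inverted with NumPy before being passed as mat. The sketch assumes cv2.warpPerspective is used in its default mode, without the WARP_INVERSE_MAP flag. The variable names are illustrative.

```python
import numpy as np

# OpenCV-style forward matrix: moves content one pixel up and one pixel left,
# i.e. dst(x, y) = src(x + 1, y + 1).
M_cv = np.array([[1., 0., -1.],
                 [0., 1., -1.],
                 [0., 0., 1.]], dtype=np.float32)

# The matrix expected here maps output coordinates back to input coordinates,
# so it is the inverse of the OpenCV-style matrix.
M_inv = np.linalg.inv(M_cv)

# Add the leading batch dimension required by the (batch, 3, 3) mat argument.
mat = M_inv.reshape(1, 3, 3)
```

For this pure translation the inverse simply flips the sign of the offsets, giving a matrix with +1 in both translation entries.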

Examples

>>> import numpy as np
>>> import megengine.functional as F
>>> from megengine import Tensor
>>> inp_shape = (1, 1, 4, 4)
>>> x = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
>>> M_shape = (1, 3, 3)
>>> # M defines a translation: dst(1, 1, h, w) = src(1, 1, h+1, w+1)
>>> M = Tensor(np.array([[1., 0., 1.],
...                      [0., 1., 1.],
...                      [0., 0., 1.]], dtype=np.float32).reshape(M_shape))
>>> out = F.vision.warp_perspective(x, M, (2, 2))
>>> out.numpy()
array([[[[ 5.,  6.],
         [ 9., 10.]]]], dtype=float32)