与其它框架 API 进行对比

警告

MegEngine 的 API 设计遵循 MEP 3 – Tensor API 设计规范 , 向《数组 API 标准》靠齐。

  • 在同其它框架进行对比时,同样的命名不意味着用法也完全一致;

  • 如果有新的 API 支持需求,可在 GitHub 创建相应的 Issue 或 Pull Request.

注解

  • 你可以利用浏览器的查找功能在当前页面查询对应的 API.

  • 当前页面并非自动生成,如发现有缺失/过时内容,欢迎编辑当前页面。

参见

当前页面更多用于检索,具有其它框架使用经验的用户还可以参考 用户迁移指南

Data Structure

NumPy

Pytorch

MegEngine

Comment

ndarray

Tensor

Tensor

深入理解 Tensor 数据结构

General tensor operations

Trigonometric functions

NumPy

Pytorch

MegEngine

Comment

sin

sin

sin

cos

cos

cos

tan

tan

tan

arcsin

asin

asin

arccos

acos

acos

arctan

atan

atan

Hyperbolic functions

NumPy

Pytorch

MegEngine

Comment

sinh

sinh

sinh

cosh

cosh

cosh

tanh

tanh

tanh

arcsinh

asinh

asinh

arccosh

acosh

acosh

arctanh

atanh

atanh

Bit operations

NumPy

Pytorch

MegEngine

Comment

left_shift

Not Found

left_shift

<< operator

right_shift

Not Found

right_shift

>> operator

Statistical Functions

NumPy

Pytorch

MegEngine

Comment

sum

sum

sum

prod

prod

prod

mean

mean

mean

min

min

min

max

max

max

var

var

var

std

std

std

cumsum

cumsum

cumsum

Linear Algebra Functions

NumPy

Pytorch

MegEngine

Comment

transpose

transpose

transpose

dot

dot

dot

inv

inv

matinv

matmul

matmul

matmul

svd

svd

svd

norm

norm

norm

Indexing Functions

NumPy

Pytorch

MegEngine

Comment

take_along_axis

gather

gather

put_along_axis

scatter

scatter

where

where

where / cond_take

取决于传参情况

Searching Functions

NumPy

Pytorch

MegEngine

Comment

argmin

argmin

argmin

argmax

argmax

argmax

Sorting Functions

NumPy

Pytorch

MegEngine

Comment

argsort

argsort

argsort

sort

sort

sort

NN Funtional Operations

Linear functions

Pytorch

MegEngine

Comment

linear

linear

bilinear

Not Implemeted

Sparse functions

Pytorch

MegEngine

Comment

embedding

embedding

embedding_bag

Not Implemeted

one_hot

one_hot

NN Module

Pytorch

MegEngine

Comment

Parameter

Parameter

Containers

Pytorch

MegEngine

Comment

Module

Module

Sequential

Sequential

ModuleList

MegEngine 原生支持

ModuleDict

MegEngine 原生支持

ParameterList

MegEngine 原生支持

ParameterDict

MegEngine 原生支持

Padding Layers

Pytorch

MegEngine

Comment

ReflectionPad1d

Pad

mode = REFLECT

ReflectionPad2d

Pad

mode = REFLECT

ReflectionPad3d

Pad

mode = REFLECT

ReplicationPad1d

Pad

mode = EDGE

ReplicationPad2d

Pad

mode = EDGE

ReplicationPad3d

Pad

mode = EDGE

ZeroPad2d

Pad

mode = CONSTANT

ConstantPad1d

Pad

mode = CONSTANT

ConstantPad2d

Pad

mode = CONSTANT

ConstantPad3d

Pad

mode = CONSTANT

Linear Layers

Pytorch

MegEngine

Comment

Identity

Identity

Linear

Linear

Bilinear

Not Implemeted

Sparse Layers

Pytorch

MegEngine

Comment

Embedding

Embedding

EmbeddingBag

Not Implemeted

Distance Functions

Pytorch

MegEngine

Comment

CosineSimilarity

Not Implemeted

PairwiseDistance

Not Implemeted

Vision functions

Pytorch

MegEngine

Comment

pixel_shuffle

Not Implemeted

pad

Not Implemeted

interpolate

interpolate

upsample

interpolate

upsample_nearest

interpolate

upsample_bilinear

interpolate

grid_sample

remap

affine_grid

warp_affine

nms

nms

roi_align

roi_align

roi_pool

roi_pooling

OpenCV Python Package

Pytorch

MegEngine

Comment

cvtColor

cvt_color

resize

interpolate

remap

remap

warpAffine

warp_affine

warpPerspective

warp_perspective

NVIDIA

Pytorch

MegEngine

Comment

correlation

correlation

nvof

nvof

Not Implemeted

注解

一些 API 在 MegEngine 中可能还没有实现,但所有的 API 并不是一开始就被设计出来的。 我们可以像搭积木一样,利用已经存在的基础 API 来组合出 MegEngine 中尚未提供的接口。

比如 “如何实现 roll ” 这个问题,可以使用 splitconcat 拼接出来:

import megengine.functional as F

def roll(x, shifts, axis):
    shp = x.shape
    dim = len(shp)
    if isinstance(shifts, int):
        assert isinstance(axis, int)
        shifts = [shifts]
        axis = [axis]
    assert len(shifts) == len(axis)
    y = x
    for i in range(len(shifts)):
        axis_ = axis[i]
        shift_ = shifts[i]
        axis_t_ = axis_ + dim if axis_ < 0 else axis_
        assert (
            dim > axis_t_ >= 0
        ), "axis out of range (expected to be in range of [{}, {}], but got {})".format(
            -dim, dim - 1, axis_
        )
        if shift_ == 0:
            continue
            size = shp[axis_t_]
        if shift_ > 0:
            a, b = F.split(y, [size - shift_,], axis=axis_t_)
        else:
            a, b = F.split(y, [-shift_,], axis=axis_t_)
        y = F.concat((b, a), axis=axis_t_)
      return y

除此之外,你可以尝试在 GitHub Issues 或论坛中针对 API 问题发起求助。

我们也欢迎你将自己实现的 API 以 Pull Request 的形式提交到 MegEngine 代码库中来~

注解

对于缺失的 Loss Funtions 算子,大都可自行设计实现。