与其它框架 API 进行对比

注解

  • 你可以利用浏览器的查找功能在当前页面查询对应的 API.

  • 当前页面并非自动生成,如发现有缺失/过时内容,欢迎编辑当前页面。

Standard tensor operations

NumPy

Pytorch

MegEngine

Comment

numpy.ndarray

torch.Tensor

megengine.Tensor

Basic N-dimensional array type

警告

Pytorch 的 Tensor 类中提供了许多操作/计算方法,而在 MegEngine 中这些方法被统一实现在 functional 模块中, 意味着类似 functional.add() 等操作并不一定存在着对应的 Tensor.add() 实现,这是设计上的历史决定。

警告

Pytorch 中默认所有 Tensor 都需要被求导,因此提供了 torch.no_grad 来禁用梯度计算。 而在 MegEngine 中 Tensor 默认不需要被求导,需要通过 megengine.autodiff.GradManager.attach 来进行绑定, 被绑定后的 Tensor 可以通过 megengine.Tensor.detach 来解除绑定。

Bit operations

NumPy

Pytorch

MegEngine

Comment

numpy.left_shift

Not Found

megengine.functional.left_shift

numpy.right_shift

Not Found

megengine.functional.right_shift

NN Funtional Operations

NN Module

Pytorch

MegEngine

torch.nn.parameter.Parameter

megengine.Parameter

Containers

Pytorch

MegEngine

torch.nn.Module

megengine.module.Module

torch.nn.Sequential

megengine.module.Sequential

torch.nn.ModuleList

MegEngine 原生支持

torch.nn.ModuleDict

MegEngine 原生支持

torch.nn.ParameterList

MegEngine 原生支持

torch.nn.ParameterDict

MegEngine 原生支持

Not Implemeted

注解

一些 API 在 MegEngine 中可能还没有实现,但所有的 API 并不是一开始就被设计出来的。 我们可以像搭积木一样,利用已经存在的基础 API 来组合出 MegEngine 中尚未提供的接口。

比如 “如何实现 torch.roll ” 这个问题,可以使用 splitconcat 拼接出来:

import megengine.functional as F

def roll(x, shifts, axis):
    shp = x.shape
    dim = len(shp)
    if isinstance(shifts, int):
        assert isinstance(axis, int)
        shifts = [shifts]
        axis = [axis]
    assert len(shifts) == len(axis)
    y = x
    for i in range(len(shifts)):
        axis_ = axis[i]
        shift_ = shifts[i]
        axis_t_ = axis_ + dim if axis_ < 0 else axis_
        assert (
            dim > axis_t_ >= 0
        ), "axis out of range (expected to be in range of [{}, {}], but got {})".format(
            -dim, dim - 1, axis_
        )
        if shift_ == 0:
            continue
            size = shp[axis_t_]
        if shift_ > 0:
            a, b = F.split(y, [size - shift_,], axis=axis_t_)
        else:
            a, b = F.split(y, [-shift_,], axis=axis_t_)
        y = F.concat((b, a), axis=axis_t_)
      return y

除此之外,你可以尝试在 GitHub Issues 或论坛中针对 API 问题发起求助。

我们也欢迎你将自己实现的 API 以 Pull Request 的形式提交到 MegEngine 代码库中来~

注解

对于缺失的 Loss Funtions 算子,大都可自行设计实现。