import megengine.functional as F
from megengine.jit import trace

# import trace之后设置 enabled 属性切换动静态图
trace.enabled = True

# 使用 trace 类装饰网络 forward 的函数
@trace
def train_func(data, label, *, opt, net):
    pred = net(data)
    loss = F.cross_entropy_with_softmax(pred, label)
    opt.backward(loss)
    return pred, loss

# 调用函数训练网络,动静态图一套代码
train_func(data, label, opt=optimizer, net=le_net)
动静合一
瞄准痛点:静态图好部署,动态图易调试,但两者难以兼得
  • 同时适配科研实验和生产部署环境
  • 内置动静转换
  • 动静态混合编程
兼容并包
瞄准痛点:框架学习接口各异,模型复现困难,学习成本高
  • Pythonic风格API,简单直接,易于上手
  • 支持导入PyTorch Module
  • 特别为计算机视觉(Computer Vision)任务优化
import megengine as mge
import megengine.functional as F
import megengine.module as M
import numpy as np

# 经典的基于 Module 的网络搭建接口
class LeNet(M.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = M.Conv2d(1, 6, 5)
        self.relu1 = M.ReLU()
        self.pool1 = M.MaxPool2d(2, 2)
        # 省略部分代码...
        self.classifer = M.Linear(84, 10)

    # 符合 Pythonic 风格的计算流程代码
    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        # 省略部分代码...
        x = self.classifer(x)
        return x
灵活高效
瞄准痛点:生产环境计算设备繁多,缺乏优秀性能
  • 高性能算子,充分利用算力
  • 高效内存优化策略,支持自动 Sublinear 内存优化
  • JIT代码生成机制,加速计算
  • 内置算法选择,智能适配设备
训练推理一体
瞄准痛点:从研究到生产,流程复杂,精度难以对齐
  • 从训练到推理,无需模型转化,精度损失最小化
  • 跨设备模型精度对齐
  • 自动模型优化简化流程
from megengine.jit import trace

# 使用 trace 类装饰网络 forward 的函数
@trace
def val_func(x, *, net):
    return net(x)

# 调用trace接口无需运行直接编译网络
val_func.trace(inp, net=net)

# 将编译后的网络进行导出,直接生成可用于部署的序列化文件
val_func.dump('./mnist.mge', arg_names=["data"])

京ICP备19000496号-11

用户协议| 隐私政策