API 与 使用方式

注解

注意:TracedModule API 在未来一段时间会根据使用反馈进行调整,请关注 github release note 获取变更。欢迎在文档或 Github 提交使用反馈,一起让模型到应用更快更便捷!

以 resnet18 为例介绍 TracedModule 的使用方式,model.py 可从 这里 下载。 通过 trace_module 方法将一个普通的 Module 转变成 TracedModule。接口形式如下:

def trace_module(module: Module, *inputs, **kwargs) -> TracedModule:
    """
    module: 要被 trace 的原 Module
    inputs/kwargs: Module.forward 所需的参数
    """
    ...
    return traced_module

将自定义的 resnet18(Module)转换为 TracedModule:

import megengine.functional as F
import megengine.module as M
import megengine as mge
import model

# resnet : Module
resnet = model.__dict__["resnet18"]()

import megengine.traced_module as tm
inp = F.zeros(shape=(1,3,224,224))

# traced_resnet : TracedModule
traced_resnet =  tm.trace_module(resnet, inp)

Node 、Expr 、InternalGraph 的常用属性和方法

TracedModule.graph

查看 TracedModule 对应的 InternalGraph,以及子 TracedModule 对应的 InternalGraph。通过 "{:ip}".format(InternalGraph) 查看 Expr 的 id,Node 的 id 和 name。在一个 InternalGraph 中每个 Expr 和 Node 都有一个唯一的 id 与其对应。通过这个 id 可以区分和定位不同的 Expr 与 Node。

InternalGraph.exprs

遍历 Graph 中的 Expr。通过访问 InternalGraph.exprs 可得到该 graph 按执行顺序的 Expr 序列。

InternalGraph.exprs (recursive : bool = True)

按 Expr 执行顺序获取 Expr 执行序列

  • recursive: 是否获取子 Graph 中的 Expr,默认为 True

InternalGraph.nodes

遍历 Graph 中的 Node。通过访问 InternalGraph.nodes 可得到该 graph 中的 Node 序列。

InternalGraph.nodes (recursive : bool = True)

按 id 从小到大返回 Graph 中的 Node

  • recursive: 是否获取子 Graph 中的 Node,默认为 True

Expr.inputs & Expr.outputs

通过访问 Expr 的 inputs 和 outputs 属性,可获得该 Expr 的输入和输出 Node。

Expr.inputs : List[Node] Expr.outputs : List[Node]

Node.expr

通过访问 Node 的 expr 属性,可获得该 Node 是由哪个 Expr 生成的。

Node.expr : Expr

Node.users

通过访问 Node 的 users 属性,可获得该 Node 是将会被哪些 Expr 作为输入。

Node.users : Lsit[Expr]

ModuleNode.owner

通过访问 ModuleNode 的 owner 属性,可直接访问该 ModuleNode 所对应的 Module。

ModuleNode.owner : Module

Node.top_graph & Expr.top_graph

通过访问 Node 或 Expr 的 top_graph 属性,可直获得该 Node 或 Expr 所属的 InternalGraph。

Node.top_graph : InternalGraph

Expr.top_graph : InternalGraph

InternalGraph.eval

通过访问 InternalGraph 的 eval 方法,可以直接运行该 Graph。

InternalGraph.eval (*inputs)

将 Tensor 直接输入 Graph 并返回按 Expr 执行序列执行后的结果

  • inputs 模型的输入

Node 和 Expr 的查找方法

BaseFilter

BaseFilter 是一个可迭代的类,其提供了一些方法将迭代器转换为 list, dict 等。

NodeFilterExprFilter 继承于 BaseFilter,NodeFilter 负责处理 Node,ExprFilter 负责处理 Expr。

  • BaseFilter.as_list : 返回 Node 或 Expr 列表

  • BaseFilter.as_dict : 返回 Node 或 Expr 的 id 和 Node 或 Expr 组成的字典

  • BaseFilter.as_unique : 如果查找到的 Node 或 Expr 只有一个,直接返回该 Node 或 Expr, 否则报错

  • BaseFilter.as_count : 返回查找到 Node 或 Expr 的数量

get_node_by_id

通过 Node 的 id 从 Graph 中获取对应 id 的 Node。

InternalGraph.get_node_by_id (node_id: List[int] = None, recursive=True)

获取 InternalGraph 中 id 在 node_id 里的 Node,支持一次查找多个 Node

  • node_id 待查找 Node 的 id 列表

  • recursive 是否查找子 Graph 中的 Node,默认为 True

get_expr_by_id

与 get_node_by_id 类似,该方法通过 Expr 的 id 从 Graph 中获取对应 id 的 Expr

InternalGraph.get_expr_by_id (expr_id: List[int] = None, recursive=True)

获取 InternalGraph 中 id 在 expr_id 里的 Expr,支持一次查找多个 Expr

  • expr_id 待查找 Expr 的 id 列表

  • recursive 是否查找子 Graph 中的 Expr,默认为 True

get_function_by_type

通过该方法查找 Graph 中调用了某个 function 的 CallFunction Expr

InternalGraph.get_function_by_type (func: Callable = None, recursive=True)

获取 InternalGraph 中 self.func == func 的 CallFunction

  • func 可调用的函数

  • recursive 是否查找子 Graph 中的 Expr,默认为 True

get_method_by_type

通过该方法查找 Graph 中调用了某个 method 的 CallMethod Expr

InternalGraph.get_method_by_type (method: str = None, recursive=True)

获取 InternalGraph 中 self.method == method 的 CallMethod

  • method 待查找某对象的方法的名字(该方法是一个可调用的函数)

  • recursive 是否查找子 Graph 中的 Expr,默认为 True

get_module_by_type

通过该方法查找 Graph 中对应某种 Module 的 ModuleNode

InternalGraph.get_module_by_type (module_cls: Module, recursive=True)

获取 InternalGraph 中对应于 module_cls 的 ModuleNode

  • module_cls Module 某个子类

  • recursive 是否查找子 Graph 中的 Expr,默认为 True

图手术常用方法

add_input_node

为最顶层的 InternalGraph 增加一个输入,此输入会作为一个 free_varargs 参数(即无形参名称)。当调用该方法的 Graph 是一个子 Graph 时,将会报错。

InternalGraph.add_input_node (shape, dtype="float32", name="args")

为顶层 Graph 新增一个输入

  • shape 新增输入的 shape

  • dtype 新增输入的 dtype,默认为 “float32”

  • name 新增输入的名字,默认为 “args”,若该名字在 Graph 种已存在,则会在 name 后添加后缀,以保证 name 在 Graph 在的唯一性。

add_output_node

为最顶层的 InternalGraph 增加一个输出,此输入会作为输出元组种的最后一个元素。当调用该方法的 Graph 是一个子 Graph 时,将会报错。

InternalGraph.add_output_node (node: TensorNode)

将 Graph 种的某个 Node 作为 Graph 的一个输出

  • node Graph 中的某 Node

reset_outputs

重新设置最顶层 InternalGraph 的输出。当调用该方法的 Graph 是一个子 Graph 时,将会报错。

当要改变的输出较多时,一个一个调用 add_output_node 较为麻烦,通过 reset_outputs 方法一次性重置输出内容于结构。

InternalGraph.reset_outputs (node: outputs)

重置 Graph 的输出

  • node 由 Graph 中的 TensorNode 构成的某种结构,支持 list, dict, tuple 等(最底层的元素必须是 TensorNode)。

replace_node

替换 InternalGraph 中的指定 Node。可用于新增 Expr 后替换一些 Node,或结合 InternalGraph.compile 删某些 Expr。

InternalGraph.replace_node (repl_dict : Dict[Node, Node])

替换 Graph 中的 key 替换为 value

  • repl_dict 为一个 keyvalue 都为 Node 的字典,且 keyvalue 必须在同一个 Graph 中。生成 value 的 Expr 之后的所有将 key 作为输入的 Expr 的输入将被替换为 value

insert_exprs

向 InternalGraph 中插入 Expr。可用于插入 functionModule ,并在插入的过程中将这些 functionModule 解析为 Expr 或 TracedModule。

一般与 replace_nodecompile 一起使用完成图手术。

InternalGraph.insert_exprs (expr: Optional[Expr] = None)

向 Graph 中插入 Expr

  • exprInternalGraph._exprsexpr 之后插入新的 functionModule

insert_exprs 的作用域里,TensorNode 可以当作 Tensor 使用, ModuleNode 可以当作 Module

注解

由于 __setitem__ 比较特殊,因此在图手术模式下 TensorNode 的赋值结果作为输出时需要特别注意要图手术结果是否符合预期。

# x_node 是一个 TensorNode , x_node 的 name 为 x_node
x = x_node
with graph.insert_exprs():
    x[0] = 1  # 此操作会解析为 setitem_out = x_node.__setietm__(0, 1, ), 此时变量 x 依然对应的是 x_node
    x[0] = 2  # 此操作会解析为 setitem_out_1 = setitem_out.__setietm__(0, 2, ), 此时变量 x 依然对应的是 x_node
graph.replace_node({* : x}) #此处实际替换的 x 依然为 x_node

with graph.insert_exprs():
    x[0] = 1  # 此操作会解析为 setitem_out = x_node.__setietm__(0, 1, ), 此时变量 x 依然对应的是 x_node
    x = x * 1 # 此操作会解析为 mul_out = setitem_out.__mul__(1, ), 此时变量 x 对应的是 mul_out
graph.replace_node({* : x}) #此处实际替换的 x 为 mul_out

compile

该方法会将 InternalGraph 与输出无关的 Expr 删除。

InternalGraph.compile ()

常与 insert_exprsreplace_node 一起使用。

wrap

有时不希望插入的函数被解析为 megengine 内置的 function, 此时可以选择用 wrap 函数将自定义的函数当作 megengine 内置函数处理, 即不再 trace 到函数内部。

wrap (func: Callable)

将自定义函数注册为内置函数

  • func 为一个可调用的对象。

TracedModule 常用方法

flatten

该方法可去除 InternalGraph 的层次结构,即将子 graph 展开, 并返回一个新的 TracedModule。在新的 TracedModule 中,所有的 Getattr Expr 将被转换为 Constant Expr。

TracedModule.flatten ()

返回一个新的 TracedModule,其 Graph 无层次结构

拍平后的 InternalGraph 仅包含内置 Module 的 Expr,此时可以直观的得到数据之间的连接关系。

set_watch_points & clear_watch_points

查看 TracedModule 执行时 graph 中某个 Node 对应的真正的 Tensor/Module。

TracedModule.set_watch_points (nodes : Sequence[Node])

设置观察点

  • nodes 待观察的 Node

TracedModule.clear_watch_points ()

清除观察点

set_end_points & clear_end_points

设置模型停止运行的位置,接受一个 List[Node] 作为输入,当网络生成所有设置的 Node 后会立即返回,不再继续往下执行。

TracedModule.set_end_points (nodes : Sequence[Node])

设置结束运行点

  • nodes 停止运行处的的 Node

TracedModule.clear_end_points ()

清除结束运行点

TracedModule 的局限

  • 不支持动态控制流,动态控制流是指 if 语句中的 condition 随输入的变化而变化,或者是 for, while 每次运行的语句不一样。当 trace 到控制流时,仅会记录并解释满足条件的那个分支。

  • 不支持全局变量(Tensor),即跨 Module 使用 Tensor 将会得到不可预知的结果,如下面的例子。

    import megengine.module as M
    import megengine as mge
    
    g_tensor = mge.Tensor([0])
    
    class Mod(M.Module):
        def forward(self, x):
            x = g_tensor + 1
            return x
    
  • trace 的 Module 或 function 参数中的非 Tensor 类型,将会被看作是常量存储在 Expr 的 const_val 属性中,并且该值将不会再变化。

  • 当被 trace 的自定义 Module 被调用了多次,并且每次传入参数中的非 Tensor 数据不一致时,将会被 trace 出多个 Graph。此时将无法通过 TracedModule.graph 属性访问 Graph,只能通过对应 Moldule 的 CallMethod Expr 访问,如下面的例子。

    import megengine.functional as F
    import megengine.module as M
    import megengine.traced_module as tm
    
    class Mod(M.Module):
        def forward(self, x, b):
            x  = x + b
            return x
    
    class Net(M.Module):
        def __init__(self, ):
            super().__init__()
            self.mod = Mod()
    
        def forward(self, x):
            x = self.mod(x, 1)
            x = self.mod(x, 2)
            return x
    
    net = Net()
    inp = F.zeros(shape=(1, ))
    
    traced_net = tm.trace_module(net, inp)
    
    print(traced_net.graph)
    '''
    Net.Graph (self, x) {
            %5:     mod = getattr(self, "mod") -> (Module)
            %6:     mod_out = mod(x, 1, )
            %10:    mod_1 = getattr(self, "mod") -> (Module)
            %11:    mod_1_out = mod_1(mod_out, 2, )
            return mod_1_out
    }
    '''
    # 此时 traced_net.mod 将会被 trace 出 2 个 graph,因此无法直接访问 graph 属性
    try:
        print(traced_net.mod.graph)
    except:
        print("error")
    
    # 可通过 mod 的 CallMethod Expr 访问对应的 Graph
    print(traced_net.graph.get_expr_by_id(6).as_unique().graph)
    '''
    mod.Graph (self, x) {
            %9:     add_out = x.__add__(1, )
            return add_out
    }
    '''
    print(traced_net.graph.get_expr_by_id(11).as_unique().graph)
    '''
    mod_1.Graph (self, x) {
            %14:    add_out = x.__add__(2, )
            return add_out
    }
    '''