TracedModule

class TracedModule(is_top, argdef_graph_map, argdef_outdef_map, is_qat=False)[源代码]

TracedModule 是通过追溯(tracing)普通模块创建的 Module.

它拥有一个 argdef 到计算图 (InternalGraph) 的映射。TracedModule 的 forward 方法将根据输入参数 args/kwargs 的 argdef 从 argdef_graph_map 中获取一个计算图,并对其进行解释执行。

注解

TracedModule 只能由 trace_module 创建。看 trace_module 获取更多细节。

clear_end_points()[源代码]

清除 end_points.

clear_watch_points()[源代码]

清除 watch_pointswatch_node_value.

flatten()[源代码]

获取一个去除了 GetAttr 并且没有层次的 TracedModule.

返回:

一个新的 TracedModule.

property graph

返回这个 TracedModuleInternalGraph.

返回类型

InternalGraph

set_end_points(nodes)[源代码]

初始化 end_points.

当所有的 nodes 生成后,模块将停止执行并直接返回。

参数

nodes (Sequence[Node]) – 一个 Node 列表。

例如,下面的代码

import megengine.module as M
import megengine as mge
import megengine.traced_module as tm

class MyModule(M.Module):
    def forward(self, x):
        x = x + 1 + 2
        return x

net = MyModule()

inp = mge.Tensor([0])
traced_module = tm.trace_module(net, inp)
add_1_node = traced_module.graph.get_node_by_id(2).as_unique()
traced_module.set_end_points(add_1_node)

out = traced_module(inp)

将获得以下 out:

print(out)
[Tensor([1.], device=xpux:0)]
set_watch_points(nodes)[源代码]

初始化 watch_points.

你可以在运行时调用这个函数来获得与 Node 对应的 Tensor/Module.

参数

nodes – 一个 Node 列表。

例如,下面的代码

import megengine.module as M
import megengine as mge
import megengine.traced_module as tm

class MyModule(M.Module):
    def forward(self, x):
        x = x + 1 + 2
        return x

net = MyModule()

inp = mge.Tensor([0])
traced_module = tm.trace_module(net, inp)
add_1_node = traced_module.graph.get_node_by_id(2).as_unique()
traced_module.set_watch_points(add_1_node)

out = traced_module(inp)

将获得以下 watch_node_value:

print(traced_module.watch_node_value)
{add_out: Tensor([1.], device=xpux:0)}