TracedModule¶
- class TracedModule(is_top, argdef_graph_map, argdef_outdef_map, is_qat=False)[源代码]¶
TracedModule
是通过追溯(tracing)普通模块创建的 Module.它拥有一个 argdef 到计算图 (InternalGraph) 的映射。
TracedModule
的 forward 方法将根据输入参数args/kwargs
的 argdef 从argdef_graph_map
中获取一个计算图,并对其进行解释执行。注解
TracedModule
只能由trace_module
创建。看trace_module
获取更多细节。- flatten()[源代码]¶
获取一个去除了
GetAttr
并且没有层次的 TracedModule.- 返回:
一个新的
TracedModule
.
- property graph¶
返回这个
TracedModule
的InternalGraph
.- 返回类型
- set_end_points(nodes)[源代码]¶
初始化
end_points
.当所有的
nodes
生成后,模块将停止执行并直接返回。- 参数
nodes (
Sequence
[Node
]) – 一个Node
列表。
例如,下面的代码
import megengine.module as M import megengine as mge import megengine.traced_module as tm class MyModule(M.Module): def forward(self, x): x = x + 1 + 2 return x net = MyModule() inp = mge.Tensor([0]) traced_module = tm.trace_module(net, inp) add_1_node = traced_module.graph.get_node_by_id(2).as_unique() traced_module.set_end_points(add_1_node) out = traced_module(inp)
将获得以下
out
:print(out)
[Tensor([1.], device=xpux:0)]
- set_watch_points(nodes)[源代码]¶
初始化
watch_points
.你可以在运行时调用这个函数来获得与
Node
对应的Tensor/Module
.- 参数
nodes – 一个
Node
列表。
例如,下面的代码
import megengine.module as M import megengine as mge import megengine.traced_module as tm class MyModule(M.Module): def forward(self, x): x = x + 1 + 2 return x net = MyModule() inp = mge.Tensor([0]) traced_module = tm.trace_module(net, inp) add_1_node = traced_module.graph.get_node_by_id(2).as_unique() traced_module.set_watch_points(add_1_node) out = traced_module(inp)
将获得以下
watch_node_value
:print(traced_module.watch_node_value)
{add_out: Tensor([1.], device=xpux:0)}