TracedModule
- class TracedModule(is_top, argdef_graph_map, argdef_outdef_map, is_qat=False)
TracedModule is the Module created by tracing a normal module. It owns a map from argdef to graph (InternalGraph). The forward method of TracedModule gets a graph from argdef_graph_map according to the argdef of the input args/kwargs and interprets it.
Note: TracedModule can only be created by trace_module. See trace_module for more details.
- flatten()
Get a new TracedModule that eliminates GetAttr and has no hierarchy.
- Returns: A new TracedModule.
- property graph
Return the InternalGraph of this TracedModule.
- Return type
- set_end_points(nodes)
Initialize the end_points. When all the nodes are generated, the Module will stop execution and return directly.
- Parameters
nodes (Sequence[Node]) – a list of Node.
For example, the following code

    import megengine.module as M
    import megengine as mge
    import megengine.traced_module as tm

    class MyModule(M.Module):
        def forward(self, x):
            x = x + 1 + 2
            return x

    net = MyModule()
    inp = mge.Tensor([0])
    traced_module = tm.trace_module(net, inp)

    # Node with id 2 is the output of the first addition (x + 1).
    add_1_node = traced_module.graph.get_node_by_id(2).as_unique()
    # Execution stops once this node has been generated.
    traced_module.set_end_points(add_1_node)

    out = traced_module(inp)

will produce the following out:

    print(out)

    [Tensor([1.], device=xpux:0)]
- set_watch_points(nodes)
Initialize the watch_points. You can call this function to get the Tensor/Module corresponding to a Node at runtime.
- Parameters
nodes – a list of Node.
For example, the following code

    import megengine.module as M
    import megengine as mge
    import megengine.traced_module as tm

    class MyModule(M.Module):
        def forward(self, x):
            x = x + 1 + 2
            return x

    net = MyModule()
    inp = mge.Tensor([0])
    traced_module = tm.trace_module(net, inp)

    # Node with id 2 is the output of the first addition (x + 1).
    add_1_node = traced_module.graph.get_node_by_id(2).as_unique()
    # Record the runtime value of this node during the next call.
    traced_module.set_watch_points(add_1_node)

    out = traced_module(inp)

will produce the following watch_node_value:

    print(traced_module.watch_node_value)

    {add_out: Tensor([1.], device=xpux:0)}
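To make the class description more concrete, here is a plain-Python sketch of the lookup that forward is described as performing: derive a key from the structure of the inputs, fetch the matching graph from a map, and run it. This is not MegEngine code. ToyTracedModule and structure_key are hypothetical names, the "graphs" are ordinary callables, and the key format is an assumption chosen only to keep the example small.

```python
from typing import Any, Callable, Dict


def structure_key(obj: Any) -> Any:
    """Return a hashable description of the container structure of obj.

    Hypothetical stand-in for an argdef: only the nesting matters,
    leaf values are ignored.
    """
    if isinstance(obj, (list, tuple)):
        return (type(obj).__name__, tuple(structure_key(o) for o in obj))
    if isinstance(obj, dict):
        items = sorted(obj.items())
        return ("dict", tuple((k, structure_key(v)) for k, v in items))
    return "leaf"


class ToyTracedModule:
    """Toy model of a module that dispatches on input structure."""

    def __init__(self, argdef_graph_map: Dict[Any, Callable]):
        self.argdef_graph_map = argdef_graph_map

    def forward(self, *args, **kwargs):
        # Compute the structure of the call, pick the graph recorded
        # for that structure, and "interpret" it by calling it.
        argdef = structure_key((args, kwargs))
        graph = self.argdef_graph_map[argdef]
        return graph(*args, **kwargs)


# Two recorded structures: one positional leaf, or one positional pair.
single = structure_key(((0,), {}))
pair = structure_key((((0, 0),), {}))

toy = ToyTracedModule({
    single: lambda x: x + 1,
    pair: lambda xs: xs[0] + xs[1],
})

result_single = toy.forward(3)       # matches the `single` structure
result_pair = toy.forward((3, 4))    # matches the `pair` structure
```

Calling the toy module with a structure that was never recorded raises KeyError, which mirrors the idea that a graph is only available for input structures that were seen when the map was built.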