TracedModule

class TracedModule(is_top, argdef_graph_map, argdef_outdef_map, is_qat=False)

TracedModule is the Module created by tracing a normal Module.

It owns a map from argdef to graph (InternalGraph). The forward method of TracedModule looks up a graph in argdef_graph_map according to the argdef of the input args/kwargs, then interprets that graph.
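The lookup-then-interpret dispatch can be illustrated with a small plain-Python sketch that does not use MegEngine. make_argdef and ToyTracedModule are hypothetical names; the argdef is approximated here by the types of the positional arguments plus the names and types of the keyword arguments, which is a simplification of whatever the real argdef records, and each "graph" is an ordinary callable standing in for an InternalGraph.

```python
from typing import Any, Callable, Dict, Hashable, Tuple


def make_argdef(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Hashable:
    # Hypothetical, simplified "argdef": the type of each positional argument
    # plus the sorted (name, type) pairs of the keyword arguments.
    pos = tuple(type(a).__name__ for a in args)
    kw = tuple(sorted((k, type(v).__name__) for k, v in kwargs.items()))
    return (pos, kw)


class ToyTracedModule:
    # Hypothetical stand-in for TracedModule; each "graph" is a plain callable
    # standing in for an InternalGraph.
    def __init__(self, argdef_graph_map: Dict[Hashable, Callable[..., Any]]):
        self.argdef_graph_map = argdef_graph_map

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        argdef = make_argdef(args, kwargs)
        if argdef not in self.argdef_graph_map:
            raise KeyError(f"no graph was traced for argdef {argdef}")
        # Select the graph traced for this input structure and "interpret" it.
        return self.argdef_graph_map[argdef](*args, **kwargs)


m = ToyTracedModule({
    make_argdef((1.0,), {}): lambda x: x + 1 + 2,
    make_argdef((1.0, 2.0), {}): lambda x, y: x * y,
})
print(m(0.0))       # 3.0
print(m(2.0, 5.0))  # 10.0
```

Keying the map on input structure rather than input values means one traced graph serves every call whose arguments have the same shape of args/kwargs.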

Note

TracedModule can only be created by trace_module. See trace_module for more details.

clear_end_points()

Clear the end_points.

clear_watch_points()

Clear the watch_points and watch_node_value.

flatten()

Return a new, flat TracedModule in which GetAttr is eliminated and there is no submodule hierarchy.

Returns

A new TracedModule.
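As a rough illustration of what removing the hierarchy means, the toy sketch below inlines nested submodule nodes into one flat list of primitive operations. Op, SubModule and flatten are hypothetical names unrelated to MegEngine's own classes, and the sketch ignores tensors and data flow; it only shows the recursive inlining.

```python
from dataclasses import dataclass, field
from typing import List, Sequence, Union


@dataclass
class Op:
    # Hypothetical primitive operation, e.g. "conv" or "relu".
    name: str


@dataclass
class SubModule:
    # Hypothetical hierarchical node: a named child module with its own body.
    # In a real traced graph, reaching the child is roughly a GetAttr on its parent.
    attr: str
    body: List[Union[Op, "SubModule"]] = field(default_factory=list)


def flatten(nodes: Sequence[Union[Op, SubModule]]) -> List[Op]:
    # Recursively inline every SubModule so only primitive ops remain, in order.
    flat: List[Op] = []
    for node in nodes:
        if isinstance(node, SubModule):
            flat.extend(flatten(node.body))
        else:
            flat.append(node)
    return flat


net = [Op("conv"), SubModule("block", [Op("bn"), SubModule("act", [Op("relu")])]), Op("add")]
print([op.name for op in flatten(net)])  # ['conv', 'bn', 'relu', 'add']
```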

property graph

Return the InternalGraph of this TracedModule.

Return type

InternalGraph

set_end_points(nodes)

Initialize the end_points.

Once all of the given nodes have been generated, the Module stops execution and directly returns the values of those nodes.

Parameters

nodes (Sequence[Node]) – a list of Nodes.

For example, the following code

import megengine.module as M
import megengine as mge
import megengine.traced_module as tm

class MyModule(M.Module):
    def forward(self, x):
        x = x + 1 + 2
        return x

net = MyModule()

inp = mge.Tensor([0])
traced_module = tm.trace_module(net, inp)
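# Select the Node with id 2 (the output of `x + 1` in this trace);
# get_node_by_id returns a node filter and as_unique() extracts the single Node.
add_1_node = traced_module.graph.get_node_by_id(2).as_unique()
# Stop execution as soon as add_1_node has been computed.
traced_module.set_end_points(add_1_node)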

out = traced_module(inp)

produces the following value for out:

print(out)
[Tensor([1.], device=xpux:0)]
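The early-stop behaviour can be mimicked with a toy interpreter. This is not MegEngine code: a graph is assumed to be a list of (output name, function, input names) triples in execution order, and run and the node name add_out_1 are hypothetical. With the same x + 1 + 2 computation as the example, stopping at add_out returns [1.0], and any node after the end point is never executed.

```python
from typing import Any, Callable, Dict, Sequence, Tuple

# Hypothetical toy node: (output name, function, input names).
ToyNode = Tuple[str, Callable[..., float], Tuple[str, ...]]


def run(graph: Sequence[ToyNode], inputs: Dict[str, float],
        end_points: Sequence[str] = ()) -> Any:
    env = dict(inputs)          # value of every name computed so far
    pending = set(end_points)   # end points not yet produced
    for out, fn, args in graph:
        env[out] = fn(*(env[a] for a in args))
        pending.discard(out)
        if end_points and not pending:
            # Every end point is available: stop here and return their values.
            return [env[name] for name in end_points]
    # No end points (or some never produced): return the last node's value.
    return env[graph[-1][0]]


graph = [
    ("add_out", lambda x: x + 1, ("x",)),
    ("add_out_1", lambda x: x + 2, ("add_out",)),
]
print(run(graph, {"x": 0.0}))                          # 3.0
print(run(graph, {"x": 0.0}, end_points=["add_out"]))  # [1.0]
```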

set_watch_points(nodes)

Initialize the watch_points.

Call this method to capture, at runtime, the Tensor or Module corresponding to each given Node. The captured values are stored in watch_node_value.

Parameters

nodes – a list of Nodes.

For example, the following code

import megengine.module as M
import megengine as mge
import megengine.traced_module as tm

class MyModule(M.Module):
    def forward(self, x):
        x = x + 1 + 2
        return x

net = MyModule()

inp = mge.Tensor([0])
traced_module = tm.trace_module(net, inp)
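# Select the Node with id 2 (the output of `x + 1` in this trace).
add_1_node = traced_module.graph.get_node_by_id(2).as_unique()
# Record the runtime value of add_1_node in traced_module.watch_node_value.
traced_module.set_watch_points(add_1_node)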

out = traced_module(inp)

produces the following watch_node_value:

print(traced_module.watch_node_value)
{add_out: Tensor([1.], device=xpux:0)}
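Unlike end points, watch points do not stop execution. A toy interpreter can record watched values while still running the whole graph. As before, this is a hypothetical plain-Python sketch (run_with_watch and add_out_1 are illustrative names, not MegEngine API), assuming a graph given as (output name, function, input names) triples in execution order.

```python
from typing import Callable, Dict, Sequence, Tuple

# Hypothetical toy node: (output name, function, input names).
ToyNode = Tuple[str, Callable[..., float], Tuple[str, ...]]


def run_with_watch(graph: Sequence[ToyNode], inputs: Dict[str, float],
                   watch_points: Sequence[str] = ()) -> Tuple[float, Dict[str, float]]:
    env = dict(inputs)
    watch_node_value: Dict[str, float] = {}
    for out, fn, args in graph:
        env[out] = fn(*(env[a] for a in args))
        if out in watch_points:
            # Record the runtime value of a watched node; execution continues.
            watch_node_value[out] = env[out]
    return env[graph[-1][0]], watch_node_value


graph = [
    ("add_out", lambda x: x + 1, ("x",)),
    ("add_out_1", lambda x: x + 2, ("add_out",)),
]
out, watched = run_with_watch(graph, {"x": 0.0}, watch_points=["add_out"])
print(out)      # 3.0
print(watched)  # {'add_out': 1.0}
```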