megengine.traced_module.TracedModule.set_watch_points

TracedModule.set_watch_points(nodes)

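Initializes the watch_points of this TracedModule.

Call this function to record, at runtime, the Tensor or Module corresponding to each given Node. After the traced module is called, the recorded values are available in its watch_node_value attribute, a dict keyed by the watched Node.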
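Conceptually, a watch point is a Node whose value is copied into a dictionary while the graph is being evaluated. The pure-Python toy below sketches that idea. Node, ToyTracedModule, and the node name add_out_1 are hypothetical stand-ins, not MegEngine classes, and this is an illustration of the idea rather than MegEngine's implementation. It mirrors the x + 1 + 2 example further down: watching the first addition records 1 when the input is 0.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Sequence, Union


@dataclass(eq=False)  # eq=False keeps identity-based hashing, so Nodes can be dict keys
class Node:
    """Hypothetical stand-in for a graph Node: a name, the op producing it, and its inputs."""
    name: str
    op: Optional[Callable[..., Any]] = None
    inputs: List["Node"] = field(default_factory=list)


class ToyTracedModule:
    """Hypothetical toy module that evaluates Nodes and records watched values."""

    def __init__(self, inputs: List[Node], nodes: List[Node], output: Node):
        self.inputs = inputs    # graph input Nodes
        self.nodes = nodes      # computed Nodes, assumed to be in topological order
        self.output = output
        self.watch_points: List[Node] = []
        self.watch_node_value: Dict[Node, Any] = {}

    def set_watch_points(self, nodes: Union[Node, Sequence[Node]]) -> None:
        # Accept a single Node or a sequence of Nodes, and clear previously recorded values.
        if isinstance(nodes, Node):
            nodes = [nodes]
        self.watch_points = list(nodes)
        self.watch_node_value = {}

    def __call__(self, *args: Any) -> Any:
        env: Dict[Node, Any] = dict(zip(self.inputs, args))  # Node -> runtime value
        for node in self.nodes:
            env[node] = node.op(*(env[i] for i in node.inputs))
            if node in self.watch_points:  # identity comparison, since eq=False
                self.watch_node_value[node] = env[node]
        return env[self.output]


# Mirror of the documented example: x + 1 + 2 with input 0, watching the first addition.
x = Node("x")
add_out = Node("add_out", op=lambda a: a + 1, inputs=[x])
add_out_1 = Node("add_out_1", op=lambda a: a + 2, inputs=[add_out])
net = ToyTracedModule(inputs=[x], nodes=[add_out, add_out_1], output=add_out_1)

net.set_watch_points(add_out)
out = net(0)
print(out)  # 3
print({n.name: v for n, v in net.watch_node_value.items()})  # {'add_out': 1}
```

In this toy, passing a list such as [add_out, add_out_1] records both intermediate values (1 and 3) on the next call.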

Parameters

nodes – a Node or a list of Nodes whose runtime values should be recorded.

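For example, the following code traces MyModule, watches the output of the first addition (x + 1), and then runs the traced module: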

import megengine.module as M
import megengine as mge
import megengine.traced_module as tm

class MyModule(M.Module):
    def forward(self, x):
        x = x + 1 + 2
        return x

net = MyModule()

inp = mge.Tensor([0])
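traced_module = tm.trace_module(net, inp)
# In this graph, the Node with id 2 is the output of "x + 1"
add_1_node = traced_module.graph.get_node_by_id(2).as_unique()
traced_module.set_watch_points(add_1_node)

# Calling the traced module records the value of each watched Node
out = traced_module(inp)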

After the call, watch_node_value holds the recorded value, keyed by the watched Node:

print(traced_module.watch_node_value)
{add_out: Tensor([1.], device=xpux:0)}