megengine.traced_module.TracedModule.set_end_points

TracedModule.set_end_points(nodes)

Initialize the end_points of this TracedModule, i.e. the nodes at which execution stops.

Once all of the given nodes have been computed, the module stops executing and directly returns their values; the remaining operations in the graph are skipped.
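The stopping rule can be modelled with a small plain-Python interpreter. This is a hypothetical sketch, not MegEngine's implementation: Expr, run_graph, add and the integer node ids are invented for illustration. Each Expr runs in order, and as soon as every requested end point has a value the remaining Exprs are skipped and the end-point values are returned as a list.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Sequence


@dataclass
class Expr:
    """Hypothetical toy op: reads input node ids, writes one output node id."""
    inputs: List[int]
    fn: Callable
    output: int


def run_graph(
    exprs: Sequence[Expr],
    graph_inputs: Dict[int, int],
    graph_output: int,
    end_points: Optional[Sequence[int]] = None,
):
    """Run exprs in order; if end_points is given, stop once all are computed."""
    env = dict(graph_inputs)  # node id -> computed value
    # End points still waiting for a value (graph inputs already have one).
    pending = set(end_points or ()) - set(env)
    for expr in exprs:
        if end_points and not pending:
            break  # every end point exists: skip the remaining exprs
        env[expr.output] = expr.fn(*(env[i] for i in expr.inputs))
        pending.discard(expr.output)
    if end_points:
        # Return end-point values as a list, in the order they were given.
        return [env[n] for n in end_points]
    return env[graph_output]


calls = []  # records which additions actually ran


def add(c):
    def fn(x):
        calls.append(c)
        return x + c
    return fn


# Toy version of forward(x): x = x + 1 + 2, with input x as node 1.
exprs = [Expr([1], add(1), 2), Expr([2], add(2), 3)]

print(run_graph(exprs, {1: 0}, graph_output=3))                  # 3
print(run_graph(exprs, {1: 0}, graph_output=3, end_points=[2]))  # [1]
print(calls)  # [1, 2, 1]: the second addition ran only in the full run
```

Passing several end points, e.g. end_points=[3, 2], keeps executing until the later of the two nodes has been computed, which matches the "all of the given nodes" condition.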

Parameters

nodes (Sequence[Node]) – a list of Node objects; execution stops once all of them have been computed.

For example, consider the following code, which stops execution at the output of the first addition (x + 1):

import megengine.module as M
import megengine as mge
import megengine.traced_module as tm

class MyModule(M.Module):
    def forward(self, x):
        x = x + 1 + 2
        return x

net = MyModule()

inp = mge.Tensor([0])
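# Trace the module into a TracedModule whose graph records each operation.
traced_module = tm.trace_module(net, inp)
# Node 2 holds the output of the first addition (x + 1); get_node_by_id
# returns a node filter, and as_unique() extracts the single matching Node.
add_1_node = traced_module.graph.get_node_by_id(2).as_unique()
# Stop execution as soon as this node has been computed.
traced_module.set_end_points([add_1_node])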

out = traced_module(inp)

Printing out shows only the value of the x + 1 node, returned in a list, because the remaining + 2 is never executed:

print(out)
[Tensor([1.], device=xpux:0)]