megengine.traced_module.traced_module.InternalGraph

class InternalGraph(name=None, prefix_name='', module_name='')

InternalGraph is the main data structure used in TracedModule. It represents the execution procedure of a Module's forward method as a sequence of Exprs operating on Nodes.

For example, the following code

import megengine.random as rand
import megengine.functional as F
import megengine.module as M

import megengine.traced_module as tm

class MyModule(M.Module):
    def __init__(self):
        super().__init__()
        self.param = rand.normal(size=(3, 4))
        self.linear = M.Linear(4, 5)

    def forward(self, x):
        return F.relu(self.linear(x + self.param))

net = MyModule()

# Trace the module with an example input to obtain a TracedModule
inp = F.zeros(shape=(3, 4))
traced_module = tm.trace_module(net, inp)

produces the following InternalGraph:

print(traced_module.graph)
MyModule.Graph (self, x) {
        %2:     linear = getattr(self, "linear") -> (Linear)
        %3:     param = getattr(self, "param") -> (Tensor)
        %4:     add_out = x.__add__(param, )
        %5:     linear_out = linear(add_out, )
        %6:     relu_out = nn.relu(linear_out, )
        return relu_out
}
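
To make the structure concrete, here is a hypothetical, heavily simplified model in plain Python (not MegEngine's implementation). A graph holds its input Nodes, an ordered list of Exprs, and its output Nodes, and each Expr consumes and produces Nodes. The names `ToyNode`, `ToyExpr`, `ToyGraph`, and `record` are illustrative assumptions; the sketch rebuilds the MyModule graph by hand and prints it in a similar layout.

```python
from dataclasses import dataclass, field
from typing import List


# Hypothetical toy classes; they only mimic the shape of MegEngine's Node/Expr.
@dataclass(eq=False)
class ToyNode:
    name: str


@dataclass(eq=False)
class ToyExpr:
    op: str                 # textual description of the operation
    inputs: List[ToyNode]
    outputs: List[ToyNode]


@dataclass(eq=False)
class ToyGraph:
    name: str
    inputs: List[ToyNode]
    outputs: List[ToyNode] = field(default_factory=list)
    exprs: List[ToyExpr] = field(default_factory=list)

    def record(self, op, inputs, out_name):
        # Append one Expr in execution order and return the Node it produces
        out = ToyNode(out_name)
        self.exprs.append(ToyExpr(op, list(inputs), [out]))
        return out

    def __str__(self):
        args = ", ".join(n.name for n in self.inputs)
        lines = [f"{self.name}.Graph ({args}) {{"]
        for e in self.exprs:
            ins = ", ".join(n.name for n in e.inputs)
            lines.append(f"    {e.outputs[0].name} = {e.op}({ins})")
        lines.append("    return " + ", ".join(n.name for n in self.outputs))
        lines.append("}")
        return "\n".join(lines)


# Rebuild the MyModule graph by hand
self_node, x = ToyNode("self"), ToyNode("x")
g = ToyGraph("MyModule", [self_node, x])
linear = g.record("getattr_linear", [self_node], "linear")
param = g.record("getattr_param", [self_node], "param")
add_out = g.record("add", [x, param], "add_out")
linear_out = g.record("call_linear", [linear, add_out], "linear_out")
relu_out = g.record("relu", [linear_out], "relu_out")
g.outputs.append(relu_out)
print(g)
```

Every intermediate value is an explicit Node, which is what lets the methods below search, rewire, and prune the graph.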

Attributes

inputs

Get the list of input Nodes of this graph.

outputs

Get the list of output Nodes of this graph.

top_graph

Get the parent graph of this graph.
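
The three attributes can be sketched as read-only properties on a hypothetical toy class (not MegEngine's code). In this sketch, Node names stand in for Nodes, the graph of a sub-module keeps a reference to the graph that contains it, and the outermost graph's `top_graph` is `None`; that last detail is an assumption of the sketch.

```python
from typing import List, Optional


class ToyGraph:
    """Hypothetical toy graph exposing inputs, outputs and top_graph."""

    def __init__(self, name: str, inputs: List[str], outputs: List[str],
                 top_graph: Optional["ToyGraph"] = None):
        self.name = name
        self._inputs = list(inputs)
        self._outputs = list(outputs)
        self._top_graph = top_graph  # None for the outermost graph in this sketch

    @property
    def inputs(self) -> List[str]:
        return list(self._inputs)  # a copy, so callers cannot mutate the graph

    @property
    def outputs(self) -> List[str]:
        return list(self._outputs)

    @property
    def top_graph(self) -> Optional["ToyGraph"]:
        return self._top_graph


root = ToyGraph("MyModule", ["self", "x"], ["relu_out"])
sub = ToyGraph("Linear", ["self", "inp"], ["out"], top_graph=root)
print(sub.top_graph.name)  # MyModule
```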

Methods

add_input_node(shape[, dtype, name])

Add an input node to the graph.

add_output_node(node)

Add an output node to the graph.
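
A hypothetical toy version of the two methods is shown below (not MegEngine's implementation). It only records the new Node in the graph's input or output list; the default `dtype="float32"` and the auto-generated `arg_<i>` names are assumptions of this sketch, and wiring the new input into existing Exprs is not modeled.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass(eq=False)
class ToyNode:
    name: str
    shape: Optional[Tuple[int, ...]] = None
    dtype: Optional[str] = None


class ToyGraph:
    def __init__(self):
        self.inputs: List[ToyNode] = []
        self.outputs: List[ToyNode] = []

    def add_input_node(self, shape, dtype="float32", name=None):
        # Defaults here are assumptions of this toy sketch
        if name is None:
            name = f"arg_{len(self.inputs)}"
        node = ToyNode(name, tuple(shape), dtype)
        self.inputs.append(node)
        return node

    def add_output_node(self, node):
        self.outputs.append(node)


g = ToyGraph()
x = g.add_input_node((3, 4))
y = g.add_input_node((3,), dtype="int32", name="y")
g.add_output_node(x)
print([n.name for n in g.inputs])  # ['arg_0', 'y']
```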

compile()

Delete unused Exprs.
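
Deleting unused Exprs is commonly done with a backward liveness pass: starting from the graph outputs, keep an Expr only if one of its results is still needed. The sketch below is a hypothetical plain-Python version on a toy Expr list (not MegEngine's code); it assumes Exprs are stored in execution order and ignores side effects, which a real implementation may need to preserve.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyExpr:
    name: str
    inputs: List[str]   # names of the Nodes this Expr reads
    outputs: List[str]  # names of the Nodes this Expr produces


def compile_exprs(exprs: List[ToyExpr], graph_outputs: List[str]) -> List[ToyExpr]:
    """Keep only Exprs whose results can reach the graph outputs."""
    live = set(graph_outputs)
    kept = []
    # Walk backwards: exprs are stored in execution (topological) order
    for expr in reversed(exprs):
        if any(out in live for out in expr.outputs):
            kept.append(expr)
            live.update(expr.inputs)
    kept.reverse()
    return kept


exprs = [
    ToyExpr("add", ["x", "param"], ["add_out"]),
    ToyExpr("mul", ["x", "x"], ["unused"]),  # result never consumed
    ToyExpr("relu", ["add_out"], ["relu_out"]),
]
print([e.name for e in compile_exprs(exprs, ["relu_out"])])  # ['add', 'relu']
```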

eval(*inputs)

Execute the graph on the given inputs.
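
Conceptually, executing a graph means binding the given values to the input Nodes, running each Expr in order, and reading back the output Nodes. The sketch below is a hypothetical toy interpreter over scalar floats (not MegEngine's implementation); the tuple encoding of Exprs and the name `eval_graph` are assumptions, and the scalar "linear" step is a stand-in for the real layer.

```python
from typing import Callable, Dict, List, Tuple

# A toy Expr: (callable, names of input Nodes, name of output Node)
ToyExpr = Tuple[Callable[..., float], List[str], str]


def eval_graph(input_names: List[str], exprs: List[ToyExpr],
               output_names: List[str], *inputs: float) -> Tuple[float, ...]:
    if len(inputs) != len(input_names):
        raise TypeError(f"expected {len(input_names)} inputs, got {len(inputs)}")
    env: Dict[str, float] = dict(zip(input_names, inputs))  # Node name -> value
    for fn, in_names, out_name in exprs:  # Exprs are already in execution order
        env[out_name] = fn(*(env[n] for n in in_names))
    return tuple(env[n] for n in output_names)


# Scalar stand-in for MyModule: param = 2.0, "linear" = multiply by 3.0
exprs: List[ToyExpr] = [
    (lambda x: x + 2.0, ["x"], "add_out"),
    (lambda v: 3.0 * v, ["add_out"], "linear_out"),
    (lambda v: max(v, 0.0), ["linear_out"], "relu_out"),
]
print(eval_graph(["x"], exprs, ["relu_out"], 1.0))   # (9.0,)
print(eval_graph(["x"], exprs, ["relu_out"], -5.0))  # (0.0,)
```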

exprs([recursive])

Get the Exprs that constitute this graph.
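
The `recursive` flag matters because calling a sub-module has its own graph. The hypothetical toy below (not MegEngine's code, which returns a filter object rather than a plain list) yields a sub-module's Exprs right after the Expr that calls it when `recursive` is true. The default of `True` and that ordering are assumptions of this sketch.

```python
from typing import Iterator, List, Optional


class ToyExpr:
    def __init__(self, name: str, subgraph: Optional["ToyGraph"] = None):
        self.name = name
        self.subgraph = subgraph  # graph of the called sub-module, if any


class ToyGraph:
    def __init__(self, exprs: List[ToyExpr]):
        self._exprs = exprs

    def exprs(self, recursive: bool = True) -> List[ToyExpr]:
        # recursive=True as the default is an assumption of this sketch
        return list(self._walk(recursive))

    def _walk(self, recursive: bool) -> Iterator[ToyExpr]:
        for expr in self._exprs:
            yield expr
            if recursive and expr.subgraph is not None:
                # Descend into the sub-module's graph right after the call
                yield from expr.subgraph._walk(recursive)


linear_graph = ToyGraph([ToyExpr("matmul"), ToyExpr("add_bias")])
top = ToyGraph([ToyExpr("add"), ToyExpr("call_linear", linear_graph), ToyExpr("relu")])
print([e.name for e in top.exprs(recursive=False)])  # ['add', 'call_linear', 'relu']
print([e.name for e in top.exprs()])
# ['add', 'call_linear', 'matmul', 'add_bias', 'relu']
```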

get_dep_exprs(nodes)

Get the dependent Exprs of the nodes.
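
Finding the Exprs a set of Nodes depends on is a backward traversal: look up the Expr that produced each Node, then repeat for that Expr's inputs. A hypothetical toy version (not MegEngine's implementation) follows. Node names stand in for Nodes, graph inputs such as `x` and `self` have no producer, and the result is returned in execution order, which is a choice of this sketch.

```python
from dataclasses import dataclass
from typing import Iterable, List


@dataclass(eq=False)  # identity-based hashing so Exprs can live in a set
class ToyExpr:
    name: str
    inputs: List[str]
    outputs: List[str]


def get_dep_exprs(exprs: List[ToyExpr], nodes: Iterable[str]) -> List[ToyExpr]:
    producer = {out: e for e in exprs for out in e.outputs}
    needed, stack = set(), list(nodes)
    while stack:
        expr = producer.get(stack.pop())  # graph inputs have no producer
        if expr is not None and expr not in needed:
            needed.add(expr)
            stack.extend(expr.inputs)
    return [e for e in exprs if e in needed]  # keep execution order


exprs = [
    ToyExpr("getattr_linear", ["self"], ["linear"]),
    ToyExpr("getattr_param", ["self"], ["param"]),
    ToyExpr("add", ["x", "param"], ["add_out"]),
    ToyExpr("call_linear", ["linear", "add_out"], ["linear_out"]),
    ToyExpr("relu", ["linear_out"], ["relu_out"]),
]
print([e.name for e in get_dep_exprs(exprs, ["add_out"])])  # ['getattr_param', 'add']
```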

get_expr_by_id([expr_id, recursive])

Filter Exprs by their id.
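
A hypothetical toy filter (not MegEngine's implementation) accepting either one id or a collection of ids is sketched below; get_node_by_id works the same way on Nodes. The ids reuse the `%N` labels from the printed graph purely for illustration, and returning every Expr when no id is given is an assumption of this sketch.

```python
from dataclasses import dataclass
from typing import Iterable, List, Optional, Union


@dataclass
class ToyExpr:
    id: int
    name: str


def get_expr_by_id(exprs: List[ToyExpr],
                   expr_id: Optional[Union[int, Iterable[int]]] = None) -> List[ToyExpr]:
    if expr_id is None:
        return list(exprs)  # no filter: keep everything
    wanted = {expr_id} if isinstance(expr_id, int) else set(expr_id)
    return [e for e in exprs if e.id in wanted]


exprs = [ToyExpr(2, "linear"), ToyExpr(3, "param"), ToyExpr(4, "add_out")]
print([e.name for e in get_expr_by_id(exprs, 3)])       # ['param']
print([e.name for e in get_expr_by_id(exprs, [2, 4])])  # ['linear', 'add_out']
```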

get_function_by_type([func, recursive])

Filter CallFunction Exprs by the function they call.

get_method_by_type([method, recursive])

Filter CallMethod Exprs by the method they call.
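
The two filters differ in what they match on: a function-call Expr is matched by the function object it calls, a method-call Expr by the method name (such as `__add__` in the printed graph). The sketch below uses hypothetical toy classes `ToyCallFunction` and `ToyCallMethod` (not MegEngine's Expr classes) and a toy `relu`; returning all Exprs of the kind when no target is given is an assumption.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Union


@dataclass
class ToyCallFunction:   # toy: an Expr that calls a free function
    func: Callable
    out: str


@dataclass
class ToyCallMethod:     # toy: an Expr that calls a method by name
    method: str
    out: str


ToyExpr = Union[ToyCallFunction, ToyCallMethod]


def relu(x: float) -> float:  # toy stand-in for a library function
    return max(x, 0.0)


def get_function_by_type(exprs: List[ToyExpr], func: Optional[Callable] = None):
    return [e for e in exprs
            if isinstance(e, ToyCallFunction) and (func is None or e.func is func)]


def get_method_by_type(exprs: List[ToyExpr], method: Optional[str] = None):
    return [e for e in exprs
            if isinstance(e, ToyCallMethod) and (method is None or e.method == method)]


exprs = [ToyCallMethod("__add__", "add_out"), ToyCallFunction(relu, "relu_out"),
         ToyCallFunction(abs, "abs_out")]
print([e.out for e in get_function_by_type(exprs, relu)])     # ['relu_out']
print([e.out for e in get_method_by_type(exprs, "__add__")])  # ['add_out']
```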

get_module_by_type(module_cls[, recursive])

Filter ModuleNodes by their module_type.
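
Only Nodes that refer to modules carry a module type, so the filter skips plain tensor Nodes. The hypothetical toy below (not MegEngine's code) uses stand-in classes `ToyModule`, `ToyLinear`, and `ToyConv2d`; matching subclasses via `issubclass` is this sketch's choice and may differ from MegEngine's behavior.

```python
from dataclasses import dataclass
from typing import List, Type, Union


class ToyModule:          # toy stand-ins for module classes
    pass


class ToyLinear(ToyModule):
    pass


class ToyConv2d(ToyModule):
    pass


@dataclass
class ToyModuleNode:
    name: str
    module_type: Type[ToyModule]  # class of the module this Node refers to


@dataclass
class ToyTensorNode:
    name: str


def get_module_by_type(nodes: List[Union[ToyModuleNode, ToyTensorNode]],
                       module_cls: Type[ToyModule]) -> List[ToyModuleNode]:
    # Only ModuleNodes carry a module_type; subclass matching is a sketch choice
    return [n for n in nodes
            if isinstance(n, ToyModuleNode) and issubclass(n.module_type, module_cls)]


nodes = [ToyModuleNode("linear", ToyLinear), ToyTensorNode("param"),
         ToyModuleNode("conv", ToyConv2d)]
print([n.name for n in get_module_by_type(nodes, ToyLinear)])  # ['linear']
print([n.name for n in get_module_by_type(nodes, ToyModule)])  # ['linear', 'conv']
```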

get_node_by_id([node_id, recursive])

Filter Nodes by their id.

get_node_by_name([name, ignorecase, recursive])

Filter Nodes by their full name.
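
A hypothetical toy version of name filtering with optional case folding is sketched below (not MegEngine's implementation). Plain strings stand in for Nodes, the full names are invented for illustration, the `ignorecase=True` default is an assumption, and only exact matches are modeled; whether MegEngine also accepts patterns is not covered here.

```python
from typing import List


def get_node_by_name(node_names: List[str], name: str,
                     ignorecase: bool = True) -> List[str]:
    # Exact full-name match; the ignorecase default is an assumption of this sketch
    if ignorecase:
        target = name.casefold()
        return [n for n in node_names if n.casefold() == target]
    return [n for n in node_names if n == name]


names = ["MyModule.linear_out", "MyModule.relu_out"]  # hypothetical full names
print(get_node_by_name(names, "mymodule.RELU_OUT"))                    # ['MyModule.relu_out']
print(get_node_by_name(names, "mymodule.RELU_OUT", ignorecase=False))  # []
```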

insert_exprs([expr])

Initialize the trace mode and the insertion position; new Exprs are inserted after expr.
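
The idea of an insertion position can be sketched as a context manager that remembers where new Exprs go and resets on exit. The toy below is hypothetical (not MegEngine's implementation): strings stand in for Exprs, `record` is an invented helper that plays the role of tracing a new operation, and inserting at the end when `expr` is `None` is an assumption of this sketch.

```python
from contextlib import contextmanager
from typing import Iterator, List, Optional


class ToyGraph:
    def __init__(self, exprs: List[str]):
        self.exprs = list(exprs)
        self._insert_pos: Optional[int] = None  # None: not in insertion mode

    @contextmanager
    def insert_exprs(self, expr: Optional[str] = None) -> Iterator["ToyGraph"]:
        # Insert after `expr`, or at the end when expr is None (sketch assumption)
        self._insert_pos = len(self.exprs) if expr is None else self.exprs.index(expr) + 1
        try:
            yield self
        finally:
            self._insert_pos = None  # leave insertion mode even on error

    def record(self, expr: str) -> None:
        if self._insert_pos is None:
            raise RuntimeError("record() must be called inside insert_exprs()")
        self.exprs.insert(self._insert_pos, expr)
        self._insert_pos += 1  # keep later insertions in call order


g = ToyGraph(["add", "linear", "relu"])
with g.insert_exprs("add"):
    g.record("scale")
    g.record("shift")
print(g.exprs)  # ['add', 'scale', 'shift', 'linear', 'relu']
```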

interpret(*inputs)

nodes([recursive])

Get the Nodes that constitute this graph.

replace_node(repl_dict)

Replace Nodes in the graph according to repl_dict, which maps each old Node to its new Node.
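
Replacing a Node means rewiring every later use of the old Node to read the new one. The hypothetical toy below (not MegEngine's implementation) uses Node names and rewrites only Exprs that come after the Expr producing the new Node, so an Expr that computes the new Node from the old one does not end up reading its own output. That guard is a design choice of this sketch.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToyExpr:
    name: str
    inputs: List[str]   # Node names read by this Expr
    outputs: List[str]  # Node names produced by this Expr


def replace_node(exprs: List[ToyExpr], graph_outputs: List[str],
                 repl_dict: Dict[str, str]) -> None:
    """Rewire uses of each old Node to its new Node, in place (toy version)."""
    producer_pos = {out: i for i, e in enumerate(exprs) for out in e.outputs}
    for old, new in repl_dict.items():
        # Graph inputs have no producer, so every Expr is rewritten (start = 0)
        start = producer_pos.get(new, -1) + 1
        for expr in exprs[start:]:
            expr.inputs = [new if n == old else n for n in expr.inputs]
        graph_outputs[:] = [new if n == old else n for n in graph_outputs]


exprs = [
    ToyExpr("add", ["x", "param"], ["add_out"]),
    ToyExpr("scale", ["add_out"], ["scaled"]),  # newly inserted Expr
    ToyExpr("relu", ["add_out"], ["relu_out"]),
]
outputs = ["relu_out"]
replace_node(exprs, outputs, {"add_out": "scaled"})
print(exprs[2].inputs)  # ['scaled']
print(exprs[1].inputs)  # ['add_out']
```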

reset_inputs(*args, **kwargs)

reset_outputs(outputs)

Reset the output Nodes of the graph.
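
A hypothetical toy version of resetting outputs is sketched below (not MegEngine's code). It replaces the output list after checking that every requested Node belongs to the graph; Node names stand in for Nodes, and accepting a single Node as well as a list is an assumption of this sketch.

```python
from typing import Iterable, List, Union


class ToyGraph:
    def __init__(self, node_names: Iterable[str], outputs: List[str]):
        self.node_names = set(node_names)  # every Node the graph knows about
        self._outputs = list(outputs)

    @property
    def outputs(self) -> List[str]:
        return list(self._outputs)

    def reset_outputs(self, outputs: Union[str, List[str]]) -> None:
        # Accepting a single Node as well as a list is this sketch's choice
        if isinstance(outputs, str):
            outputs = [outputs]
        missing = [n for n in outputs if n not in self.node_names]
        if missing:
            raise ValueError(f"Nodes not in graph: {missing}")
        self._outputs = list(outputs)


g = ToyGraph(["x", "add_out", "relu_out"], ["relu_out"])
g.reset_outputs(["add_out", "relu_out"])  # also expose an intermediate result
print(g.outputs)  # ['add_out', 'relu_out']
```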