megengine.traced_module.traced_module.InternalGraph
- class InternalGraph(name=None, prefix_name='', module_name='')
InternalGraph is the main data structure used in the TracedModule. It represents the execution procedure of a Module's forward method. For example, the following code
    import megengine.random as rand
    import megengine.functional as F
    import megengine.module as M
    import megengine.traced_module as tm

    class MyModule(M.Module):
        def __init__(self):
            super().__init__()
            self.param = rand.normal(size=(3, 4))
            self.linear = M.Linear(4, 5)

        def forward(self, x):
            return F.relu(self.linear(x + self.param))

    net = MyModule()
    inp = F.zeros(shape=(3, 4))
    traced_module = tm.trace_module(net, inp)
will produce the following InternalGraph:

    print(traced_module.graph)

    MyModule.Graph (self, x) {
        %2: linear = getattr(self, "linear") -> (Linear)
        %3: param = getattr(self, "param") -> (Tensor)
        %4: add_out = x.__add__(param, )
        %5: linear_out = linear(add_out, )
        %6: relu_out = nn.relu(linear_out, )
        return relu_out
    }
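The printed graph reads as an ordered list of steps, each taking some named values and producing a new one. To make that idea concrete, here is a minimal, dependency-free sketch of such a structure. It is only a toy model and not MegEngine's implementation: ToyNode, ToyExpr and ToyGraph are hypothetical names, and the scalar lambdas merely stand in for the tensor operations in the example.

```python
import operator
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass(frozen=True)
class ToyNode:
    """A named value slot in the graph (hypothetical, not MegEngine's Node)."""
    name: str


@dataclass
class ToyExpr:
    """One recorded step: apply `fn` to the input nodes, store into `output`."""
    fn: Callable
    inputs: Tuple[ToyNode, ...]
    output: ToyNode


class ToyGraph:
    def __init__(self, inputs: List[ToyNode], exprs: List[ToyExpr],
                 outputs: List[ToyNode]):
        self.inputs = inputs
        self.exprs = exprs
        self.outputs = outputs

    def interpret(self, *args):
        # Bind the graph inputs, then run every step in recorded order.
        env = dict(zip(self.inputs, args))
        for expr in self.exprs:
            env[expr.output] = expr.fn(*(env[n] for n in expr.inputs))
        return tuple(env[n] for n in self.outputs)


x, param, add_out, linear_out, relu_out = (
    ToyNode(n) for n in ("x", "param", "add_out", "linear_out", "relu_out")
)

graph = ToyGraph(
    inputs=[x],
    exprs=[
        # Stand-in for `param = getattr(self, "param")`: a constant.
        ToyExpr(lambda: 1.0, (), param),
        # Stand-in for `add_out = x.__add__(param)`.
        ToyExpr(operator.add, (x, param), add_out),
        # Stand-in for `linear_out = linear(add_out)`: an arbitrary affine map.
        ToyExpr(lambda v: 2.0 * v - 10.0, (add_out,), linear_out),
        # Stand-in for `relu_out = nn.relu(linear_out)`.
        ToyExpr(lambda v: max(v, 0.0), (linear_out,), relu_out),
    ],
    outputs=[relu_out],
)

print(graph.interpret(6.0))  # (6 + 1) * 2 - 10 = 4, relu keeps it: (4.0,)
print(graph.interpret(1.0))  # (1 + 1) * 2 - 10 = -6, relu clamps it: (0.0,)
```

Keeping the steps in a flat, ordered list is what makes such a graph easy to print, filter and rewrite, which is the kind of use the attributes and methods listed below are for.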
Attributes
Get the list of input Nodes of this graph.
Get the list of output Nodes of this graph.
Get the parent graph of this graph.
Methods
add_input_node(shape[, dtype, name]): Add an input node to the graph.
add_output_node(node): Add an output node to the Graph.
compile(): Delete unused expr.
eval(*inputs): Call this method to execute the graph.
exprs([recursive]): Get the Exprs that constitute this graph.
get_dep_exprs(nodes): Get the dependent Exprs of the nodes.
get_expr_by_id([expr_id, recursive]): Filter Exprs by their id.
get_function_by_type([func, recursive]): Filter Exprs by the type of CallFunction.
get_method_by_type([method, recursive]): Filter Exprs by the type of CallMethod.
get_module_by_type(module_cls[, recursive]): Filter Nodes by the module_type of ModuleNode.
get_node_by_id([node_id, recursive]): Filter Nodes by their id.
get_node_by_name([name, ignorecase, recursive]): Filter Nodes by their full name.
insert_exprs([expr]): Initialize the trace mode and insertion position.
interpret(*inputs)
nodes([recursive]): Get the Nodes that constitute this graph.
replace_node(repl_dict): Replace the Nodes in the graph.
reset_inputs(*args, **kwargs)
reset_outputs(outputs): Reset the output Nodes of the graph.
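The summaries above suggest a common editing pattern: find a step by its type, redirect the nodes that depend on it, then delete whatever is no longer used. The sketch below illustrates that pattern on a toy list of steps in plain Python. It is an assumption-laden illustration and not MegEngine code: ToyExpr, filter_by_op, replace_node and drop_unused are hypothetical helpers that only loosely mirror get_function_by_type, replace_node and compile.

```python
from typing import Dict, List, NamedTuple, Tuple


class ToyExpr(NamedTuple):
    op: str                   # operation label, e.g. "relu"
    inputs: Tuple[str, ...]   # names of the nodes this step reads
    output: str               # name of the node this step writes


def filter_by_op(exprs: List[ToyExpr], op: str) -> List[ToyExpr]:
    """Return the steps whose operation label matches `op`."""
    return [e for e in exprs if e.op == op]


def replace_node(exprs: List[ToyExpr], outputs: List[str],
                 repl: Dict[str, str]):
    """Redirect every use of an old node name to its replacement."""
    new_exprs = [
        e._replace(inputs=tuple(repl.get(n, n) for n in e.inputs))
        for e in exprs
    ]
    new_outputs = [repl.get(n, n) for n in outputs]
    return new_exprs, new_outputs


def drop_unused(exprs: List[ToyExpr], outputs: List[str]) -> List[ToyExpr]:
    """Keep only steps whose result is (transitively) needed by the outputs."""
    needed = set(outputs)
    kept = []
    for e in reversed(exprs):      # walk backwards from the outputs
        if e.output in needed:
            needed.update(e.inputs)
            kept.append(e)
    kept.reverse()                 # restore execution order
    return kept


# Toy version of the traced forward pass, plus one newly inserted step.
exprs = [
    ToyExpr("add", ("x", "param"), "add_out"),
    ToyExpr("linear", ("add_out",), "linear_out"),
    ToyExpr("relu", ("linear_out",), "relu_out"),
    ToyExpr("sigmoid", ("linear_out",), "sigmoid_out"),  # inserted step
]
outputs = ["relu_out"]

# 1. Locate the step to swap out by its operation type.
relu_expr = filter_by_op(exprs, "relu")[0]
# 2. Point every consumer of its output at the new node instead.
new_exprs, new_outputs = replace_node(
    exprs, outputs, {relu_expr.output: "sigmoid_out"}
)
# 3. Remove steps whose results are no longer used.
pruned = drop_unused(new_exprs, new_outputs)

print([e.op for e in pruned])  # ['add', 'linear', 'sigmoid']
print(new_outputs)             # ['sigmoid_out']
```

The backwards walk in drop_unused is one simple way to find unused steps when the list is already in execution order; the real compile() may work differently.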