megengine.traced_module.TracedModule

class TracedModule(is_top, argdef_graph_map, argdef_outdef_map, is_qat=False)

TracedModule is the Module created by tracing a normal Module.

It owns a map from argdef to graph (InternalGraph). The forward method of TracedModule looks up a graph in argdef_graph_map according to the argdef of the input args/kwargs and interprets it.
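The dispatch described above can be pictured with a toy, pure-Python model. It is not MegEngine's implementation: the names argdef, flatten_leaves and ToyTracedModule are hypothetical, an "argdef" is approximated here as the nesting structure of args/kwargs with every leaf value erased, and a "graph" is simply a Python callable over the flattened leaves.

```python
def argdef(args, kwargs):
    """Hypothetical argdef: the nesting structure of the inputs, leaves erased."""
    def structure(x):
        if isinstance(x, (list, tuple)):
            return (type(x).__name__, tuple(structure(v) for v in x))
        if isinstance(x, dict):
            return ("dict", tuple((k, structure(v)) for k, v in sorted(x.items())))
        return "leaf"  # toy assumption: every leaf (tensor, number, ...) looks the same
    return (structure(tuple(args)), structure(dict(kwargs)))


def flatten_leaves(x):
    """Collect leaves in the same order that argdef() walks the structure."""
    if isinstance(x, (list, tuple)):
        return [leaf for v in x for leaf in flatten_leaves(v)]
    if isinstance(x, dict):
        return [leaf for k in sorted(x) for leaf in flatten_leaves(x[k])]
    return [x]


class ToyTracedModule:
    def __init__(self, argdef_graph_map):
        # Maps each traced input structure to the "graph" recorded for it.
        self.argdef_graph_map = argdef_graph_map

    def forward(self, *args, **kwargs):
        key = argdef(args, kwargs)
        if key not in self.argdef_graph_map:
            raise KeyError(f"no graph traced for input structure {key!r}")
        graph = self.argdef_graph_map[key]
        # "Interpret" the selected graph on the flattened input leaves.
        return graph(*flatten_leaves(list(args)), *flatten_leaves(kwargs))


m = ToyTracedModule({
    argdef((1,), {}): lambda x: x + 1,          # graph traced for f(x)
    argdef(((1, 2),), {}): lambda a, b: a * b,  # graph traced for f((a, b))
})
print(m.forward(10))      # 11
print(m.forward((3, 4)))  # 12
```

The key point the toy captures is that one TracedModule may hold several graphs, and the input structure, not the input values, decides which one runs; calling it with a structure that was never traced fails instead of re-tracing.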

Note

TracedModule can only be created by trace_module. See trace_module for more details.

Attributes

argdef_graph_map

argdef_outdef_map

graph

Returns the InternalGraph of this TracedModule.

Methods

apply(fn)

Applies function fn to all the modules within this module, including itself.

buffers([recursive])

Returns an iterable for the buffers of the module.

children(**kwargs)

Returns an iterable for all the submodules that are direct attributes of this module.

clear_end_points()

Clears the end_points.

clear_watch_points()

Clears the watch_points and watch_node_value.

disable_quantize([value])

Sets the module's quantize_disabled attribute and returns the module.

eval()

Sets training mode of all the modules within this module (including itself) to False.

flatten()

Returns a new TracedModule in which GetAttr is eliminated and which has no module hierarchy.

forward(*args, **kwargs)

load_state_dict(state_dict[, strict])

Loads a given dictionary created by state_dict into this module.

modules(**kwargs)

Returns an iterable for all the modules within this module, including itself.

named_buffers([prefix, recursive])

Returns an iterable for key buffer pairs of the module, where key is the dotted path from this module to the buffer.

named_children(**kwargs)

Returns an iterable of key-submodule pairs for all the submodules that are direct attributes of this module, where 'key' is the attribute name of submodules.

named_modules([prefix])

Returns an iterable of key-module pairs for all the modules within this module, including itself, where 'key' is the dotted path from this module to the submodules.

named_parameters([prefix, recursive])

Returns an iterable for key Parameter pairs of the module, where key is the dotted path from this module to the Parameter.

named_tensors([prefix, recursive])

Returns an iterable for key tensor pairs of the module, where key is the dotted path from this module to the tensor.

parameters([recursive])

Returns an iterable for the Parameters of the module.

register_forward_hook(hook)

Registers a hook to handle forward results.

register_forward_pre_hook(hook)

Registers a hook to handle forward inputs.

replace_param(params, start_pos[, seen])

Replaces module's parameters with params, used by ParamPack to

set_end_points(nodes)

Initializes the end_points.

set_watch_points(nodes)

Initializes the watch_points.

state_dict([rst, prefix, keep_var])

tensors([recursive])

Returns an iterable for the Tensors of the module.

train([mode, recursive])

Sets training mode of all the modules within this module (including itself) to mode.

zero_grad()

Sets the gradients of all parameters to zero.
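The flatten() entry in the list above says the result has no hierarchy. The toy model below is a hypothetical sketch of only that aspect (it does not model GetAttr and is not MegEngine's implementation; ToyModule and its (name, step) format are assumptions): nested submodules are inlined into one flat sequence of primitive steps, each named by its dotted path in the style of named_modules, and the flattened module computes the same result as the nested one.

```python
class ToyModule:
    """Hypothetical module: an ordered list of (name, step) pairs, where a step
    is either a plain function or a nested ToyModule."""

    def __init__(self, steps):
        self.steps = steps

    def __call__(self, x):
        for _, step in self.steps:
            x = step(x)
        return x

    def flatten(self, prefix=""):
        """Inline every nested ToyModule, keeping dotted paths as step names."""
        flat = []
        for name, step in self.steps:
            path = prefix + name
            if isinstance(step, ToyModule):
                flat.extend(step.flatten(prefix=path + ".").steps)
            else:
                flat.append((path, step))
        return ToyModule(flat)


block = ToyModule([("scale", lambda x: x * 2), ("shift", lambda x: x + 1)])
net = ToyModule([("block", block), ("square", lambda x: x * x)])
flat = net.flatten()
print([name for name, _ in flat.steps])  # ['block.scale', 'block.shift', 'square']
print(net(3), flat(3))                   # 49 49
```

Flattening trades the module structure (and therefore per-submodule access) for a single level of operations, which is why it returns a new TracedModule rather than modifying the original.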
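The watch-point and end-point methods in the list above only name their effect. The following toy interpreter is a hypothetical sketch that assumes watch points record every value produced for the chosen nodes in watch_node_value, and that end points make interpretation stop and return as soon as all of them have been computed. ToyGraph, its expression format, and the use of node names instead of real Node objects are illustrative assumptions, not MegEngine's API.

```python
class ToyGraph:
    """Hypothetical straight-line graph: exprs are (output_name, fn, input_names)
    tuples in topological order. Not MegEngine's InternalGraph."""

    def __init__(self, exprs):
        self.exprs = exprs
        self.watch_points = []
        self.watch_node_value = {}
        self.end_points = []

    def set_watch_points(self, names):
        self.watch_points = list(names)
        self.watch_node_value = {name: [] for name in self.watch_points}

    def clear_watch_points(self):
        self.watch_points = []
        self.watch_node_value = {}

    def set_end_points(self, names):
        self.end_points = list(names)

    def clear_end_points(self):
        self.end_points = []

    def interpret(self, **inputs):
        env = dict(inputs)
        for out, fn, ins in self.exprs:
            env[out] = fn(*(env[name] for name in ins))
            # Assumption: every value produced for a watched node is recorded.
            if out in self.watch_node_value:
                self.watch_node_value[out].append(env[out])
            # Assumption: once all end points exist, stop and return them early.
            if self.end_points and all(name in env for name in self.end_points):
                return [env[name] for name in self.end_points]
        return env[self.exprs[-1][0]]


g = ToyGraph([
    ("h", lambda x: x * 2, ["x"]),
    ("y", lambda h: h + 3, ["h"]),
    ("z", lambda y: y * y, ["y"]),
])
g.set_watch_points(["h"])
print(g.interpret(x=5))    # 169
print(g.watch_node_value)  # {'h': [10]}
g.set_end_points(["y"])
print(g.interpret(x=5))    # [13]  (z is never computed)
```

Keeping the recorded values in a per-node list means repeated runs accumulate until clear_watch_points() resets them, which mirrors the way the reference above pairs clearing watch_points with clearing watch_node_value.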