megengine.utils.network.NodeFilter

class NodeFilter(node_iter)[源代码]

作用于网络节点(包括计算节点和变量)的过滤器,其本质是一个封装特定过滤功能的 NetworkNode 迭代器。

示例

# find all :class:`.ImmutableTensor` nodes
for i in NodeFilter(node_iter).param_provider():
    print(i)

# find all :class:`.ImmutableTensor` nodes that end with ':W'
for i in NodeFilter(node_iter).param_provider().name('*:W'):
    print(i)

# number of inputs
nr_input = NodeFilter(node_iter).data_provider().as_count()

方法

as_count()

返回迭代器的长度。

as_dict()

遍历过滤器并返回一个有序字典,其键为节点名称,值为节点对象。

as_list()

将过滤器中的内容以列表的形式表示并返回。

as_unique()

断言过滤器中只包含单个元素,并返回该元素。

check_type(node_type)

断言过滤器中的所有计算节点都属于给定类型。

data_provider()

返回所有类型为 DataProvider 的计算节点,该方法是 .type(DataProvider) 的简写。

has_input(var)

遍历过滤器中的所有算子节点,返回以给定变量节点为输入的对象集合。

make_all_deps(*dest_vars)

创建并返回一个过滤器 NodeFilter ,其中包含给定变量依赖的所有节点。

name(pattern[, ignorecase])

通过节点名称过滤。

not_type(node_type)

移除过滤器中所有指定类型的计算节点。

param_provider()

get ParamProvider oprs; shorthand for .type(ParamProvider)

type(node_type)

通过节点类型过滤。