megengine.utils.network.NodeFilter¶
- class NodeFilter(node_iter)[源代码]¶
作用于网络节点(包括计算节点和变量)的过滤器,其本质是一个封装特定过滤功能的
NetworkNode
迭代器。示例
# find all :class:`.ImmutableTensor` nodes for i in NodeFilter(node_iter).param_provider(): print(i) # find all :class:`.ImmutableTensor` nodes that end with ':W' for i in NodeFilter(node_iter).param_provider().name('*:W'): print(i) # number of inputs nr_input = NodeFilter(node_iter).data_provider().as_count()
方法
as_count
()返回迭代器的长度。
as_dict
()遍历过滤器并返回一个有序字典,其键为节点名称,值为节点对象。
as_list
()将过滤器中的内容以列表的形式表示并返回。
断言过滤器中只包含单个元素,并返回该元素。
check_type
(node_type)断言过滤器中的所有计算节点都属于给定类型。
返回所有类型为
DataProvider
的计算节点,该方法是.type(DataProvider)
的简写。has_input
(var)遍历过滤器中的所有算子节点,返回以给定变量节点为输入的对象集合。
make_all_deps
(*dest_vars)创建并返回一个过滤器
NodeFilter
,其中包含给定变量依赖的所有节点。name
(pattern[, ignorecase])通过节点名称过滤。
not_type
(node_type)移除过滤器中所有指定类型的计算节点。
get
ParamProvider
oprs; shorthand for.type(ParamProvider)
type
(node_type)通过节点类型过滤。