TracedModule 常见图手术#
Note
阅读本文所展示的图手术例子,需要先了解 TracedMdoule 图手术的相关接口和用法, 请参考 Node 和 Expr 的查找方法 和 图手术常用方法。
修改 Node 的名字#
修改 graph 中的 Node 的名字,可以直接对 Node.name 赋值即可,但在赋值时要确保新的名字未被 graph 中其它的 Node 所使用。
例如修改某 graph 第一个输出的 Node 的名字,可以通过直接获取 graph 的 outputs,来获得输出 Node, 再直接重新设置 Node 的 name 就可对其重命名。
>>> out_node = traced_net.graph.outputs[0]
>>> out_node.name = "I_am_output"
为模型添加前后处理#
由于 TracedModule 可以被重新 trace,因此在加前后处理时,可以新写一个 Module ,并将前处理,主体模型和后处理作为新 Module 的 3 个子 Module, 并在新 module 的 forward 函数中分别调用 3 个 module 就完成了前后处理的添加。例子如下:
添加前后处理
import numpy as np
import pickle
import megengine.functional as F
import megengine.module as M
import megengine.traced_module as tm
class Main(M.Module):
def forward(self, x):
return x
class PreProcess(M.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
x = x*y
return x
class PostProcess(M.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
x = x/y
return x
class Net(M.Module):
def __init__(self, traced_module):
super().__init__()
self.pre_process = PreProcess()
self.traced_module = traced_module
self.post_process = PostProcess()
def forward(self, x, y):
x = self.pre_process(x, y)
x = self.traced_module(x)
x = self.post_process(x, y)
return x
if __name__ == "__main__":
module = Main()
x = F.zeros((1, 14, 8, 8))
traced_module = tm.trace_module(module, x)
obj = pickle.dumps(traced_module)
traced_module = pickle.loads(obj)
# 新写一个 module,将 之前 dump 的 TracedModule 作为该 module 的一个子 module
new_module = Net(traced_module)
x = F.zeros((1, 14, 8, 8))
y = F.ones((1, 14, 8, 8))
traced_module = tm.trace_module(new_module, x, y)
predict = traced_module(x, y)
np.testing.assert_equal(x.numpy(), predict.numpy())
将一些常量吸收到卷积里#
对于一些基于 anchor 的检测算法,经常会在卷积的输出后,对卷积结果乘 stride
或除 anchor_size
,
在推理部署时,可以将这些常量吸收到卷积里,基于 TracedModule 可以较容易的实现这些转换,如下面的例子:
吸常量到卷积中
import numpy as np
import pickle
import megengine.functional as F
import megengine.module as M
import megengine.traced_module as tm
from megengine.traced_module.node import TensorNode
import megengine as mge
class Net(M.Module):
def __init__(self,):
super().__init__()
self.conv = M.Conv2d(in_channels=3, out_channels=16, kernel_size=1, bias=True)
def forward(self, x):
x = self.conv(x)
stride, anchor_size= 8, 128
x = x * stride
x = x / anchor_size
return x
def fuse_const():
net = Net()
inp = mge.Tensor(np.random.random(size = (1,3,16,16)), dtype=np.float32)
traced_net = tm.trace_module(net, inp)
obj = pickle.dumps(traced_net)
traced_net = pickle.loads(obj)
graph = traced_net.graph
for div_expr in graph.get_method_by_type("__truediv__").as_list():
div_self, div_inp = div_expr.args
if isinstance(div_inp, TensorNode):
# 除数不是 TensorNode,就满足了我们的条件
continue
mul_expr = div_self.expr
mul_self, mul_inp = mul_expr.args
call_conv_expr = mul_self.expr
conv_node = call_conv_expr.inputs[0]
# 直接通过 owner 访问 self.conv ,并修改其 weight 和 bias
conv_module = conv_node.owner
conv_module.weight = conv_module.weight * mul_inp / div_inp
conv_module.bias = conv_module.bias * mul_inp / div_inp
# 修改之后,要用 conv 的输出替换 div 的输出
call_conv_expr.top_graph.replace_node({div_expr.outputs[0] : call_conv_expr.outputs[0]})
# 把与 graph 输出无关的 expr 删掉
call_conv_expr.top_graph.compile()
gt = net(inp)
actual = traced_net(inp)
np.testing.assert_equal(gt.numpy(), actual.numpy())
if __name__ == "__main__":
fuse_const()
将一些 OP 转换为 fp16#
对于一些计算量特别大的全连接层,会占用较多的存储资源,可以通过将其转换为 fp16 计算减少其占用的资源, 如下面的例子:
将 Linear 转为 fp16 计算
import numpy as np
import pickle
import megengine.functional as F
import megengine.module as M
import megengine.traced_module as tm
import megengine as mge
class Net(M.Module):
def __init__(self,):
super().__init__()
self.linear_0 = M.Linear(3, 1024, bias = True)
self.linear_1 = M.Linear(1024, 4096, bias=True)
def forward(self, x):
x = self.linear_0(x)
x = self.linear_1(x)
return x
def to_fp16():
net = Net()
inp = mge.Tensor(np.random.random(size = (1,3)), dtype=np.float32)
traced_net = tm.trace_module(net, inp)
obj = pickle.dumps(traced_net)
traced_net = pickle.loads(obj)
graph = traced_net.graph
for linear_node in graph.get_module_by_type(M.Linear).as_list():
linear_module = linear_node.owner
if linear_module.in_features * linear_module.out_features < 100*1024:
# 不满足条件的 Linear 跳过
continue
# 将 weight 和 bias 转换为 fp16
linear_module.weight = linear_module.weight.astype(np.float16)
linear_module.bias = linear_module.bias.astype(np.float16)
linear_call_expr = linear_node.users[0]
# 把输入转换为 fp16
inp_node = linear_call_expr.inputs[1]
with linear_call_expr.top_graph.insert_exprs():
new_inp_node = inp_node.astype(np.float16)
# 将 linear 的输入替换为fp16的输入
linear_call_expr.replace_inputs({inp_node: new_inp_node})
# 把输出转换为 fp16
out_node = linear_call_expr.outputs[0]
with linear_call_expr.top_graph.insert_exprs():
new_out_node = out_node.astype(np.float32)
# 将 out_node 作为输入的 expr 的输入替换为 new_out_node
linear_call_expr.top_graph.replace_node({out_node: new_out_node})
linear_call_expr.top_graph.compile()
gt = net(inp)
actual = traced_net(inp)
np.testing.assert_allclose(gt.numpy(), actual.numpy(), atol=5e-2)
if __name__ == "__main__":
to_fp16()
通过 Graph 确定数据流向#
在量化训练时,常常会对 concat 的输入做某些约束,通过 TracedModule 可以轻易的找到这些 concat 的输入是来自于哪个内置的 function 或 Module 的输出,如下面的例子。
查找 concat 的输入
import numpy as np
import megengine.functional as F
import megengine.module as M
import megengine.traced_module as tm
import megengine as mge
class Net(M.Module):
def __init__(self,):
super().__init__()
self.conv = M.Conv2d(3, 16, 1, bias=False)
self.bn = M.BatchNorm2d(16)
self.conv_bn = M.Sequential(
M.Conv2d(16, 16, 1,bias=False),
M.BatchNorm2d(16)
)
def forward(self, x):
x = self.conv(x)
x0 = self.bn(x)
x1 = self.conv_bn(x0)
x = F.concat((x0, x1), 1)
return x
def find_cat_inputs():
net = Net()
inp = mge.Tensor(np.random.random(size = (1,3, 16, 16)), dtype=np.float32)
traced_net = tm.trace_module(net, inp)
flattened_net = traced_net.flatten()
cat_expr = flattened_net.graph.get_function_by_type(F.concat).as_unique()
print(cat_expr)
# _orig_name 包含了其是由哪个 builtin 的 module 输出的信息
print([n._orig_name for n in cat_expr.inputs])
"""
%8: concat_out = tensor.concat((bn_out, conv_bn_out), 1, )
['bn_out', 'conv_bn.1_out']
"""
if __name__ == "__main__":
find_cat_inputs()
Conv 和 BN 融合#
在 推理 或 量化训练 时,常常需要将 Conv 和 Bn 融合到一起,基于 TracedModule 的 Graph 可以找到满足融合条件的 Conv 和 Bn,并以图手术的方式将其融合,如下面的例子。
将 BN 融合到 Conv 中
import numpy as np
import pickle
import megengine.functional as F
import megengine.module as M
import megengine.module.qat as Q
import megengine.traced_module as tm
from megengine.traced_module.expr import CallMethod
from megengine.traced_module.node import ModuleNode
import megengine as mge
class Net(M.Module):
def __init__(self,):
super().__init__()
self.conv = M.Conv2d(3,16,1, bias=False)
self.bn = M.BatchNorm2d(16)
self.conv_bn = M.Sequential(
M.Conv2d(16,16,1,bias=False),
M.BatchNorm2d(16)
)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = F.relu(x)
x = self.conv_bn(x)
return x
def _fuse_conv_bn(conv : M.Conv2d, bn : M.BatchNorm2d = None):
weight, bias = conv.weight, conv.bias
target_cls = M.ConvBn2d
if not conv.training:
class FakeCls:
def __init__(self, conv, bn):
self.conv = conv
self.bn = bn
def apply_quant_weight(self, inp):
return inp
weight, bias = Q.ConvBn2d.fold_weight_bias(
FakeCls(conv, bn),
bn.running_mean,
bn.running_var
)
target_cls = M.Conv2d
this_module = target_cls(
conv.in_channels,
conv.out_channels,
conv.kernel_size,
conv.stride,
conv.padding,
conv.dilation,
conv.groups,
conv.bias is not None,
conv.conv_mode,
conv.compute_mode,
name=conv.name,
)
if conv.training:
this_module.conv.weight = weight
this_module.conv.bias = bias
this_module.bn = bn
else:
this_module.weight = weight
this_module.bias = bias
return this_module
def fuse_bn_transform():
net = Net()
inp = mge.Tensor(np.random.random(size = (1,3, 16, 16)), dtype=np.float32)
traced_net = tm.trace_module(net, inp)
obj = pickle.dumps(traced_net)
traced_net = pickle.loads(obj)
graph = traced_net.graph
for conv_node in graph.get_module_by_type(M.Conv2d).as_list():
if len(conv_node.users) > 1:
continue
conv_expr = conv_node.users[0]
conv_out_node = conv_expr.outputs[0]
if len(conv_out_node.users) > 1:
# conv -> bn,conv 的输出只能被 bn 使用
continue
# 判断 conv 之后的 expr 是否是 bn
bn_expr = conv_out_node.users[0]
if not isinstance(bn_expr, CallMethod):
continue
bn_node = bn_expr.inputs[0]
if not isinstance(bn_node, ModuleNode) or bn_node.module_type != M.BatchNorm2d:
continue
conv_module = conv_node.owner
bn_module = bn_node.owner
new_module = _fuse_conv_bn(conv_module, bn_module)
cur_graph = conv_node.top_graph
self_node = cur_graph.inputs[0]
self_module = self_node.owner
name = conv_module._name
# 将 fuse 后的 module 设置到 调用 conv 的 module 上
setattr(self_module, conv_module._name, new_module)
inp_node = conv_expr.inputs[1]
bn_out_node = bn_expr.outputs[0]
# 将 fuse 后的 module 以图手术的方式 insert 到 graph 中
with cur_graph.insert_exprs():
fused_conv_out = getattr(self_node, name)(inp_node)
cur_graph.replace_node({bn_out_node: fused_conv_out})
cur_graph.compile()
gt = net(inp)
actual = traced_net(inp)
np.testing.assert_allclose(gt.numpy(), actual.numpy(), atol=5e-2)
if __name__ == "__main__":
fuse_bn_transform()