©作者 | 董春玉
单位 | OpenMMLab
研究方向 | 深度视觉
问题描述
MMOCR在 MMDeploy中部署时,PANet模型在以 TensorRT-fp16 为后端的情况下会有精度损失。hmean-iou 由原本的 0.8- 掉点到 0.2-。此时需要相应的 debug 查找问题原因。

排除法查找节点

首先请教了有相关经验的同事,被告知一般只能二分查找,没有更方便的工具。此外,如果 ONNX 中有 reduce_sum 和 avg_pool 节点,可以重点尝试排查。经过 netron 可视化 ONNX 文件,发现 PANet 中无以上两种节点。没办法,只能笨办法慢慢查。
这时候,我首先想到的是,骨干网络应该不会出现数值溢出问题。而 PANet 中与其他常用的无此相关问题的模型的差异在于其 neck 部分是,FPEM-FFM 结构。会不会是在这个结构中产生问题的呢?所以,我没有直接使用二分查找。

2.1 卷积层映射

具体深入查看 FPEM-FFM 结构,发现有其他常用网络不用的卷积,深度可分离卷积。这时,很直觉的判断就是这个,所以,这时候我对该结构内的层采用的应对 FP16 数值溢出常用的解决方法。计算前减小数值,计算后恢复原来大小。
那么对于卷积而言,一次卷积运算可以认为是:输入 x,输出 conv(x) = wx+b,其中 W 是权重 weight,b 是偏移量 bias。如果我想要保证卷积过程无数值溢出情况,只需要将输入减小,计算完成后再将结果映射成真实值。
实际卷积的计算过程变成:conv(x/p) = w(x/p)+b。卷积完成后,先乘以 p 再减去 (p-1)b 即可。亦即:p*conv(x/p)-(p-1)b = wx+b。很自然地,对于一个卷积层 conv 和输入张量 x,进行如下变换:
defwarp_conv(x, conv, factor: int=32):
"""(W*(x/p)+b)*p-b*(p-1) == Wx+b

    conv(x) == warp_conv(x, conv)

    """

    x_tmp = conv(x / factor)

return
 factor * x_tmp - (factor - 
1
) * conv.bias.reshape(

1
-1
1
1
).repeat(
1
1
, x_tmp.size(
2
), x_tmp.size(
3
))

原始计算过程 conv(x) 被等效成 warp_conv(x, conv),并且,等效过程不会出现数值溢出。

2.2 归一化层映射

除了可能的卷积层中数值溢出外,归一化层出现问题的几率更高。归一化层的计算过程如下:
# input:x, output:out
w
 = weight / torch.sqrt(running_var + eps)

out
 = x * w + (bias - running_mean * w)

可以发现,归一化层 bn 的计算过程实际也是个线性运算:bn(x) = wx+b,其中 w = weight / torch.sqrt(running_var + eps),而 b = bias-running_mean * w。那么类似卷积层的一个映射函数可以是:
defwarp_bn(x, bn, factor: int=32):
import
 torch

    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)

    bias = bn.bias - bn.running_mean * scale

    bias_t = bias.reshape(
1
-1
1
1
).repeat(
1
1
, x.size(
2
), x.size(
3
))

return
 bn(x / factor) * factor - (factor - 
1
) * bias_t

2.3 验证

有了以上的分析,直接将 FPEM-FFM 结构中的所有的卷积和归一化层都替换一下。这样导出的模型再转换成 TensorRT 模型,进行精度测试后发现结果有很大提升,由原来的 0.2- 提升到 0.6+。看起来验证是成功的,必然是其中的某个层出现了数值溢出,替换后可以进行 fp16 推理。可是,为什么精度没有对齐到 fp32 呢?难道上述的替换并不完全等效?为此我进行了如下实验:
print
(conv(x) - warp_conv(x, conv).sum())

print
(bn(x) - warp_bn(x, bn).sum())

发现对于卷积层,并无数值误差,而对归一化层,出现了一定的误差。而且,bn 层的误差是随着映射函数中的变量 factor 增大而增大。“一定是多个归一化层累计效应,同时 factor 不应该设置太大”,我想。
就这样,我减小了 factor 的大小,同时逐一验证是具体哪个归一化层有问题。最后却没有定位到一个具体的归一化层产生数值溢出。一时间,整个 debug 似乎卡住了,没法继续推进了,似乎只能二分法?

整理思路

3.1 数值溢出的可能形式

不考虑 TensorRT 做层融合等情况,只考虑 PyTorch 做推理,数值溢出的可能形式有哪些呢?
  • 某个算子内部计算过程数值溢出,输入输出均可以用 fp16 表示
  • 跨内部连续多个算子出现数值溢出
  • 整个网络计算过程都有数值溢出
首先排除最后一种情况,一般的,输入是图像归一化后的结果,而输出一般需要对应到标签或者标签相关的值。所以,输入和输出都不可能出现数值溢出。而第一种情况种,PANet 只有 FPEM-FFM 结构是迥异于其他模型的,替换其中的卷积层和归一化层都不能解决问题的话,数值溢出只剩下第二种情况了。

3.2 多算子数值溢出

考虑到连续多个算子计算过程均出现数值溢出的可能,似乎直接进行单个算子映射已经无法解决问题了。同时,具体从哪个算子出现问题,我们仍然不知道。难道还是要寻求二分法的帮助?

寻找fp16失效的算子

没辙,只好二分法吧。可是如何二分法其实也很有考究。可行的 debug 方式:
  1. 每次提前返回结果,二分地导出 ONNX 再导出 TensorRT 模型,未被导出的部分继续以 PyTorch 代码衔接到 TensoRT 的计算结果后。
  2. 直接运行 PyTorch 模型,设置断点,查看哪些计算过程有数值异常地大。
第一种方法最为精准,肯定是可以找到具体的节点的。但是过程非常繁琐,同时需要大量的测试代码。第二种方法最为直接,但是也同样繁琐,因为一个图的节点太多了。要断点查看的话,可能需要很久。而且不能保证结果一定找到,可能存在疏漏。
第二种方法需要指导 fp16 数值表示的变化范围。IEEE 标准中,fp32 的取值范围 是1.4e-45 至 3.4e38,fp16的取值范围是 5.96e-8 ~ 65504。也就是说,fp16 的最大值不超过 65504。

4.1 断点调试

权衡了下,选择断点查看是否有数值异常。最后找到了异常的起始位置为骨干网络输出的特征层的最后一层,里面有大量的点的数值大于 65504。随后再逐步往前找到开始出现异常的节点,同时再逐步往后,找到异常结束的节点。结果发现起始位置在最后一层的最后一次归一化层,结束位置在 FPEM-FFM 结构的某个归一化完成后。

4.2 优化

上面的方法仍然略显繁琐,有没有不用手动调试,直接用程序查找的方式呢?
之前的断点调试本质是寻找图计算过程中,内部嵌套的 nn.Module 的推理过程中的张量是否超出 65504 的点罢了。我们完全可以用代码查找。直接对模型的每一层进行遍历,查看是否有异常点。
一个可能的实现方式是利用 PyTorch 的钩子(hook)技术。所谓的钩子技术其实并非 PyTorch 所独有,事实上,很多的软件架构都有提供。钩子技术是在某个事件执行完成后,自动执行的函数。也就是说,我们只要在网络的每个层设置一个钩子,在该层推理完成后再对输入和输出进行查找是否有异常点即可。
from pyclbr import Function

from typing import Sequence

import torch



def fp16_check(module: torch.nn.Module, 
input
: torch.Tensor, 
output
: torch.Tensor) -> None:

if
 isinstance(
input
, dict):

for
 _, value 
ininput
.items():

            fp16_check(module, value, 
output
)

return
if
 isinstance(
input
, Sequence):

for
 value 
ininput
:

            fp16_check(module, value, 
output
)

return
if
 isinstance(
output
, dict):

for
 _, value 
inoutput
.items():

            fp16_check(module, 
input
, value)

return
if
 isinstance(
output
, Sequence):

for
 value 
inoutput
:

            fp16_check(module, 
input
, value)

return
if
 torch.
abs
(
input
).
max
()<
65504and
 torch.
abs
(
output
).
max
()>
65504
:

print
(
'from: '
, module.finspect_name)

if
 torch.
abs
(
input
).
max
()>
65504and
 torch.
abs
(
output
).
max
()<
65504
:

print
(
'to: '
, module.finspect_name)

return



from contextlib import contextmanager

class FInspect:

    module_names = [
'model'
]

    handlers = []


    def hook_all_impl(cls, module: torch.nn.Module, hook_func: Function)-> None:

for
 name, child 
in
 module.named_children():

            cls.module_names.append(name)

            cls.hook_all_impl(cls, module=child, hook_func=hook_func)

        linked_name=
'->'
.join(cls.module_names)

        setattr(module, 
'finspect_name'
, linked_name)

        cls.module_names.pop()

        handler = module.register_forward_hook(hook=hook_func)

        cls.handlers.append(handler)


    @classmethod

    @contextmanager

    def hook_all(cls, module: torch.nn.Module, hook_func: Function)-> None:

        cls.hook_all_impl(cls, module, hook_func)

yield
        [i.
remove
() 
for
 i 
in
 cls.handlers]


with FInspect.hook_all(patched_model, fp16_check):

    patched_model(inputs)

尝试映射

整个异常过程可以表达成:bn(conv(relu(bn(x) + residual)))。对于这样一个过程,能否通过类似上述的对卷积层和归一化层的处理方式解决数值溢出问题呢?

5.1 relu一生之敌

我们知道,一个初始值,经过多个线性运算后,结果可以用一次线性运算还原。比如:w2(w1x + b1) + b2 = w1w2x + w2b1 + b2,结果还是个线性运算。这也是为什么神经网络需要激活函数————不然多个线性层的结果等效于一个线性层。
那么,上面的异常过程可以简化成 w2(relu(w1x + b1)) + b2。其中 w1, w2, b1, b2 都可以计算出来。可是,relu 激活函数可以在输入缩放减小后,对输出进行还原得到吗?我们都知道,relu 函数有个很好的性质,那就是 relu(px) = p*relu(x)。除此之外,再难有其他性质可以被利用,以对抗数值溢出。可是,如果计算过程是 relu(wx + b),缩放输入 x 以后得到的结果 relu(wx/p + b) 不能再简单地恢复到 relu(wx + b)。
不过,我们可以通过整体缩放 relu 的输入,将计算过程变成 relu((wx + b)/p),这样再乘以 p 就可以恢复成 relu(wx + b) 了。如此,relu 函数也可以绕过,对抗数值溢出。

5.2 公式

那么对于原公式 w2(relu(w1x + b1)) + b2,就可以映射成 p*w2(relu((w1x + b1)/p)) + b2 - (p-1)*b2。

5.3 实施

在 MMDeploy 框架中,只要利用函数重写劫持掉原始的 PANet 中的两段函数
mmocr.models.textdet.necks.FPEM_FFM.forward
mmdet.models.backbones.resnet.BasicBlock.forward
即可。在劫持的函数里,替换原有的计算过程,将上述的映射实现即可。
import torch

import torch.nn.functional as F


from mmdeploy.core import FUNCTION_REWRITER

from mmdeploy.utils.constants import Backend


FACTOR = 
32
ENABLE = False

CHANNEL_THRESH = 
400


@FUNCTION_REWRITER.register_rewriter(

    func_name=
'mmocr.models.textdet.necks.FPEM_FFM.forward'
,

    backend=Backend.TENSORRT.value)

deffpem_ffm__forward__trt(ctx, self, x, *args, **kwargs)
:

    c2, c3, c4, c5 = x

# reduce channel
    c2 = 
self
.reduce_conv_c2(c2)

    c3 = 
self
.reduce_conv_c3(c3)

    c4 = 
self
.reduce_conv_c4(c4)


ifENABLE:
        bn_w = 
self
.reduce_conv_c5[
1
].weight / torch.sqrt(

self
.reduce_conv_c5[
1
].running_var + 
self
.reduce_conv_c5[
1
].eps)

        bn_b = 
self
.reduce_conv_c5[

1
].bias - 
self
.reduce_conv_c5[
1
].running_mean * bn_w

        bn_w = bn_w.reshape(
1
, -
1
1
1
).repeat(
1
1
, c5.size(
2
), c5.size(
3
))

        bn_b = bn_b.reshape(
1
, -
1
1
1
).repeat(
1
1
, c5.size(
2
), c5.size(
3
))

        conv_b = 
self
.reduce_conv_c5[
0
].bias.reshape(
1
, -
1
1
1
).repeat(

1
1
, c5.size(
2
), c5.size(
3
))

        c5 = FACTOR * (
self
.reduce_conv_c5[
:-1
](c5)) - (FACTOR - 
1
) * (

            bn_w * conv_b + bn_b)

        c5 = 
self
.reduce_conv_c5[-
1
](c5)

else:
        c5 = 
self
.reduce_conv_c5(c5)


# FPEM
for
 i, fpem 
in
 enumerate(
self
.fpems):

        c2, c3, c4, c5 = fpem(c2, c3, c4, c5)

if
 i == 
0
:

            c2_ffm = c2

            c3_ffm = c3

            c4_ffm = c4

            c5_ffm = c5

else:
            c2_ffm += c2

            c3_ffm += c3

            c4_ffm += c4

            c5_ffm += c5


# FFM
    c5 = F.interpolate(

        c5_ffm,

        c2_ffm.size()[-
2:
],

        mode=
'bilinear'
,

        align_corners=
self
.align_corners)

    c4 = F.interpolate(

        c4_ffm,

        c2_ffm.size()[-
2:
],

        mode=
'bilinear'
,

        align_corners=
self
.align_corners)

    c3 = F.interpolate(

        c3_ffm,

        c2_ffm.size()[-
2:
],

        mode=
'bilinear'
,

        align_corners=
self
.align_corners)

    outs = [c2_ffm, c3, c4, c5]

return
 tuple(outs)



@FUNCTION_REWRITER.register_rewriter(

    func_name=
'mmdet.models.backbones.resnet.BasicBlock.forward'
,

    backend=Backend.TENSORRT.value)

defbasic_block__forward__trt(ctx, self, x)
:

ifself
.conv1.in_channels < 
CHANNEL_THRESH:
return
 ctx.origin_func(
self
, x)


    identity = x


    out = 
self
.conv1(x)

    out = 
self
.norm1(out)

    out = 
self
.relu(out)


    out = 
self
.conv2(out)


if
 torch.abs(
self
.norm2(out)).max() < 
65504
:

        out = 
self
.norm2(out)

        out += identity

        out = 
self
.relu(out)

return
 out

else:
        global ENABLE

        ENABLE = True

# the output of the last bn layer exceeds the range of fp16
        w1 = 
self
.norm2.weight / torch.sqrt(
self
.norm2.running_var +

self
.norm2.eps)

        bias = 
self
.norm2.bias - 
self
.norm2.running_mean * w1

        w1 = w1.reshape(
1
, -
1
1
1
).repeat(
1
1
, out.size(
2
), out.size(
3
))

        bias = bias.reshape(
1
, -
1
1
1
).repeat(
1
1
, out.size(
2
),

                                                out.size(
3
)) + identity

        out = 
self
.relu(w1 * (out / FACTOR) + bias / FACTOR)


return
 out

通过上述的重写函数,最后导出 PANet 模型可以媲美原始 PyTorch 模型,甚至略有超过(数值误差)。

5.4 总结

总结一下这篇博客,分享了一般的 FP16 数值溢出情况下的处理方式。
  • 一个快速查找数值溢出算子的方法。
  • 一个替换多个算子,从原始模型解决 FP16 数值溢出的方法。
更多阅读
#投 稿 通 道#
 让你的文字被更多人看到 
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:[email protected] 
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧
·
继续阅读
阅读原文