T5语言模型自从首次在Hugging Face Transformers中亮相以来就颇受欢迎。有很多人一直不断要求能够在float16精度下运行T5模型。

迄今为止,T5语言模型只能在支持bfloat16格式的硬件上运行,该格式是模型最初进行训练所采用的格式。这限制了T5的使用范围,仅适用于一些特定的CPU、TPU(v2及以上版本)和GPU(A100及以上型号)。
与使用float32精度相比,使用float16精度往往更佳,原因是在使用float32精度时往往会超出硬件存储限制,或者执行时间会变得过长。
伴随FLAN-T5的发布,我们非常高兴能够支持这些模型在IPU上运行——使用float16精度(见本次推送头条)。  
在本文中,我们非常高兴地向大家介绍我们的FLAN-T5 IPU解决方案。尽管该解决方案是专门为T5模型开发的,但这些方法可重复使用,在类似场景中为您提供帮助。
在IPU上将T5转换
为float16精度
识别计算图中的动态部分
在运行模型之前,我们需要对模型代码进行快速检查,以查找那些无法编译成静态图的部分。我们在T5Block中发现了图中的动态分支。同时,如果数据在float16精度时已经溢出,那么所创建的分支就会截断数据:
# clamp inf values to enable fp16 training 
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): 
    clamp_value = torch.finfo(hidden_states.dtype).max - 1000 
    hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) 
我们选择将动态条件torch.isinf(hidden_states).any()从这个分支*中移除,因为:
  • 我们无法静态编译这个动态分支条件
  • 虽然对隐藏状态进行截断只是处理float16问题的一种临时措施,但在训练过程中仍然是必要的,因此无法完全移除。有关我们在推理过程中如何处理问题根源的详细信息,请参阅“前馈的向下投影”部分。
*这个改变也已经在最新版本的Transformers中进行了更新。
启用Poplar的浮点异常检测
我们的Poplar后端已内置浮点异常检测功能,这使得追踪数值问题的源头变得更加直观明了。具体过程包括以下步骤:
  1. 在应用程序中启用浮点异常。在PopTorch中,您可以使用:opts.Precision.enableFloatingPointExceptions(True)。(更多信息请参考PopTorch用户指南[1]
  2. 运行应用程序时启用图分析功能:POPLAR_ENGINE_OPTIONS='"{"autoReport.all":"true", "autoReport.outputExecutionProfile": "false", "autoReport.directory":"./report"}'。更多详情,请参阅图分析工具用户指南[2]中的捕获IPU报告[3]部分
  3. 如果触发了浮点异常,将生成一个poptorch_error.log文件。打开此文件并向下滚动到(或搜索)Backtrace。找到回溯顶部附近的ID,用(Id: 1234)表示,并在图分析的程序树中搜索该ID。从这里,您应该能够查看有问题的操作的调试信息,并确定它在模型中的位置。
*请注意,我们使用"autoReport.outputExecutionProfile": "false"来避免执行性能分析的开销[4],因为我们只关注程序树的信息。
利用这种方法,我们解决了其余的浮点异常。
解决浮点异常问题
注意力掩码
前两个异常情况被发现在注意力掩码中。在这两个地方注意力掩码被“倒置”并叠加使用。掩码值被设定为-torch.finfo(torch.float16).min(即-65504),而传值设定为0。这样做是为了使掩码的注意力值传递到softmax时,在输出结果中它们的相关性最小化。然而,如果掩码的是负数且绝对值大于float16的最小分辨率-65504,那么最终会得到一个负无穷大值:
>>> torch.tensor([-65504], dtype=torch.float16) - 10 
tensor([-65504.], dtype=torch.float16) 
>>> torch.tensor([-65504], dtype=torch.float16) - 100 
tensor([-inf], dtype=torch.float16) 
我们通过简单地将掩码缩减25%来解决这两个异常情况,这意味着您可以有低至-16376的注意力值,并且掩码不会导致溢出。
由tanh近似的GeLU
第三个异常是在FLAN-T5模型所使用的tanh GeLU近似的显式定义中发现的(原始T5模型使用ReLU激活)。该公式
0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) 
对输入进行立方运算,如果输入的绝对值大于大约39,就会导致溢出。我们通过在输入大于39时恢复到ReLU来解决这个问题,这是一个安全的近似值,因为当输入的绝对值大于5时,ReLU==GeLU。
Pre-norm的残差连接
第四个异常是在编码器的FF层中发现的残差添加。我们看到,当FF网络的输出被添加到其输入时,该操作就溢出了。我们通过以下方式解决了这个问题:
  1. 将输入到第一个编码器块的嵌入数据转换为float32

  2. 对于每个编码器块中的自注意力和前馈:
    1. 在LayerNorm*之后转为float16,这样大部分的计算仍是float16
    2. 在添加到float32剩余连接之前、dropout之后,将其转为float32

  3. 在所有的编码器块之后,将final_layer_norm*的输出转为float16,准备用于解码器,而解码器都是float16
*T5的LayerNorm的实现方式使得其实际上是自动发生的
下面的图表用以下颜色编码来表示数据的精确度:
T5编码器由一连串的块组成,每个块包含一个自注意力层和一个前馈层:
其中的每一层都有相同的基本结构,唯一不同的是注意力/隐藏层:
在上面第2步中提到的类型转换之后,这些层看起来像这样:
这可以防止(能通过编码器一路传递的)pre-norm残差的溢出。
前馈的向下投影
最后一个浮点异常是在编码器前馈层隐藏部分的向下投影中发现的。在代码中,这是wo层,为了清楚起见,我们将其称为DownProject。目前,前馈层及其隐藏部件看起来是这样的:
通过对DownProject的输入进行缩小处理,然后在其安全地转换回float32后将输出进行放大,我们成功解决了溢出问题。
通过分析DownProject输出激活的标准差,并确定一个适合这些激活的值作为2的幂次,我们就选出了一个缩放因子。我们选择使用2的幂次是因为这样只需要改变float16激活的指数,而不必对尾数进行有损修改。
我们发现标准差为~4400,因此我们选择了8作为缩放因子,将标准差减小到~550。在应用了这个缩放后,前馈层及其隐藏部件的效果如下所示:
最新版本的Transformers对这个问题的解决方法是始终将这个层保持在float32。
验证
由于我们对模型进行了一些更改,您可能会想确认该模型是否仍然能够按预期运行。为此,我们在CPU上使用float32,在IPU上使用float16,在MMLU基准测试的一个子集上对其进行了验证。结果显示,CPU和IPU分别达到了整体平均值49.3%和49.4%,证明我们没有降低原始模型的性能。
*我们目前的FLAN-T5-XL实施最大输入长度为896个标记,所以我们此处使用的MMLU子集,其样本没有超过这个长度。
结论
现在,我们就拥有了可以在IPU上以float16进行推理的FLAN-T5-XL的实施。您还可以前往Paperspace,亲身体验更多精彩。
[1]https://docs.graphcore.ai/projects/poptorch-user-guide/en/latest/index.html
[2]https://docs.graphcore.ai/projects/graph-analyser-userguide/en/latest/index.html
[3]https://docs.graphcore.ai/projects/graph-analyser-userguide/en/latest/capturing-ipu-reports.html
[4]https://docs.graphcore.ai/projects/graph-analyser-userguide/en/latest/capturing-ipu-reports.html#profiling-overhead
获取更多Graphcore资讯,阅读深度技术文章,并与其他创新者们一起交流,请至中国官网graphcore.cn,以及关注Graphcore微信、微博和知乎创新社区。
Graphcore中国官网
Graphcore官方微信
Graphcore微博创新社区
Graphcore知乎创新社区
点击阅读原文,查看英文blog。
继续阅读
阅读原文