本文介绍新一代 Kaldi 中模型平均:
相关代码:https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4

原来的模型平均策略

使用过 icefall 的用户可能会注意到,使用 decode.py 解码时,除了需要提供参数 --epoch 来指定要读取的模型外,还会搭配参数 --avg 来实现模型平均。
在 icefall 的训练过程中,我们会在每个 epoch 结束时保存当前的模型,为文件 epoch-*.pt 。这么做一方面是为了方便用户在训练中断时继续训练,另一方面则是为了在解码时实现模型平均。
例如,使用 decode.py 解码时,指定参数 --epoch 24 --avg 3, 我们会读取文件 epoch-24.ptepoch-23.ptepoch-22.pt 所保存的模型,利用它们求得一个平均模型来解码,如下图所示。

模型平均的目的

在使用随机梯度下降算法优化模型参数的过程中,得到的每个模型,可以看作是在参数空间中,朝着优化目标函数的方向探索的不同采样点。
在训练的后期,模型参数采样点会在最优点附近浮动。通过在参数空间,对这些在最优点附近浮动的采样点求均值,可以得到一个噪声(随机性)更低的模型,即更加接近最优点的模型。

存在的局限性

  • 每个 epoch 采样一个模型,中间经过了很多个 batch,这种采样方式可能过于稀疏。
  • 如果每个 epoch 对训练数据的遍历顺序不是随机的,那么每个 epoch 结束时所保存的模型采样点,可能会“记住”了数据遍历的顺序。使用这些采样点来进行模型平均,显然不是我们所期望的。

潜在的解决方案

为了解决上述问题,我们可以缩小采样间隔,实现更加密集的模型平均,例如每隔 100 个 batch 采样一个模型。这种做法的好处在于:
  • 通过使用更密集的采样点,可以得到一个噪声(随机性)更低的平均模型。
  • 用来作模型平均的采样点,覆盖了每个 epoch 中数据遍历的不同阶段,这样就不用担心前面提到的模型可能会“记住”了数据遍历顺序的问题。
然而,这种做法可能会占用更多的存储空间和 I/O :
  • 训练时,我们需要保存大量的模型,例如每 100 个 batch 保存一个;
  • 解码时,为了实现模型平均,我们需要读取大量的模型文件。

改进后的模型平均策略

为了解决上述的问题,icefall 中采用了一种巧妙的策略,来实现更加密集的模型平均。

训练中维护平均模型 

在训练的过程中,除了当前的模型 外,还维护着一个平均模型 。 表示从训练开始,每隔 个 batch 作一个模型采样,并对所得到的样本集合取均值的结果。
假设从训练开始,到当前为止已经采样 次,此时所维护的平均模型 为:
式中, 为采样的 batch 下标集合。
在进行第 次采样时,我们结合当前的模型 ,更新平均模型,得到 :
感兴趣的同学可以简单证明,更新后的平均模型 ,就是 batch 下标集合 所对应的模型采样点的均值,即 :
在 icefall 实现中, 对应的是 train.py 中的参数 --average-period,其默认值为 100,即每 100 个 batch 作一次采样。
可参考代码:https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
在每个 epoch 结束时,我们在文件 epoch-*.pt中,除了保存当前的模型之外,也会保存当前所维护的平均模型 。
  • 注意,我们并没有每隔 100 个 batch 保存一个模型,保存的文件 epoch-*.pt个数并没有增多。

解码时利用平均模型 

解码时,我们可以基于训练过程中所保存的平均模型 ,以一种节省 I/O 的方式,实现模型平均。
假设 epoch- 和 epoch- 分别表示我们保存的两个模型, :
  • epoch- 保存着平均模型 ,其经过了 次采样;
  • epoch- 保存着平均模型 ,其经过了 次采样。
为了获取在 epoch- 和 epoch- 中间的平均模型,即第 次采样到第 次采样之间样本点的均值 , 我们只需要根据 和 两个模型计算得到:
例如,使用 decode.py 解码时,指定参数 --epoch 24 --avg 3 --use-averaged-model 1, 我们会读取文件 epoch-24.ptepoch-21.pt 中分别保存的平均模型 ,利用这两个模型,求得在这个区间内的每隔 100 个 batch 作模型采样的平均模型,并依靠该平均模型来解码,如下图所示:
可参考代码:https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
如此一来,解码时我们只需要加载两个模型,便能获得采样点更加密集,即噪声(随机性)更低的平均模型。

实验结果

我们在 full librispeech 数据集上使用 Reworked Conformer 训练 20 个 epoch,然后比较以下三种方式,在 test-cleantest-other 两个测试集上,对解码结果的影响:
  • epoch-20:使用 epoch-20 的模型解码,即指定参数 --epoch 20
  • epoch-20-avg-5:使用 epoch-16 ~ epoch 20 这个五个模型的平均模型解码,即指定参数 --epoch 20 --avg 5
  • epoch-20-avg-5-use-averaged-model:使用 epoch-15 ~ epoch 20 这个区间内每隔 100 个 batch 作采样的平均模型解码,即指定参数 --epoch 20 --avg 5 --use-averaged-model 1
Decoding modelWER on test-clean (%)WER on test-other (%)
epoch-203.348.18
epoch-20-avg-52.937.1
epoch-20-avg-5-use-averaged-model2.827.0
由表可以看出,相比较于使用原来的模型平均策略(每个epoch采样一次),使用改进后的模型平均策略(每隔 100 个 batch 采样一次),可以在识别性能上获得比只用单个模型更显著的性能提升。
注意:本文没有讨论文件 checkpoint-*.pt,与文件 epoch-*.pt 的区别在于,它是默认每 8000 个 batch 保存一个文件,方便用户通过指定参数 --iter 使用 epoch 中间的模型。
请注意区分 checkpoint-*.pt 和上述的每 100 个 batch 更新一次平均模型的区别。上述的模型平均策略同样应用于文件 checkpoint-*.pt

往期文章

本文出品:新一代 Kaldi-NGK编辑部
撰文:蛋哥的 Zengwei Yao
(如果你也想成为蛋哥的某某,我们欢迎大家给新一代Kaldi投稿

投稿邮箱地址:[email protected]
小编将送出蛋哥亲自手绘的新一代KaldiT恤~)
继续阅读
阅读原文