低时延 RNN-T 训练
Dan神又出新东西了!(前面的都学完了吗?)本文主要介绍k2
中的低时延RNN-T
训练,这是一篇短小的写给懒人的科普文,不会有详细的理论推导,感兴趣的大佬可以直接阅读论文:https://arxiv.org/pdf/2211.00490.pdf
流式模型与时延
这里说的时延主要是指由模型本身带来的输出延迟,比如一个字是在第 100 帧说的,但是直到送了 150 帧数据进去才输出来。时延问题可以说是端到端模型基因里带来的缺点,一个大家都比较认可的解释是,
RNN-T/CTC
这样基于序列的损失函数对于 Alignments
的优化是无差别的,即只管优化能输出 transcript 对应的路径,不管这个路径是先输出 symbol
还是先输出 blank
。所以,对于流式模型的训练,由于当前看到的 context
有限,模型总是倾向于看到更多的 context 后再决定是否输出 symbol
。从图一中可以获得一个感性的认识,图中从上到下三条线分别代表:没有使用时延正则的流式模型,使用了时延正则的流式模型和非流式模型在训练过程中的时延曲线。
可以看到,非流式模型在训练过程中的时延几乎是不变的,而且由于能看到全部
right context
,时延是很低的。而对于流式模型,可以看到随着模型优化得越来越好,时延反而越来越大,这也从侧面验证了模型倾向于看更多的数据来提高输出置信度。中间那条线是使用了我们提出的时延正则的流式训练,可以看到时延是随着模型的优化持续降低。
路径与时延
对于
RNN-T
的训练 lattice
是一个 U * T
的矩阵(如图二所示),理论上从左下角到右上角的所有路径都是合法的路径,由于向上是输出 symbol
,向右是输出 blank
,所以偏向左上角的路径的时延要小于偏向右下角的路径,即图中红色路径的时延比蓝色路径的时延低。时延正则
时延正则的目标是给低时延的路径一些鼓励(加分),给高时延的路径一些抑制(减分)。此处有非常详细的理论推理,十几个公式,这里就不展开了,感兴趣的大佬可以读原论文(链接见文章开头)。最终的实现就是给
lattice
中每条输出 symbol
的边加一个分数,这个分数根据边所在的帧而不同,以中轴线为基准,左侧加正值
(鼓励),右侧加负值
(惩罚),示意图如图三所示。这样位于左上角的路径的分数得到增强,位于右下角的路径分数会被抑制,从而达到降低时延的目的。实验及结果
目前 k2[1] 和 fast_rnnt[2] 两个仓库都已经合并了
delay-penalty
的实现(见 delay-penalty[3]),只需要在使用 pruned rnnt
损失函数时多传入一个 delay_penalty
参数就可以实现低延时的 RNN-T
训练(注意:rnnt_loss_smoothed
和 rnnt_loss_pruned
两个地方都要加)。我们在 Streaming Conformer 和 LSTM 上都做了一些实验,结果证明我们提出的时延正则方法很有效果,并且能简单的通过调整超参数来平衡准确率和时延。结果中的 MAD
表示 token 的平均时延,MED
表示最后一个 token 的平均时延,时延都是根据 Montreal-Forced-Aligner[4] 对齐结果来计算的。我们还对比了使用不同
chunk size
解码的结果,chunk
解码本身就会带来时延,chunk size
越大,带来的时延越大。下图是不同 chunk size
解码的准确率和时延的关系图(这里的时延为总时延,即 chunk_size / 2 + MAD
). 可以看出,使用大些的 chunk size
,在相同时延情况下,可以取得更好的准确率。另外,说起时延不得不提 Google 提出的 FastEmit[5], 我们也与
fast emit
做了对比,发现结果不相上下,有时略好。不过我们相信我们的方法有一个更清晰的理论解释(比如考虑了 symbol
输出的时间信息)。当然,如果执念要使用
FastEmit
,我们在 k2
中也提供了实现,见 k2 FastEmit[6],合并是不会合并的, 欢迎尝试。总结
本文介绍了
新一代 Kaldi
中提出的低时延 RNN-T
训练的方法,粗略介绍了时延产生的原因,阐明了我们做时延正则的方案,欢迎大家尝鲜!欲知更多细节和推导,请阅读原论文:https://arxiv.org/pdf/2211.00490.pdf。往期文章
新一代 Kaldi 基于 WebSocket 的语音识别服务实战sherpa + ncnn 进行语音识别
如何在icefall中玩转预训练模型
本文出品:新一代 Kaldi-NGK 编辑部 撰文:蛋哥的 pkufool
参考资料
k2: https://github.com/k2-fsa/k2
[2]fast_rnnt: https://github.com/danpovey/fast_rnnt
[3]delay-penalty: https://github.com/k2-fsa/k2/pull/976
[4]Montreal-Forced-Aligner: https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner
[5]FastEmit: https://arxiv.org/pdf/2010.11148.pdf
[6]k2 FastEmit: https://github.com/k2-fsa/k2/pull/1069
关键词
模型
结果
论文
流式模型
更多
最新评论
推荐文章
作者最新文章
你可能感兴趣的文章
Copyright Disclaimer: The copyright of contents (including texts, images, videos and audios) posted above belong to the User who shared or the third-party website which the User shared from. If you found your copyright have been infringed, please send a DMCA takedown notice to [email protected]. For more detail of the source, please click on the button "Read Original Post" below. For other communications, please send to [email protected].
版权声明:以上内容为用户推荐收藏至CareerEngine平台,其内容(含文字、图片、视频、音频等)及知识版权均属用户或用户转发自的第三方网站,如涉嫌侵权,请通知[email protected]进行信息删除。如需查看信息来源,请点击“查看原文”。如需洽谈其它事宜,请联系[email protected]。
版权声明:以上内容为用户推荐收藏至CareerEngine平台,其内容(含文字、图片、视频、音频等)及知识版权均属用户或用户转发自的第三方网站,如涉嫌侵权,请通知[email protected]进行信息删除。如需查看信息来源,请点击“查看原文”。如需洽谈其它事宜,请联系[email protected]。