本文介绍新一代 Kaldi 项目的中函数 get_tot_scores(log_semiring=False) 梯度计算背后的科学原理。
文中作图代码对应的colab: https://colab.research.google.com/drive/1RKwfPVecn2e-sSSvAvZdYrXex21xMj9j#scrollTo=6o3lkwyjU_cx
本文灵感来源于一位热心群友反馈,对 k2 中 get_tot_scores(log_semiring=False)时的梯度计算方式进行简要分析。建议阅读本文时运行一下上述 colab 中的代码。
注意计算 ctc_loss 需要设置 log_semiring=True. 文中用到的 ctc_graph 仅仅是用来解释梯度计算/传递方式,计算出的结果并不是符合数学定义的 ctc_loss。

1. 群友总结

1. 因为我在看 k2 文档时发现,在autograd 阶段,

提取计算感觉有点粗暴,比如说在tropical semiring时,

梯度是不是有点粗暴了。

https://k2.readthedocs.io/en/latest/core_concepts/index.html
#autograd

2. 因为从公式部分理解也是如此。只有对应 arc 才会存在 grad

那么偏导应该是 1 ,其他部分是 0

该总结从
某种角度
上来说是正确的,但是不够全面,本文是对这一现象的解释和补充。

2. 梯度反向传播的基本原理


注意图中画红框的部分,它告诉我们梯度反向传递的起始值是 1. (这句话很重要,后面会用到),上图来自 Automatic Differentiation in Machine Learning: a Survey[1] 第13页。

3. 群友总结的“对应的 arc” 梯度为 1, 其它部分是 0" 现象背后的原因

函数 get_tot_scores(log_semiring=False),即 tropical semiring,计算梯度的底层核心代码为GetTotScoresTropicalBackward[2]
ans_grad_data[arc_idx012] =

    tot_scores_grad_data[fsas_idx0 * tot_scores_grad_stride];

可以看出,对于群友所说的“对应的 arc”的梯度其实是直接赋值为其所在的 fsa 的 tot_scores_grad。或者说在 tropical semiring 的条件下, get_tot_scores 会原封不动的把从损失函数方向传递过来的梯度直接赋值给 “对应的 arc”。
正如前文图中红框部分指出,在反向传递过程中,损失函数对最后一个节点的梯度是 1. 所以如果 get_tot_scores 刚好就是最后一个节点,其梯度则为 1。然后该梯度又被
原封不动地
传递给 “对应的 arc”。

(备注:所谓对应的 arc 是指参与计算 tot_scores 的 arc)

4. 代码中 tot_scores_grad 变量可以不是 1 吗?

当然可以。
如前文所述,梯度反向传播算法给最后一个节点的梯度自动赋值为 1。那只需要让 tot_scores 不是最后一个节点,就有可能使其梯度不为 1。
比如设想一个场景,我们在进行 multi task 的训练,同时使用了 ctc_loss 和 RNN_T loss, 我们想给 ctc_loss 加一个 0.15 的权重,则大概代码为:
# 声学模型输出
dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)

# 对应的 transcript 构成的状态图
decoding_grpah = k2.ctc_graph([[1]])

# 生成 transcript 对应的 lattice
lat = k2.intersect_dense(decoding_grpah, dense_fsa_vec,output_beam=10000)

# 注意此处如果设置 log_semiring=True 才是 ctc_loss 的正确计算方式,此处设置为 False 仅为本文中介绍梯度传递使用。
ctc_scores=lat.get_tot_scores(log_semiring=False, use_double_scores=True)


# 乘上 0.15 的weight
weighted_ctc_scores = ctc_scores * 0.15


# 合并多个损失函数
scores = weighted_ctc_scores + rnnt_scores


# 其实要对 score 加个负号,此处我们仅关心梯度的传递,暂时忽略负号

# 反向传递梯度
scores.backward()
显然这种情况下 scores 对 ctc_scores 的梯度是 0.15, 即 get_tot_scores 函数接收到的梯度不再是 1,而是0.15, 自然各“对应 arc” 的梯度也被赋值为 0.15. 上面是大概伪码,文章开头的 colab 中有真实代码,其结果如下图所示,可以看出“对应 arc”的梯度是 0.15:

所以群友的总结可以拓展为:
当log_semiring=False时, “对应 arc” 的偏导是

原封不动赋值为 get_tot_scores 接收的梯度,

其它 arc 的梯度是 0。


只不过反向求导机制会自动把最后一个节点的梯度置为 1.

如果刚好 get_tot_scores 的结果是最后一个节点的话,

其接收到的就是这个 1,

随后这个 1 会传递给所有“对应的 arc”。

5 函数get_tot_scores(log_semiring=False)原封不动地往前传递梯度粗暴吗?

不仅不粗暴,而且很科学
概率域
中,状态机一条
线性路径
的概率等于各 arc 的概率的乘积, 比如:

在实际计算过程中,大量概率相乘,结果越来越小,容易溢出。所以往往把数据从概率域相乘转化到log 域相加, 即: 
k2 要求每一条 arc 上的 score 均为 log(p),所以对一条包含 a/b/c 三条边的
线性路径
的 score 计算方式为:

tot_score 对参与计算的各 score_a/score_b/score_c 的梯度都是 1,根据链式求导法则,应该原封不动地把梯度传递给各 “对应的 arc”。

6. k2 中其它函数的梯度也是原封不动地往前传吗?

当然不是。
如上一节所述,get_tot_scores 原封不动地往前传递梯度,是科学的需要,不是为了传而传。
如果我们把梯度再往前传一步,由 lat 传递给 decoding_graph, 可以看出 decoding_graph 中有 3 个 arc 的梯度是 0.15(lat 中原封不动传过来的),而有一个 arc 的梯度是 0.6 !!!

其实 0.6 是来自 lat 中 4 条 arc 梯度的相加,如下图四条蓝色箭头所示:

总之,k2 中梯度反向传递背后是精密的科学计算

7. 总结

本文以 get_tot_scores(log_semiring=False) 函数为例,浅析了 k2 中梯度计算和传递的
科学性

总之梯度是 0/1 也好,不是 0/1 也罢,都是
科学的需要
,而不是粗暴地乱传。
8. 展望
本文没有详细解释第 6 小节中提到的把 lat 中四条 arc 的梯度 0.15 加一块得到 0.6 赋值给 decoding_graph 中的一条 arc。
另外在利用 k2.get_tot_scores 计算损失函数时,总是先取负,再 backward,本文主要关注于梯度传播,暂时忽略了这个负号。

比如 
ctc 损失函数计算取负号[3]
    tot_scores = lattice.get_tot_scores(

        log_semiring=True, use_double_scores=self.use_double_scores)

    loss = -1 * tot_scores

比如 mmi 损失函数计算取负号[4]
    num_tot_scores = num_den_tot_scores[::2]

    den_tot_scores = num_den_tot_scores[1::2]


    tot_scores = num_tot_scores - den_scale * den_tot_scores

    loss = -1 * tot_scores.sum()

有奖征稿
:感兴趣的读者,可以将
梯度值 0.6
 以及
取负号
背后的科学原理投递至小编,一经采用,酬劳新一代 Kaldi 周边文创产品一件(最终解释权归
比珍珠还真的 Liliana
 所有)。

提示:关于梯度值 0.6,
intersect_dense 梯度计算[5]
 中的 index_add 和 arc_map 是重点。
_k2.index_add(arc_map_a, out_fsa_grad, grad_a)

_k2.index_add(arc_map_b, out_fsa_grad, grad_b.view(-1))

参考资料

[1] 
梯度反向传递原理: https://arxiv.org/pdf/1502.05767v4.pdf
[2] 
GetTotScoresTropicalBackward: https://github.com/k2-fsa/k2/blob/125d34703f898b5ca54f6f4a925f2bc2d7a5ba98/k2/python/csrc/torch/fsa.cu#L530
[3] 
ctc loss 负号: https://github.com/k2-fsa/k2/blob/5791dbb622138795cf8f31ccbfcd6b512d97d727/k2/python/k2/ctc_loss.py#L85
[4] 
mmi loss 负号: https://github.com/k2-fsa/icefall/blob/2332ba312d7ce72f08c7bac1e3312f7e3dd722dc/icefall/mmi.py#L95
[5] 
intersect_dense 梯度计算: https://github.com/k2-fsa/k2/blob/82d22f50c739a783c0bf7d23294aa953ab3dd416/k2/python/k2/autograd.py#L485

往期文章

本文出品:新一代 Kaldi-NGK 编辑部
撰文:蛋哥的 Liyong Guo
继续阅读
阅读原文