k2 核心数据结构 Ragged 解析
本文介绍新一代 Kaldi 项目的 Ragged 数据结构相关代码:文中作图代码对应的colab: https://colab.research.google.com/drive/1kQc3co8gYbJwRNjdIe2NSRPD8ijSjwV8?usp=sharingk2 ragged 代码:https://github.com/k2-fsa/k2/blob/master/k2/csrc/ragged.h
本文主要探讨 k2 中不规则数据 Ragged 类型
1 缘起 Ragged
每当我打开 k2 项目首页[1],都会不经意地回想起2021年春天的某个早上,我请教范利春博士 k2 项目是什么,他回答我说 “去看 Ragged 吧”。在 k2 首页简短的项目介绍里,Ragged 这一概念频繁出现,却让当时初次接触这个项目的我完全摸不着头脑。
去翻 Ragged 源码[2], 硬着头皮读了几遍,不免感觉宝强附体:
现在回头看去,或许可以从这个角度理解:Ragged 是在高效率地表示 List of List。(注意: 可以不断嵌套,比如 List[List[List]], List[List[List[List[...]]]]。
2 矩阵与 Ragged 数据索引计算方式对比分析
下文中的坐标以及存储位置 offset 等索引均以 0 为起始值,即 0-based。
2.1 矩阵中坐标与存储位置互相推导
比如需要存储维度为 [4, 80] 的 fbank 特征, 可以开辟 320 个 float 空间, 其中 [0, 80) 放第一帧, [80, 160) 放第二帧,依次类推(注意空间为左开右闭)。
fbank 每一帧都包含 80 个浮点数,很规整。所以只需要记住 4, 80 这两个维度信息就可以很快实现坐标 和存贮位置 的互换。即 。或者反过来 。
上述换算过程中,只有 80 维这个信息用到了,4 帧这个信息貌似没用到。其实借助 "4 帧"这一信息,我们可以推导出数据逻辑上是 2 维。比如 [4, 80] 和 [2,2,80] 需要的存储空间是一样的,但是逻辑上一个是 2 维,一个是 3 维。
总体而言,对于这种结构规整的数据,只需要很少的辅助信息就可以推导出来数据逻辑上的组织形式。以及实现其坐标与实际存储位置的互换。
但是一旦数据结构不规整,要完成这种互换就需要更多的辅助信息。以下图为例:
如果要存储这句话的发音序列,可以开辟一串存储空间,各 phone 依次存入:
aux_labels = [h e sh an t on g yi]
(注:该发音方式仅仅是为了本文讲解方便,不代表用于实际生产。)
如果没有辅助信息,仅有这种线性的存储方式完全不可能完成逻辑坐标和实际存储位置的互换。比如第 2 个字的第 2 个 phone 是什么?或者第 5 个 phone 属于哪个字?
2.2 Ragged 数据坐标与存储位置的互相推导
在 k2 中,可以借助 row_ids / row_splits 的概念[3]解决上述两个问题。
2.2.1 使用 row_splits 由坐标计算存储位置
从某种角度来说,row_splits 记录的是每个字起始 phone 在存储空间中的offset。在本例中:
# 注意最后有两个 7, 解释见3.1
row_splits = [0, 2, 4, 7, 7]
如果需要计算第 2 个字的第 2 个 phone(即坐标为 [2, 2])对应的存储位置 offset:
offset = row_splists[2] + 2 = 4 + 2 = 6
# 对应的 phone 为:
# aux_labels[6], 即为 g
所以坐标 [2, 2] 对应的 offset=6, 对应的 phone 为 g。
接下来目测验证一下上述结果, 用List[List] 结构来记录各个字对应的 phone 序列,如下所示:
[
[h e]
[sh an]
[t on g]
[yi]
]
其中第 2 个字“统”的第 2 个 phone 是 g。
说明上面通过 row_splits 计算所得的 offset=6 是正确的。
2.2.2 使用 row_ids 由存储位置计算坐标
现在我们反过来由 offset=6 计算其坐标,显然通过 row_splits 也可以算,比如:
offset = 6
word = 0
for i in range(len(row_splits) - 1):
if row_splits[i] <= offset and row_splits[i + 1] > offset:
word = i
这样也能算,但是每次都要跑一遍循环。所以对于这样可能会经常用到,但是却不好算的变量该怎么办?怎么办?怎么办?
是的,正如你所想,算一次,然后存下来。
如果把每个 phone 对应的 word 的坐标用 row_ids 存储如下:
row_ids = [0, 0, 1, 1, 2, 2, 2, 3]
计算第 6 个 phone 所属 word 的坐标:
word = row_ids[6] = 2
获取到对应的 word 的坐标后,还可以借助 row_splits 进一步计算该 phone 在这个字内的坐标:
# 找到该 word 的起始 start_phone 对应的 offset:
word_start_phone_offset = row_splits[2] = 4
# phone 的 offset 与 start_phone 的 offset
# 差值就是 phone 在当前 word 中的坐标
offset_inside_word = 6 - 4 = 2
所以, offset=6 可得坐标 [word, phone] = [2, 2], 即 2 号字的 2 号 phone。至此,我们借助 row_splits/row_ids 实现了坐标与存储位置 offset 之间的互相转换。而且计算的效率很高,只是过程略微烧脑亿点点。
3 更深层次的 Ragged 数据解析
如前文所述,理论上,Ragged 可以建模任意层次深度的 List[List[***]]。接下来探究更深层次 Ragged 数据描述方式。
假设我们要研究的对象是按照城市和省份组织起来的一些国内高校,如下所示:
# 最外层 国家 (对应 china_univ_fsav)
[
# 省份层
# 陕西省 (对应 china_univ_fsav 中的 fsa level)
[
# 城市层
# 西安市 (对应 china_univ_fsav 中的 state level)
[
# 高校层
西北工业大学 (对应 china_univ_fsav 中的 arc level)
西安交通大学
西安电子科技大学
长安大学
]
# 延安市
[
延安大学
]
# 汉中市
[
陕西理工大学
]
]
# 台湾省 (对应 china_univ_fsav 中的各个 fsa level)
[
# 台南市 (对应 china_univ_fsav 中的各个 state level)
[
成功大学 (对应 china_univ_fsav 中的各个 arc level)
]
# 台中市
[
逢甲大学
]
]
]
该问题等价于用 Ragged 中 row_splits/row_ids 的概念描述下面三个 List 之间的关系:
provinces = [陕西省, 台湾省]
cites = [西安市, 延安市,汉中市,台南市, 台中市]
univs = [西北工业大学,西安交通大学,西安电子科技大学,长安大学,延安大学,陕西理工大学,成功大学,逢甲大学]
3.1 k2 中 state 与 arc 的“从属”关系
在 k2 中,每个 state “包含”的 arc 是从它出发的 arc(“进入”它的 arc 不归它管)。因此我们可以把下图中的每个 state 看作一个“城市”,从该城市出发的一条 arc 代表一所位于该城市的大学。比如 state 0 可以代表“西安市”, 它包含西工大,西交,西电,长安大学等四所大学。
既然一个 state 代表一个城市,那自然而然包含多个 state 的 fsa 可以代表一个省份,如上图可以代表陕西省,其中 state 0/1/2 分别对应西安市,延安市,汉中市。
state 3/4 没有对应的城市名字。因为在 k2 中,指向终止节点的 arc 的 input label 一般是 -1, state 3/4 仅仅是为了满足这一要求,不对应任何城市。同理 state 3 包含的那条 arc, 也不对应任何大学。各 state 包含的 arc 如下:
state 0(西安) --> [西工大,西交,西电,长安大学]
state 1(延安) --> [延安大学]
state 2(汉中) --> [陕西理工大学]
state 3(辅助 state) --> [-1 辅助 arc]
state 4(辅助终止 state) --> []
值得注意的是,最终的 state 4 不包含任何 arc, 它包含的 arc 记为 []。
该 fsa 对应的 row_splits 为:
# 注意末尾有两个 7
# 7(最后一个) - 7(倒数第二个) = 0;
# 表示最后一个 state 对应的包含的 arc 为 []
row_splits = [0, 4, 5, 6, 7, 7]
3.2 多个 Ragged 可以构建更高一维的 Ragged
把 3.1 中 Ragged 理解为 List[List[List]] 即
Provice[Cites[Universities]]
多个省份放在一块需要在最外层再套一个List, 即
Country[Provices[Cites[Universities]]]
假设我们已经创建了台湾省的部分高校如下图所示(完整代码见文首 colab 链接):
在 k2 中,可以很方便地把两个 fsa 套起来组成更高维度的 Ragged:
china_univ_fsav = k2.create_fsa_vec([shaanxi_univ_fsa, taiwan_univ_fsa])
使用如下代码可以查看 china_univ_fsav 组织形式以及 province/city 各层级的 row_splits/row_ids
shape = china_univ_fsav.arcs.shape()
print(shape)
# 对应输出
# [ [ [ x x x x ] [ x ] [ x ] [ x ] [ ] ] [ [ x ] [ x ] [ x ] [ ] ] ]
注意上面输出结果中两个空list "[]", 对应的是我们 3.1 节中提到的连续两个 “7”(即辅助终止 state)。
添加上这些辅助 state/arc, 我们要研究的对象更新为:
provinces = [陕西省, 台湾省]
cites = [西安市, 延安市, 汉中市,辅助state, 辅助终止state, 台南市, 台中市, 辅助state, 辅助终止state]
univs = [西北工业大学,西安交通大学,西安电子科技大学,长安大学,延安大学,延安大学,陕西理工大学,辅助arc, 成功大学,逢甲大学, 辅助arc]
“省 -- 城市” 级别的 row_ids/row_splits:
print(f"shape.row_ids(1): {shape.row_ids(1)}")
print(f"shape.row_splits(1): {shape.row_splits(1)}")
# 对应输出
# shape.row_ids(1): tensor([0, 0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)
# shape.row_splits(1): tensor([0, 5, 9], dtype=torch.int32)
“城市 -- 高校” 级别的 row_ids/row_splits:
print(f"shape.row_ids(2): {shape.row_ids(2)}")
print(f"shape.row_splits(2): {shape.row_splits(2)}")
# 对应输出
# shape.row_ids(2): tensor([0, 0, 0, 0, 1, 2, 3, 5, 6, 7], dtype=torch.int32)
# shape.row_splits(2): tensor([ 0, 4, 5, 6, 7, 7, 8, 9, 10, 10], dtype=torch.int32)
3.2.1 由坐标计算存储位置 offset
比如需要计算坐标 [省份, 城市, 高校] = [1, 0, 0] 对应的高校在 univs 中的存储位置 offset。
通过分析第 3 节一开始的 List, 我们目测一下,该坐标对应的是“台湾省,台南市,成功大学”。
接下来我们用 row_splits 计算该结果。
首选由坐标[省份, 城市] = [1, 0] 和 “省--城市” 对应的 row_splits 计算是城市 offset:
province, city = 1, 0
city_offset = shape.row_splits(1)[province] + city = 5 + 0 = 5
# 到此环节可得出目标城市为
# 台南市 = cites[city_offset] = cites[5]
确实是“台南市”,与目测结果一致。
再由 [城市,高校] = [0, 0] 和 "城市 -- 大学" 对应的 row_splits 计算是大学对应的 offset:
univ = 0
univ_offset = shape.row_splits(2)[city_offset] + univ = 7 + 0 = 7
# 所以可得出目标高校为
# 成功大学 = univs[univ_offset] = univs[7]
所以 坐标 [省份, 城市, 高校] = [1, 0, 0] 对应的高校存储位置 offset = 7。
确实是“成功大学”,与目测结果一致。
3.2.2 由存储位置 offset 计算坐标
由 row_ids 可以完成存储位置到坐标的推导。假设需要计算 univ_offset = 8(即逢甲大学) 所属的城市和省份坐标。
由 univ_offset = 8 和“城市 -- 大学“对应的 row_ids 计算所属城市:
city_offset = shape.row_ids(2)[univ_offset] = 6
# 所以目标城市为:
# 台中市 = cites[city_offset] = cities[6]
由 city_offset = 6 和 “省--城市” 对应的 row_ids 计算所属省份:
province_offset = shape.row_ids(1)[city_offset] = 1
# 所以目标省份为:
# 台湾省 = provinces[province_offset] = provinces[1]
city_offset = 6 是在所有城市列表中的 offset 位置,可以借助 row_splits(1) 计算在省份内部的坐标:
province_start_city_offset = shape.row_splits(1)[province_offset] = 5
offset_inside_province = city_offset - province_start_city_offset = 6 - 5 = 1
即所求坐标中的 [province, city] = [1, 1]。
同理,借助 row_splits(2) 可以计算高校在城市内部的坐标:
city_start_univ_offset = shape.row_splits(2)[city_offset] = 8
offset_inside_city = univ_offset - city_start_univ_offset = 8 - 8 = 0
所以可得 univ_offset = 8 对应的坐标 [province, city, univ] = [1, 1, 0],
即 [台湾省, 台中市, 逢甲大学]。
总结而言:row_splits 描述的是“从外到内”, row_ids 描述的是 “从内到外”。二者都是在描述“内外”对应关系,只是使用在不同的场合。
总体展示如下:
row_splits 表示对下一层元素的“划分“,对应“从外到内:
row_ids 表示对上一层的“从属”,对应“从内到外”:
4 对应代码实现
从上面的例子中可以看出,对于 Ragged 数据的描述,每一层都有对应的 row_split/row_ids。这对应k2 中 RaggedShape[4] 的 layers_ 成员变量:
class RaggedShape{
...
private:
...
std::vector<RaggedShapeLayer> layers_;
};
每一层 RaggedShapeLayer[5] 都有自己的 row_splits 和 row_ids:
struct RaggedShapeLayer {
...
Array1<int32_t> row_splits;
...
Array1<int32_t> row_ids;
...
};
一个 Ragged 数据,就是对一串线性存储的数据配合各个层级维度的描述,简而言之:Ragged = Array1 + RaggedShape。对应 Ragged[6] 定义如下:
template <typename T>
struct Ragged {
RaggedShape shape;
...
Array1<T> values;
};
注意 Ragged 是一个模版类,当类型参数 typename T 为 Arc[7] 时,就得到了状态机 Fsa 的底层实现[8]:
struct Arc {
int32_t src_state;
int32_t dest_state;
int32_t label;
float score;
...
};
...
using Fsa = Ragged<Arc>; // 2 axes: state,arc
using FsaVec = Ragged<Arc>; // 3 axes: fsa,state,arc. Note, the src_state
// and dest_state in the arc are *within the
// FSA*, i.e. they are idx1 not idx01.
using FsaOrVec = Ragged<Arc>; // for when we don't know if it will have 2 or
// 3 axes. (i.e. Fsa or FsaVec)
...
注意 Fsa 和 FsaVec 都是 Ragged, 因为它们底层都是 Array1+ RaggedShape。二者的差异只是 RaggedShape 的层次深度(即代码中的 axes)不同。
5 总结
本文通过对比矩阵与不规则 Ragged 数据的索引方式的计算,介绍了 k2 中 Ragged 数据结构相关设计。
本文没有探讨 k2 中关于 fsa 的操作,比如get_tot_scores. 不过在 k2 中,边上的 score 一般要求为 log(p), 所以聪明的读者能帮小编算一下概率 p("河山统一") = ?/% 吗?
(将上面问题结果和本文转发至朋友圈,周日(八月七号)晚上十点之前把截图发给小编,点赞数前三名赠送 k2 T恤一件。)
6 展望
本文构思于建军节前夜,后续也会带来更多地关于 k2 底层算法的介绍。
往期文章
参考资料
k2项目首页: https://github.com/k2-fsa/k2
[2]Ragged 头文件: https://github.com/k2-fsa/k2/blob/master/k2/csrc/ragged.h
[3]row_ids / row_splits概念: https://github.com/k2-fsa/k2/blob/7dcabf85e8bf06984c4abab0400ef1322b5ff3df/k2/csrc/utils.h#L50
[4]RaggedShape 定义: https://github.com/k2-fsa/k2/blob/7dcabf85e8bf06984c4abab0400ef1322b5ff3df/k2/csrc/ragged.h#L77
[5]RaggedShapeLayer 定义: https://github.com/k2-fsa/k2/blob/7dcabf85e8bf06984c4abab0400ef1322b5ff3df/k2/csrc/ragged.h#L40
[6]Ragged 定义: https://github.com/k2-fsa/k2/blob/master/k2/csrc/ragged.h#L329
[7]Arc 定义: https://github.com/k2-fsa/k2/blob/7dcabf85e8bf06984c4abab0400ef1322b5ff3df/k2/csrc/fsa.h#L31
[8]Fsa 定义: https://github.com/k2-fsa/k2/blob/7dcabf85e8bf06984c4abab0400ef1322b5ff3df/k2/csrc/fsa.h#L113
最新评论
推荐文章
作者最新文章
你可能感兴趣的文章
Copyright Disclaimer: The copyright of contents (including texts, images, videos and audios) posted above belong to the User who shared or the third-party website which the User shared from. If you found your copyright have been infringed, please send a DMCA takedown notice to [email protected]. For more detail of the source, please click on the button "Read Original Post" below. For other communications, please send to [email protected].
版权声明:以上内容为用户推荐收藏至CareerEngine平台,其内容(含文字、图片、视频、音频等)及知识版权均属用户或用户转发自的第三方网站,如涉嫌侵权,请通知[email protected]进行信息删除。如需查看信息来源,请点击“查看原文”。如需洽谈其它事宜,请联系[email protected]。
版权声明:以上内容为用户推荐收藏至CareerEngine平台,其内容(含文字、图片、视频、音频等)及知识版权均属用户或用户转发自的第三方网站,如涉嫌侵权,请通知[email protected]进行信息删除。如需查看信息来源,请点击“查看原文”。如需洽谈其它事宜,请联系[email protected]。