作者 | Dustin Zelle、Arno Eigenwillig
译者 | 核子可乐
策划 | 冬梅
对象以及各对象间的关系在我们生活的世界中无处不在。在理解对象时,各对象间的关系往往与对象本身的属性同等重要,最典型的例子包括物流网络、生产网络、知识图谱或者社交网络。长久以来,离散数学与计算机科学一直把此类网络转化为图,即通过边以各种不规则方式将任意节点连接起来。然而,大多数机器学习(ML)算法仅允许各输入对象之间存在规则且统一的关系,例如像素网络、单词序列,或者完全没有关系。
图神经网络(GNN)是一种强大的技术,能够利用图连接性(例如早期 DeepWalk 与 Node2Vec 算法)以及各节点与边上的输入特征。GNN 可以对整个图(特定分子是否存在某种形式的反应?)、单个节点(根据引文,本文档的主题是什么?)或者潜在边(消费者是否可能同时购买某产品与另一产品?)进行预测。除了图预测之外,GNN 也是一种强大的工具,能够弥合与各类典型神经网络用例之间的鸿沟。它们会以连续方式对图中的离散关系信息进行编码,将其自然包含在另一深度学习系统当中。
近日,谷歌在博客中官宣正式发布 TensorFlow GNN 1.0——用于大规模构建 GNN 的经过生产测试的库。
据谷歌称,TensorFlow GNN 1.0 支持在 TensorFlow 中建模和训练,以及从大型数据存储中提取输入图。TF-GNN 属于从零开始建立的异构图,使用不同的节点和边集合表示对象类型及其关系。现实世界中的对象及其关系往往归属于多种类别,可通过 TF-GNN 的异构焦点以非常自然的方式进行表现。
在 TensorFlow 之内,此类图以 tfgnn.GraphTensor 类型的对象来表示。这是一种复合张量类型(即一个 Python 类中的张量集合),也是 tf.data.Dataset、tf.function 等中的首要对象类型。它既能存储图结构,也可存储节点、边和整个图的特征。GraphTensors 的可训练变换能被定义为网络服务 Keras API 中的 Layers 对象,或直接使用 tfgnn.GraphTensor 原语。
GNN:对上下文中的对象进行预测
为了便于理解,谷歌团队给出了 TF-GNN 的一个典型用例:预测一个大型数据库中,上交叉引用表定义的图中某类节点的属性。例如,计算机科学(CS)arXiv 论文的引文数据库涉及一对多引用和多对一引用的关系,人们希望借此预测每篇论文的主题领域。
与大多数神经网络一样,GNN 在包含大量标记示例(约数百万个)的数据集上进行训练,但每个训练步骤仅涉及一小批训练示例(例如几百个)。为了扩展至百万量级,GNN 会在底层图中较小的子图流上进行训练。每个子图都包含足够的原始数据来计算其中心标记节点的 GNN 结果并训练模型。这个过程(常被称为子图采样)对于 GNN 训练非常重要。大多数现有工具以批量方式完成采样,并生成用于训练的静态子图。TF-GNN 则通过动态加交互采样改进了这一过程。
图中所示为子图采样过程,即从较大的图中抽取更小、更易于处理的子图,以创建用于 GNN 训练的输入示例。
TF-GNN 1.0 首次带来更灵活的 Python API,能够在所有相应比例上配置动态或批量子图采样,包括:在 Colab notebook 以交互方式对存储在单一训练主机主内存中的小型数据集进行高效采样,或者通过 Apache Beam 对存储在网络文件系统中的大型数据集(可能包含数亿个节点和数十亿条边)进行分布式采样。
在这些采样子图上,GNN 的作用是计算根节点处的隐藏(或潜在)状态;隐藏状态中聚集并编码有根节点邻域的相关信息。一种早期方法是使用消息传递神经网络。在每轮消息传递中,节点会沿着传入边接收来自相邻节点的消息,并据此更新自身隐藏状态。经过 n 轮传递之后,根节点的隐藏状态已经反映出 n 个边内所有节点的聚合信息(在下图中,n=2)。消息和新的隐藏状态由神经网络的隐藏层计算。在异构图中,一般应当为不同的顶点和边类型使用单独训练的隐藏层。
如图所示,在这个简单的消息传递神经网络中,每一步都将节点状态从外部节点传播至内部节点,并在内部节点中进行汇总以计算新的节点状态。一旦到达根节点,即可做出最终预测。
训练设置的具体实现,则是将输出层放置在标记节点的 GNN 隐藏状态之上、计算损失(以测量预测误差),再通过反向传播来更新模型权重。整个过程与常规神经网络训练一致。
除了监督训练(即最小化由标签定义的损失)之外,GNN 还能以无监督方式进行训练(即不依赖标签)。我们可以借此计算节点及其特征的离散图结构的连续表示(也称嵌入)。这些表示常可用于其他机器学习系统。通过这种方式,即可将编码为图的离散关系信息包含在更典型的神经网络用例当中。TF-GNN 支持对异构图的无监督目标进行细粒度特化。
GNN 架构是如何构建的
TF-GNN 库支持以不同抽象级别构建并训练 GNN。
在最高层,用户能够使用与库捆绑的任何以 Keras 层表示的预定义模型。除了研究文献中涉及的少部分模型之外,TF-GNN 还提供一套高度可配置的模型模板,其中提供精选的建模选项。我们发现这些模型选项能够为谷歌的诸多内部问题提供强有力的基线支持。有这些模板实现 GNN 层,用户只需从 Keras 层开始初始化。
import tensorflow_gnn as tfgnn

from tensorflow_gnn.models import mt_albis


def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):

"""Builds a GNN as a Keras model."""

graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)


# Encode input features (callback omitted for brevity).

graph = tfgnn.keras.layers.MapFeatures(

node_sets_fn=set_initial_node_states)(graph)


# For each round of message passing...

for _ in range(2):

# ... create and apply a Keras layer.

graph = mt_albis.MtAlbisGraphUpdate(

units=128, message_dim=64,

attention_type="none", simple_conv_reduce_type="mean",

normalization_type="layer", next_state_type="residual",

state_dropout_rate=0.2, l2_regularization=1e-5,

)(graph)


return tf.keras.Model(inputs, graph)

在最低层,用户可以根据用于在图中传递数据的原语,从头开始编写 GNN 模型。比如将数据从节点广播至所有传出边,或者将数据通过所有传入边汇聚至节点当中。在涉及特征或隐藏状态时,TF-GNN 将平等处理各节点、边和整个输入图,这样不仅能够直接表达以节点为中心的模型(例如前文提到的 MPNN),还可表示更通用的 GraphNet 形式。这一过程可以(但不一定)通过将 Kera 作为核心 TensorFlow 上的建模框架来实现。
训练编排
高级用户当然可以自由进行自定义模型训练,此外 TF-GNN Tunner 还提供一种简洁的方法,用以协调常见情况下 Keras 模型的训练。以下为一条简单的调用示例:
Runner 为 ML pains 提供现成的解决方案,例如分布式训练和云 TPU 上固定形状的 tfgnn.GraphTensor 填充。除了前面提到的单一训练任务之外,Runner 还支持多个(两个或以上)任务的联合训练。例如,可以将无监督任务与监督任务混合,以形成具有特定于应用归纳偏差的最终连续表示(即嵌入)。调用方只需将任务参数替换为任务映射即可:
from tensorflow_gnn import runner

from tensorflow_gnn.models import contrastive_losses


runner.run(

task={

"classification": runner.RootNodeBinaryClassification("papers", ...),

"dgi": contrastive_losses.DeepGraphInfomaxTask("papers"),

},

...

)

此外,TF-GNN Runner 还包含用于模型归因的集成梯度实现。集成梯度输出是一个 GraphTensor,其连接性与观察到的 GraphTensor 相同,只是用梯度值替代其特征。在 GNN 预测中,较大的梯度值往往比较小梯度值贡献更多。用户可以检查梯度值,借此了解哪些特征在 GNN 中作用更大。
原文链接:
https://blog.tensorflow.org/2024/02/graph-neural-networks-in-tensorflow.html?m=1
 活动推荐
AICon 全球人工智能与大模型开发与应用大会暨通用人工智能开发与应用生态展将于 5 月 17 日正式开幕,本次大会主题为「智能未来,探索 AI 无限可能」。如您感兴趣,可点击「阅读原文」查看更多详情。
目前会议 8 折优惠购票,火热进行中,购票或咨询其他问题请联系票务同学:13269078023,或扫描上方二维码添加大会福利官,可领取福利资料包。
今日荐文

再也不用羡慕修仙永生了,我已经“做到”了

Taylor Swift 身陷不雅照风波:AI 越强、Deepfakes 越猖狂,微软和推特们无法推责

清华系2B模型杀出,性能吊打LLaMA-13B,170万tokens仅需1块钱!

性能逼近GPT-4,开源Mistral-Medium意外泄露?CEO最新回应来了

碾压前辈!Meta发布“最大、性能最好”的开源Code Llama 70B,但开发者纷纷喊穷:玩不起!

OpenAI出手后,GPT-4真的不懒了?网友不买账:只靠打补丁恐怕无济于事!


你也「在看」吗? 👇
继续阅读
阅读原文