发布时间:2022-08-19 11:39
https://www.toutiao.com/a6690329469023945220/
本文重点介绍了 DGL v0.3的重要特性之一 — 消息融合。
我们在去年12月发布了Deep Graph Library (DGL)的首个公开版本。在过去的几个版本的更新中,DGL主要注重框架的易用性,比如怎样设计一系列灵活易用的接口,如何便于大家实现各式各样的图神经网络(GNN)模型,以及怎样和主流深度学习框架(如PyTorch,MXNet等)集成。因为这些设计,让DGL快速地获得了社区的认可和接受。然而天下没有免费的午餐,不同的框架对于相同的运算支持程度不同,并且普遍缺乏图层面上的计算原语,导致了计算速度上的不足。随着DGL接口的逐渐稳定,我们终于可以腾出手来解决性能问题。即将发布的DGL v0.3版本中,性能问题将得到全面而系统地改善。
相比当前的DGL稳定版本v0.2,DGL v0.3在性能上取得了显著提升。相比v0.2, DGL v0.3训练速度提高了19倍,并且大幅度降低了内存使用量,使得单GPU上能训练的图的大小提高到原来的8倍。比起PyG等其他框架,DGL不但训练更快,而且能够在巨大的图上(5亿节点,250亿边)训练图神经网络。
接下来,我们将介绍DGL v0.3的重要特性之一 — 消息融合(Fused Message Passing)。我们会逐一解释,为什么普通的消息传递无法拓展到大图上以及消息融合是怎么解决这一问题的。更多细节可以参考我们被 ICLR’19 的 RLGM workshop 所收录的论文[1]。
大图训练的性能瓶颈
绝大多数图神经网络模型遵循消息传递的计算范式,用户需要提供两个函数:
下图中,用户自定义的消息函数用
表示。消息函数将点 i 和 j 上的特征
,
以及边i->j上的特征
作为输入,生成边上的消息(黄色方框)。在每个节点上,用户定义的累和函数将消息累和,然后调用另一个用户定义的更新函数
更新节点的特征。
普通的消息传递很容易在DGL中实现:首先,我们通过 send 接口调用消息函数,然后通过recv 接口调用累和函数。下面的例子实现了目前流行的图卷积网络 Graph Convolution Network(GCN)。
# 使用自定义消息函数和累和函数计算图卷积
G.update_all(lambda edges: {'m' : edges.src['h']},
lambda nodes: {'h' : sum(nodes.mailbox['m'], axis=1)})
以上的代码非常简洁易懂,但性能却不佳。原因在于消息传递的过程中实际生成了消息张量(message tensor)。消息张量的大小正比于图中边的数量,因而当图增大时,消息张量消耗的内存空间也会显著上升。以 GraphSage 论文中的 Reddit 数据集(23.2万节点,1.14亿边)为例,如果我们用上述代码训练
GCN
,点上的特征会被拷贝成边上的信息,这会导致内存使用量骤增500倍。除了浪费内存,该做法还使得访存变得更为频繁,进而导致 GPU 的利用率降低。
消息融合解决大图训练难题
为了避免生成消息张量带来的额外开销,DGL实现了消息融合技术。DGL将 send 和 recv 接口合并成 send_and_recv(见下图)。DGL的后端通过自己的CUDA代码,在每个GPU线程中将源节点特征载入其本地内存并计算消息函数,然后将计算结果直接累和到目标节点,从而避免生成消息张量。
为实现消息融合,DGL提供了一系列预先定义好的内建函数。尽管这限制了用户对消息函数和累和函数的选择,但DGL提供了非常丰富的内建函数以实现绝大多数GNN模型。当然,用户也可以选择自己定义消息函数和累和函数,这种情况下,DGL不会进行消息融合优化。
另外在
反向传播
中,由于消息张量没有保存,因此需要被重新计算。实际操作中,许多消息函数的求导都不需要使用到消息张量(比如拷贝源节点特征到边上),而我们的实现也利用了这一特性。
在DGL中使用消息融合
使用消息融合非常简单。比如,我们可以用copy_src内建消息函数和sum内建累和函数改写先前的GCN实现:
import dgl.function as fn
G = ... # 任意图结构
# 将源节点的特征h拷贝为消息,并在目标节点累和生成新的特征h。
G.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'h'))
图注意力模型 Graph Attention Network (GAT) 则可以用 src_mul_edge 内建消息函数和 sum内建累和函数组合实现:
# 这里假设注意力分数为边上特征e
G.update_all(fn.src_mul_edge('h', 'e', 'm'), fn.sum('m', 'h'))
DGL v0.3 将支持以下内建函数: