DGL笔记3——自己写一个GNN模型

发布时间:2024-07-16 14:01

原文:Write your own GNN module

DGL笔记1——用DGL表示图
DGL笔记2——用DGL识别节点
DGL笔记3——自己写一个GNN模型

之前我们学习了 DGL 怎么表示一个图,然后怎么写一个简单的 GCN 模型进行节点识别。但是有时候我们的模型不仅仅是简单地堆叠现有的 GNN 模块。 比如我们现在想发明一种考虑节点重要性或边权重来聚合邻域信息的新方法,该怎么办?

所以我们现在将要学习:

  • DGL 的消息传递 API。

  • 自己实现 GraphSAGE 卷积模块。

记得在看这篇之前先看上一篇 GNN 分类哈~

首先导入相关包:

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

消息传递与GNNs

DGL 遵循 由 Gilmer 等人提出的 Message Passing Neural Network 中启发产生的消息传递范式(message passing paradigm )。本质上,很多GNN模型都符合以下框架:

m u → v = M ( l ) ( h v ( l − 1 ) , h u ( l − 1 ) , e u → v ( l − 1 ) ) \large m_{u\rightarrow v}=M^{(l)}\left( h_v^{(l-1)},h_u^{(l-1)},e_{u\rightarrow v}^{(l-1)} \right) muv=M(l)(hv(l1),hu(l1),euv(l1))
m u = Σ u ∈ N ( v ) m u → v ( l ) \large m_{u}=\Sigma_{u\in\mathcal N(v)}m_{u\rightarrow v}^{(l)} mu=ΣuN(v)muv(l)
h v ( l ) = U ( l ) ( h v ( l − 1 ) , m v ( l ) ) \large h_{v}^(l)=U^{(l)}\left( h_v^{(l-1)},m_{v}^{(l)} \right) hv(l)=U(l)(hv(l1),mv(l))

其中 DGL 将 M ( l ) M^{(l)} M(l) 称为消息函数(message function),将 Σ \Sigma Σ 称为聚合函数(reduce function),而将 U ( l ) U^{(l)} U(l) 称为更新函数(update function)。注意 Σ \Sigma Σ 在这儿可以代表任何函数,而不单单是一个求和函数。

举个 ,在 GraphSAGE convolution (Hamilton et al., 2017) 采用了下列属性公式:

h N ( v ) k ← A v e r a g e { h u k − 1 , ∀ u ∈ N ( v ) } \large h^k_{\mathcal N(v)}\leftarrow {\rm Average}\{ h_u^{k-1},\forall_u\in\mathcal N(v) \} hN(v)kAverage{huk1,uN(v)}
h v k ← R e L U ( W k ⋅ C O N C A T ( h v k − 1 , h N ( v ) k ) ) \large h_v^k \leftarrow {\rm ReLU}\left( W^k \cdot {\rm CONCAT}(h_v^{k-1},h^k_\mathcal {N(v)}) \right) hvkReLU(WkCONCAT(hvk1,hN(v)k))

可以看出消息传递有有向的:消息从 u u u v v v 并不一定需要和从 v v v u u u 一致。

虽然 DGL 已经有内置的 GraphSAGE ,也就是 dgl.nn.SAGEConv ,但是我们还是得试试怎么利用 DGL 手写一个 GraphSAGE 卷积。

import dgl.function as fn

class SAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model.

    参 数
    ----------
    in_feat : int
        输入特征维度.
    out_feat : int
        输出特征维度.
    """
    def __init__(self, in_feat, out_feat):
        super(SAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h):
        """Forward computation

        参 数
        ----------
        g : Graph
            输入的图.
        h : Tensor
            输入的节点特征.
        """
        with g.local_scope():
            g.ndata['h'] = h
            # update_all is a message passing API.
            g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h_N'))
            h_N = g.ndata['h_N']
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)

这段代码的核心部分是 g.update_all 函数,它会聚合领域特征然后做平均。 这里有三个概念:

  • 消息函数 fn.copy_u('h', 'm') 将名为 h 的节点特征复制,然后作为 消息 传递给邻居节点。

  • 聚合函数 fn.mean(‘m’, ‘h_N’) 将所有收到的名为 m 的消息做平均,然后将结果保存为一个新的节点特征 h_N

  • update_all 会告诉 DGL 向所有的节点和边发送消息,然后启动聚合函数。

之后,我们就可以堆叠自己的 GraphSAGE 卷积层,从而形成多层 GraphSAGE 网络。

class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

训练

下面的代码包括了数据读取和训练。

import dgl.data

dataset = dgl.data.CoraGraphDataset()
g = dataset[0]

def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    all_logits = []
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    for e in range(200):
        # Forward
        logits = model(g, features)

        # Compute prediction
        pred = logits.argmax(1)

        # Compute loss
        # Note that we should only compute the losses of the nodes in the training set,
        # i.e. with train_mask 1.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Compute accuracy on training/validation/test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        all_logits.append(logits.detach())

        if e % 20 == 0:
            print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(
                e, loss, val_acc, best_val_acc, test_acc, best_test_acc))

model = Model(g.ndata['feat'].shape[1], 16, dataset.num_classes)
train(g, model)

Out:

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
In epoch 0, loss: 1.952, val acc: 0.072 (best 0.072), test acc: 0.091 (best 0.091)
In epoch 20, loss: 1.281, val acc: 0.668 (best 0.668), test acc: 0.678 (best 0.678)
In epoch 40, loss: 0.265, val acc: 0.720 (best 0.720), test acc: 0.735 (best 0.735)
In epoch 60, loss: 0.042, val acc: 0.742 (best 0.742), test acc: 0.759 (best 0.757)
In epoch 80, loss: 0.016, val acc: 0.748 (best 0.748), test acc: 0.756 (best 0.756)
In epoch 100, loss: 0.009, val acc: 0.746 (best 0.748), test acc: 0.755 (best 0.756)
In epoch 120, loss: 0.007, val acc: 0.748 (best 0.748), test acc: 0.755 (best 0.756)
In epoch 140, loss: 0.005, val acc: 0.750 (best 0.752), test acc: 0.759 (best 0.759)
In epoch 160, loss: 0.004, val acc: 0.752 (best 0.752), test acc: 0.757 (best 0.759)
In epoch 180, loss: 0.003, val acc: 0.750 (best 0.752), test acc: 0.760 (best 0.759)

首先这份代码和上一节是几乎一模一样的,注意 Model 中的几个参数:

  • g.ndata['feat'].shape[1] 指的是节点的特征维度,这里我们用的是 Cora 数据集,是1433。
  • 16 指的是 h_feats,也就是手动指定的隐藏层维度。
  • 最后的 num_classes 是类别数量,也就是总共 7 种类别。

更进一步的定制化

在 DGL 中还提供了很多内置的消息和聚合函数,它们都在 dgl.function 中。具体可以参阅 API 文档。

这些 API 可以帮助我们快速实现新的图卷积模型。这里再举个 ,下面的代码实现了一个新的 SAGEConv ,它使用加权平均的聚合领域表示。注意,edata 成员可以保存边特征,这些特征也会参与消息传递。

class WeightedSAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model with edge weights.

    参 数
    ----------
    in_feat : int
        输入特征维度.
    out_feat : int
        输出特征维度.
    """
    def __init__(self, in_feat, out_feat):
        super(WeightedSAGEConv, self).__init__()
        # 一个线性子模块,用于将输入和领域特征投影到输出
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h, w):
        """Forward computation

        参 数
        ----------
        g : Graph
            输入的图.
        h : Tensor
            输入的节点特征.
        w : Tensor
            边的权重.
        """
        with g.local_scope():
            g.ndata['h'] = h  # 将 h 存入节点的 h 特征
            g.edata['w'] = w  # 将 w 存入边的 w 特征
            g.update_all(message_func=fn.u_mul_e('h', 'w', 'm'), reduce_func=fn.mean('m', 'h_N'))
            h_N = g.ndata['h_N']
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)

当前我们的数据集中的图并没有一个边权重,所以我们用 torch.ones(g.num_edges()).to(g.device) 人工地将所有边的权重指定为 1 ,放入模型的 forward() 函数中。当然你也可以用你想用的边权重去代替。fn.u_mul_e 稍后会讲到,我们先跳过。

class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = WeightedSAGEConv(in_feats, h_feats)
        self.conv2 = WeightedSAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat, torch.ones(g.num_edges()).to(g.device))
        h = F.relu(h)
        h = self.conv2(g, h, torch.ones(g.num_edges()).to(g.device))
        return h

model = Model(g.ndata['feat'].shape[1], 16, dataset.num_classes)
train(g, model)

Out:

In epoch 0, loss: 1.952, val acc: 0.156 (best 0.156), test acc: 0.144 (best 0.144)
In epoch 20, loss: 1.218, val acc: 0.550 (best 0.550), test acc: 0.560 (best 0.560)
In epoch 40, loss: 0.209, val acc: 0.744 (best 0.744), test acc: 0.753 (best 0.751)
In epoch 60, loss: 0.033, val acc: 0.746 (best 0.752), test acc: 0.756 (best 0.754)
In epoch 80, loss: 0.013, val acc: 0.746 (best 0.752), test acc: 0.757 (best 0.754)
In epoch 100, loss: 0.008, val acc: 0.744 (best 0.752), test acc: 0.760 (best 0.754)
In epoch 120, loss: 0.006, val acc: 0.742 (best 0.752), test acc: 0.758 (best 0.754)
In epoch 140, loss: 0.005, val acc: 0.742 (best 0.752), test acc: 0.757 (best 0.754)
In epoch 160, loss: 0.004, val acc: 0.744 (best 0.752), test acc: 0.758 (best 0.754)
In epoch 180, loss: 0.003, val acc: 0.746 (best 0.752), test acc: 0.756 (best 0.754)

用 user-defined funtion 进一步定制化

为了更大的自由度,DGL 允许用户自定义消息和聚合函数。这里有个例子,我们编写一个用户自定义消息函数,它等价于 fn.u_mul_e('h', 'w', 'm')

def u_mul_e_udf(edges):
    return {'m' : edges.src['h'] * edges.data['w']}

这里我们先来简单讲解一下相关内容。首先,edges 有三个成员:srcdatadst ,分别代表源节点特征,边特征,和目标节点特征。

在 DGL 中,经常会把 src 记做 u ,把 edge 记做 e
比如 copy_src(src, out) 等价于 copy_u(u, out),而 copy_edge(edge, out) 等价于 copy_e(e, out)

在 DGL 的 nn.function.u_mul_e 已经实现了此函数。这是一个消息函数,如果特征具有相同的shape,则通过在 u 和 e 特征之间逐元素地执行乘法来计算边上的消息;否则,它首先将特征广播(Broadcasting)到一个新的形状并执行逐元素操作。

广播的过程和 NumPy 一样,可以参见 NumPy 的 Broadcasting 文档。

这里说白了,就是一个将两个不同维度的矩阵做一次广播过程。关键在于要搞清楚输入和输出的含义,这里的 u 就是源节点的特征域,比如在上面的代码中是 h ,而 e 就是边的特征域,也就是之前说的边权重 wm 是它的输出,也就是消息(message)域。

在写完消息函数以后,我们还能再写一下聚合函数。比如,下面这个函数的实现就等价于内置函数 fn.sum('m', 'h'),它的作用是将消息函数给加起来。

def sum_udf(nodes):
    return {'h': nodes.mailbox['m'].sum(1)}

注意这里的 nodes.mailbox 中保存的是节点收到的消息。简单来说,DGL 将根据节点的入度对节点进行分组,对于每个组,DGL 会沿着第二个维度来堆叠传入的消息。 然后,我们可以沿第二个维度执行 reduction 以聚合消息。

更多的自定义消息函数和聚合函数可以参考 User-defined Functions 文档。

自定义 GNN 模型的最佳实践

官方还很贴心地给出了自定义GNN模型的小建议。

DGL recommends the following practice ranked by preference:

  • Use dgl.nn modules.

  • Use dgl.nn.functional functions which contain lower-level complex operations such as computing a softmax for each node over incoming edges.

  • Use update_all with builtin message and reduce functions.

  • Use user-defined message or reduce functions.

本篇到此结束~

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号