第七周.01.Message更新讲解+GCN实例

发布时间:2023-06-28 08:30

文章目录

  • update_all
  • send_and_recv
  • Built-in Function
  • 实例代码

本文内容整理自深度之眼《GNN核心能力培养计划》
公式输入请参考: 在线Latex公式
update_all和send_and_recv是图网中非常重要两个函数,用人话来描述就是我们要如何汇聚邻居的信息:
1、汇聚什么?节点还是边?还是节点加边?还是节点减边?
2、如何汇聚?求和?最大?平均?
3、汇聚后更新节点表征需要什么操作?这个不是必须的,可以是做个特征变化啥的。

update_all

官网说明:https://docs.dgl.ai/generated/dgl.DGLGraph.update_all.html

DGLGraph.update_all(message_func, reduce_func, apply_node_func=None, etype=None)

来看下里面的几个参数。
message_func,消息函数(从源节点到目标节点进行操作),可以使用DGL自带的消息函数或者自定义消息函数
reduce_func,产生消息后,对消息进行汇聚aggregate操作,也是可以使用DGL自带的汇聚函数或者自定义汇聚函数
apply_node_func,节点更新函数,经过上面两步后如何更新节点embedding,这个函数只有用户自定义
以上三个函数是update_all中最核心的部分。

send_and_recv

官网说明:https://docs.dgl.ai/generated/dgl.DGLGraph.send_and_recv.html

DGLGraph.send_and_recv(edges, message_func, reduce_func, apply_node_func=None, etype=None, inplace=False)

这个函数和上面的update_all里面一样有三个核心消息函数,这里就不写了,不一样的是send_and_recv函数可以指定边(第一个参数)进行消息操作。
边这个参数可以有以下几种方式:

方式 含义
整型int 代表单个边的编号
整型 Tensor Tensor 中的每个元素代表一个边的编号,tensor的device类型及数据类型要和Graph的ID类型要一致
可迭代的整型 每个元素代表一个边的编号
(Tensor ,Tensor ) 用节点的方式来表示边,两个Tensor 分别表示起始和结束节点
(可迭代的整型,可迭代的整型) 同上

Built-in Function

上面讲三个核心函数的时候有提到DGL有自带的消息处理函数,我们来看看:
官网地址:https://docs.dgl.ai/api/python/dgl.function.html#dgl-built-in-function
第七周.01.Message更新讲解+GCN实例_第1张图片
从表中可以看到有三大类:
第一大类是单对象操作,直接copy消息,下划线后面分别代表拷贝的对象:节点、边,后面两个和前面两个是功能一样的
第二大类是双对象操作,下划线前后分别代表两个对象,中间代表操作类型
第三大类是reduce函数,四个。

实例代码

https://docs.dgl.ai/tutorials/models/1_gnn/1_gcn.html
针对官网的GCN代码重新进行讲解,这次重点看上面提到的函数。
原文的模型描述公式没显示出来,这里重新贴下:
For each node u u u:

  1. Aggregate neighbors’ representations h v h_{v} hv to produce an intermediate representation h ^ u \hat{h}_u h^u.
  2. Transform the aggregated representation h ^ u \hat{h}_{u} h^u with a linear projection followed by a non-linearity: h u = f ( W u h ^ u ) h_{u} = f(W_{u} \hat{h}_u) hu=f(Wuh^u).

We will implement step 1 with DGL message passing, and step 2 by PyTorch nn.Module.

GCN implementation with DGL

We first define the message and reduce function as usual. Since the
aggregation on a node u u u only involves summing over the neighbors’
representations h v h_v hv, we can simply use builtin functions:
具体代码:

import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph

gcn_msg = fn.copy_u(u='h', out='m')#update_all的第一个参数,单对象操作,直接拷贝原节点信息作为消息输出
gcn_reduce = fn.sum(msg='m', out='h')#update_all的第二个参数,采用sum作为aggregate方式,吃的上面的输出
class GCNLayer(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)#out_feats是输出的分类数量

    def forward(self, g, feature):
        # Creating a local scope so that all the stored ndata and edata
        # (such as the `'h'` ndata below) are automatically popped out
        # when the scope exits.
        with g.local_scope():
            g.ndata['h'] = feature# 初始化的特征丢给节点
            g.update_all(gcn_msg, gcn_reduce)# 更新节点表征,里面两个函数在上面,考虑一下博文中提出的三个问题
            h = g.ndata['h']#将最后g.ndata读取出来作为结果
            return self.linear(h)#update_all的第二个参数,采用sum作为aggregate方式,吃的上面的输出

下面定义一个GCN模型,在cora上进行一个分类,模型包含两层GCN layer,输入特征维度是1433,分类数是7

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer1 = GCNLayer(1433, 16)#输入1433,输出中间层为16
        self.layer2 = GCNLayer(16, 7)#输入16,最后输出7分类
    
    def forward(self, g, features):
        x = F.relu(self.layer1(g, features))
        x = self.layer2(g, x)#最后输出层不用做ReLU
        return x
net = Net()
print(net)

导入cora数据,划分训练测试数据集

from dgl.data import CoraGraphDataset
def load_cora_data():
    dataset = CoraGraphDataset()
    g = dataset[0]
    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    test_mask = g.ndata['test_mask']
    return g, features, labels, train_mask, test_mask

测试模型效果

def evaluate(model, g, features, labels, mask):
    model.eval()
    with th.no_grad():
        logits = model(g, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = th.max(logits, dim=1)
        correct = th.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

训练模型

import time
import numpy as np
g, features, labels, train_mask, test_mask = load_cora_data()
# Add edges between each node and itself to preserve old node representations
g.add_edges(g.nodes(), g.nodes())#加selfloop:A'=A+I
optimizer = th.optim.Adam(net.parameters(), lr=1e-2)
dur = []
for epoch in range(50):
    if epoch >=3:
        t0 = time.time()

    net.train()
    logits = net(g, features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[train_mask], labels[train_mask])#两步计算交叉熵损失
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if epoch >=3:
        dur.append(time.time() - t0)
    
    acc = evaluate(net, g, features, labels, test_mask)
    print("Epoch {:05d} | Loss {:.4f} | Test Acc {:.4f} | Time(s) {:.4f}".format(
            epoch, loss.item(), acc, np.mean(dur)))

结果:
第七周.01.Message更新讲解+GCN实例_第2张图片

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号