跟着官方文档学DGL框架第六天——异构图卷积模块(HeteroGraphConv)

发布时间:2023-02-12 11:30

参考链接

  1. https://docs.dgl.ai/guide/nn-heterograph.html#guide-nn-heterograph
  2. https://docs.dgl.ai/api/python/nn.pytorch.html#dgl.nn.pytorch.HeteroGraphConv

在异构图中,我们分别对每种关系进行处理(不同的DGL NN模块),让源节点的消息沿着不同的关系传递到目标节点,然后对于同一目标节点,聚合不同关系传来的信息来更新特征。公式如下

h d s t ( l + 1 ) = A G G r ∈ R , r d s t = d s t ( f r ( g r , h r s r c ( l ) , h r d s t ( l ) ) ) h_{dst}^{\left ( l+1\right )}=AGG_{r\in R,r_{dst}=dst}\left ( f_{r}\left ( g_{r},h_{r_{src}}^{\left ( l\right )},h_{r_{dst}}^{\left ( l\right )}\right )\right ) hdst(l+1)=AGGrR,rdst=dst(fr(gr,hrsrc(l),hrdst(l)))

其中, f r f_{r} fr是每个关系 r r r对应的的NN模块, A G G AGG AGG是聚合函数。

DGL提供了HeteroGraphConv模块

HeteroGraphConv初始化

HeteroGraphConv的初始化有两个参数:

1) mods

包含了对处理各种关系的NN模块,是一个字典类型参数(dict[str, nn.Module]),键为关系名字符串,值为作用在对应关系上的NN模块。注:NN模块中的forword()函数第一个参数必须为DGLHeteroGraph对象,第二个参数可以为代表节点特征的张量或代表源节点特征和目标节点特征的张量对。使用“跟着官方文档学DGL框架第五天”中定义NN模块的方法和DGL已实现的NN模块都是可以的。

2) aggregate

是一个字符串类型参数,表示聚合目标节点上来自不同关系的信息的方式。支持 ‘sum’, ‘max’, ‘min’, ‘mean’, ‘stack’。其中,‘stack’是固定地在第二维执行。当然,也可以根据下面这个格式自定义聚合函数。

def my_agg_func(tensors, dsttype):
    # tensors: is a list of tensors to aggregate
    # dsttype: string name of the destination node type for which the
    #          aggregation is performed
    stacked = torch.stack(tensors, dim=0)
    return torch.sum(stacked, dim=0)

初始化部分的代码如下:

import torch.nn as nn

class HeteroGraphConv(nn.Module):
    def __init__(self, mods, aggregate='sum'):
        super(HeteroGraphConv, self).__init__()
        self.mods = nn.ModuleDict(mods)
        if isinstance(aggregate, str):
            # 获取聚合函数的内部函数
            self.agg_fn = get_aggregate_fn(aggregate)
        else:
            self.agg_fn = aggregate

forward()函数

forward()有四个参数:

  1. g (DGLHeteroGraph):异构图,可以根据“跟着官方文档学DGL框架第三天”来构造。

  2. inputs (dict[str, Tensor] or pair of dict[str, Tensor]):输入的节点特征,字典型参数,键为节点类型字符串,值为节点特征;也可以是两个字典组成的元组,分别表示源节点特征和目标节点特征。

  3. mod_args (dict[str, tuple[any]], optional):字典类型参数,键为关系类型,值为对应NN模块的额外位置参数。

  4. mod_kwargs (dict[str, dict[str, any]], optional):字典类型参数,键为关系类型,值为对应NN模块的key-word参数。

forward()函数首先为每种目标节点类型声明一个空列表,用于保存来自不同NN模块的输出张量。代码如下:

def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
    if mod_args is None:
        mod_args = {}
    if mod_kwargs is None:
        mod_kwargs = {}
    outputs = {nty : [] for nty in g.dsttypes}

然后根据图的类型,对输入的节点特征分为源节点特征和目标结点特征。但这样做的话,如果输入是两个字典构成的元组时,似乎有矛盾。

        if g.is_block:
            src_inputs = inputs
            dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
        else:
            src_inputs = dst_inputs = inputs

利用“g.canonical_etypes”便利所有关系类型,得到相应的关系子图“rel_graph”。

       for stype, etype, dtype in g.canonical_etypes:
            rel_graph = g[stype, etype, dtype]
            if rel_graph.num_edges() == 0:
                continue
            if stype not in src_inputs or dtype not in dst_inputs:
                continue

得到的关系子图“rel_graph”是一个二部图,输入特征为该关系下的源节点特征与目标结点特征构成的元组。(这里也与“跟着官方文档学DGL框架第五天”附言中提到的“expand_as_pair()”对二部图的处理方式对应上了。)接着使用相应关系的NN模块得到目标在该关系下的节点信息“dstdata”。

# outputs: {dtype: [dstdata1, dstdata2, ...]}
        for stype, etype, dtype in g.canonical_etypes:
            rel_graph = g[stype, etype, dtype]
            if rel_graph.num_edges() == 0:
                continue
            if stype not in src_inputs or dtype not in dst_inputs:
                continue
            dstdata = self.mods[etype](
                rel_graph,
                (src_inputs[stype], dst_inputs[dtype]),
                *mod_args.get(etype, ()),
                **mod_kwargs.get(etype, {}))
            outputs[dtype].append(dstdata)

最后调用聚合函数,聚合目标节点来自各种关系的消息。

        rsts = {}
        for nty, alist in outputs.items():
            if len(alist) != 0:
                rsts[nty] = self.agg_fn(alist, nty)
        
        return rsts

例子

构建异构图

这里构建了一个三种关系的异构图

import dgl
g = dgl.heterograph({
    ('user', 'follows', 'user') : edges1,
    ('user', 'plays', 'game') : edges2,
    ('store', 'sells', 'game')  : edges3})

建立一个HeteroGraphConv对象

这里对“follows”、“plays”和“sells”三种关系分别定义了NN模块,聚合函数为“sum”。

import dgl.nn.pytorch as dglnn
conv = dglnn.HeteroGraphConv({
    'follows' : dglnn.GraphConv(...),
    'plays' : dglnn.GraphConv(...),
    'sells' : dglnn.SAGEConv(...)},
    aggregate='sum')

输入“user”特征

“user”通过两种关系可以分别到达“user”和“game”。所以最后“user”和“game”都会收到“user”的消息。

import torch as th
h1 = {'user' : th.randn((g.number_of_nodes('user'), 5))}
h2 = conv(g, h1)
print(h2.keys())
# dict_keys(['user', 'game'])
### 输入“user”和“store”特征
通过关系,“user”只会收到来自“user”的消息;而“game”会同时收到来自“user”和“store”的消息,所以需要聚合。
f1 = {'user' : ..., 'store' : ...}
f2 = conv(g, f1)
print(f2.keys())
# dict_keys(['user', 'game']

输入一对输入

这一对输入,是由源节点特征字典和目标节点特征字典构成的元组。但是从前面的代码看,对输入不支持这种形式,所以前面的应该不是完整代码。

x_src = {'user' : ..., 'store' : ...}
x_dst = {'user' : ..., 'game' : ...}
y_dst = conv(g, (x_src, x_dst))
print(y_dst.keys())
# dict_keys(['user', 'game'])

HeteroGraphConv代码汇总

import torch.nn as nn

class HeteroGraphConv(nn.Module):
    def __init__(self, mods, aggregate='sum'):
        super(HeteroGraphConv, self).__init__()
        self.mods = nn.ModuleDict(mods)
        if isinstance(aggregate, str):
            # 获取聚合函数的内部函数
            self.agg_fn = get_aggregate_fn(aggregate)
        else:
            self.agg_fn = aggregate

    def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
        if mod_args is None:
            mod_args = {}
        if mod_kwargs is None:
            mod_kwargs = {}
        outputs = {nty : [] for nty in g.dsttypes}

        if g.is_block:
            src_inputs = inputs
            dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
        else:
            src_inputs = dst_inputs = inputs

# outputs: {dtype: [dstdata1, dstdata2, ...]}
        for stype, etype, dtype in g.canonical_etypes:
            rel_graph = g[stype, etype, dtype]
            if rel_graph.num_edges() == 0:
                continue
            if stype not in src_inputs or dtype not in dst_inputs:
                continue
            dstdata = self.mods[etype](
                rel_graph,
                (src_inputs[stype], dst_inputs[dtype]),
                *mod_args.get(etype, ()),
                **mod_kwargs.get(etype, {}))
            outputs[dtype].append(dstdata)

        rsts = {}
        for nty, alist in outputs.items():
            if len(alist) != 0:
                rsts[nty] = self.agg_fn(alist, nty)
        
        return rsts

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号