发布时间:2023-06-13 08:30
DGL库很友好出了汉语教程地址就在这个地方,这里基本从那边粘贴过来,算作个人笔记。
DGL的核心数据结构DGLGraph
提供了一个以图为中心的编程抽象。 DGLGraph
提供了接口以处理图的结构、节点/边 的特征,以及使用这些组件可以执行的计算。
了解基本概念,图、图的表示、加权图与未加权图、同构与异构图、多重图
DGL用唯一整数表示节点,即点ID;对应的两个端点ID表示一条边。根据添加顺序每条边有边ID。DGL中边是有方向的,即边 ( u , v ) (u,v) (u,v)表示节点 u u u指向节点 v v v。
对于多节点,DGL使用一个一维整形张量(如,PyTorch的Tensor类)保持图的点ID,DGL称之为”节点张量”。对于多多条边,DGL使用一个包含2个节点张量的元组 ( U , V ) (U,V) (U,V) ,其中,用 ( U [ i ] , V [ i ] ) (U[i],V[i]) (U[i],V[i]) 指代一条 U [ i ] U[i] U[i] 到 V [ i ] V[i] V[i] 的边。
创建一个 DGLGraph
对象的一种方法是使用 dgl.graph()
函数。它接受一个边的集合作为输入。DGL也支持从其他的数据源来创建图对象。
下面的代码段使用了 dgl.graph()
函数来构建一个 DGLGraph
对象,对应着下图所示的包含4个节点的图。 其中一些代码演示了查询图结构的部分API的使用方法。
import dgl
import torch as th
# 边 0->1, 0->2, 0->3, 1->3
u,v = th.tensor([0,0,0,1]), th.tensor([1,2,3,3])
g = dgl.graph((u,v))
print(g)
Using backend: pytorch
Graph(num_nodes=4, num_edges=4,
ndata_schemes={}
edata_schemes={})
# 获取节点
print(g.nodes())
tensor([0, 1, 2, 3])
# 获取边对应的点
print(g.edges())
(tensor([0, 0, 0, 1]), tensor([1, 2, 3, 3]))
# 获取边的对应端点和边ID
print(g.edges(form='all'))
(tensor([0, 0, 0, 1]), tensor([1, 2, 3, 3]), tensor([0, 1, 2, 3]))
# 如果具有最大ID的节点没有边,在创建图的时候,用户需要明确地指明节点的数量。
g = dgl.graph((u, v), num_nodes=8)
对于无向的图,用户需要为每条边都创建两个方向的边。可以使用 dgl.to_bidirected()
函数来实现这个目的。 如下面的代码段所示,这个函数可以把原图转换成一个包含反向边的图。
bg = dgl.to_bidirected(g)
bg.edges()
(tensor([0, 0, 0, 1, 1, 2, 3, 3]), tensor([1, 2, 3, 0, 3, 0, 0, 1]))
DGL可以用32或64位整数作为ID但类型要一致。下面是两种转换方法
edges = th.tensor([2, 5, 3]), th.tensor([3, 5, 0]) # 边:2->3, 5->5, 3->0
g64 = dgl.graph(edges) # DGL默认使用int64
print(g64.idtype)
torch.int64
g32 = dgl.graph(edges, idtype=th.int32) # 使用int32构建图
g32.idtype
torch.int32
g64_2 = g32.long() # 转换成int64
g64_2.idtype
torch.int64
g32_2 = g64.int() # 转换成int32
g32_2.idtype
torch.int32
DGLGraph
对象的节点和边可具有多个用户定义的、可命名的特征,以储存图的节点和边的属性。
通过 ndata
和 edata
接口可访问这些特征。
例如,以下代码创建了2个节点特征(分别在第8、15行命名为 'x'
、 'y'
)和1个边特征(在第9行命名为 'x'
)。
import dgl
import torch as th
g = dgl.graph((th.tensor([0,0,1,5]), th.tensor([1,2,2,0]))) # 6个节点,四条边
# g = dgl.graph(([0, 0, 1, 5], [1, 2, 2, 0]))
g
Graph(num_nodes=6, num_edges=4,
ndata_schemes={}
edata_schemes={})
g.ndata['x'] = th.ones(g.num_nodes(), 3) # 长度为3的节点特征
g.edata['x'] = th.ones(g.num_edges(), dtype=th.int32) # # 标量整型特征
g
Graph(num_nodes=6, num_edges=4,
ndata_schemes={'x': Scheme(shape=(3,), dtype=torch.float32)}
edata_schemes={'x': Scheme(shape=(), dtype=torch.int32)})
# 不同名称的特征可以具有不同形状
g.ndata['y'] = th.randn(g.num_nodes(), 5) # x, y两种特征
g.ndata['x'][1] # 获取节点1的特征
tensor([1., 1., 1.])
g.edata['x'][th.tensor([0, 3])] # 获取边0和3的特征
tensor([1, 1], dtype=torch.int32)
关于 ndata
和 edata
接口的重要说明:
'x'
)