发布时间:2023-05-29 14:00
matlab2019a
图神经网络(Graph Neural Network,GNN)是一种基于信息扩散机制的神经网络模型。在GNN模型中,其包括一组处理单元,每个处理单元的表示图中的一个节点,各个节点之间具有一定的连通性。当这些节点相互交换信息的时候,系统将逐渐收敛并达到平衡状态。
GNN图神经网络的输出由每个节点本地计算的单元状态获得。为保证GNN各个节点中平衡点的唯一性,需要对其扩散机制进行约束。GNNs不同于传统的细胞神经网络,其可以用于处理更一般的图类,例如各类循环图、有向图和无向图,GNNs还可以处理以节点为中心的系统模型,而无需对系统做任何预处理。
在图神经网络中,图的各个节点将被认为是目标对象,每一个目标对象通过各自的特征信息来关联其他目标的特征。然后通过顶点包含的信息以及其邻域的信息,如图1所示。
GNN图神经网络的处理过程包括状态更新,学习过程两个环节,下面对其主要处理过程进行原理介绍。
在GNN中,全局状态变量X(t+1)的迭代更新方式可以表示为:
公式3中,对于任意个节点v其局部状态变量Xv(t+1)的迭代更新方式可以表示为:
在公式中,GNN其包含了一个编码网络,其结构如下图所示
图2中,局部变换函数fw和局部输出函数gw的基本结构如图3所示。
由图3可知,编码网络在结构上属于递归神经网络,其每一层网络结构对应一个时刻,并且包含编码网络所有单元的副本。不同时刻的两个层之间的网连接取决于编码网络的连接。神经元之间的连接可以分为内部连接和外部连接,内部连通性由用于实现该单元的神经网络结构决定,外部连接取决于处理图形的边缘。
GNN的训练过程主要为权值参数w的估计过程,根据公式3可知,GNN的学习任务可以通过如下LOSS函数来表示:
% Mutagenesis example
clc;
clear;
close all;
warning off;
rng(1);
addpath 'GNN_1.1.c-master\comparisonNet\'
addpath 'GNN_1.1.c-master\datasets\'
addpath 'GNN_1.1.c-master\experiments\'
addpath 'GNN_1.1.c-master\initialization\'
addpath 'GNN_1.1.c-master\isomorphism\'
addpath 'GNN_1.1.c-master\MLP\'
addpath 'GNN_1.1.c-master\neuralNetworks\'
addpath 'GNN_1.1.c-master\private\'
addpath 'GNN_1.1.c-master\systemModels\'
addpath 'GNN_1.1.c-master\utils\'
addpath 'GNN_1.1.c-master\database\'
addpath 'GNN_1.1.c-master\'
startSession
% Create a 10-fold cross validation data set
makeMutagenicDataset
global multidata
% Train the GNN by only 1 data set
dataSet = multidata(1);
dataSet.trainSet
LEN=0.1;
Configure('GNN3.config')
learn
LEN=0.2;
Configure('GNN3.config')
learn
LEN=0.3;
Configure('GNN3.config')
learn
LEN=0.4;
Configure('GNN3.config')
learn
LEN=0.5;
Configure('GNN3.config')
learn
LEN=0.6;
Configure('GNN3.config')
learn
LEN=0.7;
Configure('GNN3.config')
learn
LEN=0.8;
Configure('GNN3.config')
learn
LEN=0.9;
Configure('GNN3.config')
learn
LEN=10;
Configure('GNN3.config')
learn
plotTrainingResults;
% Test
% test
close all;
t1=learning.history.forwardItHistory;
t2=learning.history.backwardItHistory;
KK=32;
for i = 1:length(t1);
if i<=KK
t1b(i)=mean(t1(1:i));
t2b(i)=mean(t2(1:i));
else
t1b(i)=mean(t1(i-KK:i));
t2b(i)=mean(t2(i-KK:i));
end
end
figure;
plot(t1b,'b--');
hold on
plot(t2b,'r');
hold off
legend('Forward iterations', 'Backward iterations');
xlabel('训练次数');
ylabel('训练误差%');
figure;
plot([1:size(learning.history.trainErrorHistory,2)],learning.history.trainErrorHistory/10,'b');
hold on
t=[learning.config.stepsForValidation:learning.config.stepsForValidation:...
learning.config.stepsForValidation*(size(learning.history.validationErrorHistory,2))];
t(end)=learning.current.nSteps-1;
% xlim([0,250]);
xlabel('训练次数');
ylabel('训练误差%');
load data.mat%调用10次训练的结果
figure;
plot(X,ERR,'b-o');
xlabel('训练样本数量');
ylabel('训练误差%');
axis([1000,10000,0,40]);
grid on
其中GNN工具箱内部代码列表如下:
[1]Gnana J B , Rani S M M . Graph Neural Network for Minimum Dominating Set[J]. International Journal of Com- puter Applications, 2012, 56(1):12-16.A05-87