您的位置:首页 > 游戏 > 手游 > 图神经网络简单理解 — — 附带案例

图神经网络简单理解 — — 附带案例

2024/12/23 7:52:26 来源:https://blog.csdn.net/qq_73910510/article/details/141213773  浏览:    关键词:图神经网络简单理解 — — 附带案例

图神经网络

图神经网络(Graph Neural Network, GNN)是一种深度学习模型,专门用于处理图结构数据。它能够捕捉节点的邻域结构信息,广泛应用于各种领域,如社交网络分析、生物信息学、推荐系统等。

GNN的核心思想是通过节点间的连接关系来传递和更新信息。基本的GNN模型包括递归图神经网络(Recursive GNN, RecGNN)和卷积图神经网络(Convolutional GNN, ConvGNN)。RecGNN通过递归地应用相同的参数集来提取节点的高级表示,而ConvGNN则通过图卷积层来学习节点特征的高级表示。

GNN可以针对不同类型的任务进行输出,包括节点级输出(如节点分类和回归任务)、边级输出(如边分类和链接预测任务)以及图级输出(如图分类任务)。在训练GNN时,可以根据任务和数据标签的可用性采用不同的训练策略,包括半监督学习、监督学习和无监督学习。

GNN

流程

  1. 聚合
  2. 更新
  3. 循环

在这里插入图片描述

假设每个节点的特征为: F A , F B , F C , F D , F E F_A, F_B, F_C, F_D, F_E FA,FB,FC,FD,FE

1. 聚合

节点A的邻居信息: N = W 1 F B + W 2 F C + W 3 F D N = W_1F_B + W_2F_C + W_3F_D N=W1FB+W2FC+W3FD,其中 W 1 , W 2 , W 3 W_1, W_2, W_3 W1,W2,W3表示可训练参数,这些参数也可以手动设置。

节点A的信息: σ ( W ⋅ F A + α ⋅ N ) \sigma(W \cdot F_A + \alpha \cdot N) σ(WFA+αN),其中 σ \sigma σ是激活函数, W W W是可训练参数

同理,节点B、C、D、E以同样的方式聚合信息。

  • 第一次聚合:节点A可以包含节点A、B、C、D的信息。
  • 第二次聚合:节点A可以包含节点A、B、C、D、E的信息,因为节点C会在第一次聚合后包含节点E的信息
2. 更新、循环

多次聚合、得到Loss,更新参数、循环训练。

GNN是一个特征提取的方法,GNN的层数越多,其感受野越大,每个节点考虑其他点的信息越多,考虑越全面。

GCN

GCN(Graph Convolutional Network)是GNN的一种特定形式,专注于使用卷积操作来处理图数据。GCN通过在图上定义滤波器来学习节点表示,这些滤波器可以捕捉局部图结构。GCN的关键优势是其参数共享机制和对图结构的局部感知能力,这使得GCN在处理大规模图数据时非常高效 。

图中常见的任务:

  1. 节点分类(Node Classification)
    • 任务目标:预测图中每个节点的类别标签。
    • 应用场景:社交网络中用户分类、生物信息学中基因功能预测等。
  2. 图分类(Graph Classification)
    • 任务目标:对整个图进行分类,例如区分不同的化学分子结构。
    • 应用场景:分子结构分析、社交网络中的社区检测等。
  3. 链接预测(Link Prediction)
    • 任务目标:预测图中节点间是否存在链接或边。
    • 应用场景:推荐系统中的好友推荐、社交网络中的潜在联系发现等。
  4. 图聚类(Graph Clustering)/社区检测(Community Detection)
    • 任务目标:将图中的节点分组,使得同一组内的节点相互之间的连接比其他组的节点更紧密。
    • 应用场景:社交网络分析、生物网络中的模块识别等。
  5. 图嵌入(Graph Embedding)
    • 任务目标:将图的节点或整个图映射到低维向量空间,以便于进行可视化或作为其他机器学习任务的输入。
    • 应用场景:节点相似性度量、图的可视化等。
  6. 图生成(Graph Generation)
    • 任务目标:生成新的图结构或扩展现有图。
    • 应用场景:分子结构生成、社交网络演化模拟等。
  7. 图编辑或图重构(Graph Editing/Reconstruction)
    • 任务目标:对现有图进行修改或重构,以改善其某些性质或适应特定的需求。
    • 应用场景:网络结构优化、社交网络信息更新等。
  8. 异常检测(Anomaly Detection)
    • 任务目标:在图中识别异常或不正常的模式。
    • 应用场景:网络安全中的入侵检测、社交网络中的欺诈行为识别等。
  9. 信息传播(Information Diffusion)
    • 任务目标:模拟和预测信息在图中的传播过程。
    • 应用场景:疾病传播模型、社交媒体中的消息扩散等。
  10. 知识图谱补全(Knowledge Graph Completion)
    • 任务目标:预测知识图谱中缺失的实体间的关系。
    • 应用场景:知识库的自动填充、推荐系统等。
  11. 图匹配(Graph Matching)
    • 任务目标:在两个或多个图之间找到节点和边的最佳对应关系。
    • 应用场景:模式识别、计算机视觉中的物体识别等。

GCN属于半监督学习(不需要每个节点都有标签都可以进行训练),计算Loss时,只需要考虑有标签的节点即可。

GCN的基本思想:

  • 邻接节点特征聚合:每个节点会将其邻居节点的特征聚合到自己身上,这个过程可以通过求和、均值等方式完成。
  • 图的卷积操作:聚合后的特征向量会通过一个线性变换和一个非线性激活函数(如ReLU)进行处理,以增加网络的非线性表达能力。
  • 特征更新:最后,每个节点都会用其新的特征向量更新自己,这个过程会反复进行,直到网络的特征表示达到稳定状态。

图卷积也可以做多层,但一般不建议做太深层,通常2 ~ 5层即可。

实验表明:GCN中,深层的网络结果往往不会带来更好的效果(直观解释:我认识我的朋友,那么我朋友的朋友的朋友就很可能不是我的朋友,也就是说我与他是没有联系的)

公式推导:

符号说明

  • G:表示图。

    在这里插入图片描述

  • A:邻接矩阵,表示各个节点的连接关系,维度:[节点个数, 节点个数]

    在这里插入图片描述

  • D:各个节点的度,表各个节点有几条边,是一个对角矩阵,维度:[节点个数,节点个数]

    在这里插入图片描述

  • X:每个节点的特征,维度[节点个数,特征维度]

    在这里插入图片描述

由于每个节点与自己也是有相关性的,因此对邻接矩阵进行如下变换:
A ~ = A + λ I N \widetilde{A} = A + \lambda I_N A =A+λIN
这相当于给每个节点加一条自连接的边,即自己与自己也相邻,此时 D ~ = D \widetilde{D} = D D =D,即度数在原来的基础上 + 1 + 1 +1

在这里插入图片描述

在这里插入图片描述

由于节点的度越大,做矩阵乘法后值越大,相当于一个人认识的人越多,其特征值越大,这样不好。令 D ~ = D ~ − 1 \widetilde{D} = \widetilde{D}^{-1} D =D 1,即对度数矩阵求倒数,此时每个值变为 1 度数 \frac{1}{度数} 度数1。此时公式为: D ~ − 1 ( A ~ X ) \widetilde{D}^{-1}(\widetilde{A}X) D 1(A X),左乘 D ~ − 1 \widetilde{D}^{-1} D 1相当于对行做了归一化操作(可以仔细思考这个过程)。

由于 D ~ − 1 \widetilde{D}^{-1} D 1只有对角线上有值,因此在做矩阵的乘法的时候,只会对第一行相乘,且值为: 1 度数 \frac{1}{度数} 度数1,因此相乘之后的值每一行和为1,相当于进行以此行归一化操作。

同理也需要对列进行归一化,公式如下:
D ~ − 1 ( A ~ X ) D ~ − 1 = D ~ − 1 A ~ D ~ − 1 X \widetilde{D}^{-1}(\widetilde{A}X)\widetilde{D}^{-1} = \widetilde{D}^{-1}\widetilde{A}\widetilde{D}^{-1}X D 1(A X)D 1=D 1A D 1X
由于同一位置进行两次归一化,为了抵消其中的一次,我们将 D ~ − 1 \widetilde{D}^{-1} D 1需转换为 D ~ − 1 2 \widetilde{D}^{-\frac{1}{2}} D 21。可得公式:

D ~ 1 2 A ~ D ~ − 1 2 X \widetilde{D}^{\frac{1}{2}}\widetilde{A}\widetilde{D}^{-\frac{1}{2}}X D 21A D 21X

归一化解释:

假如BA很重要,但对B来讲,A并不是特别重要,左乘 D ~ − 1 2 \widetilde{D}^{-\frac{1}{2}} D 21,考虑了BA的重要性,右乘 D ~ − 1 2 \widetilde{D}^{-\frac{1}{2}} D 21则又考虑了AB的重要性,因为B的度数较大,取倒数开根号后,值就很小,再进行矩阵的乘法的时候就平衡了这一重要性。

前向传播公式:

A ^ = D ~ − 1 2 A ~ D ~ − 1 2 , Z = f ( X , A ^ ) = s o f t m a x ( A ^ R e l u ( A ^ X W ( 0 ) ) W ( 1 ) ) ) \widehat{A} = \widetilde{D}^{-\frac{1}{2}}\widetilde{A}\widetilde{D}^{-\frac{1}{2}}, \\ Z = f(X, \widehat{A}) = softmax(\widehat{A} Relu(\widehat{A}XW^{(0)})W^{(1)})) A =D 21A D 21,Z=f(X,A )=softmax(A Relu(A XW(0))W(1)))
其中: W ( 0 ) W^{(0)} W(0)的维度为:[特征维度, 自定义中间层维度] W ( 1 ) W^{(1)} W(1)的维度为:[自定义中间层维度,分类维度]

代码实验

数据介绍

Cora 数据集是一个广泛使用的图结构数据集,特别适用于图神经网络(GNN)的研究和应用。它由2708篇科学出版物组成,这些出版物之间通过5429条引用边相互关联,并且被分为7个类别。每个出版物都由一个1433维的二进制词向量表示,这些词向量反映了论文中是否包含词典中的特定单词,词典共有1433个独特单词。

文件介绍:

  • ind.cora.x:包含140个节点的特征向量,用于训练。
  • ind.cora.tx:包含1000个节点的特征向量,用作测试集。
  • ind.cora.allx:包含除测试节点外的所有节点特征,总共1708个节点。
  • ind.cora.yind.cora.ty:分别表示训练和测试节点的one-hot编码标签。
  • ind.cora.ally:表示ind.cora.allx对应节点的one-hot编码标签。
  • ind.cora.graph:记录节点间连接信息。
  • ind.cora.test.index:记录测试集节点的索引

Cora 数据集的节点数在训练集和测试集中的分布可能不是总和为全部节点数,其中训练集通常只使用140个节点,而测试集使用1000个节点 11。剩余的节点被包含在ind.cora.allx文件中。在训练过程中,尽管只有训练集节点用于更新梯度,但所有节点都会参与到训练过程中,作为特征的一部分。

数据集下载地址:kimiyoung/planetoid: Semi-supervised learning with graph embeddings (github.com)
数据下载之后,需要再建一个 row 文件夹,然后将下载下来的文件放入到 row 文件夹中,并在同级文件中创建processed空文件夹,然后进行导入即可,具体如下:

在这里插入图片描述

dataset = Planetoid(root=r'data/', name='Cora')

Cora 数据集一共有2708个样本点,每个样本点表示一篇科学论文,所有样本点被分为8个类别,每个样本点被编码为一个1433维度的词向量。

所有节点一共被分为七类,分别为:

  1. 基于案例
  2. 遗传算法
  3. 神经网络
  4. 概率方法
  5. 强化学习
  6. 规则学习
  7. 理论

每一篇论文都至少引用了一篇其他论文,或被其他论文引用。

数据内容如下

在这里插入图片描述

  • x:表示一共有2708个节点,每个节点的特征编码为1433个向量。

  • y:表示节点标签向量。

  • edge_index:表示边索引,是一个二维的列表,两个维度的值分别表示边的起点和终点。

  • train_mask:表示哪些数据是用于训练的,是一个bool类型的列表。

  • val_mask:表示哪些数据是用于验证的,是一个bool类型的列表。

  • test_mask:表示哪些数据是用于测试的,是一个bool类型的列表。

代码

对节点进行分类的代码

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import NormalizeFeatures#载入数据,  transform=NormalizeFeatures() 对节点进行归一化处理
dataset = Planetoid(root=r'data/', name='Cora')
data = dataset[0]
print(data)#定义网络架构
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = GCNConv(dataset.num_features, 16)  # 输入的通道是每个节点特征维度,16是中间隐藏神经元个数self.conv2 = GCNConv(16, dataset.num_classes)   # 最后输出的通道是节点分类的个数def forward(self, x, edge_index):x = self.conv1(x, edge_index)x = F.relu(x)x = self.conv2(x, edge_index)return F.log_softmax(x, dim=1)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)#模型训练
model.train()
for epoch in range(200):optimizer.zero_grad()out = model(data.x, data.edge_index)    # 模型的输入有节点特征还有边特征, 使用的是全部数据loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])   # 损失仅仅计算的是训练集的损失,out[data.train_mask]这样可直接筛选出训练数据,前提是数据类型是tensor的类型loss.backward()optimizer.step()#测试:
model.eval()
test_predict = model(data.x, data.edge_index)[data.test_mask]
max_index = torch.argmax(test_predict, dim=1)
test_true = data.y[data.test_mask]
correct = 0
for i in range(len(max_index)):if max_index[i] == test_true[i]:correct += 1
print('测试集准确率为:{}%'.format(correct*100/len(test_true)))

在这里插入图片描述

参考文章

  • 【图神经网络实战】深入浅出地学习图神经网络GNN(上)-CSDN博客
  • 图神经网络(GNN)最简单全面原理与代码实现_gnn基本原理-CSDN博客
  • 【数据集介绍】Cora数据集介绍-CSDN博客

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com