您的位置:首页 > 房产 > 家装 > 最新任免名单最新_高清直播_西安网站搭建_网站搜什么关键词

最新任免名单最新_高清直播_西安网站搭建_网站搜什么关键词

2025/1/10 13:36:48 来源:https://blog.csdn.net/LOVEmy134611/article/details/144656817  浏览:    关键词:最新任免名单最新_高清直播_西安网站搭建_网站搜什么关键词
最新任免名单最新_高清直播_西安网站搭建_网站搜什么关键词

图神经网络实战(22)——基于Captum解释图神经网络

    • 0. 前言
    • 1. Captum 库
    • 2. 积分梯度法
    • 3. 实现集成梯度
    • 小结
    • 系列链接

0. 前言

我们已经学习了在图神经网络 (Graph Neural Networks, GNN) 上实现 GNNExplainer 模型,使用 MUTAG 数据集的对图分类预测结果进行解释。在本节中,我们将首先介绍 Captum,以及另一类应用于图数据的解释技术——基于梯度 (Gradient-based) 的方法,积分梯度 (integrated gradients)。然后,在 Twitch 社交网络数据集上使用 PyTorch Geometric 实现该技术。

1. Captum 库

Captum 是一个 Python 库,实现了许多适用于 PyTorch 模型的最先进的解释算法。该库并非专用于图神经网络 (Graph Neural Networks, GNN) ,它也可应用于文本、图像、表格数据等。它允许用户快速测试各种技术,并比较同一预测的不同解释。此外,Captum 还实现了 LIMEGradient SHAP 等流行算法,用于对模型的输入、层和神经元进行归因。
可以在 shell 中使用 pip 命令安装 Captum 库:

pip install captum

2. 积分梯度法

在本节中,我们将使用 Captum 在图数据上实现积分梯度 (integrated gradients)。这项技术旨在为每个输入特征分配一个归因分数 (attribution score)。为此,它使用相对于模型输入的梯度。具体来说,它使用一个输入 x x x 和一个基准输入 x ′ x' x (在本节中,所有边的权重都为零)。它计算 x x x x ′ x' x 之间路径上所有点的梯度,并将其累加。
形式上,对于输入 x x x,在第 i i i 维上的积分梯度定义如下:
I n t e g r a t e d G r a d s i ( x ) : : = ( x i − x i ′ ) × ∫ α = 0 1 ∂ F ( x ′ + α × ( x − x ′ ) ) ∂ x i d α IntegratedGrads_i(x)::=(x_i-x'_i)\times\int _{\alpha=0}^1 \frac {\partial F(x'+\alpha\times(x-x'))}{\partial x_i}d\alpha IntegratedGradsi(x)::=(xixi)×α=01xiF(x+α×(xx))dα
实际上,我们并不直接计算这个积分,而是用离散和来近似计算。积分梯度与模型无关,基于以下两个公理:

  • 敏感性 (Sensitivity): 每一个对预测有贡献的输入都必须得到一个非零的归因
  • 实现不变性 (Implementation invariance): 对于所有输入,输出都相等的两个神经网络(这些网络被称为功能等效网络)必须具有相同的归因分析

在本节中,我们将考虑的是节点和边(而非特征)。因此,可以看到输出结果与 GNNExplainer 有所不同,后者考虑的是节点特征和边。因此这两种方法可以互补。接下来,我们实现积分梯度,并将结果可视化。

3. 实现集成梯度

本节中,我们将在 Twitch 社交网络数据集数据集上实现积分梯度。该数据集表示一个用户-用户图,其中节点对应 Twitch 主播,连接对应相互之间的朋友关系。128 个节点特征代表了主播习惯、位置、喜欢的游戏等信息。我们的目标是确定流媒体用户是否使用明确的语言(二元分类)。
使用 PyTorch Geometric 实现一个简单的双层图卷积网络 (Graph Convolutional Network, GCN) 来完成此任务。然后,把模型转换为 Captum,以使用积分梯度 (integrated gradients) 算法,并解释结果。

(1) 导入所需的库:

import numpy as np
import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F
from captum.attr import IntegratedGradientsimport torch_geometric.transforms as T
from torch_geometric.datasets import Twitch
from torch_geometric.nn import Explainer, GCNConv, to_captum

(2) 加载 Twitch 社交网络数据集:

dataset = Twitch('.', name="EN")
data = dataset[0]

(3) 使用带有 Dropout 的简单的双层 GCN

class GCN(torch.nn.Module):def __init__(self, dim_h):super().__init__()self.conv1 = GCNConv(dataset.num_features, dim_h)self.conv2 = GCNConv(dim_h, dataset.num_classes)def forward(self, x, edge_index):h = self.conv1(x, edge_index).relu()h = F.dropout(h, p=0.5, training=self.training)h = self.conv2(h, edge_index)return F.log_softmax(h, dim=1)

(4) 使用 Adam 优化器在 GPU 上训练模型:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(64).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

(5) 使用负对数似然损失函数对模型进行 200epoch 训练:

for epoch in range(200):model.train()optimizer.zero_grad()log_logits = model(data.x, data.edge_index)loss = F.nll_loss(log_logits, data.y)loss.backward()optimizer.step()

(6) 测试训练后的模型。由于我们没有指定任何测试集,因此我们将评估 GCN 在训练集上的准确性:

def accuracy(pred_y, y):return ((pred_y == y).sum() / len(y)).item()@torch.no_grad()
def test(model, data):model.eval()out = model(data.x, data.edge_index)acc = accuracy(out.argmax(dim=1), data.y)return accacc = test(model, data)
print(f'Accuracy: {acc*100:.2f}%')# Accuracy: 72.83%

该模型的准确率为 72.83%,考虑到模型是在训练集上进行评估,准确率相对较低。

(7) 接下来,实现解释方法——积分梯度。首先,必须指定要解释的节点(本节中为节点 0),并将 PyTorch Geometric 模型转换为 Captum。在这里,我们还指定要使用特征和边掩码 (mask_type=node_and_feature):

node_idx = 0
captum_model = to_captum(model, mask_type='node_and_edge', output_idx=node_idx)

(8) 使用 Captum 创建积分梯度对象,将上一步的结果作为输入:

ig = IntegratedGradients(captum_model)

(9) 有了需要传递给 Captum 的节点掩码 (data.x) 后,还需要为边掩码创建一个张量。在本节中,我们要考虑图中的每一条边,因此要初始化一个大小为 data.num_edges 的全为 1 的张量:

edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device)

(10) attribute() 方法使用特定格式的节点和边掩码输入(因此使用 unsqueeze(0) 来重新格式化这些张量)。目标对应于目标节点的类别。最后,将邻接矩阵 (data.edge_index) 作为额外的前向参数传递:

attr_node, attr_edge = ig.attribute((data.x.unsqueeze(0), edge_mask.unsqueeze(0)),target=int(data.y[node_idx]),additional_forward_args=(data.edge_index),internal_batch_size=1)

(11)01 之间对归因分数进行缩放:

attr_node = attr_node.squeeze(0).abs().sum(dim=1)
attr_node /= attr_node.max()
attr_edge = attr_edge.squeeze(0).abs()
attr_edge /= attr_edge.max()

(12) 使用 PyTorch GeometricExplainer 类,将这些归因的图表示可视化:

fig = plt.figure(dpi=200)
explainer = Explainer(model)
ax, G = explainer.visualize_subgraph(node_idx, data.edge_index, attr_edge, node_alpha=attr_node, y=data.y)
ax.axis('off')
plt.show()

可视化表示

节点 0 的子图由蓝色节点组成,这些节点共享同一个类别。可以看到,节点 82 是最重要的节点(除节点 0 以外),而这两个节点之间的连接是最关键的边。这是一个简单明了的解释:我们有四个使用相同语言的主播组成的群体。节点 082 之间的相互友好关系很好地证明了这一预测。

接下来,查看节点 101 分类的解释:

node_idx = 101
captum_model = to_captum(model, mask_type='node_and_edge', output_idx=node_idx)
ig = IntegratedGradients(captum_model)
edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device)attr_node, attr_edge = ig.attribute((data.x.unsqueeze(0), edge_mask.unsqueeze(0)),target=int(data.y[node_idx]),additional_forward_args=(data.edge_index),internal_batch_size=1)attr_node = attr_node.squeeze(0).abs().sum(dim=1)
attr_node /= attr_node.max()
attr_edge = attr_edge.squeeze(0).abs()
attr_edge /= attr_edge.max()fig = plt.figure(dpi=200)
explainer = Explainer(model)
ax, G = explainer.visualize_subgraph(node_idx, data.edge_index, attr_edge, node_alpha=attr_node, y=data.y)
ax.axis('off')
plt.show()

结果可视化

在这种情况下,我们的目标节点与不同类别的邻居(节点 53982849)相连。积分梯度更加重视与节点 101 共享相同类别的节点。我们还可以看到它们之间的连接对于这个分类做出了最大的贡献。这个子图更加丰富;可以看到即使是两跳邻居也有一点贡献。
但,这些解释并非无往不利,人工智能中的可解释性是一个内容丰富的课题,往往涉及不同背景。了解边、节点和特征的重要性至关重要,利用领域专家知识可以利用或完善这些解释,甚至能够推动模型架构的改进。

小结

可解释性是许多深度学习领域的关键要素,可以帮助我们建立更好的模型,在本节中,我们介绍了积分梯度(基于梯度的方法)技术。使用 PyTorch GeometricCaptumTwitch 数据集上实现了此方法,以获得节点分类的解释,最后对结果进行了可视化和讨论。

系列链接

图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(2)——图论基础
图神经网络实战(3)——基于DeepWalk创建节点表示
图神经网络实战(4)——基于Node2Vec改进嵌入质量
图神经网络实战(5)——常用图数据集
图神经网络实战(6)——使用PyTorch构建图神经网络
图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)——图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(9)——GraphSAGE详解与实现
图神经网络实战(10)——归纳学习
图神经网络实战(11)——Weisfeiler-Leman测试
图神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)
图神经网络实战(13)——经典链接预测算法
图神经网络实战(14)——基于节点嵌入预测链接
图神经网络实战(15)——SEAL链接预测算法
图神经网络实战(16)——经典图生成算法
图神经网络实战(17)——深度图生成模型
图神经网络实战(18)——消息传播神经网络
图神经网络实战(19)——异构图神经网络
图神经网络实战(20)——时空图神经网络
图神经网络实战(21)——图神经网络的可解释性

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com