图神经网络实战(22)——基于Captum解释图神经网络
- 0. 前言
- 1. Captum 库
- 2. 积分梯度法
- 3. 实现集成梯度
- 小结
- 系列链接
0. 前言
我们已经学习了在图神经网络 (Graph Neural Networks, GNN) 上实现 GNNExplainer 模型,使用 MUTAG
数据集的对图分类预测结果进行解释。在本节中,我们将首先介绍 Captum
,以及另一类应用于图数据的解释技术——基于梯度 (Gradient-based
) 的方法,积分梯度 (integrated gradients
)。然后,在 Twitch
社交网络数据集上使用 PyTorch Geometric
实现该技术。
1. Captum 库
Captum
是一个 Python
库,实现了许多适用于 PyTorch
模型的最先进的解释算法。该库并非专用于图神经网络 (Graph Neural Networks, GNN) ,它也可应用于文本、图像、表格数据等。它允许用户快速测试各种技术,并比较同一预测的不同解释。此外,Captum
还实现了 LIME
和 Gradient SHAP
等流行算法,用于对模型的输入、层和神经元进行归因。
可以在 shell
中使用 pip
命令安装 Captum
库:
pip install captum
2. 积分梯度法
在本节中,我们将使用 Captum
在图数据上实现积分梯度 (integrated gradients
)。这项技术旨在为每个输入特征分配一个归因分数 (attribution score
)。为此,它使用相对于模型输入的梯度。具体来说,它使用一个输入 x x x 和一个基准输入 x ′ x' x′ (在本节中,所有边的权重都为零)。它计算 x x x 和 x ′ x' x′ 之间路径上所有点的梯度,并将其累加。
形式上,对于输入 x x x,在第 i i i 维上的积分梯度定义如下:
I n t e g r a t e d G r a d s i ( x ) : : = ( x i − x i ′ ) × ∫ α = 0 1 ∂ F ( x ′ + α × ( x − x ′ ) ) ∂ x i d α IntegratedGrads_i(x)::=(x_i-x'_i)\times\int _{\alpha=0}^1 \frac {\partial F(x'+\alpha\times(x-x'))}{\partial x_i}d\alpha IntegratedGradsi(x)::=(xi−xi′)×∫α=01∂xi∂F(x′+α×(x−x′))dα
实际上,我们并不直接计算这个积分,而是用离散和来近似计算。积分梯度与模型无关,基于以下两个公理:
- 敏感性 (
Sensitivity
): 每一个对预测有贡献的输入都必须得到一个非零的归因 - 实现不变性 (
Implementation invariance
): 对于所有输入,输出都相等的两个神经网络(这些网络被称为功能等效网络)必须具有相同的归因分析
在本节中,我们将考虑的是节点和边(而非特征)。因此,可以看到输出结果与 GNNExplainer 有所不同,后者考虑的是节点特征和边。因此这两种方法可以互补。接下来,我们实现积分梯度,并将结果可视化。
3. 实现集成梯度
本节中,我们将在 Twitch
社交网络数据集数据集上实现积分梯度。该数据集表示一个用户-用户图,其中节点对应 Twitch
主播,连接对应相互之间的朋友关系。128
个节点特征代表了主播习惯、位置、喜欢的游戏等信息。我们的目标是确定流媒体用户是否使用明确的语言(二元分类)。
使用 PyTorch Geometric
实现一个简单的双层图卷积网络 (Graph Convolutional Network, GCN) 来完成此任务。然后,把模型转换为 Captum
,以使用积分梯度 (integrated gradients
) 算法,并解释结果。
(1) 导入所需的库:
import numpy as np
import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F
from captum.attr import IntegratedGradientsimport torch_geometric.transforms as T
from torch_geometric.datasets import Twitch
from torch_geometric.nn import Explainer, GCNConv, to_captum
(2) 加载 Twitch
社交网络数据集:
dataset = Twitch('.', name="EN")
data = dataset[0]
(3) 使用带有 Dropout
的简单的双层 GCN
:
class GCN(torch.nn.Module):def __init__(self, dim_h):super().__init__()self.conv1 = GCNConv(dataset.num_features, dim_h)self.conv2 = GCNConv(dim_h, dataset.num_classes)def forward(self, x, edge_index):h = self.conv1(x, edge_index).relu()h = F.dropout(h, p=0.5, training=self.training)h = self.conv2(h, edge_index)return F.log_softmax(h, dim=1)
(4) 使用 Adam
优化器在 GPU
上训练模型:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(64).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
(5) 使用负对数似然损失函数对模型进行 200
个 epoch
训练:
for epoch in range(200):model.train()optimizer.zero_grad()log_logits = model(data.x, data.edge_index)loss = F.nll_loss(log_logits, data.y)loss.backward()optimizer.step()
(6) 测试训练后的模型。由于我们没有指定任何测试集,因此我们将评估 GCN
在训练集上的准确性:
def accuracy(pred_y, y):return ((pred_y == y).sum() / len(y)).item()@torch.no_grad()
def test(model, data):model.eval()out = model(data.x, data.edge_index)acc = accuracy(out.argmax(dim=1), data.y)return accacc = test(model, data)
print(f'Accuracy: {acc*100:.2f}%')# Accuracy: 72.83%
该模型的准确率为 72.83%
,考虑到模型是在训练集上进行评估,准确率相对较低。
(7) 接下来,实现解释方法——积分梯度。首先,必须指定要解释的节点(本节中为节点 0
),并将 PyTorch Geometric
模型转换为 Captum
。在这里,我们还指定要使用特征和边掩码 (mask_type=node_and_feature
):
node_idx = 0
captum_model = to_captum(model, mask_type='node_and_edge', output_idx=node_idx)
(8) 使用 Captum
创建积分梯度对象,将上一步的结果作为输入:
ig = IntegratedGradients(captum_model)
(9) 有了需要传递给 Captum
的节点掩码 (data.x
) 后,还需要为边掩码创建一个张量。在本节中,我们要考虑图中的每一条边,因此要初始化一个大小为 data.num_edges
的全为 1
的张量:
edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device)
(10) attribute()
方法使用特定格式的节点和边掩码输入(因此使用 unsqueeze(0)
来重新格式化这些张量)。目标对应于目标节点的类别。最后,将邻接矩阵 (data.edge_index
) 作为额外的前向参数传递:
attr_node, attr_edge = ig.attribute((data.x.unsqueeze(0), edge_mask.unsqueeze(0)),target=int(data.y[node_idx]),additional_forward_args=(data.edge_index),internal_batch_size=1)
(11) 在 0
和 1
之间对归因分数进行缩放:
attr_node = attr_node.squeeze(0).abs().sum(dim=1)
attr_node /= attr_node.max()
attr_edge = attr_edge.squeeze(0).abs()
attr_edge /= attr_edge.max()
(12) 使用 PyTorch Geometric
的 Explainer
类,将这些归因的图表示可视化:
fig = plt.figure(dpi=200)
explainer = Explainer(model)
ax, G = explainer.visualize_subgraph(node_idx, data.edge_index, attr_edge, node_alpha=attr_node, y=data.y)
ax.axis('off')
plt.show()
节点 0
的子图由蓝色节点组成,这些节点共享同一个类别。可以看到,节点 82
是最重要的节点(除节点 0
以外),而这两个节点之间的连接是最关键的边。这是一个简单明了的解释:我们有四个使用相同语言的主播组成的群体。节点 0
和 82
之间的相互友好关系很好地证明了这一预测。
接下来,查看节点 101
分类的解释:
node_idx = 101
captum_model = to_captum(model, mask_type='node_and_edge', output_idx=node_idx)
ig = IntegratedGradients(captum_model)
edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device)attr_node, attr_edge = ig.attribute((data.x.unsqueeze(0), edge_mask.unsqueeze(0)),target=int(data.y[node_idx]),additional_forward_args=(data.edge_index),internal_batch_size=1)attr_node = attr_node.squeeze(0).abs().sum(dim=1)
attr_node /= attr_node.max()
attr_edge = attr_edge.squeeze(0).abs()
attr_edge /= attr_edge.max()fig = plt.figure(dpi=200)
explainer = Explainer(model)
ax, G = explainer.visualize_subgraph(node_idx, data.edge_index, attr_edge, node_alpha=attr_node, y=data.y)
ax.axis('off')
plt.show()
在这种情况下,我们的目标节点与不同类别的邻居(节点 5398
和 2849
)相连。积分梯度更加重视与节点 101
共享相同类别的节点。我们还可以看到它们之间的连接对于这个分类做出了最大的贡献。这个子图更加丰富;可以看到即使是两跳邻居也有一点贡献。
但,这些解释并非无往不利,人工智能中的可解释性是一个内容丰富的课题,往往涉及不同背景。了解边、节点和特征的重要性至关重要,利用领域专家知识可以利用或完善这些解释,甚至能够推动模型架构的改进。
小结
可解释性是许多深度学习领域的关键要素,可以帮助我们建立更好的模型,在本节中,我们介绍了积分梯度(基于梯度的方法)技术。使用 PyTorch Geometric
和 Captum
在 Twitch
数据集上实现了此方法,以获得节点分类的解释,最后对结果进行了可视化和讨论。
系列链接
图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(2)——图论基础
图神经网络实战(3)——基于DeepWalk创建节点表示
图神经网络实战(4)——基于Node2Vec改进嵌入质量
图神经网络实战(5)——常用图数据集
图神经网络实战(6)——使用PyTorch构建图神经网络
图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)——图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(9)——GraphSAGE详解与实现
图神经网络实战(10)——归纳学习
图神经网络实战(11)——Weisfeiler-Leman测试
图神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)
图神经网络实战(13)——经典链接预测算法
图神经网络实战(14)——基于节点嵌入预测链接
图神经网络实战(15)——SEAL链接预测算法
图神经网络实战(16)——经典图生成算法
图神经网络实战(17)——深度图生成模型
图神经网络实战(18)——消息传播神经网络
图神经网络实战(19)——异构图神经网络
图神经网络实战(20)——时空图神经网络
图神经网络实战(21)——图神经网络的可解释性