图神经网络实战(21)——图神经网络的可解释性
- 0. 前言
- 1. 解释技术介绍
- 2. GNNExplainer 简介
- 3. 实现 GNNExplainer
- 小结
- 系列链接
0. 前言
一直以来,神经网络最受人诟病的缺陷之一就是其输出结果难以解释,图神经网络也难逃这一局限:除了要解释哪些特征是重要的,还需要考虑相邻节点和连接。针对这一问题,可解释性领域已经发展出许多技术,以更好地理解预测背后的原因或模型的一般行为。其中一些技术已经被应用到了图神经网络中,或利用图结构来提供更精确的解释。在本节中,我们将探讨一些解释技术,以了解模型做出特定预测的原因。我们将介绍不同类别的技术,并重点介绍 GNNExplainer
技术,并使用 MUTAG
数据集的图分类任务中应用 GNNExplainer
模型。
1. 解释技术介绍
图神经网络 (Graph Neural Networks, GNN) 解释技术在很大程度上受到其他可解释性人工智能 (Explainable Artificial Intelligence
, XAI
) 技术的启发,这些技术可以分为基于每个预测的局部解释和针对整个模型的全局解释。虽然理解整个 GNN
模型的技术是可行的,但我们将重点关注局部解释,因为局部解释对于深入了解预测至关重要。
在本节中,我们将区分 “可理解” (Interpretable
) 模型和 “可解释” (Explainable
) 模型。如果一个模型(如决策树)在设计上是人类可以理解的,那么它就被称为"可理解"模型;另一方面,如果模型是一个黑盒子,其预测结果只能通过解释技术来追溯理解,那么它就是"可解释"模型。可解释模型通常发生在神经网络上:它们的权重和偏置不像决策树那样提供明确的规则,但它们的结果可以间接解释。
局部解释技术主要分为四类:
- 基于梯度 (
Gradient-based
) 的方法:通过分析输出的梯度来估计归因分数,例如,积分梯度 (integrated gradients
) - 基于扰动的方法:通过屏蔽或修改输入特征来衡量输出的变化(例如,
GNNExplainer
) - 分解法:将模型的预测分解为若干项,以衡量其重要性(例如,
GNN-LRP
) - 代理法:使用一个简单、可解释的模型来近似原始模型对某一区域的预测(例如,
GraphLIME
)
这些技术是互补的:它们有时会对边和特征的贡献产生分歧,这可以用来进一步完善预测的解释。解释技术通常使用以下指标进行评估:
- 保真度 (
Fidelity
):比较原始图 G i G_i Gi 和修改后的图 G ^ i \hat G_i G^i 之间的预测概率 y i y_i yi。根据对 y ^ i \hat y_i y^i 的解释,修改后的图只保留 G ^ i \hat G_i G^i 中最重要的特征(节点、边、节点特征)。换句话说,保真度衡量的是被确定为重要的特征是否足以获得正确预测。其形式化定义如下:
F i d e l i t y = 1 N ∑ i = 1 N ( f ( G i ) y i − f ( G ^ i ) y i ) Fidelity=\frac 1N\sum_{i=1}^N(f(G_i)_{y_i}-f(\hat G_i)_{y_i}) Fidelity=N1i=1∑N(f(Gi)yi−f(G^i)yi) - 稀疏性 (
Sparsity
):衡量被认为重要的特征(节点、边、节点特征)的比例。过于冗长的解释更难理解,这也是鼓励稀疏性的原因。其计算方法如下:
S p a r s i t y = 1 N ∑ i = 1 N ∑ i = 1 N ( 1 − ∣ m i ∣ ∣ M i ∣ ) Sparsity=\frac 1N\sum_{i=1}^N\sum_{i=1}^N(1-\frac {|m_i|}{|M_i|}) Sparsity=N1i=1∑Ni=1∑N(1−∣Mi∣∣mi∣)
其中, ∣ m i ∣ |m_i| ∣mi∣ 是重要输入特征的数量, ∣ M i ∣ |M_i| ∣Mi∣ 是特征的总数。除了传统图数据外,解释技术还经常在合成数据集上进行评估,如BA-Shapes
、BA-Community
、Tree-Cycles
和Tree-Grid
等。这些数据集是用图生成算法生成的,用于创建特定的模式。接下来,我们将介绍基于扰动的技术 (GNNExplainer
)。
2. GNNExplainer 简介
在本节中,我们将介绍基于梯度的可解释性人工智能 (Explainable Artificial Intelligence
, XAI
) 技术 GNNExplainer
,并用它来解释图同构网络 (Graph Isomorphism Network, GIN) 模型在 MUTAG
数据集上输出的预测结果。
GNNExplainer
于 2019
年由 Ying
等人提出,是一种图神经网络 (Graph Neural Networks, GNN) 架构,旨在解释来自另一个 GNN
模型的预测。对于表格数据,我们想知道哪些特征对预测最重要。然而,对于图数据来说,这还不够,我们还需要知道哪些节点最有影响力。GNNExplainer
通过提供子图 G S G_S GS 和节点特征子集 X S X_S XS 来生成包含这两个部分的解释。下图说明了 GNNExplainer
为给定节点提供的解释:
为了预测 G S G_S GS 和 X S X_S XS,GNNExplainer
实现了一个边掩码(用于隐藏连接)和一个特征掩码(用于隐藏节点特征)。如果某条连接或某个特征很重要,那么删除它就会显著改变预测结果。反之,如果预测结果没有变化,则说明这些信息是多余的或根本不相关。这一原则是基于扰动的技术(如 GNNExplainer
)的核心。
在实践中,我们必须设计一个损失函数找到最佳掩码。GNNExplainer
衡量了预测标签分布 Y Y Y 与 ( G S , X S ) (G_S, X_S) (GS,XS) 之间的相互依赖性,也称为互信息 (mutual information
, MI
)。我们的目标是最大化 MI
,这等价于最小化条件交叉熵。训练 GNNExplainer
的目的是找到能使预测 Y Y Y 的概率最大化的变量 G S G_S GS 和 X S X_S XS。
除此优化框架外,GNNExplainer
还会学习一个二元特征掩码,并实现了几种正则化技术。其中,最重要的是一个用于最小化稀疏性的技术。它的计算方法是将掩码参数的所有元素之和添加到损失函数中。GNNExplainer
可以创建更友好、更简洁、更易于理解的解释。
GNNExplainer
可应用于大多数 GNN
架构和不同的任务,如单节点分类、链接预测或图分类,它还能生成类标签或整个图的解释。在对图进行分类时,该模型会考虑图中所有节点的邻接矩阵,而不是单个节点的邻接矩阵。下一节,我们将应用 GNNExplainer
来解释图分类。
3. 实现 GNNExplainer
在本节中,我们将使用 MUTAG
数据集。该数据集中的 188
个图各代表一种化合物,其中节点是原子(有 7
种可能的原子),边是化学键(有 4
种可能的化学键)。节点和边特征分别代表原子和边类型的独热编码。我们的目标是根据每种化合物对沙门氏菌的诱变作用将每个化合物分为两种类别。
我们将重用图同构网络 (Graph Isomorphism Network, GIN) 模型进行蛋白质分类,在 GIN
一节中,我们可视化了模型进行的正确和不正确的分类。但是,我们无法解释其所做的预测。在本节,我们将使用 GNNExplainer
来理解最重要的子图和节点特征,以解释分类。为了方便起见,我们将忽略边特征。
(1) 从 PyTorch
和 PyTorch Geometric
中导入所需的类:
import matplotlib.pyplot as pltimport torch
import torch.nn.functional as F
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU, Dropoutfrom torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, global_add_pool, GNNExplainer
(2) 加载 MUTAG
数据集并将其打散:
dataset = TUDataset(root='data/TUDataset', name='MUTAG').shuffle()
(3) 创建训练集、验证集和测试集:
train_dataset = dataset[:int(len(dataset)*0.8)]
val_dataset = dataset[int(len(dataset)*0.8):int(len(dataset)*0.9)]
test_dataset = dataset[int(len(dataset)*0.9):]
(4) 创建数据加载器,实现小批量处理:
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)
(5) 创建一个有 32
个隐藏维度的 GIN
模型:
class GIN(torch.nn.Module):"""GIN"""def __init__(self, dim_h):super(GIN, self).__init__()self.conv1 = GINConv(Sequential(Linear(dataset.num_node_features, dim_h),BatchNorm1d(dim_h), ReLU(),Linear(dim_h, dim_h), ReLU()))self.conv2 = GINConv(Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),Linear(dim_h, dim_h), ReLU()))self.conv3 = GINConv(Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),Linear(dim_h, dim_h), ReLU()))self.lin1 = Linear(dim_h*3, dim_h*3)self.lin2 = Linear(dim_h*3, dataset.num_classes)def forward(self, x, edge_index, batch):h1 = self.conv1(x, edge_index)h2 = self.conv2(h1, edge_index)h3 = self.conv3(h2, edge_index)h1 = global_add_pool(h1, batch)h2 = global_add_pool(h2, batch)h3 = global_add_pool(h3, batch)h = torch.cat((h1, h2, h3), dim=1)h = self.lin1(h)h = h.relu()h = F.dropout(h, p=0.5, training=self.training)h = self.lin2(h)return F.log_softmax(h, dim=1)model = GIN(dim_h=32)
(6) 对模型进行 100
个 epoch
训练,并进行测试:
@torch.no_grad()
def test(model, loader):criterion = torch.nn.CrossEntropyLoss()model.eval()loss = 0acc = 0for data in loader:out = model(data.x, data.edge_index, data.batch)loss += criterion(out, data.y) / len(loader)acc += accuracy(out.argmax(dim=1), data.y) / len(loader)return loss, accdef accuracy(pred_y, y):return ((pred_y == y).sum() / len(y)).item()criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
epochs = 200model.train()
for epoch in range(epochs+1):total_loss = 0acc = 0val_loss = 0val_acc = 0# Train on batchesfor data in train_loader:optimizer.zero_grad()out = model(data.x, data.edge_index, data.batch)loss = criterion(out, data.y)total_loss += loss / len(train_loader)acc += accuracy(out.argmax(dim=1), data.y) / len(train_loader)loss.backward()optimizer.step()# Validationval_loss, val_acc = test(model, val_loader)# Print metrics every 20 epochsif(epoch % 20 == 0):print(f'Epoch {epoch:>3} | Train Loss: {total_loss:.2f} | Train Acc: {acc*100:>5.2f}% | Val Loss: {val_loss:.2f} | Val Acc: {val_acc*100:.2f}%')test_loss, test_acc = test(model, test_loader)
print(f'Test Loss: {test_loss:.2f} | Test Acc: {test_acc*100:.2f}%')
(7) GIN
模型经过训练后获得了较高的准确率 (84.21%
)。接下来,使用 PyTorch Geometric
中的 GNNExplainer
类创建一个 GNNExplainer
模型,对其进行 100
个 epoch
训练:
explainer = GNNExplainer(model, epochs=100, num_hops=1)
(8) GNNExplainer
可用于解释对某个节点 (.explain_node()
) 或整个图 (.explain_graph()
) 所做的预测。在本节中,我们将在测试集的最后一个图上使用 GNNExplainer
模型:
data = dataset[-1]
feature_mask, edge_mask = explainer.explain_graph(data.x, data.edge_index)
(9) 上一步返回了特征和边掩码。打印特征掩码以查看最重要的值:
print(feature_mask)Explain graph: 100%|██████████| 100/100 [00:00<00:00, 121.91it/s]
tensor([0.7777, 0.6492, 0.6702, 0.2613, 0.2655, 0.2748, 0.2574])
这些值在 0
(不太重要)和 1
(更重要)之间进行了归一化处理。这七个值对应于我们在数据集中发现的七个原子,依次为碳 (C
)、氮 (N
)、氧 (O
)、氟 (F
)、碘 (I
)、氯 (Cl
) 和溴 (Br
)。特征的重要性相似:最有用的是代表碳 (C
) 的第一个特征,而最不重要的是代表碘 (I
) 的第七个特征。
(10) 可以使用 visualize_graph()
方法在图上绘制边掩码,箭头的不透明度代表每个连接的重要性:
fig = plt.figure(dpi=200)
ax, G = explainer.visualize_subgraph(-1, data.edge_index, edge_mask, y=data.y)
ax.axis('off')
plt.show()
图中显示了对预测贡献最大的连接。在本节中,GIN
模型正确地对图进行了分类。我们可以看到,节点 6
、7
和 8
之间的连接是最相关的,高亮显示的连接对于该化合物的分类至关重要。我们可以通过打印 data.edge_attr
获得与化学键(芳香键、单键、双键或三键)相关的标签,从而进一步了解它们。在本节中,它对应于边 16
至 19
,它们都是单键或双键。
通过打印 data.x
,还可以查看节点 6
、7
和 8
以获取更多信息。节点 6
代表一个氮原子,而节点 7
和 8
则代表两个氧原子。
GNNExplainer
并没有提供有关决策过程的精确规则,但提供了有关 GNN 模型预测重点的见解,需要人类专业知识来确保这些见解是连贯一致的,并符合传统领域知识。
小结
在本节中,我们探讨了应用于图神经网络 (Graph Neural Networks
, GNN
) 的可解释性人工智能 (Explainable Artificial Intelligence
, XAI
) 技术。可解释性是许多领域的关键要素,可以帮助我们建立更好的模型。我们介绍了不同的局部解释技术,并重点讨论了 GNNExplainer
(基于扰动的方法),在图分类任务中应用 GNNExplainer
。
系列链接
图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(2)——图论基础
图神经网络实战(3)——基于DeepWalk创建节点表示
图神经网络实战(4)——基于Node2Vec改进嵌入质量
图神经网络实战(5)——常用图数据集
图神经网络实战(6)——使用PyTorch构建图神经网络
图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)——图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(9)——GraphSAGE详解与实现
图神经网络实战(10)——归纳学习
图神经网络实战(11)——Weisfeiler-Leman测试
图神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)
图神经网络实战(13)——经典链接预测算法
图神经网络实战(14)——基于节点嵌入预测链接
图神经网络实战(15)——SEAL链接预测算法
图神经网络实战(16)——经典图生成算法
图神经网络实战(17)——深度图生成模型
图神经网络实战(18)——消息传播神经网络
图神经网络实战(19)——异构图神经网络
图神经网络实战(20)——时空图神经网络