图神经网络GNN的可解释性问题与解释方法最新进展
作者:互联网
卷积神经网络(Convolutional Neural Network,CNN)和图神经网络(Graph Neural Network,GNN)的主要区别是什么?
简单来说,就是输入数据。
你可能还记得,CNN 所需的输入是一个固定大小的向量或矩阵。然而,某些类型的数据自然是用图表示的,如分子、引用网络或社交媒体连接网络都可以用图数据来表示。在过去 GNN 还没有普及的时候,图数据往往需要经过这样的转换,才可以直接提供给 CNN 作为输入。例如,分子结构仍被转换为固定大小的指纹,每个位表示是否存在某种分子子结构 [1] 。这是一种将分子数据插入 CNN 的好方法,但它会不会导致信息损失呢?
GNN 利用图数据,省去了数据预处理步骤,充分利用了数据中包含的信息。现在,有许多不同的 GNN 架构,其背后的理论也很快变得复杂起来。然而,GNN 可以分为两类:空间方法(Spatial approach)和谱方法(Spectral approach)。空间方法更为直观,因为它将池化和卷积操作(如 CNN)重新定义到图域。谱方法从一个稍有不同的角度来处理这个问题,因为它侧重于处理被定义为使用傅里叶变换(Fourier transform)的图网络的信号。
如果你想了解更多信息,可以阅读 Thomas Kipf 写的博文,这篇博文深入介绍了 GCN。另一篇值得阅读的有趣论文是《MoleculeNet:分子机器学习的基准》[2] ,它很好地介绍了 GNN,并描述了流行的 GNN 架构。
当前 GNN 可解释性问题
在撰写本文时,对 GNN 解释方法有贡献的论文屈指可数。尽管如此,这仍然是一个非常重要的主题,最近变得越来越流行。
GNN 的流行要比标准神经网络晚得多。虽然这一领域有很多有趣的研究,但还不是很成熟。GNN 的库和工具目前仍然处于“实验阶段”,我们现在真正需要的是,让更多的人使用它来发现 Bug 和错误,并转向生产就绪的模型。
创建生产就绪的模型的一种方法是更好地理解它们所做的预测。可以使用不同的解释方法来完成。我们已经看到了许多应用于 CNN 的有趣的可解释性方法,例如梯度归因(Gradient attribution)、显著性映射(Saliency maps)或类激活映射(Class activation mapping)等。那么,为什么不将它们重新用于 GNN 呢?
事实上,这就是目前正在发生的事情。最初用于 CNN 的解释性方法正在重新设计并应用于 GNN。虽然无需重新发明轮子,但我们仍然必须通过重新定义数学操作来调整这些方法,使它们适用于图数据。这里的主要问题是,关于这一主题的研究还相当新,第一篇论文发表于 2019 年。然而,随着时间的推移,它将会变得越来越流行,目前几乎还没有什么有趣的可解释性方法可用于 GNN 模型。
在本文中,我们将研究新的图解释方法,并了解如何将它们应用于 GNN 模型。
第一次尝试:可视化节点激活
2015 年,Duvenaud 等人 [3] 发表了关于 GNN 解释技术的开创性论文。该论文主要贡献是提出了一种新的 神经图指纹模型,同时也为这种结构创造了一种解释方法。该模型背后的主要思想是直接从图数据本身创建可区分的指纹。为实现这一点,作者不得不重新定义图的池化和平滑操作。这些操作然后被用来创建一个单层。
如图 4 所示,这些层叠加 n 次以产生向量输出。层深度也与相邻节点的半径相对应,从这些节点收集和汇集节点特征(在本例中为和函数)。这是因为,对于每一层,池化操作都从邻近节点收集信息,如图 1 所示。对于更深的层,池化操作的传播会延伸到更远的邻近节点。与正常的指纹相反,这种方法是可区分的,这使得反向传播可以以类似 CNN 的方式更新其权值。
除了 GNN 模型之外,他们还创建了一个简单的方法,可以显示节点激活及其邻近的节点。遗憾的是,这种方法在论文中没有很好地解释,因此有必要查看它们的 代码实现 以理解其底层机制。然而,他们在分子数据上运行它来预测溶解度,它突出了部分对溶解度预测有交稿预测能力的分子。
到底是怎么工作的呢?为计算节点激活量,我们需要进行以下计算。对于每个分子,让我们通过每一层向前传递数据,类似于在一个典型 CNN 网络中对图像的处理。然后,我们用 softmax()
函数提取每个指纹位各层的贡献。然后,我们就能够将一个节点(原子)与它周围对特定指纹位贡献最大的邻居(取决于层深)关联起来。
这种方法相对简单,但没有留下很好的文档记录。但该论文所做的初步工作是相当有前途的,研究人员随后进行了更为详细的尝试,将 CNN 的解释方法转化为图域。
如果你想进一步了解它,可以参阅他们的论文、代码仓库 或我在这个 Github 问题 底部对方法的详细解释。
卷积神经网络的重用方法
敏感度分析(Sensitive Analysis)、类激活映射 或 激发反向传播(Excitation Backpropagation)都是已经成功应用于 CNN 的解释技术的示例。目前针对可解释 GNN 的研究试图将这种方法转换成图域。这一领域的大部分工作都是在这些论文 [4][5] 中完成的。
本文不再着重于这些方法的数学解释,我将为你提供这些方法的直观解释,并简要讨论这些方法的结果。
归纳 CNN 的解释方法
为了重用 CNN 的解释方法,让我们将 CNN 输入的数据(即图像)视为一个格子形状的图。图 6 说明了这个想法。
如果我们牢记图像的图推广,我们可以说 CNN 解释方法不怎么关注边缘(像素之间的连接),而是关注节点(像素值)。这里的问题是,图数据中的边缘包含了很多有用的信息 [4][5] ,而不是以网格状的顺序。
哪些方法已经转换为图域?
到目前为止,文献已经讨论过以下解释方法:
敏感度分析(Sensitivity Analysis)[4]
引导反向传播 (Guided Backpropagation)[4]
层次相关传播 (Layer-wise Relevance Propagation)[4]
基于梯度的热图 (Gradient-based heatmaps)[5]
类激活映射(Class Activation Maps)[5]
梯度加权类激活映射(Gradient-weighted Class Activation Mapping (Grad-CAM))[5]
激发反向传播 (Excitation Backpropagation)[5]
请注意,其中两篇论文 [4][5] 的作者并未提供这些方法的开源实现,因此目前尚无法使用它们。
模型无关方法:GNNExplainer
本文的代码仓库可以在这里找到。
不用担心,实际上有一个 GNN 解释工具可以供你使用!
GNNExplainer 是一种模型无关的开源 GNN 解释方法。它也是一个相当通用的,因为它可以应用于节点分类、图分类和边预测。这是创建图特定解释方法的首次尝试,由斯坦福大学的研究人员提出 [6] 。
它是如何工作的?
作者声称,重用之前应用于 CNN 的解释方法是一种糟糕的方法,因为它们未能纳入关系信息,而关系信息是图数据的本质。此外,基于梯度的方法对于离散输入并不是特别有效,而 GNN 通常就是这种情况(例如,邻接矩阵 是二进制矩阵)。
为了克服这些问题,他们创建了一种与模型无关的方法,该方法可以找到输入数据的子图,这些子图以最重要的方式影响 GNN 的预测。说得更具体一些,子图的选择是为了最大化与模型预测的互信息(Mutual information)。下图显示了 GNNExplainer 如何处理由体育活动组成的图数据的一个示例。
作者提出的一个非常重要的假设是 GNN 模型的公式。模型的架构或多或少是任意的,但它需要实现 3 个关键计算:
相邻节点间的神经信息计算。
来自节点邻域的消息聚合。
聚合消息的非线性变换及节点表示。
这些所需的计算在某种程度上是有限制的,但大多数现代 GNN 架构无论如何,都是基于消息传递架构的。[6]
论文给出了 GNNExplainer 的数学定义和优化框架的描述。然而,我会省去一些细节,向你展示一些可以通过 GNNExplainer 获得的有趣结果。
GNNExplainer 在实践中的应用
为比较结果,他们使用了 3 种不同的解释方法:GNNExplainer,基于梯度的方法(Gradient-based method,GRAD)和 图注意力模型(Graph attention model,GAT)。在本文前面部分提到了 GRAD,但 GAT 模型需要一些解释。这是另一个 GNN 架构,它学习边的注意力权重,这将帮助我们确定图网络中哪些边对节点分类实际上很重要。这可以作为另一种解释技术使用,但它只适用于这个特定的模型,而且它不解释节点特性,这与 GNNExplainer 相反。
首先让我们看一下应用于合成数据集的解释方法的性能。图 8 显示了在两个不同的数据集上运行的实验。
BA-Shapes 数据集 基于 Barabasi-Albert(BA) 图,它是一种图网络,我们可以通过改变它的一些参数来自由调整大小。在这个基础图上,我们将附加一些小的房屋结构图(motifs),这些图在图 8 中显示为真相(Ground Truth)。此房屋结构的节点有 3 个不同的标签:顶部、中部和底部。这些节点标签只是表示节点在房子中的位置。因此,对于一个类房屋的节点,我们有 1 个顶部节点、2 个中部节点和 2 个底部节点。还有一个额外的标签,表示该节点不属于类房屋图结构。总体而言,我们有 4 个标签和一个 BA 图,其中包含 300 个节点和 80 个类房屋结构,这些结构被添加到 BA 图的随机节点中。我们还通过添加 0.1N 个随机边来增加一些随机性。
BA-Community 是两个 BA-Shapes 数据集的联合,因此总共有 8 个不同的标签(每个 BA-Shapes 数据集有 4 个标签)和两倍的节点。
让我们来看一看结果。
结果似乎很有希望。GNNExplainer 似乎以最准确的方式解释了结果,因为所选择的子图与真相相同。Grad 和 Att 方法未能提供类似的解释。
如图 9 所示,该实验也是在真实数据集上进行的。这一次的任务是对整个图网络进行分类,而不是对单个节点进行分类。
Mutag 是一个由分子组成的数据集,这些分子是根据对某类细菌的诱变效应进行分类的。数据集有很多不同的标签,它包含 4337 个分子图。
Reddit-Binary 是一个代表 Reddit 中在线讨论主题的数据集。在这个图网络中,用户用节点表示,边表示对另一个用户评论的响应。根据用户交互的类型,有两种可能的标签,它可以是问答互动,也可以是在线讨论互动。总的来说,它包含 2000 个图。
对于 Mutag 数据集,GNNExplainer 正确识别已知具有致突变性的化学基团(如 NO2、NH2)。GNNExplainer 还解释了归类为 Reddit-Binary 数据集在线讨论分类图。这种类型的交互通常可以用树状的模式表示(查看真相)。
标签:解释,GNN,解释性,GNNExplainer,神经网络,CNN,方法,节点 来源: https://blog.51cto.com/15060462/2675210