其他分享
首页 > 其他分享> > Graph Convolution Network 理解与实现

Graph Convolution Network 理解与实现

作者:互联网

Graph Convolution Network 理解与实现

https://zhuanlan.zhihu.com/p/51990489

Graph Convolution作为Graph Networks的一个分支,可以说几乎所有的图结构网络都是大同小异,详见综述,而Graph Convolution Network又是Graph Networks中最简单的一个分支。理解了它便可以理解很多近年来的图结构网络,比如Scene Graph Generation中的Message Passing机制等。后续打算持续更新一些原始GCN的变体。

【相关文章和网站】:

  1. Paper: Semi-Supervised Classification with Graph Convolutional Networks, 2016
  2. Paper: Gated Graph Sequence Neural Networks, 2016
  3. Website: How powerful are Graph Convolutional Networks?
  4. Github: 关于Gated Graph Convolution Network的Pytorch实现 KaihuaTang/GGNN-for-bAbI-dataset.pytorch.1.0
  5. 其实Graph Convolution Network (GCN)可以看作Graph Networks的一个分支(只有Node feature,无Edge feature和global attribute),而Graph Networks则有一篇2018年的综述:Relational inductive biases, deep learning, and graph networks, 2018

【Graph Convolution Network和传统CNN的关系】

img

我们不妨把传统的CNN的输入图片\(I\)也定义为一个Graph,他包含一堆Pixel集合\({p_i}\)看作是Node, 而graph的边则是通过pixel的连通性定义的,所以每个pixel有至多8个edge和他相连。而Convolution其实就是把他的8个neighbour pixel的feature和他自己的feature乘以一个可学习的参数化kernel,来update这个pixel的feature.

那么由此,就不难理解GCN了。GCN主要的区别在于,他的node间的边,不是通过连通性定义的,而是需要给定了一个edge set,或者说graph的adjacent matrix。而且由于每个node可以有任意数量的neighbour node,所以update feature时,所有node其实是乘以了同一套参数。

【公式化】

这里我们参考Semi-Supervised Classification with Graph Convolutional Networks, 2016给出Graph Convolution的最终公式,忽略了原文的推导过程。

GCN可以定义为如下公式:

\[Z=GCN(X,A) \]

详细展开如下:

\[Z=\hat{D}^{-\frac{1}{2}}\hat{A}\hat{D}^{-\frac{1}{2}}X\Theta \]

【伪代码实现】

\[input : X, A, output : Z \]

\[Y = f_c(X) \]

\[Z = (A+I) * Y / (A\cdot sum(1)+1) \]

【Gated Graph Convolution Network】

但是上述Node特征更新的方式比较原始,Gated Graph Sequene Neural Networks, ICLR, 2016将Graph Convolution的X to Z的更新改成了GRU(LSTM)的形式。同时设计了一个Graph-Level的特征。下面实现参考了上文的思想,但做了些简化,比如原文将Incoming Edges和Outgoing Edges区分了这里我就沿用朴素Graph Convolution的A,不做拓展。

【Gated Graph Convolution Network 公式&伪代码】

\[input:X^t,output:X^{t+1}(即Z) \]

\[Y=A\ast f_c(X^t) \]

\[U=\sigma(W_1Y+W_2X^t) \]

\[R=\sigma(W_3Y+W_4X^t) \]

\[X^{t+1}_{tem}=tanh(W_5Y+W_6(R\cdot X^t)) \]

\[X^{t+1}=(1-U)\cdot X^t+U\cdot X^{t+1}_{tem} \]

【Graph-Level特征获取】

很多应用需要将一整个graph整合成一个特征,而原始的Graph Convolution则只能生成每个node的特征。graph-level的定义如下:

\[h_G=tanh(\sum_{nodes}\sigma([X^T,X^0]))\cdot tanh(f_{c_2}([X^T,X^0]) \]

当然,还有很多文章,采取更为简单的graph-level feature提取方法:

\[h_G=\sum_{nodes}X^T_i,or h_G=\frac{1}{Num(nodes)}\sum_{nodes}X^T_i \]

【Code】

关于Gated Graph Convolution Network的代码,可以参考以下Github项目 KaihuaTang/GGNN-for-bAbI-dataset.pytorch.1.0

标签:node,Convolution,Graph,feature,Networks,Network
来源: https://www.cnblogs.com/fusheng-rextimmy/p/15387340.html