其他分享
首页 > 其他分享> > 自监督-Iterative Graph Self-distillation

自监督-Iterative Graph Self-distillation

作者:互联网

自监督-Iterative Graph Self-distillation

标签:自监督、图神经、知识蒸馏、图学习、对比学习

动机

贡献

思想

核心

在对比学习的框架下,结合自蒸馏技术使得教师网络同时对学生网络进行训练

框架

对于一个图数据集合首先进行分批,对于三个原图 \(G_1、G_2、G_3\),利用扩散技术对原图进行增强得到 \(G_1'、G_2'、G_3'\),

都经过一个编码器 Encoder \(f_{\theta}\):

\[h = f_{\theta}(G)\\ h_v^k = \sigma(W^k·CONCAT(h_v^{k - 1}, AGGREGATE(\{h_u^{k - 1} \quad \forall u \in N(v) \}))) \]

通过编码器得到一个图的表示 \(h\) 后经过一个投影头 \(g_{\theta}\)(两层的MLP):

\[z = g_{\theta}(h)\\ g_{\theta}(h) = W^{(2)}\sigma(w^{(1)}h) \]

投影后得到 \(z\) ,对于学生网络我们还有一个预测器 \(h_{\theta}\):

\[h_{\theta} (z) = W^{(2)}\sigma(W^{(1)}z) \]

得到 \(z、h_{\theta}(z)\) ,我们在潜在空间使用 \(L_2\) 范式得到一个近似输入空间中的语义距离,并且一致性损失可以被定义为会议话预测之间的均方误差

\[L^{con}(G_i,G_j) = ||h_{\theta}(z_i) - z_j'||_2^2 + ||h_{\theta}(z_i') - z_j||_2^2 \]

由于一致性损失,教师网络提供了一个回归目标来训练学生网络,并且使用梯度更新了学生网络的权重之后,使用 EMA(exponential moving average 指数移动平均) 更新教师网络的权重:

\[{\theta}_t' \leftarrow \tau{\theta}_{t - 1}' + (1 - \tau){\theta}_t \]

数据增强

损失函数

在自监督学习中,为了对比 锚点(anchor)\(G_i\) 和其他负样例 \(G_j\) ,采用一下目标函数:

\[L^{self-sup} = -E_{G_i G} [\frac{exp(-L_{i,i}^{con})}{exp(-L_{i,i}^{con}) + \sum_{j = 1}^{N-1}I_{i≠j}exp(-L_{i,j}^{con})}] \]

在最后的图表示中,我们利用混合函数获得最后图的表示:

\[Mix_{\lambda}(a, b) = {\lambda} a + (1 - {\lambda})b \\ \hat{h} = Mix_{\lambda}(h, h')\\ h = f_{\theta}(G) ~~~~~~~~~~~~ h' = f_{\theta'}(G) \]

在半监督学习中,可以使用少量的标记数据来进一步概括相似性损失,以处理任意数量的属于同一类的正样本:

\[L^{supcon} = \sum_{i = 1}^{Kl}\frac{1}{KN_{y'}}\sum_{j=1}^{Kl}I_{i≠j}I_{y_i'≠y_j'}L^{con}(G_i, G_j) \]

最后半监督的损失函数:

\[L^{semi}=L(G_L,{\theta}) + wL^{self-sup}(G_L ∪G_U, \theta) + w'L^{supcon}(G_L,\theta) \]

实验

自监督学习中图分类任务的准确率

半监督下进行图分类任务的准确率

结论

在本文中,提出了IGSD,一个新的图级表示学习框架,通过自我蒸馏。我们的框架通过对图实例的增广视图进行实例判别来迭代执行师生蒸馏。

标签:Graph,Iterative,网络,学习,监督,IGSD,theta,distillation,con
来源: https://www.cnblogs.com/owoRuch/p/15587907.html