其他分享
首页 > 其他分享> > DGL官方教程--图卷积神经网络(GCN)

DGL官方教程--图卷积神经网络(GCN)

作者:互联网

Graph Convolutional Network

Author:
, Minjie Wang, Yu Gai, Quan Gan, Zheng Zhang
这是使用DGL实施图卷积网络的简要介绍(Kipf & Welling et al., Semi-Supervised Classification with Graph Convolutional Networks)。我们以DGLGraph上较早的教程为基础,并演示DGL如何将图与深度神经网络相结合并学习结构表示。

Model Overview

从消息传递的角度看GCN

我们从消息传递的角度描述了图卷积神经网络的一层; 数学可以在这里找到。 对于每个节点uuu,它归结为以下步骤:

1)汇总邻居的表示 hvh_{v}hv​ 产生中间表示 h^u\hat{h}_uh^u​。2)转换汇总表示h^u\hat{h}_{u}h^u​线性投影,然后非线性: hu=f(Wuh^u)h_{u} = f(W_{u} \hat{h}_u)hu​=f(Wu​h^u​)。
我们将通过DGL消息传递实现第1步,并通过apply_nodes方法实现第2步,该 方法的节点UDF将是PyTorch nn.Module

使用DGL的GCN实现

我们首先定义消息并像往常一样减少功能。由于聚合在一个节点上u 只涉及总结邻居的表象 hv,我们可以简单地使用内置函数:

import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph

gcn_msg = fn.copy_src(src='h', out='m')
gcn_reduce = fn.sum(msg='m', out='h')

然后,我们为定义节点UDF apply_nodes,它是一个完全连接的层:

class NodeApplyModule(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node.data['h'])
        if self.activation is not None:
            h = self.activation(h)
        return {'h' : h}

然后,我们继续定义GCN模块。GCN层本质上在所有节点上执行消息传递,然后应用NodeApplyModule。请注意,为简单起见,我们省略了本文中的缺失。

class GCN(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(GCN, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        g.ndata['h'] = feature
        g.update_all(gcn_msg, gcn_reduce)
        g.apply_nodes(func=self.apply_mod)
        return g.ndata.pop('h')

前向功能与PyTorch中任何其他常见的NN模型相同。我们可以像一样初始化GCN nn.Module。例如,让我们定义一个由两个GCN层组成的简单神经网络。假设我们正在训练cora数据集的分类器(输入要素大小为1433,类别数为7)。最后一个GCN层计算节点嵌入,因此最后一个层通常不应用激活。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.gcn1 = GCN(1433, 16, F.relu)
        self.gcn2 = GCN(16, 7, None)

    def forward(self, g, features):
        x = self.gcn1(g, features)
        x = self.gcn2(g, x)
        return x
net = Net()
print(net)

out:

Net(
  (gcn1): GCN(
    (apply_mod): NodeApplyModule(
      (linear): Linear(in_features=1433, out_features=16, bias=True)
    )
  )
  (gcn2): GCN(
    (apply_mod): NodeApplyModule(
      (linear): Linear(in_features=16, out_features=7, bias=True)
    )
  )
)

我们使用DGL的内置数据模块加载cora数据集。

from dgl.data import citation_graph as citegrh
import networkx as nx
def load_cora_data():
    data = citegrh.load_cora()
    features = th.FloatTensor(data.features)
    labels = th.LongTensor(data.labels)
    train_mask = th.BoolTensor(data.train_mask)
    test_mask = th.BoolTensor(data.test_mask)
    g = data.graph
    # add self loop
    g.remove_edges_from(nx.selfloop_edges(g))
    g = DGLGraph(g)
    g.add_edges(g.nodes(), g.nodes())
    return g, features, labels, train_mask, test_mask

训练模型后,我们可以使用以下方法评估模型在测试数据集上的性能:

def evaluate(model, g, features, labels, mask):
    model.eval()
    with th.no_grad():
        logits = model(g, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = th.max(logits, dim=1)
        correct = th.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

然后,我们按照以下方式训练网络:

import time
import numpy as np
g, features, labels, train_mask, test_mask = load_cora_data()
optimizer = th.optim.Adam(net.parameters(), lr=1e-3)
dur = []
for epoch in range(50):
    if epoch >=3:
        t0 = time.time()

    net.train()
    logits = net(g, features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[train_mask], labels[train_mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >=3:
        dur.append(time.time() - t0)

    acc = evaluate(net, g, features, labels, test_mask)
    print("Epoch {:05d} | Loss {:.4f} | Test Acc {:.4f} | Time(s) {:.4f}".format(
            epoch, loss.item(), acc, np.mean(dur)))

out:

/home/ubuntu/.pyenv/versions/miniconda3-latest/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3257: RuntimeWarning: Mean of empty slice.
  out=out, **kwargs)
/home/ubuntu/.pyenv/versions/miniconda3-latest/lib/python3.7/site-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars
  ret = ret.dtype.type(ret / rcount)
Epoch 00000 | Loss 1.9444 | Test Acc 0.1470 | Time(s) nan
Epoch 00001 | Loss 1.9181 | Test Acc 0.1610 | Time(s) nan
Epoch 00002 | Loss 1.8911 | Test Acc 0.1900 | Time(s) nan
Epoch 00003 | Loss 1.8613 | Test Acc 0.2360 | Time(s) 0.0839
Epoch 00004 | Loss 1.8310 | Test Acc 0.2630 | Time(s) 0.0840
Epoch 00005 | Loss 1.8007 | Test Acc 0.2850 | Time(s) 0.0837
Epoch 00006 | Loss 1.7705 | Test Acc 0.2980 | Time(s) 0.0836
Epoch 00007 | Loss 1.7408 | Test Acc 0.3100 | Time(s) 0.0838
Epoch 00008 | Loss 1.7109 | Test Acc 0.3170 | Time(s) 0.0837
Epoch 00009 | Loss 1.6810 | Test Acc 0.3280 | Time(s) 0.0839
Epoch 00010 | Loss 1.6513 | Test Acc 0.3550 | Time(s) 0.0841
Epoch 00011 | Loss 1.6219 | Test Acc 0.3760 | Time(s) 0.0841
Epoch 00012 | Loss 1.5942 | Test Acc 0.3910 | Time(s) 0.0840
Epoch 00013 | Loss 1.5674 | Test Acc 0.4030 | Time(s) 0.0841
Epoch 00014 | Loss 1.5413 | Test Acc 0.4140 | Time(s) 0.0841
Epoch 00015 | Loss 1.5157 | Test Acc 0.4270 | Time(s) 0.0842
Epoch 00016 | Loss 1.4912 | Test Acc 0.4430 | Time(s) 0.0842
Epoch 00017 | Loss 1.4676 | Test Acc 0.4520 | Time(s) 0.0843
Epoch 00018 | Loss 1.4451 | Test Acc 0.4600 | Time(s) 0.0842
Epoch 00019 | Loss 1.4233 | Test Acc 0.4640 | Time(s) 0.0842
Epoch 00020 | Loss 1.4021 | Test Acc 0.4730 | Time(s) 0.0842
Epoch 00021 | Loss 1.3815 | Test Acc 0.4760 | Time(s) 0.0842
Epoch 00022 | Loss 1.3616 | Test Acc 0.4810 | Time(s) 0.0842
Epoch 00023 | Loss 1.3423 | Test Acc 0.4890 | Time(s) 0.0842
Epoch 00024 | Loss 1.3236 | Test Acc 0.5080 | Time(s) 0.0842
Epoch 00025 | Loss 1.3056 | Test Acc 0.5180 | Time(s) 0.0843
Epoch 00026 | Loss 1.2881 | Test Acc 0.5240 | Time(s) 0.0843
Epoch 00027 | Loss 1.2713 | Test Acc 0.5310 | Time(s) 0.0844
Epoch 00028 | Loss 1.2550 | Test Acc 0.5400 | Time(s) 0.0843
Epoch 00029 | Loss 1.2392 | Test Acc 0.5570 | Time(s) 0.0844
Epoch 00030 | Loss 1.2238 | Test Acc 0.5670 | Time(s) 0.0844
Epoch 00031 | Loss 1.2089 | Test Acc 0.5800 | Time(s) 0.0843
Epoch 00032 | Loss 1.1944 | Test Acc 0.5860 | Time(s) 0.0843
Epoch 00033 | Loss 1.1803 | Test Acc 0.5960 | Time(s) 0.0843
Epoch 00034 | Loss 1.1666 | Test Acc 0.6000 | Time(s) 0.0843
Epoch 00035 | Loss 1.1532 | Test Acc 0.6070 | Time(s) 0.0843
Epoch 00036 | Loss 1.1401 | Test Acc 0.6160 | Time(s) 0.0843
Epoch 00037 | Loss 1.1273 | Test Acc 0.6220 | Time(s) 0.0843
Epoch 00038 | Loss 1.1147 | Test Acc 0.6240 | Time(s) 0.0843
Epoch 00039 | Loss 1.1023 | Test Acc 0.6310 | Time(s) 0.0843
Epoch 00040 | Loss 1.0901 | Test Acc 0.6340 | Time(s) 0.0843
Epoch 00041 | Loss 1.0782 | Test Acc 0.6400 | Time(s) 0.0843
Epoch 00042 | Loss 1.0664 | Test Acc 0.6410 | Time(s) 0.0843
Epoch 00043 | Loss 1.0548 | Test Acc 0.6460 | Time(s) 0.0842
Epoch 00044 | Loss 1.0434 | Test Acc 0.6470 | Time(s) 0.0842
Epoch 00045 | Loss 1.0322 | Test Acc 0.6520 | Time(s) 0.0842
Epoch 00046 | Loss 1.0211 | Test Acc 0.6600 | Time(s) 0.0842
Epoch 00047 | Loss 1.0101 | Test Acc 0.6600 | Time(s) 0.0841
Epoch 00048 | Loss 0.9993 | Test Acc 0.6650 | Time(s) 0.0841
Epoch 00049 | Loss 0.9886 | Test Acc 0.6670 | Time(s) 0.0841

GCN in one formula

在数学上,GCN模型遵循以下公式:
H(l+1)=σ(D~12A~D~12H(l)W(l))H^{(l+1)} = \sigma(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)})H(l+1)=σ(D~−21​A~D~−21​H(l)W(l))

这里, H(l)H^{(l)}H(l) 表示lthl^{th}lth 网络中的层 σ\sigmaσ 是非线性的,并且 WWW 是该层的权重矩阵。 DDD 和 AAA如通常所见,分别代表度矩阵和邻接矩阵。〜是一种重新规范化的技巧,其中,我们向图的每个节点添加了自连接,并构建了相应的度数和邻接矩阵。输入的形状 H(0)H^{(0)}H(0)是 N×DN×DN×D,在哪里 NNN 是节点数,并且 DDD是输入要素的数量。我们可以将多层链接起来,以生成形状为:mathN`N‘N 乘以 FF`F‘的节点级表示输出,其中F 是输出节点特征向量的维。

可以使用稀疏矩阵乘法内核(例如Kipf的pygcn代码)有效地实现该方程 。实际上,由于使用内置函数,上述DGL实现实际上已经使用了该技巧。要了解其内幕,请阅读我们在PageRank上的教程。
脚本的总运行时间:(0分钟17.986秒)

下载源代码:1_gcn.py
下载源代码:1_gcn.ipynb

平湖片帆 发布了0 篇原创文章 · 获赞 1 · 访问量 36 私信 关注

标签:Acc,Loss,卷积,self,DGL,GCN,Epoch,Time,Test
来源: https://blog.csdn.net/weixin_45613751/article/details/104088428