其他分享
首页 > 其他分享> > DGL GAT

DGL GAT

作者:互联网

GAT
DGL document
DGL GAT
DGL官方教程GAT

# Case 1: Homogeneous graph
pip install dgl

import dgl
import numpy as np
import torch as th
from dgl.nn import GATConv

# Build a directed homogeneous graph from (source ids, destination ids):
# six edges over node ids 0..5. Node 5 never appears in the destination
# list, so it has no incoming edge at this point.
g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
# Uncomment to inspect the graph, its node ids, and its edge lists.
# print(g)
# print(g.nodes())
# print(g.edges())
g = dgl.add_self_loop(g) 
# After adding self-loops the edges are
#   src = tensor([0, 1, 2, 3, 2, 5, 0, 1, 2, 3, 4, 5])
#   dst = tensor([1, 2, 3, 4, 0, 3, 0, 1, 2, 3, 4, 5])
# so every node (including node 5) has at least one incoming edge.
# NOTE(review): per the DGL docs, GATConv rejects graphs containing
# zero-in-degree nodes by default, which is why the self-loops are added
# here — confirm against the installed DGL version.
feat = th.ones(6, 10) # 6x10 input feature matrix: one all-ones row per node
'''
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
'''
# GAT layer: 10 input features, 2 output features per head, 3 attention heads.
# The recorded repr below shows a single Linear(10 -> 6), i.e. 2 * 3 outputs.
gatconv = GATConv(10, 2, num_heads=3)
'''
GATConv(
  (fc): Linear(in_features=10, out_features=6, bias=False)
  (feat_drop): Dropout(p=0.0, inplace=False)
  (attn_drop): Dropout(p=0.0, inplace=False)
  (leaky_relu): LeakyReLU(negative_slope=0.2)
)
'''
# Apply the layer; the string below notes the two arguments (graph, features).
res = gatconv(g, feat)
'''
g -graph
feat(torch.Tensor)
'''
# Notebook-style display of the result. The recorded output below has one
# 3x2 block (heads x output features) for each of the 6 nodes. All blocks are
# identical in this recording, presumably because every node carries the same
# input features. No random seed is set, so the exact numbers are expected to
# differ from run to run.
res
'''
tensor([[[ 0.0171,  2.3988],
         [-0.1598,  0.1317],
         [ 0.3279, -1.9599]],

        [[ 0.0171,  2.3988],
         [-0.1598,  0.1317],
         [ 0.3279, -1.9599]],

        [[ 0.0171,  2.3988],
         [-0.1598,  0.1317],
         [ 0.3279, -1.9599]],

        [[ 0.0171,  2.3988],
         [-0.1598,  0.1317],
         [ 0.3279, -1.9599]],

        [[ 0.0171,  2.3988],
         [-0.1598,  0.1317],
         [ 0.3279, -1.9599]],

        [[ 0.0171,  2.3988],
         [-0.1598,  0.1317],
         [ 0.3279, -1.9599]]], grad_fn=<AddBackward0>)
'''
# Case 2: Unidirectional bipartite graph
pip install dgl

import dgl
import numpy as np
import torch as th
from dgl.nn import GATConv

# Edge list of a unidirectional bipartite graph: edge i goes from source
# node u[i] (2 source nodes, ids 0..1) to destination node v[i]
# (4 destination nodes, ids 0..3).
u = [0, 1, 0, 0, 1]
v = [0, 1, 2, 3, 2]
# dgl.bipartite() was deprecated in DGL 0.5 and removed in later releases.
# The supported way to build the same graph is a heterograph with a single
# canonical edge type (source type 'A', relation 'r', destination type 'B');
# node counts are inferred from the largest id on each side.
g = dgl.heterograph({('A', 'r', 'B'): (u, v)})
# Random float32 features: 5-dim for the 2 source nodes, 10-dim for the
# 4 destination nodes. No seed is set, so values differ between runs.
u_feat = th.tensor(np.random.rand(2, 5).astype(np.float32))
v_feat = th.tensor(np.random.rand(4, 10).astype(np.float32))
# A (src, dst) tuple of input sizes gives the layer separate projections for
# the two node types (see fc_src / fc_dst in the recorded repr below);
# 2 output features per head, 3 heads, so each projection maps to 6 features.
gatconv = GATConv((5, 10), 2, num_heads=3)
'''
GATConv(
  (fc_src): Linear(in_features=5, out_features=6, bias=False)
  (fc_dst): Linear(in_features=10, out_features=6, bias=False)
  (feat_drop): Dropout(p=0.0, inplace=False)
  (attn_drop): Dropout(p=0.0, inplace=False)
  (leaky_relu): LeakyReLU(negative_slope=0.2)
)
'''
# Features are passed as a (source features, destination features) pair.
# NOTE(review): per the DGL docs the result has one row per destination node,
# i.e. shape (4, 3, 2) here — confirm against the installed DGL version.
res = gatconv(g, (u_feat, v_feat))
# Notebook-style display of the result (a no-op when run as a plain script).
res

标签:10,False,features,GAT,dgl,DGL,import,feat
来源: https://blog.csdn.net/LoveKKarlie_/article/details/117922366