其他分享
首页 > 其他分享> > torch_geometric使用指南 (作个人纪录)

torch_geometric使用指南 (作个人纪录)

作者:互联网

建议用最新版本的torch_geometric,不同版本的API变动会比较大。

这个包最关键的一个类是MessagePassing
在这里插入图片描述

其他不做解释。

MessagePassing的forward参数是:
在这里插入图片描述
最重要的是这个edge_index,结合前面的flow参数,edge_index包含了你输入这个图的所有边的信息(start node、end node)。如图(黄色highlight的部分),输入的edge_index一般情况下是LongTensor,此时形状必须为[2, num_messages],第一维存放start node idx, 第二维存放end node idx。比方说:

## 假设图里面有三个节点,node index为 0,1,2
### 有向图:0->1, 1->2, 2->0 
edge_index = [[0,1,2]
			  [1,2,0]]
### 无向图,全连接
edge_index = [[0,1,2,1,2,0]
			  [1,2,0,0,1,2]]

torch_geometric目前实现了很多近几年的GCN变体 (GAT,RGCN,etc.), 都是继承自MessagePassing, 只要理解了这个MessagePassing和他的edge_index,这些变体都可以直接调包用就可以了。

参考:
torch_geometric: MessagePassing

标签:node,index,torch,MessagePassing,edge,geometric,使用指南
来源: https://blog.csdn.net/weixin_43301333/article/details/122016380