其他分享
首页 > 其他分享> > 动手学PyG(三):PyG中的mini-batches

动手学PyG(三):PyG中的mini-batches

作者:互联网

PyG中的mini-batches


本文主要参考了 PyG英文文档

神经网络通常会采用分批的形式来训练。PyG通过创建稀疏块对角矩阵(由edge_index来定义)的形式来实现小批量图的并行化。而节点属性与训练目标则会在节点维度进行拼接。这种设计使得我们可以将不同规模的图放在同一个batch中。
A = [ A 1 ⋱ A n ] , X = [ X 1 ⋮ X n ] Y = [ Y 1 ⋮ Y n ] A=\begin{bmatrix} A_1 & & \\ & \ddots & \\ & & A_n \end{bmatrix},\quad X=\begin{bmatrix} X_1\\ \vdots \\ X_n\end{bmatrix} Y=\begin{bmatrix} Y_1\\ \vdots \\ Y_n\end{bmatrix} A=⎣⎡​A1​​⋱​An​​⎦⎤​,X=⎣⎢⎡​X1​⋮Xn​​⎦⎥⎤​Y=⎣⎢⎡​Y1​⋮Yn​​⎦⎥⎤​

PyG有其自己的torch_geometric.loader.DataLoader,这让我们可以很好的完成上述分批操作:

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:
	batch
	>>> DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])

	batch.num_graphs
	>>> 32

torch_geometric.data.Batch类继承自torch_geometric.data.Data,多出了一个batch属性。
batch属性为一个用于标记每一个来自于哪张图的列向量。
b a t c h = [ 0 ⋯ 0 1 n − 2 n − 1 ⋯ n − 1 ] {\rm batch}=\begin{bmatrix}0&\cdots&0&1&n-2&n-1&\cdots&n-1\end{bmatrix} batch=[0​⋯​0​1​n−2​n−1​⋯​n−1​]

我们可以利用batch属性来对每张图独立的实现一些操作。

from torch_scatter import scatter_mean
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZZYMES', use_node_attr=True)
loader = DataLoader(dataset batch_size=32, shuffle=True)

for data in loader:
	data
	>>> DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32]

	data.num_graphs
	>>> 32

	x = scatter_mean(data.x, data.batch, dim=0)
	x.size()
	>>> torch.Size([32,21])

更多关于PyG中的batch操作参考这里。torch-scatter的参考文档在这里

标签:mini,batches,32,torch,batch,loader,bmatrix,geometric,PyG
来源: https://blog.csdn.net/ln_guangchen/article/details/122566984