动手学PyG(三):PyG中的mini-batches
作者:互联网
PyG中的mini-batches
本文主要参考了 PyG英文文档
神经网络通常会采用分批的形式来训练。PyG通过创建稀疏块对角矩阵(由edge_index来定义)的形式来实现小批量图的并行化。而节点属性与训练目标则会在节点维度进行拼接。这种设计使得我们可以将不同规模的图放在同一个batch中。
A
=
[
A
1
⋱
A
n
]
,
X
=
[
X
1
⋮
X
n
]
Y
=
[
Y
1
⋮
Y
n
]
A=\begin{bmatrix} A_1 & & \\ & \ddots & \\ & & A_n \end{bmatrix},\quad X=\begin{bmatrix} X_1\\ \vdots \\ X_n\end{bmatrix} Y=\begin{bmatrix} Y_1\\ \vdots \\ Y_n\end{bmatrix}
A=⎣⎡A1⋱An⎦⎤,X=⎣⎢⎡X1⋮Xn⎦⎥⎤Y=⎣⎢⎡Y1⋮Yn⎦⎥⎤
PyG有其自己的torch_geometric.loader.DataLoader,这让我们可以很好的完成上述分批操作:
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in loader:
batch
>>> DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])
batch.num_graphs
>>> 32
torch_geometric.data.Batch类继承自torch_geometric.data.Data,多出了一个batch属性。
batch属性为一个用于标记每一个来自于哪张图的列向量。
b
a
t
c
h
=
[
0
⋯
0
1
n
−
2
n
−
1
⋯
n
−
1
]
{\rm batch}=\begin{bmatrix}0&\cdots&0&1&n-2&n-1&\cdots&n-1\end{bmatrix}
batch=[0⋯01n−2n−1⋯n−1]
我们可以利用batch属性来对每张图独立的实现一些操作。
from torch_scatter import scatter_mean
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZZYMES', use_node_attr=True)
loader = DataLoader(dataset batch_size=32, shuffle=True)
for data in loader:
data
>>> DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32]
data.num_graphs
>>> 32
x = scatter_mean(data.x, data.batch, dim=0)
x.size()
>>> torch.Size([32,21])
更多关于PyG中的batch操作参考这里。torch-scatter的参考文档在这里。
标签:mini,batches,32,torch,batch,loader,bmatrix,geometric,PyG 来源: https://blog.csdn.net/ln_guangchen/article/details/122566984