其他分享
首页 > 其他分享> > DataWhale图网络学习(七)超大规模数据集类的创建

DataWhale图网络学习(七)超大规模数据集类的创建

作者:互联网

超大规模数据集类的创建


前面介绍了可以将数据集全部加载到内存的数据集,然而在实际中,一般数据集都很大,无法全部加载到内存,因此我们需要了解 按需加载样本到内存的数据集类

1 Dataset基类

我们通过继承torch_geometric.data.Dataset基类来自定义一个按需加载样本到内存的数据集类,此基类与Torchvision的Dataset类的概念密切相关。

这里我们继承torch_geometric.data.InMemoryDataset基类,我们需要实现以下方法:

import os.path as osp
import torch
from torch_geometric.data import Dataset, download_url

class MyOwnDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(MyOwnDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data_1.pt', 'data_2.pt', ...]

    def download(self):
        # Download to `self.raw_dir`.
        path = download_url(url, self.raw_dir)
        ...

    def process(self):
        i = 0
        for raw_path in self.raw_paths:
            # Read data from `raw_path`.
            data = Data(...)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
            i += 1

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
        return data

其中,每个Data对象在process()方法中单独被保存,并在get()中通过指定索引进行加载。当我们不需要download数据集,不需要对数据集进行process时,我们就可以跳过这些函数。

1.1 无需定义Dataset类情况

我们可以在不重新定义Dataset类的情况下直接Dataloader对象并用于训练:

from torch_geometric.data import Data, DataLoader

data_list = [Data(...), ..., Data(...)]
loader = DataLoader(data_list, batch_size=32)

这里是对每个Data对象进行batch操作,我我们也可以把整个data_list作为一个对象进行batch操作:

from torch_geometric.data import Data, Batch

data_list = [Data(...), ..., Data(...)]
loader = Batch.from_data_list(data_list, batch_size=32)

2 图样本封装成批(BATCHING)与DataLoader

2.1 合并小图组成大图

前边《基于图神经网络的图表示学习》里边介绍了将大图拆分成小图然后再组合的方法,这里我们也可以参考。
由于图是不规整的数据结构,它可以有任意数量的节点和边,因此对图数据封装成批的操作与对图像和序列等数据封装成批的操作不同。PyTorch Geometric中采用的将多个图封装成批的方式是,将小图作为连通组件(connected component)的形式合并,构建一个大图。于是小图的邻接矩阵存储在大图邻接矩阵的对角线上。大图的邻接矩阵、属性矩阵、预测目标矩阵分别为:
A = [ A 1 ⋱ A n ] , X = [ X 1 ⋮ X n ] , Y = [ Y 1 ⋮ Y n ] . \mathbf{A} = \begin{bmatrix} \mathbf{A}_1 & & \\ & \ddots & \\ & & \mathbf{A}_n \end{bmatrix}, \qquad \mathbf{X} = \begin{bmatrix} \mathbf{X}_1 \\ \vdots \\ \mathbf{X}_n \end{bmatrix}, \qquad \mathbf{Y} = \begin{bmatrix} \mathbf{Y}_1 \\ \vdots \\ \mathbf{Y}_n \end{bmatrix}. A=⎣⎡​A1​​⋱​An​​⎦⎤​,X=⎣⎢⎡​X1​⋮Xn​​⎦⎥⎤​,Y=⎣⎢⎡​Y1​⋮Yn​​⎦⎥⎤​.

此方法有以下关键的优势

通过torch_geometric.data.DataLoader类,多个小图被封装成一个大图。其为PyTorch的DataLoader的子类,它覆盖了collate()函数,该函数定义了一列表的样本是如何封装成批的。

2.2 小图的属性增值与拼接

将小图存储到大图中时需要对小图的属性做一些修改,一个最显著的例子就是要对节点序号增值。在最一般的形式中,PyTorch Geometric的DataLoader类会自动对edge_index张量增值,增加的值为当前被处理图的前面的图的累积节点数量。然而,有一些特殊的场景中(如下所述),基于需求我们希望能修改这一行为。PyTorch Geometric允许我们通过覆盖torch_geometric.data.__inc__()torch_geometric.data.__cat_dim__()函数来实现我们希望的行为。在未做修改的情况下,它们在Data类中的定义如下。

def __inc__(self, key, value):
    if 'index' in key or 'face' in key:
        return self.num_nodes
    else:
        return 0

def __cat_dim__(self, key, value):
    if 'index' in key or 'face' in key:
        return 1
    else:
        return 0

我们可以看到,__inc__()定义了两个连续的图的属性之间的增量大小,而__cat_dim__()定义了同一属性的图形张量应该在哪个维度上被连接起来。PyTorch Geometric为存储在Data类中的每个属性调用此二函数,并以它们各自的key和值value作为参数。

接下来,我们将学习一些对__inc__()__cat_dim__()的修改可能是绝对必要的案例。

2.3 图的匹配

当我们在一个Data对象中存储多个图,例如用于图匹配等应用,我们需要确保所有这些图的正确封装成批行为。

下面我们以两个图为例,,一个源图 G s G_s Gs​和一个目标图 G t G_t Gt​,存储在一个Data类中,用以进行匹配:

class PairData(Data):
    def __init__(self, edge_index_s, x_s, edge_index_t, x_t):
        super(PairData, self).__init__()
        self.edge_index_s = edge_index_s
        self.x_s = x_s
        self.edge_index_t = edge_index_t
        self.x_t = x_t

在这种情况中,edge_index_s应该根据源图 G s G_s Gs​的节点数做增值,即x_s.size(0),而edge_index_t应该根据目标图 G t G_t Gt​的节点数做增值,即x_t.size(0)。因此我们定义返回增量大小的函数__inc__()如下:

class PairData(Data):
    def __init__(self, edge_index_s, x_s, edge_index_t, x_t):
        super(PairData, self).__init__()
        self.edge_index_s = edge_index_s
        self.x_s = x_s
        self.edge_index_t = edge_index_t
        self.x_t = x_t

    def __inc__(self, key, value):
        if key == 'edge_index_s':
            return self.x_s.size(0)
        if key == 'edge_index_t':
            return self.x_t.size(0)
        else:
            return super().__inc__(key, value)

这里我们构建两个图 G s G_s Gs​和 G t G_t Gt​, G s G_s Gs​有5个节点, G t G_t Gt​有4个节点,每个节点有16个属性值:
在这里插入图片描述

edge_index_s = torch.tensor([
    [0, 0, 0, 0],
    [1, 2, 3, 4],
])
x_s = torch.randn(5, 16)  # 5 nodes.
edge_index_t = torch.tensor([
    [0, 0, 0],
    [1, 2, 3],
])
x_t = torch.randn(4, 16)  # 4 nodes.

然后通过定义的数据集类load数据:

data = PairData(edge_index_s, x_s, edge_index_t, x_t)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))

print(batch)

print(batch.edge_index_s)

print(batch.edge_index_t)

Batch(edge_index_s=[2, 8], edge_index_t=[2, 6], x_s=[10, 16], x_t=[8, 16])

tensor([[0, 0, 0, 0, 5, 5, 5, 5],
[1, 2, 3, 4, 6, 7, 8, 9]])

tensor([[0, 0, 0, 4, 4, 4],
[1, 2, 3, 5, 6, 7]])

虽然 G s G_s Gs​和 G t G_t Gt​的节点数不同,但是也被正确的封装成批了。然而,由于PyTorch Geometric无法识别PairData对象中实际的图,所以batch属性(将大图每个节点映射到其各自对应的小图)没有正确工作。此时就需要DataLoaderfollow_batch参数发挥作用。在这里,我们可以指定我们要为哪些属性维护批信息。

loader = DataLoader(data_list, batch_size=2, follow_batch=['x_s', 'x_t'])
batch = next(iter(loader))

print(batch)

print(batch.x_s_batch)

print(batch.x_t_batch)

Batch(edge_index_s=[2, 8], edge_index_t=[2, 6], x_s=[10, 16], x_s_batch=[10], x_t=[8, 16], x_t_batch=[8])

tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

tensor([0, 0, 0, 0, 1, 1, 1, 1])

可以看到,follow_batch=['x_s', 'x_t']现在成功地为节点特征x_sx_t分别创建了名为x_s_batchx_t_batch的赋值向量。这些信息现在可以用来在一个单一的Batch对象中对多个图进行聚合操作,例如,全局池化。

2.4 二部图

二部图也叫二分图,其邻接矩阵定义两种类型的节点之间的连接关系。一般来说,不同类型的节点数量不需要一致,于是二部图的邻接矩阵 A ∈ { 0 , 1 } N × M A \in \{0,1\}^{N \times M} A∈{0,1}N×M可能为平方矩阵,即可能有 N ≠ M N \neq M N​=M。对二部图的封装成批过程中,edge_index 中边的源节点与目标节点做的增值操作应是不同的。我们将二部图中两类节点的特征特征张量分别存储为x_sx_t。为了对二部图实现正确的封装成批,应该独立地为边的源节点和目标节点做增值操作。

class BipartiteData(Data):
    def __init__(self, edge_index, x_s, x_t):
        super(BipartiteData, self).__init__()
        self.edge_index = edge_index
        self.x_s = x_s
        self.x_t = x_t
	def __inc__(self, key, value):
	    if key == 'edge_index':
	        return torch.tensor([[self.x_s.size(0)], [self.x_t.size(0)]])
	    else:
	        return super().__inc__(key, value)

其中,edge_index[0](边的源节点)根据x_s.size(0)做增值运算,而edge_index[1](边的目标节点)根据x_t.size(0)做增值运算。下面创建一个二部图,x_s有2个节点,16维;x_t有3个节点16维。

edge_index = torch.tensor([
    [0, 0, 1, 1],
    [0, 1, 1, 2],
])
x_s = torch.randn(2, 16)  # 2 nodes.
x_t = torch.randn(3, 16)  # 3 nodes.

load数据并查看batch:

data = BipartiteData(edge_index, x_s, x_t)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))

print(batch)

print(batch.edge_index)

Batch(batch=[6], edge_index=[2, 8], ptr=[3], x_s=[4, 16], x_t=[6, 16])
tensor([[0, 0, 1, 1, 2, 2, 3, 3],
[0, 1, 1, 2, 3, 4, 4, 5]])

得到了我们期望的结果

2.5 在新的维度上做拼接

有时,Data对象的属性需要在一个新的维度上做拼接(如经典的封装成批),例如,图级别属性或预测目标。具体来说,形状为[num_features]的属性列表应该被返回为[num_examples, num_features],而不是[num_examples * num_features]。我们通过自定义__cat_dim__()返回一个None来实现自定义连接。

 class MyData(Data):
     def __cat_dim__(self, key, item):
         if key == 'foo':
             return None
         else:
             return super().__cat_dim__(key, item)

edge_index = torch.tensor([
   [0, 1, 1, 2],
   [1, 0, 2, 1],
])
foo = torch.randn(16)

data = MyData(edge_index=edge_index, foo=foo)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))

print(batch)

Batch(batch=[6], edge_index=[2, 8], foo=[2, 16], ptr=[3])

正如我们期望的,batch.foo现在由两个维度来表示,一个批维度,一个特征维度。

3 创建超大规模数据集类实践

import os
import os.path as osp

import pandas as pd
import torch
from ogb.utils.mol import smiles2graph
from ogb.utils.torch_util import replace_numpy_with_torchtensor
from ogb.utils.url import download_url, extract_zip
from rdkit import RDLogger
from torch_geometric.data import Data, Dataset
import shutil

RDLogger.DisableLog('rdApp.*')

class MyPCQM4MDataset(Dataset):

    def __init__(self, root):
        self.url = 'https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m_kddcup2021.zip'
        super(MyPCQM4MDataset, self).__init__(root)

        filepath = osp.join(root, 'raw/data.csv.gz')
        data_df = pd.read_csv(filepath)
        self.smiles_list = data_df['smiles']
        self.homolumogap_list = data_df['homolumogap']

    @property
    def raw_file_names(self):
        return 'data.csv.gz'

    def download(self):
        path = download_url(self.url, self.root)
        extract_zip(path, self.root)
        os.unlink(path)
        shutil.move(osp.join(self.root, 'pcqm4m_kddcup2021/raw/data.csv.gz'), osp.join(self.root, 'raw/data.csv.gz'))

    def len(self):
        return len(self.smiles_list)

    def get(self, idx):
        smiles, homolumogap = self.smiles_list[idx], self.homolumogap_list[idx]
        graph = smiles2graph(smiles)
        assert(len(graph['edge_feat']) == graph['edge_index'].shape[1])
        assert(len(graph['node_feat']) == graph['num_nodes'])

        x = torch.from_numpy(graph['node_feat']).to(torch.int64)
        edge_index = torch.from_numpy(graph['edge_index']).to(torch.int64)
        edge_attr = torch.from_numpy(graph['edge_feat']).to(torch.int64)
        y = torch.Tensor([homolumogap])
        num_nodes = int(graph['num_nodes'])
        data = Data(x, edge_index, edge_attr, y, num_nodes=num_nodes)
        return data

    # 获取数据集划分
    def get_idx_split(self):
        split_dict = replace_numpy_with_torchtensor(torch.load(osp.join(self.root, 'pcqm4m_kddcup2021/split_dict.pt')))
        return split_dict

if __name__ == "__main__":
    dataset = MyPCQM4MDataset('dataset2')
    from torch_geometric.data import DataLoader
    from tqdm import tqdm
    dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=4)
    for batch in tqdm(dataloader):
        pass

这里需要pip install ogb以及conda install -c rdkit rdkit,代码运行时由于要下载数据集,需要翻越一些障碍。

参考:

标签:__,index,self,集类,batch,DataWhale,超大规模,edge,data
来源: https://blog.csdn.net/Kay_Xiaohe_He/article/details/118459552