其他分享
首页 > 其他分享> > PyTorch中Dataset和DataLoader的基本使用

PyTorch中Dataset和DataLoader的基本使用

作者:互联网

import torch
import torch.utils.data.dataset as Dataset
import numpy as np
import torch.utils.data.dataloader as DataLoader

Data = np.asarray([[1, 2], [3, 4], [5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])

class SubDataSet(Dataset.Dataset):
    # 定义数据类型和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
        
    # 返回数据集的大小
    def __len__(self):
        return len(self.Data)
    
    # 得到数据内容和标签,一个一个返回的
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.Tensor(self.Label[index])
        return data, label


dataset = SubDataSet(Data, Label)
print(dataset)
print(f"dataset size: {dataset.__len__()}")
print(dataset.__getitem__(0))   # data, label
print(dataset[0])
# __getitem__(0) == dataset[0]
# batch_size表示一次性从dataset取多少个作为一个批次大小、
# data和label是一一对应
# shuffle表示每个epoch是否乱序
# num_workers表示并行的线程数 dataloader = DataLoader.DataLoader(dataset, batch_size = 2,shuffle = False, num_workers = 2) print(enumerate(dataloader)) for i, item in enumerate(dataloader): data, label = item print(f"data: {data} \n, label: {label} \n")

  

标签:__,Data,self,DataLoader,label,dataset,PyTorch,Dataset,data
来源: https://www.cnblogs.com/xjtu-yzk/p/16369302.html