PyTorch中Dataset和DataLoader的基本使用
作者:互联网
import torch import torch.utils.data.dataset as Dataset import numpy as np import torch.utils.data.dataloader as DataLoader Data = np.asarray([[1, 2], [3, 4], [5, 6], [7, 8]]) Label = np.asarray([[0], [1], [0], [2]]) class SubDataSet(Dataset.Dataset): # 定义数据类型和标签 def __init__(self, Data, Label): self.Data = Data self.Label = Label # 返回数据集的大小 def __len__(self): return len(self.Data) # 得到数据内容和标签,一个一个返回的 def __getitem__(self, index): data = torch.Tensor(self.Data[index]) label = torch.Tensor(self.Label[index]) return data, label dataset = SubDataSet(Data, Label) print(dataset) print(f"dataset size: {dataset.__len__()}") print(dataset.__getitem__(0)) # data, label print(dataset[0]) # __getitem__(0) == dataset[0] # batch_size表示一次性从dataset取多少个作为一个批次大小、 # data和label是一一对应 # shuffle表示每个epoch是否乱序
# num_workers表示并行的线程数 dataloader = DataLoader.DataLoader(dataset, batch_size = 2,shuffle = False, num_workers = 2) print(enumerate(dataloader)) for i, item in enumerate(dataloader): data, label = item print(f"data: {data} \n, label: {label} \n")
标签:__,Data,self,DataLoader,label,dataset,PyTorch,Dataset,data 来源: https://www.cnblogs.com/xjtu-yzk/p/16369302.html