其他分享
首页 > 其他分享> > pytorch datasets与dataloader阐释说明

pytorch datasets与dataloader阐释说明

作者:互联网

 

 

 

 

一.torch.utils.data包含Dataset,Sampler,Dataloader

torch.utils.data主要包括以下三个类:
1. class torch.utils.data.Dataset: 作用: (1) 创建数据集,有__getitem__(self, index)函数来根据索引序号获取图片和标签, 有__len__(self)函数来获取数据集的长度.

其他的数据集类必须是torch.utils.data.Dataset的子类,比如说torchvision.ImageFolder.

2. class torch.utils.data.sampler.Sampler(data_source)
参数: data_source (Dataset) – dataset to sample from

作用: 创建一个采样器, class torch.utils.data.sampler.Sampler是所有的Sampler的基类, 其中,iter(self)函数来获取一个迭代器,对数据集中元素的索引进行迭代,len(self)方法返回迭代器中包含元素的长度.

3. class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

 

二. datasets.ImageFolder  ,可用于提取分类网络图片使用

参数:

root:图片存储的根目录,即各类别文件夹所在目录的上一级目录。
transform:对图片进行预处理的操作(函数),原始图片作为输入,返回一个转换后的图片。
target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…
loader:表示数据集加载方式,通常默认加载方式即可。
is_valid_file:获取图像文件的路径并检查该文件是否为有效文件的函数(用于检查损坏文件)

 

属性值:

  

def verity_datasets():
root = './datasets/train' # 根路径
data = datasets.ImageFolder(root) # 可以理解载入dataset
print('data.classes:',data.classes) # 类别信息
print('data.class_to_idx:',data.class_to_idx) # 类别与索引
print('data.imgs:',data.imgs) # 图片地址与标签
img = cv2.imread(data.imgs[0][0])
plt.imshow(img)
plt.show()
for img,label in data:
image=cv2.cvtColor(np.asarray(img),cv2.COLOR_RGB2BGR)
print( image.shape,label)

代码运行结果如下:

 

 

 

若需要添加transform 可使用如下代码:

from torchvision.datasets import ImageFolder
from torchvision import transforms

#加上transforms
normalize=transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])
transform=transforms.Compose([
transforms.RandomCrop(180),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]
normalize
])

dataset=ImageFolder('./data/train',transform=transform)

 

三.dataloader加载方式,需要添加自己信息如何更改源码如下:

 

import numpy as np
from PIL import Image
from torch.utils.data.dataset import TensorDataset,Dataset
from typing import TypeVar, Generic, Iterable, Iterator, Sequence, List, Optional, Tuple
from torch.tensor import Tensor
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')


class TensorDataset(Dataset[Tuple[Tensor, ...]]):
r"""Dataset wrapping tensors.

Each sample will be retrieved by indexing tensors along the first dimension.

Arguments:
*tensors (Tensor): tensors that have the same size of the first dimension.
"""
tensors: Tuple[Tensor, ...]

def __init__(self,my_info, *tensors: Tensor) -> None:
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
self.tensors = tensors
self.my_info=my_info

def __getitem__(self, index):
return tuple([tensor[index],self.my_info[index]] for tensor in self.tensors)

def __len__(self):
return self.tensors[0].size(0)



def verity_dataloader():


x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
k = [{'img_meta':20} for _ in range(10)]
print(x,y)
# 数据集包装数据和标签,实际是一个迭代器,类似dataset方法,一般为输入图片x与对应标签y,
# 但如果想更改传入更多参数,需要自己更改源码,主要是__getiterm__方法。
# torch_dataset = torch.utils.data.TensorDataset(x, y) # 未更改源码
torch_dataset = TensorDataset(k,x,y) # 已经更改了源码

loader = torch.utils.data.DataLoader(
# 从数据库中每次抽出batch size个样本
dataset=torch_dataset,
batch_size=3,
shuffle=True,
num_workers=2,
drop_last=True # True丢弃最后bath不足数据,false不丢弃
)

for step, (batch_x, batch_y) in enumerate(loader):
# training
print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))

 

结果如下:

 

 

 

 

 

 

 

 

 

参考博客:

https://blog.csdn.net/qq_39507748/article/details/105394808

https://blog.csdn.net/tsq292978891/article/details/79414512

 

标签:__,datasets,self,torch,dataset,pytorch,data,tensors,dataloader
来源: https://www.cnblogs.com/tangjunjun/p/14856214.html