Pytorch 自定义数据加载器
作者:互联网
在前面,我们使用Lenet训练的都是使用默认数据加载器加载特定的数据,本章节我们分析下怎么使用自定义的data.Dataset加载数据
吸烟数据集
数据分为两位,smoke和no_smoke
- smoke数据
主要是嘴上有烟的,不同大小,不同距离和不同模糊度的图片
- nosmoke 数据
主要是光脸,喝饮料,戴口罩和手在嘴部等其他脸部反例数据
smoke_clas
-train //每个分类各1000张数据
-0
-1
-valid //每个数据各300张数据
-0
-1
-test //每个分类各200张数据
-0
-1
生成训练数据集的mean和std
Pytorch# python tools/getPixelMeanStd.py
img_mean: [0.479514059223452, 0.3860349787384188, 0.3531293626700743]
img_std: [0.19383713054251062, 0.17074204578960742, 0.16474959905710554]
自定义数据集加载
关键函数如下
find_classes
def find_classes(dir:str) -> Tuple[List[str], Dict[str, int]]:
classes = [d.name for d in os.scandir(dir) if d.is_dir]
classes.sort()
class_to_idx = {cls_name:i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
根据传入的目录名称,加载对应的分类名称和id,名称会按照升序排序,例如吸烟数据,最后得到的分类和分类id对应结果如下
classes[['0', '1']]
class_to_idx[{'0': 0, '1': 1}]
因为我们使用数字命名目录,所以,分类名称也是数字,如果改成字母命名的话,结果如下:
classes[['smoke', 'no_smoke']]
class_to_idx[{'smoke': 0, 'no_smoke': 1}]
make_dataset
def make_dataset(
directory: str,
class_to_idx: Dict[str, int],) -> List[Tuple[str, int]]:
instances = []#struct
if not os.path.isdir(directory):
raise ValueError("Image not dir!!!")
image_count = 0
for target_class in sorted(class_to_idx.keys()):
class_index = class_to_idx[target_class]
target_dir = os.path.join(directory, target_class)
if not os.path.isdir(target_dir):
continue
for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
if is_valid_file(path):
item = path, class_index
#print(item)
instances.append(item)
return instances
解析每一个分类目录下图片,返回每个图片的路径和分类id结果,每个item的格式如下
('./datas/smoke_clas/test/1/08021120510000090.jpg', 1)
('./datas/smoke_clas/test/0/07172027020001567.jpg', 0)
定义数据加载器 SmokeData
然后,集成继承data.Dataset实现smokeData加载器
class SmokeData(data.Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.mean = (0.479, 0.385, 0.352)
self.std = (0.194, 0.171, 0.165)
#获取类别和类别id
classes, class_to_idx = find_classes(root_dir)
samples = make_dataset(root_dir, class_to_idx)
if len(samples) == 0:
msg = "Found 0 files in subfolders of: {}\n".format(root_dir)
raise RuntimeError(msg)
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.targets = [s[1] for s in samples]
print('------SmokeData[%s]',root_dir)
print('classes[%s]'% self.classes)
print('class_to_idx[%s]'%self.class_to_idx)
self.count = 0
#print('targets[%s]'%self.targets)
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
path, target = self.samples[idx]
sample = default_loader(path)
self.count = self.count+1
#sample.save("./"+str(self.count)+'.jpg')
if self.transform is not None:
sample = self.transform(sample)
#图片内容,和分类id
return sample, target
主要实现__len__和__getitem__方法,分别返回数据集的长度,和遍历获取每个数据的图片和分类id。
使用数据加载器
def getSmokeDataloader(train_dir, test_dir, dataresize=64):
train_transforms = transforms.Compose([
transforms.RandomRotation(20),
transforms.Resize((dataresize,dataresize)),
transforms.RandomHorizontalFlip(0.5),
#transforms.ColorJitter(brightness=[0.8,1.3], contrast=[0.8,1.3], saturation=[0.8,1.3], hue=0.2),
transforms.ToTensor(),
transforms.Normalize((0.479, 0.385, 0.352),
(0.194, 0.171, 0.165))])
test_transforms = transforms.Compose([
transforms.Resize((dataresize,dataresize)),
transforms.ToTensor(),
transforms.Normalize((0.479, 0.385, 0.352),
(0.194, 0.171, 0.165))])
#创建自定义的加载器
tain_smoke_data = SmokeData(train_dir, transform = train_transforms)
test_smoke_data = SmokeData(test_dir, transform = test_transforms)
# 使用预处理格式加载图像
#train_data = datasets.ImageFolder(train_dir,transform = train_transforms)
#valid_data = datasets.ImageFolder(test_dir,transform = test_transforms)
# 创建三个加载器,分别为训练,验证,测试,将训练集的batch大小设为64,即每次加载器向网络输送64张图片
#shuffle 随机打乱,网络更容易学习不同的特征,更容易收敛
print('load dataset......')
trainloader = torch.utils.data.DataLoader(tain_smoke_data,batch_size = 64,shuffle = True)
validloader = torch.utils.data.DataLoader(test_smoke_data,batch_size = 64)
return trainloader,validloader
通过以上步骤,我们就获取了训练集和验证集的数据加载器,然后训练的时候使用方法如下
#数据读取
for i,data in enumerate(train_loader):
inputs,labels = data
#有GPU则将数据置入GPU加速
inputs, labels = inputs.to(self.device), labels.to(self.device)
以上就是自定义pytroch数据加载器的具体实现
源码参考
Pytorch/datasets/classifier/smokeDataset.py
标签:定义数据,idx,self,Pytorch,transforms,smoke,class,dir,加载 来源: https://blog.csdn.net/jiadongfengyahoo/article/details/112389826