
PyTorch Custom Data Loaders

Author: 互联网

In the previous chapters we trained LeNet on specific datasets using the default data loader. In this chapter we look at how to load data with a custom data.Dataset.

The smoking dataset

The data is split into two classes, smoke and no_smoke:



smoke_clas
	-train  // 1000 images per class
		-0
		-1
	-valid  // 300 images per class
		-0
		-1
	-test   // 200 images per class
		-0
		-1

Generating the mean and std of the training set

Pytorch# python tools/getPixelMeanStd.py
img_mean: [0.479514059223452, 0.3860349787384188, 0.3531293626700743]
img_std: [0.19383713054251062, 0.17074204578960742, 0.16474959905710554]
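The post does not show the contents of tools/getPixelMeanStd.py, but per-channel mean/std is typically computed by accumulating pixel sums and squared sums over the whole training set. The sketch below illustrates the idea on tiny synthetic images (the function name and the nested-list image format are assumptions for the example, not the script's actual API); values are scaled to [0, 1] first, matching the output above.

```python
# Illustrative sketch of a per-channel mean/std computation, similar in
# spirit to tools/getPixelMeanStd.py (not shown in the post).
import random

def channel_mean_std(images):
    """images: list of H x W x 3 nested lists with values in [0, 255]."""
    sums = [0.0, 0.0, 0.0]
    sq_sums = [0.0, 0.0, 0.0]
    n = 0  # total pixel count
    for img in images:
        for row in img:
            for px in row:
                for c in range(3):
                    v = px[c] / 255.0  # scale to [0, 1]
                    sums[c] += v
                    sq_sums[c] += v * v
                n += 1
    mean = [s / n for s in sums]
    # var = E[x^2] - E[x]^2; clamp at 0 to guard against rounding
    std = [max(sq / n - m * m, 0.0) ** 0.5 for sq, m in zip(sq_sums, mean)]
    return mean, std

# tiny synthetic "dataset": two 2x2 RGB images
random.seed(0)
imgs = [[[[random.randrange(256) for _ in range(3)] for _ in range(2)]
         for _ in range(2)] for _ in range(2)]
mean, std = channel_mean_std(imgs)
print("img_mean:", mean)
print("img_std:", std)
```

In practice the real script would load each training image with PIL or OpenCV and stream the sums the same way, so the whole dataset never has to fit in memory at once.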

Custom dataset loading

The key functions are as follows.

find_classes

import os
from typing import Dict, List, Tuple

def find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]:
    # note: d.is_dir() must be called; a bare d.is_dir is always truthy
    classes = [d.name for d in os.scandir(dir) if d.is_dir()]
    classes.sort()
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

Given a directory path, this collects the class names and class ids from its sub-directories. The names are sorted in ascending order; for the smoking dataset the resulting classes and class ids are:

classes[['0', '1']]
class_to_idx[{'0': 0, '1': 1}]

Because the directories are named with digits, the class names are digits too. If the directories were named with letters instead, the result would be:

classes[['no_smoke', 'smoke']]
class_to_idx[{'no_smoke': 0, 'smoke': 1}]
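The sorting behaviour is easy to verify with a throwaway directory tree (the tempfile usage below is just for the demonstration); note that ascending sort puts no_smoke before smoke, so the class ids depend on the alphabetical order of the folder names.

```python
# Quick check of find_classes' sorting behaviour on a temporary tree.
import os
import tempfile
from typing import Dict, List, Tuple

def find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]:
    classes = [d.name for d in os.scandir(dir) if d.is_dir()]
    classes.sort()
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

with tempfile.TemporaryDirectory() as root:
    for name in ("smoke", "no_smoke"):
        os.makedirs(os.path.join(root, name))
    classes, class_to_idx = find_classes(root)

print(classes)       # ['no_smoke', 'smoke'] -- sorted ascending
print(class_to_idx)  # {'no_smoke': 0, 'smoke': 1}
```

This matters for inference: whatever mapping training produced must be reused when interpreting the model's output logits.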

make_dataset

def make_dataset(
    directory: str,
    class_to_idx: Dict[str, int],
) -> List[Tuple[str, int]]:

    instances = []

    if not os.path.isdir(directory):
        raise ValueError("Image directory does not exist: " + directory)

    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):  # extension filter, defined elsewhere
                    instances.append((path, class_index))

    return instances

This walks the images under each class directory and returns a (path, class id) pair for every image; each item looks like:

('./datas/smoke_clas/test/1/08021120510000090.jpg', 1)
('./datas/smoke_clas/test/0/07172027020001567.jpg', 0)
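make_dataset calls an is_valid_file helper that the post does not show. A minimal extension-based version might look like the following (the exact extension list is an assumption; torchvision's own DatasetFolder accepts a larger set):

```python
# Hypothetical is_valid_file helper: keep only common image extensions.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

def is_valid_file(path: str) -> bool:
    # lower() makes the check case-insensitive (e.g. .JPG)
    return path.lower().endswith(IMG_EXTENSIONS)

print(is_valid_file('./datas/smoke_clas/test/1/08021120510000090.jpg'))  # True
print(is_valid_file('./datas/smoke_clas/test/1/notes.txt'))              # False
```

Filtering here keeps stray files (thumbnails, text notes, hidden files) from ending up in the sample list and crashing the image loader later.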

Defining the dataset SmokeData

Next, subclass data.Dataset to implement the SmokeData dataset.

class SmokeData(data.Dataset):

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.mean = (0.479, 0.385, 0.352)
        self.std = (0.194, 0.171, 0.165)
        # collect the class names and class ids
        classes, class_to_idx = find_classes(root_dir)
        samples = make_dataset(root_dir, class_to_idx)

        if len(samples) == 0:
            msg = "Found 0 files in subfolders of: {}\n".format(root_dir)
            raise RuntimeError(msg)
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]
        print('------SmokeData[%s]' % root_dir)
        print('classes[%s]' % self.classes)
        print('class_to_idx[%s]' % self.class_to_idx)
        self.count = 0
        #print('targets[%s]' % self.targets)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, target = self.samples[idx]
        # default_loader (from torchvision.datasets.folder) opens the image as PIL
        sample = default_loader(path)
        self.count = self.count + 1
        #sample.save("./" + str(self.count) + '.jpg')
        if self.transform is not None:
            sample = self.transform(sample)
        # return the image and its class id
        return sample, target

The essential methods are __len__ and __getitem__, which return the dataset length and, for a given index, the image together with its class id.
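That two-method contract is all DataLoader needs: anything with __len__ and __getitem__ can be wrapped. The toy class below (synthetic samples, no file I/O, names are my own for illustration) shows the protocol in isolation:

```python
# Minimal illustration of the Dataset protocol that DataLoader relies on:
# just __len__ and __getitem__ returning (data, label) pairs.
class ToyDataset:
    def __init__(self, samples):
        self.samples = samples  # list of (data, label) pairs

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

ds = ToyDataset([([0.1, 0.2], 0), ([0.3, 0.4], 1)])
print(len(ds))  # 2
print(ds[1])    # ([0.3, 0.4], 1)
```

SmokeData follows the same shape; the only differences are that its samples are (path, id) pairs, and __getitem__ loads and transforms the image lazily so only the current batch is held in memory.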

Using the dataset

def getSmokeDataloader(train_dir, test_dir, dataresize=64):
    train_transforms = transforms.Compose([
        transforms.RandomRotation(20),
        transforms.Resize((dataresize, dataresize)),
        transforms.RandomHorizontalFlip(0.5),
        #transforms.ColorJitter(brightness=[0.8,1.3], contrast=[0.8,1.3], saturation=[0.8,1.3], hue=0.2),
        transforms.ToTensor(),
        transforms.Normalize((0.479, 0.385, 0.352),
                             (0.194, 0.171, 0.165))])

    test_transforms = transforms.Compose([
        transforms.Resize((dataresize, dataresize)),
        transforms.ToTensor(),
        transforms.Normalize((0.479, 0.385, 0.352),
                             (0.194, 0.171, 0.165))])

    # create the custom datasets
    train_smoke_data = SmokeData(train_dir, transform=train_transforms)
    test_smoke_data = SmokeData(test_dir, transform=test_transforms)
    # the equivalent using the built-in loader:
    #train_data = datasets.ImageFolder(train_dir, transform=train_transforms)
    #valid_data = datasets.ImageFolder(test_dir, transform=test_transforms)

    # create the training and validation loaders; batch_size=64 means each
    # iteration feeds 64 images to the network
    # shuffle=True randomizes the order so the network sees varied features
    # within each batch, which helps convergence
    print('load dataset......')
    trainloader = torch.utils.data.DataLoader(train_smoke_data, batch_size=64, shuffle=True)
    validloader = torch.utils.data.DataLoader(test_smoke_data, batch_size=64)

    return trainloader, validloader

With these steps we have data loaders for the training and validation sets; during training they are consumed like this:

# iterate over batches
for i, data in enumerate(train_loader):
    inputs, labels = data
    # move the batch to the GPU if one is available
    # (self.device is set by the enclosing trainer class)
    inputs, labels = inputs.to(self.device), labels.to(self.device)

That completes the implementation of a custom PyTorch data loader.

Source code reference

Pytorch/datasets/classifier/smokeDataset.py

Source: https://blog.csdn.net/jiadongfengyahoo/article/details/112389826