其他分享
首页 > 其他分享> > datasets数据读取器

datasets数据读取器

作者:互联网

#切分数据集
img_dir = train_parameters['img_dir']
file_name = train_parameters['file_name']
df = pd.read_csv(file_name)
df = df.sample(frac=1)
train_list = []
val_list = []
for i in range(len(df)):
    if (i <= len(df)*0.8):
        dirlist = img_dir + '/' + df.iloc[i][0] + '.jpg'
        label = df.iloc[i][1]
        datainfo = [dirlist, label]
        train_list.append(datainfo)
    else:
        dirlist = img_dir + '/' + df.iloc[i][0] + '.jpg'
        label = df.iloc[i][1]
        datainfo = [dirlist, label]
        val_list.append(datainfo)

# print(len(train_list))
# print(train_list[1][1])

定义数据集

'''
继承paddle.io.Dataset类
'''

IMAGE_SIZE = [3,224,224]
class Datasets(Dataset):
def init(self, data, mode='train'):
'''
步骤二:实现构造函数,定义数据读取,划分训练和测试、验证数据集
'''

    super(Datasets, self).__init__()

    self.data = data
    self.mode = mode
    if self.mode == 'train':
        self.transforms = T.Compose([
            # T.RandomResizedCrop(IMAGE_SIZE),
            # T.RandomHorizontalFlip(0.5),
            # T.ToTensor(),
            # T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            #对输入图像进行裁剪,保持图片中心点不变。transform = CenterCrop(224)。
            T.CenterCrop(224),
            #随机调整图像的亮度,对比度,饱和度和色调。 transform = ColorJitter(0.4, 0.4, 0.4, 0.4)
            T.ColorJitter(0.4, 0.4, 0.4, 0.4), 
            #依据degrees参数指定的角度范围,按照均匀分布随机产生一个角度对图像进行旋转。
            T.RandomRotation(60),  
            #将形状为 (H x W x C)的输入数据 PIL.Image 或 numpy.ndarray 转换为 (C x H x W)。
            T.ToTensor(),
            #图像归一化处理,支持两种方式: 1. 用统一的均值和标准差值对图像的每个通道进行归一化处理; 2. 对每个通道指定不同的均值和标准差值进行归一化处理。
            T.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225],)   
        ])

    elif self.mode == 'valid':
        self.transforms = T.Compose([
            # T.Resize(IMAGE_SIZE[0]),
            # T.RandomCrop(IMAGE_SIZE),
            # T.ToTensor(),
            # T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            #对输入图像进行裁剪,保持图片中心点不变。transform = CenterCrop(224)。
            T.CenterCrop(224),  
            #随机调整图像的亮度,对比度,饱和度和色调。 transform = ColorJitter(0.4, 0.4, 0.4, 0.4)
            T.ColorJitter(0.4, 0.5, 0.6, 0.7), 
            #依据degrees参数指定的角度范围,按照均匀分布随机产生一个角度对图像进行旋转。
            T.RandomRotation(60),  
            #将形状为 (H x W x C)的输入数据 PIL.Image 或 numpy.ndarray 转换为 (C x H x W)。
            T.ToTensor(),
            #图像归一化处理,支持两种方式: 1. 用统一的均值和标准差值对图像的每个通道进行归一化处理; 2. 对每个通道指定不同的均值和标准差值进行归一化处理。
            T.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225],)   
        ])


def __getitem__(self, index):
    '''
    实现getitem方法,定义指定index时如何获取数据,并返回单条数(训练数据,对应的标签)
    '''
    image = Image.open(self.data[index][0])
    if image.mode != 'RGB':
        image = image.convert('RGB')
        
    data = self.transforms(image)
    label = np.array([self.data[index][1]-1]).astype('int64')
    return data, label

def __len__(self):
    return len(self.data)

`

标签:datasets,数据,self,图像,0.4,mode,归一化,data,读取器
来源: https://www.cnblogs.com/mumuzifeng/p/15109792.html