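# 其他分享 (Other shares)
# 首页 > 其他分享 > data_loader读取器  (Home > Other shares > data_loader reader)
#
# data_loader读取器 (data_loader reader)
#
# 作者: 互联网 (Author: Internet)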

import random
import numpy as np
import pandas as pd
import cv2
def _default_image_loader(path):
    """Read the image at *path* and return it as a float CHW array in [0, 1].

    The image is resized to 224x224 with bicubic interpolation. No colour
    conversion is applied, so channels keep the order cv2.imread returns.

    Raises:
        FileNotFoundError: if OpenCV cannot read *path*.
    """
    img = cv2.imread(path)
    # cv2.imread reports a missing/unreadable file by returning None instead
    # of raising; without this check cv2.resize fails with an opaque error.
    if img is None:
        raise FileNotFoundError(f"cannot read image: {path}")
    img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC) / 255
    return np.transpose(img, (2, 0, 1))  # HWC -> CHW


def _to_records(frame, image_dir):
    """Turn CSV rows into parallel lists of image paths and shifted labels."""
    paths = [f"{image_dir}/{name}.jpg" for name in frame.iloc[:, 0]]
    # Labels are shifted down by one (presumably 1-based class ids -> 0-based).
    labels = [int(value - 1) for value in frame.iloc[:, 1]]
    return paths, labels


def date_loader(image_dir, file_name, batch_size=1, mode='train', *,
                train_ratio=0.8, val_ratio=0.2, seed=None, drop_last=True,
                loader=None):
    """Build a reader that yields image/label batches for one data split.

    ``file_name`` is a CSV whose first column holds an image id (loaded from
    ``<image_dir>/<id>.jpg``) and whose second column holds an integer label;
    1 is subtracted from every label. The rows are shuffled, the first
    ``train_ratio`` fraction becomes the training split and the remainder the
    test split. The validation split is a random ``val_ratio`` sample of
    *all* rows, so it overlaps the training and test splits.

    The CSV is reshuffled on every call: build readers for different modes
    with the same ``seed`` (or identically seeded NumPy global RNG), otherwise
    their train/test splits will overlap.

    Args:
        image_dir: Directory containing the ``.jpg`` files.
        file_name: Path of the CSV index.
        batch_size: Number of samples per yielded batch (must be >= 1).
        mode: Split to read: ``'train'``, ``'test'`` or ``'val'``.
        train_ratio: Fraction of shuffled rows assigned to the training split.
        val_ratio: Fraction of all rows sampled into the validation split.
        seed: Seed for the shuffle and the validation sample; ``None`` uses
            NumPy's global random state.
        drop_last: If true (the default), a trailing batch smaller than
            ``batch_size`` is discarded.
        loader: Callable mapping an image path to an array. Defaults to
            reading with OpenCV, resizing to 224x224, scaling to [0, 1] and
            transposing to channels-first.

    Returns:
        A zero-argument generator function. Each call walks the chosen split
        once and yields ``(images, labels)``: ``images`` is a float32 array
        stacking the loaded images, ``labels`` an int64 array of shape
        ``(len(batch), 1)``.

    Raises:
        ValueError: if ``mode`` is unknown or ``batch_size`` < 1.
    """
    if mode not in ('train', 'test', 'val'):
        raise ValueError(f"mode must be 'train', 'test' or 'val', got {mode!r}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
    load = _default_image_loader if loader is None else loader

    # With seed=None pandas draws from NumPy's global RNG (as before); a
    # private RandomState makes both draws below reproducible.
    rng = None if seed is None else np.random.RandomState(seed)

    # Shuffle, then split train/test by position (default 0.8 / 0.2).
    df = pd.read_csv(file_name).sample(frac=1, random_state=rng)
    n_train = int(len(df) * train_ratio)
    # Random validation sample (default 0.2) drawn from the whole shuffled
    # table; drawn regardless of mode so RNG consumption is mode-independent.
    val_df = df.sample(frac=val_ratio, random_state=rng)

    splits = {
        'train': df.iloc[:n_train],
        'test': df.iloc[n_train:],
        'val': val_df,
    }
    paths, labels = _to_records(splits[mode], image_dir)

    def reader():
        for start in range(0, len(paths), batch_size):
            stop = start + batch_size
            batch_paths = paths[start:stop]
            if drop_last and len(batch_paths) < batch_size:
                break
            images = np.array([load(p) for p in batch_paths]).astype('float32')
            batch_labels = np.asarray(labels[start:stop]).astype('int64').reshape(-1, 1)
            yield images, batch_labels

    return reader

# date_loader shuffles the CSV with DataFrame.sample(), which (when no
# random_state is passed) draws from NumPy's global RNG. Every call reshuffles,
# so readers for different modes must start from the same RNG state or their
# 80/20 train/test splits overlap (test images leaking into training).
SPLIT_SEED = 0


def _make_reader(mode):
    """Return a date_loader reader for *mode* built from a reproducible shuffle."""
    np.random.seed(SPLIT_SEED)
    return date_loader('image2_100', 'a_100_drop_p.csv', mode=mode)


# Smoke check: pull only the first test batch (batch_size defaults to 1).
a = _make_reader('test')
for n, data in enumerate(a()):
    images, label = data
    break

# NOTE(review): `paddle` is never imported in the visible source, so the two
# lines below raise NameError as written; an `import paddle` is required.
# NOTE(review): date_loader's reader already yields batches (size 1 here), and
# paddle.batch groups 10 of those per item -- confirm this is intended.
train_reader = paddle.batch(_make_reader('train'), batch_size=10)
test_reader = paddle.batch(_make_reader('test'), batch_size=10)

# 标签 (tags): df, batch, loader, train, test, import, data, 读取器
# 来源 (source): https://www.cnblogs.com/mumuzifeng/p/15109802.html