其他分享
首页 > 其他分享> > 从零写CRNN文字识别 —— (3)数据加载器

从零写CRNN文字识别 —— (3)数据加载器

作者:互联网

简介

上一节实现了加载配置,加载配置文件可以方便的进行参数的修改,这一节实现加载数据。

DataLoader

我使用的数据是MLT2017的数据集,在其中把法语的分割出来了,数据集下载地址:法语OCR识别数据集

其中解压后包含训练集图片文件夹、测试集图片文件夹、训练集标签文件和测试集标签文件以及字典文件。

数据可以放置在工程的data文件夹下或者你喜欢的位置,加载数据的代码自然就放在data文件夹下,命名dataset.py:

import torch.utils.data as data # 加载torch的数据加载器
import numpy as np
import time
import cv2
import sys
import os
sys.path.append(os.getcwd())

# 实现模板类
class OCRDataset(data.Dataset):
    def __init__(self,config,is_train=True):
        self.root = config.DATASET.ROOT
        self.is_train = is_train
        self.inp_h = config.MODEL.IMAGE_SIZE.H
        self.inp_w = config.MODEL.IMAGE_SIZE.W
        self.dataset_name = config.DATASET.DATASET
        self.mean = np.array(config.DATASET.MEAN, dtype=np.float32)
        self.std = np.array(config.DATASET.STD, dtype=np.float32)
        char_file = config.DATASET.CHAR_FILE
        txt_file = config.DATASET.JSON_FILE['train'] if is_train else config.DATASET.JSON_FILE['val']
        txt_file = os.path.join(self.root,txt_file)
        # convert name:indices to name:string
        self.labels = []
        with open(txt_file, 'r', encoding='utf-8') as file:
            contents = file.readlines()
            for c in contents:
                imgname = c.split('\t')[0]
                string = c.split('\t')[1].replace("\n","")
                self.labels.append({imgname: string})

        print("load {} images!".format(self.__len__()))

    def __len__(self):
        # 实现模板方法
        return len(self.labels)

    def __getitem__(self,idx):
        img_name = list(self.labels[idx].keys())[0]
        img = cv2.imread(os.path.join(self.root, img_name))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img_h, img_w = img.shape
        img = cv2.resize(img, (0,0), fx=self.inp_w / img_w, fy=self.inp_h / img_h, interpolation=cv2.INTER_CUBIC)
        img = np.reshape(img, (self.inp_h, self.inp_w, 1))
        img = img.astype(np.float32)
        img = (img/255. - self.mean) / self.std
        img = img.transpose([2, 0, 1])

        return img, idx

这段代码看着很复杂其实很简单:

在__init__函数中最后拿到了self.labels = []
他的数据形式就是:
self.labels = [{“img.png”:“abcd”},{“img2.png”:“abcdrffff”}…]
就是把路径和标签存在了字典里,字典用列表包着。

测试

在train.py中加入测试代码:

import os
sys.path.append(os.getcwd())
import argparse
import model.model as crnn
import torch
import torch.optim as optim

from utils.utils import load_yml
from data.dataset import OCRDataset

def parse_arg():
    parser = argparse.ArgumentParser(description="train crnn")
    parser.add_argument('--cfg', help='experiment configuration filename', required=True, type=str)
    args = parser.parse_args()
    config = load_yml(args.cfg)
    return config

if __name__ == "__main__":
    config = parse_arg()
    print(config)
    train_dataset = OCRDataset(config)
    train_loader = data.DataLoader(
        dataset=train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )
    # get device
    if torch.cuda.is_available():
        device = torch.device("cuda:{}".format(config.GPUID))
    else:
        device = torch.device("cpu:0")
    for i, (inp, idx) in enumerate(train_loader):
        inp = inp.to(device)
        print("inp",inp[0].cpu().detach().numpy(),inp[0].cpu().detach().numpy().shape)
        exit(-1)# 这里就测试打印一个batch然后退出程序

输出结果:

在这里插入图片描述

数据加载完成接着就是搭建模型了~

标签:__,img,import,self,inp,CRNN,零写,config,加载
来源: https://blog.csdn.net/qq_37668436/article/details/113647812