其他分享
首页 > 其他分享> > CTPN代码研读(三)utils/dataset(data_provider)研读

CTPN代码研读(三)utils/dataset(data_provider)研读

作者:互联网

CTPN代码研读系列:

1. 数据集的使用以及模型
2. utils/prepare/label
3. utils/dataset/data_provider

(本内容为自己理解,如有错误欢迎指正)

知识点:

工作流程

  1. 接prepare处理后的数据,读取图片,读取TXT坐标文件
  2. 将拼接出来的检测矩形框显示到图片上(显示split_label效果)
  3. 进行图片的多进程批处理。(处理过程中将图片,检测框坐标,图片尺寸压人同步线程队列进行处理)
    (作为一个源源不断向其他代码提供处理完成数据的代码,故命名为provider)

代码解读:

dataset/data_provider:

# encoding:utf-8
'''
读取处理过的标签和图片
'''
import os
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np

from utils.dataset.data_util import GeneratorEnqueuer

DATA_FOLDER = "../../data/dataset/mlt/"


def get_training_data():
    '''

    :return 找到的训练集img的地址
    '''
    img_files = []
    exts = ['jpg', 'png', 'jpeg', 'JPG']
    for parent, dirnames, filenames in os.walk(os.path.join(DATA_FOLDER, "image")):
        for filename in filenames:
            for ext in exts:
                if filename.endswith(ext):
                    img_files.append(os.path.join(parent, filename))
                    break
    print('Find {} images'.format(len(img_files)))
    return img_files


def load_annoataion(p):
    '''
    读取文件中的坐标点
    :param p:  文件指针
    :return: 以坐标点为列表的各个框
    '''
    bbox = []
    with open(p, "r") as f:
        lines = f.readlines()
    for line in lines:
        line = line.strip().split(",")
        x_min, y_min, x_max, y_max = map(int, line)
        bbox.append([x_min, y_min, x_max, y_max, 1])
    return bbox


def generator(vis=False):
    '''
    读取处理过的标签和图片
    :param vis: 是否选择展示效果
    :return:
    [im]:图片列表
    Bbox:检测框坐标
    im_info: 图片——长宽通道数——信息
    '''
    image_list = np.array(get_training_data())
    #print(image_list)
    print('{} training images in {}'.format(image_list.shape[0], DATA_FOLDER))
    index = np.arange(0, image_list.shape[0])
    while True:
        np.random.shuffle(index)
        #把顺序打乱
        for i in index:
            try:
                im_fn = image_list[i]
                im = cv2.imread(im_fn)
                h, w, c = im.shape
                im_info = np.array([h, w, c]).reshape([1, 3])
                #读取图片

                #解析出标签
                _, fn = os.path.split(im_fn)
                #路径 文件名
                fn, _ = os.path.splitext(fn)
                #文件名 扩展名
                txt_fn = os.path.join(DATA_FOLDER, "label", fn + '.txt')
                #组合出对应的txt标签文件
                if not os.path.exists(txt_fn):
                    print("Ground truth for image {} not exist!".format(im_fn))
                    continue
                bbox = load_annoataion(txt_fn)
                if len(bbox) == 0:
                    print("Ground truth for image {} empty!".format(im_fn))
                    continue

                #将拼接出来的检测矩形框显示到图片上
                if vis:
                    for p in bbox:
                        cv2.rectangle(im, (p[0], p[1]), (p[2], p[3]), color=(0, 0, 255), thickness=1)
                    fig, axs = plt.subplots(1, 1, figsize=(30, 30))
                    axs.imshow(im[:, :, ::-1])
                    #设置坐标刻度
                    axs.set_xticks([])
                    axs.set_yticks([])
                    plt.tight_layout()
                    #plt.tight_layout会自动调整子图参数,使之填充整个图像区域。
                    plt.show()
                    plt.close()
                    #plt.imshow()函数负责对图像进行处理,并显示其格式,但是不能显示。
                    #其后跟着plt.show()才能显示出来。
                yield [im], bbox, im_info

            except Exception as e:
                print(e)
                continue


def get_batch(num_workers, **kwargs):
    #**kwargs:形参中按照关键字传值把多余的传值以字典的方式呈现
    '''
    在实现按批次处理??
    将各张图片以三个对象压人队列信息进行处理(内部主要实现了多线程,并没有具体处理??)
    :param num_workers:
    :param kwargs:
    :return:
    '''
    try:
        enqueuer = GeneratorEnqueuer(generator(**kwargs), use_multiprocessing=True)
        #传入的第一个参数为迭代器
        enqueuer.start(max_queue_size=24, workers=num_workers)
        generator_output = None
        while True:
            while enqueuer.is_running():
                if not enqueuer.queue.empty():
                    generator_output = enqueuer.queue.get()
                    #put放进去 get拿出来
                    #拿到处理图片的某一进程
                    break
                else:
                    time.sleep(0.01)
            yield generator_output
            generator_output = None
    finally:
        #finally块的作用就是为了保证无论出现什么情况,finally块里的代码一定会被执行。
        #结束进程
        if enqueuer is not None:
            enqueuer.stop()


if __name__ == '__main__':
    gen = get_batch(num_workers=2, vis=True)
    while True:
        image, bbox, im_info = next(gen)
        #将三个信息按原格式出队列
        print('done')

dataset/data_util

# encoding = utf-8
import multiprocessing
#多进程包
import threading
#多线程包
import time

import numpy as np

try:
    import queue
except ImportError:
    import Queue as queue


class GeneratorEnqueuer():
    def __init__(self, generator,
                 use_multiprocessing=False,
                 wait_time=0.05,
                 random_seed=None):
        self.wait_time = wait_time
        self._generator = generator
        self._use_multiprocessing = use_multiprocessing
        self._threads = []
        self._stop_event = None
        self.queue = None
        self.random_seed = random_seed

    def start(self, workers=1, max_queue_size=10):
        '''
        选择是否调用多线程,并通过守护进程,开始同时处理多张图片

        :param workers:  需要多少个工作线程??
        :param max_queue_size: 最大的队列尺寸,同时处理10张图??
        :return:
        '''
        def data_generator_task():
            '''
            把三个迭代的信息打包成一个对象并压入多线程队列中,进行同时处理
            :return:
            '''
            while not self._stop_event.is_set():
                #event:使用set()方法后,isSet()方法返回True
                try:
                    if self._use_multiprocessing or self.queue.qsize() < max_queue_size:
                        generator_output = next(self._generator)
                        #generator有三个迭代对象:图片 检测框坐标 图片尺寸信息
                        #将三个作为一个对象入队
                        self.queue.put(generator_output)
                        #print('队列入了{}'.format(generator_output))
                    else:
                        time.sleep(self.wait_time)
                except Exception:
                    self._stop_event.set()
                    # event.set():将event的标志设置为True,调用wait方法的所有线程将被唤醒。
                    raise
                #当程序出现错误,python会自动引发异常,也可以通过raise显示地引发异常。

        #这才是主函数,天呐
        try:
            #选择是否使用多进程
            if self._use_multiprocessing:
                self.queue = multiprocessing.Queue(maxsize=max_queue_size)
                self._stop_event = multiprocessing.Event()
            else:
                self.queue = queue.Queue()
                self._stop_event = threading.Event()

            for _ in range(workers):
                if self._use_multiprocessing:
                    # Reset random seed else all children processes
                    # share the same seed
                    np.random.seed(self.random_seed)
                    thread = multiprocessing.Process(target=data_generator_task)
                    #process(target): 要执行的方法
                    thread.daemon = True
                    #当且仅当主线程运行时有效,当其他非Daemon线程结束时可自动杀死所有Daemon线程。
                    if self.random_seed is not None:
                        self.random_seed += 1
                else:
                    thread = threading.Thread(target=data_generator_task)
                self._threads.append(thread)
                thread.start()
        except:
            self.stop()
            raise

    def is_running(self):
        return self._stop_event is not None and not self._stop_event.is_set()
        #是否正在运行
        #信号传输没有被阻塞并且有事件?

    def stop(self, timeout=None):
        '''
        关闭对应的多线程或者多进程
        :param timeout:
        :return:
        '''
        if self.is_running():
            self._stop_event.set()

        #线程工作情况
        for thread in self._threads:
            if thread.is_alive():
            #thread.isAlive(): 返回线程是否活动的。
                if self._use_multiprocessing:
                    thread.terminate()
                    #TerminateThread在线程外终止一个线程,用于强制终止线程。
                else:
                    thread.join(timeout)
                #thread.join():(用户进程)所完成的工作就是线程同步,
                #即主线程任务结束之后,进入阻塞状态,一直等待其他的子线程执行结束之后,主线程再终止

        #进程工作状况
        if self._use_multiprocessing:
            if self.queue is not None:
                self.queue.close()

        self._threads = []
        self._stop_event = None
        self.queue = None

    def get(self):
        while self.is_running():
            if not self.queue.empty():
                inputs = self.queue.get()
                if inputs is not None:
                    yield inputs
            else:
                time.sleep(self.wait_time)

标签:研读,generator,self,._,dataset,queue,线程,provider,event
来源: https://blog.csdn.net/qq_35307005/article/details/89929403