CTPN代码研读(三)utils/dataset(data_provider)研读
作者:互联网
CTPN代码研读系列:
1. 数据集的使用以及模型
2. utils/prepare/label
3. utils/dataset/data_provider
(本内容为自己理解,如有错误欢迎指正)
知识点:
-
python–multiprocessing包
简单介绍:http://www.cnblogs.com/tkqasn/p/5701230.html -
Python – queue包 同步的队列类
Queue类实现了一个基本的先进先出(FIFO)容器,使用put()将元素添加到序列尾端,get()从队列尾部移除元素
具体介绍:https://www.cnblogs.com/skiler/p/6977727.html -
plt.tight_layout会自动调整子图参数,使之填充整个图像区域。
-
形参中的:args和**kwargs:
多个实参,放到一个元组里面,以开头,可以传多个参数;**是形参中按照关键字传值把多余的传值以字典的方式呈现
*args:(表示的就是将实参中按照位置传值,多出来的值都给args,且以元祖的方式呈现)
**kwargs:(表示的就是形参中按照关键字传值把多余的传值以字典的方式呈现)
具体参考:https://www.cnblogs.com/xuyuanyuan123/p/6674645.html -
Queue.qsize() :返回queue的近似值。注意:qsize>0 不保证(get)取元素不阻塞。qsize< maxsize不保证(put)存元素不会阻塞
-
Python多线程的threading Event:
event它是沟通中最简单的一个过程之中,一个线程产生一个信号。Python 通过threading.Event()产生一个event对象。event对象维护一个内部标志(标志初始值为False),通过set()将其置为True。wait(timeout)则用于堵塞线程直至Flag被set(或者超时,可选的),isSet()用于查询标志位是否为True,Clear()则用于清除标志位(使之为False)。
设置\清除信号
Event的set()方法可设置Event对象内部的信号标志为真,Event对象提供了isSet()方法来推断其内部信号标志的状态,使用set()方法后,isSet()方法返回True。clear()方法可清除Event对象内部的信号标志(设为False)。使用clear方法后。isSet()方法返回False
等待
当Event对象的内部信号标志为False时。wait方法一直堵塞线程等待到其为真或者超时(若提供,浮点数,单位为秒)才返回,若Event对象内部标志为True则wait()方法马上返回。 -
Daemon线程(守护线程):
当且仅当主线程运行时有效,当其他非Daemon线程结束时可自动杀死所有Daemon线程。- 如果某个子线程的daemon属性为True(守护线程),主线程运行结束时不对这个子线程进行检查而直接退出,同时所有daemon值为True的子线程将随主线程一起结束,而不论是否运行完成。
- 如果某个子线程的daemon属性为False(用户线程),主线程结束时会检测该子线程是否结束,如果该子线程还在运行,则主线程会等待它完成后再退出;
-
#thread.isAlive(): 返回线程是否活动的。
-
thread.join():(用户进程)所完成的工作就是线程同步,即主线程任务结束之后,进入阻塞状态,一直等待其他的子线程执行结束之后,主线程在终止
-
join有一个timeout参数:
当设置守护线程时,含义是主线程对于子线程等待timeout的时间将会杀死该子线程,最后退出程序。所以说,如果有10个子线程,全部的等待时间就是每个timeout的累加和。简单的来说,就是给每个子线程一个timeout的时间,让他去执行,时间一到,不管任务有没有完成,直接杀死。
没有设置守护线程时,主线程将会等待timeout的累加和这样的一段时间,时间一到,主线程结束,但是并没有杀死子线程,子线程依然可以继续执行,直到子线程全部结束,程序退出。
工作流程
- 接prepare处理后的数据,读取图片,读取TXT坐标文件
- 将拼接出来的检测矩形框显示到图片上(显示split_label效果)
- 进行图片的多进程批处理。(处理过程中将图片,检测框坐标,图片尺寸压人同步线程队列进行处理)
(作为一个源源不断向其他代码提供处理完成数据的代码,故命名为provider)
代码解读:
dataset/data_provider:
# encoding:utf-8
'''
读取处理过的标签和图片
'''
import os
import time
import cv2
import matplotlib.pyplot as plt
import numpy as np
from utils.dataset.data_util import GeneratorEnqueuer
DATA_FOLDER = "../../data/dataset/mlt/"
def get_training_data():
'''
:return 找到的训练集img的地址
'''
img_files = []
exts = ['jpg', 'png', 'jpeg', 'JPG']
for parent, dirnames, filenames in os.walk(os.path.join(DATA_FOLDER, "image")):
for filename in filenames:
for ext in exts:
if filename.endswith(ext):
img_files.append(os.path.join(parent, filename))
break
print('Find {} images'.format(len(img_files)))
return img_files
def load_annoataion(p):
'''
读取文件中的坐标点
:param p: 文件指针
:return: 以坐标点为列表的各个框
'''
bbox = []
with open(p, "r") as f:
lines = f.readlines()
for line in lines:
line = line.strip().split(",")
x_min, y_min, x_max, y_max = map(int, line)
bbox.append([x_min, y_min, x_max, y_max, 1])
return bbox
def generator(vis=False):
'''
读取处理过的标签和图片
:param vis: 是否选择展示效果
:return:
[im]:图片列表
Bbox:检测框坐标
im_info: 图片——长宽通道数——信息
'''
image_list = np.array(get_training_data())
#print(image_list)
print('{} training images in {}'.format(image_list.shape[0], DATA_FOLDER))
index = np.arange(0, image_list.shape[0])
while True:
np.random.shuffle(index)
#把顺序打乱
for i in index:
try:
im_fn = image_list[i]
im = cv2.imread(im_fn)
h, w, c = im.shape
im_info = np.array([h, w, c]).reshape([1, 3])
#读取图片
#解析出标签
_, fn = os.path.split(im_fn)
#路径 文件名
fn, _ = os.path.splitext(fn)
#文件名 扩展名
txt_fn = os.path.join(DATA_FOLDER, "label", fn + '.txt')
#组合出对应的txt标签文件
if not os.path.exists(txt_fn):
print("Ground truth for image {} not exist!".format(im_fn))
continue
bbox = load_annoataion(txt_fn)
if len(bbox) == 0:
print("Ground truth for image {} empty!".format(im_fn))
continue
#将拼接出来的检测矩形框显示到图片上
if vis:
for p in bbox:
cv2.rectangle(im, (p[0], p[1]), (p[2], p[3]), color=(0, 0, 255), thickness=1)
fig, axs = plt.subplots(1, 1, figsize=(30, 30))
axs.imshow(im[:, :, ::-1])
#设置坐标刻度
axs.set_xticks([])
axs.set_yticks([])
plt.tight_layout()
#plt.tight_layout会自动调整子图参数,使之填充整个图像区域。
plt.show()
plt.close()
#plt.imshow()函数负责对图像进行处理,并显示其格式,但是不能显示。
#其后跟着plt.show()才能显示出来。
yield [im], bbox, im_info
except Exception as e:
print(e)
continue
def get_batch(num_workers, **kwargs):
#**kwargs:形参中按照关键字传值把多余的传值以字典的方式呈现
'''
在实现按批次处理??
将各张图片以三个对象压人队列信息进行处理(内部主要实现了多线程,并没有具体处理??)
:param num_workers:
:param kwargs:
:return:
'''
try:
enqueuer = GeneratorEnqueuer(generator(**kwargs), use_multiprocessing=True)
#传入的第一个参数为迭代器
enqueuer.start(max_queue_size=24, workers=num_workers)
generator_output = None
while True:
while enqueuer.is_running():
if not enqueuer.queue.empty():
generator_output = enqueuer.queue.get()
#put放进去 get拿出来
#拿到处理图片的某一进程
break
else:
time.sleep(0.01)
yield generator_output
generator_output = None
finally:
#finally块的作用就是为了保证无论出现什么情况,finally块里的代码一定会被执行。
#结束进程
if enqueuer is not None:
enqueuer.stop()
if __name__ == '__main__':
gen = get_batch(num_workers=2, vis=True)
while True:
image, bbox, im_info = next(gen)
#将三个信息按原格式出队列
print('done')
dataset/data_util
# encoding = utf-8
import multiprocessing
#多进程包
import threading
#多线程包
import time
import numpy as np
try:
import queue
except ImportError:
import Queue as queue
class GeneratorEnqueuer():
def __init__(self, generator,
use_multiprocessing=False,
wait_time=0.05,
random_seed=None):
self.wait_time = wait_time
self._generator = generator
self._use_multiprocessing = use_multiprocessing
self._threads = []
self._stop_event = None
self.queue = None
self.random_seed = random_seed
def start(self, workers=1, max_queue_size=10):
'''
选择是否调用多线程,并通过守护进程,开始同时处理多张图片
:param workers: 需要多少个工作线程??
:param max_queue_size: 最大的队列尺寸,同时处理10张图??
:return:
'''
def data_generator_task():
'''
把三个迭代的信息打包成一个对象并压入多线程队列中,进行同时处理
:return:
'''
while not self._stop_event.is_set():
#event:使用set()方法后,isSet()方法返回True
try:
if self._use_multiprocessing or self.queue.qsize() < max_queue_size:
generator_output = next(self._generator)
#generator有三个迭代对象:图片 检测框坐标 图片尺寸信息
#将三个作为一个对象入队
self.queue.put(generator_output)
#print('队列入了{}'.format(generator_output))
else:
time.sleep(self.wait_time)
except Exception:
self._stop_event.set()
# event.set():将event的标志设置为True,调用wait方法的所有线程将被唤醒。
raise
#当程序出现错误,python会自动引发异常,也可以通过raise显示地引发异常。
#这才是主函数,天呐
try:
#选择是否使用多进程
if self._use_multiprocessing:
self.queue = multiprocessing.Queue(maxsize=max_queue_size)
self._stop_event = multiprocessing.Event()
else:
self.queue = queue.Queue()
self._stop_event = threading.Event()
for _ in range(workers):
if self._use_multiprocessing:
# Reset random seed else all children processes
# share the same seed
np.random.seed(self.random_seed)
thread = multiprocessing.Process(target=data_generator_task)
#process(target): 要执行的方法
thread.daemon = True
#当且仅当主线程运行时有效,当其他非Daemon线程结束时可自动杀死所有Daemon线程。
if self.random_seed is not None:
self.random_seed += 1
else:
thread = threading.Thread(target=data_generator_task)
self._threads.append(thread)
thread.start()
except:
self.stop()
raise
def is_running(self):
return self._stop_event is not None and not self._stop_event.is_set()
#是否正在运行
#信号传输没有被阻塞并且有事件?
def stop(self, timeout=None):
'''
关闭对应的多线程或者多进程
:param timeout:
:return:
'''
if self.is_running():
self._stop_event.set()
#线程工作情况
for thread in self._threads:
if thread.is_alive():
#thread.isAlive(): 返回线程是否活动的。
if self._use_multiprocessing:
thread.terminate()
#TerminateThread在线程外终止一个线程,用于强制终止线程。
else:
thread.join(timeout)
#thread.join():(用户进程)所完成的工作就是线程同步,
#即主线程任务结束之后,进入阻塞状态,一直等待其他的子线程执行结束之后,主线程再终止
#进程工作状况
if self._use_multiprocessing:
if self.queue is not None:
self.queue.close()
self._threads = []
self._stop_event = None
self.queue = None
def get(self):
while self.is_running():
if not self.queue.empty():
inputs = self.queue.get()
if inputs is not None:
yield inputs
else:
time.sleep(self.wait_time)
标签:研读,generator,self,._,dataset,queue,线程,provider,event 来源: https://blog.csdn.net/qq_35307005/article/details/89929403