yolov4-deepsort目标跟踪、ROI计数、绘制轨迹
作者:互联网
项目环境部署参看我的另一篇博客《windows下yolov4-deepsort项目tensorflowGPU版本配置+项目实战》
linux下部署直接参考源码github
效果展示:
目录
1 介绍
今天整理了下之前做的一些工作,在原有yolov4-deepsort功能上增加了一些别的小功能。
增加功能如下:
- 绘制ROI区域,并记录roi内目标数量
- 目标进入roi区域显示‘enter’
- 打印各类目标跟踪数
- 绘制目标运动轨迹
其余类似任务功能做法原理类似。
2 object_tracker.py解读与重写
上述效果的object_tracker.py文件如下(其余文件是作者源码),这里需要做的是提供自己的ROI区域坐标点,取点以后要拟合roi边界,可参考我的另一篇《ROI区域提取(图上直接利用鼠标事件提取坐标点,可视化显示)》,将对应位置的判别条件换为自己的数据即可。
import os
# comment out below line to enable tensorflow logging outputs
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import time
import tensorflow as tf
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
tf.config.experimental.set_memory_growth(physical_devices[0], True)
from absl import app, flags, logging
from absl.flags import FLAGS
import core.utils as utils
from core.yolov4 import filter_boxes
from tensorflow.python.saved_model import tag_constants
from core.config import cfg
from PIL import Image
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
# deep sort imports
from deep_sort import preprocessing, nn_matching
from deep_sort.detection import Detection
from deep_sort.tracker import Tracker
from tools import generate_detections as gdet
flags.DEFINE_string('framework', 'tf', '(tf, tflite, trt')
flags.DEFINE_string('weights', './checkpoints/yolov4-416',
'path to weights file')
flags.DEFINE_integer('size', 416, 'resize images to')
flags.DEFINE_boolean('tiny', False, 'yolo or yolo-tiny')
flags.DEFINE_string('model', 'yolov4', 'yolov3 or yolov4')
flags.DEFINE_string('video', './data/video/test.mp4', 'path to input video or set to 0 for webcam')
flags.DEFINE_string('output','./outputs/test2.avi', 'path to output video')
flags.DEFINE_string('output_format', 'XVID', 'codec used in VideoWriter when saving video to file')
flags.DEFINE_float('iou', 0.45, 'iou threshold')
flags.DEFINE_float('score', 0.50, 'score threshold')
flags.DEFINE_boolean('dont_show', False, 'dont show video output')
flags.DEFINE_boolean('info', False, 'show detailed info of tracked objects')
flags.DEFINE_boolean('count', True, 'count objects being tracked on screen')
flags.DEFINE_boolean('roi',True,'Draw ROI and count')
def main(_argv):
# Definition of the parameters (参数的定义)
max_cosine_distance = 0.4
nn_budget = None
nms_max_overlap = 1.0
# initialize deep sort (初始化深度排序)
model_filename = 'model_data/mars-small128.pb'
encoder = gdet.create_box_encoder(model_filename, batch_size=1)
# calculate cosine distance metric (计算余弦距离度规)
metric = nn_matching.NearestNeighborDistanceMetric("cosine", max_cosine_distance, nn_budget)
# initialize tracker (初始化跟踪器)
tracker = Tracker(metric)
# load configuration for object detector (为对象检测器加载配置)
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
STRIDES, ANCHORS, NUM_CLASS, XYSCALE = utils.load_config(FLAGS)
input_size = FLAGS.size
video_path = FLAGS.video
# load tflite model if flag is set (如果设置了标记,则加载tflite模型)
if FLAGS.framework == 'tflite':
interpreter = tf.lite.Interpreter(model_path=FLAGS.weights)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)
print(output_details)
# otherwise load standard tensorflow saved model (否则加载标准tensorflow保存模型)
else:
saved_model_loaded = tf.saved_model.load(FLAGS.weights, tags=[tag_constants.SERVING])
infer = saved_model_loaded.signatures['serving_default']
# begin video capture (开始视频捕捉)
try:
vid = cv2.VideoCapture(int(video_path))
except:
vid = cv2.VideoCapture(video_path)
out = None
# get video ready to save locally if flag is set (如果设置了标记,准备在本地保存视频)
if FLAGS.output:
# by default VideoCapture returns float instead of int (默认情况下,VideoCapture返回float而不是int)
width = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = int(vid.get(cv2.CAP_PROP_FPS))
codec = cv2.VideoWriter_fourcc(*FLAGS.output_format)
out = cv2.VideoWriter(FLAGS.output, codec, fps, (width, height))
frame_num = 0
# 设置一个用来存放对象的字典
object_dic = {}
# while video is running (视频运行时)
while True:
return_value, frame = vid.read()
if return_value:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
image = Image.fromarray(frame)
else:
print('Video has ended or failed, try a different video format!')
break
frame_num +=1
print('Frame #: ', frame_num)
frame_size = frame.shape[:2]
image_data = cv2.resize(frame, (input_size, input_size))
image_data = image_data / 255.
image_data = image_data[np.newaxis, ...].astype(np.float32)
start_time = time.time()
# run detections on tflite if flag is set (如果设置了标记,运行tflite检测)
if FLAGS.framework == 'tflite':
interpreter.set_tensor(input_details[0]['index'], image_data)
interpreter.invoke()
pred = [interpreter.get_tensor(output_details[i]['index']) for i in range(len(output_details))]
# run detections using yolov3 if flag is set (如果设置了标记,则使用yolov3运行检测)
if FLAGS.model == 'yolov3' and FLAGS.tiny == True:
boxes, pred_conf = filter_boxes(pred[1], pred[0], score_threshold=0.25,
input_shape=tf.constant([input_size, input_size]))
else:
boxes, pred_conf = filter_boxes(pred[0], pred[1], score_threshold=0.25,
input_shape=tf.constant([input_size, input_size]))
else:
batch_data = tf.constant(image_data)
pred_bbox = infer(batch_data)
for key, value in pred_bbox.items():
boxes = value[:, :, 0:4]
pred_conf = value[:, :, 4:]
boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression(
boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),
scores=tf.reshape(
pred_conf, (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])),
max_output_size_per_class=50,
max_total_size=50,
iou_threshold=FLAGS.iou,
score_threshold=FLAGS.score
)
# convert data to numpy arrays and slice out unused elements (将数据转换为numpy数组并分割出未使用的元素)
num_objects = valid_detections.numpy()[0] #检测出的所有目标个数 int型
bboxes = boxes.numpy()[0]
bboxes = bboxes[0:int(num_objects)]
scores = scores.numpy()[0]
scores = scores[0:int(num_objects)]
classes = classes.numpy()[0]
classes = classes[0:int(num_objects)] #检测出的类的index,是个np
# format bounding boxes from normalized ymin, xmin, ymax, xmax ---> xmin, ymin, width, height
# (格式化边界框从标准化ymin, xmin, ymax, xmax—> xmin, ymin,宽度,高度)
original_h, original_w, _ = frame.shape
bboxes = utils.format_boxes(bboxes, original_h, original_w)
# store all predictions in one parameter for simplicity when calling functions
# (在调用函数时,为了简单起见,将所有预测存储在一个参数中)
pred_bbox = [bboxes, scores, classes, num_objects]
# read in all class names from config (从配置中读取所有类名)
class_names = utils.read_class_names(cfg.YOLO.CLASSES)
# by default allow all classes in .names file (认情况下允许.names文件中的所有类)
allowed_classes = list(class_names.values()) # 默认的跟踪目标项
# custom allowed classes (uncomment line below to customize tracker for only people)
# (自定义允许的类(取消下面的注释行,只为人定制跟踪器))
# allowed_classes = ['person','car']
# loop through objects and use class index to get class name, allow only classes in allowed_classes list
# (循环遍历对象并使用类索引来获取类名,只允许allowed_classes列表中的类)
# 这里我用来存放当前帧下追踪到的各类的数量dict={'name':num} 然后后面将其打印显示出来
dict={}
for i in allowed_classes:
class_indx=(list(class_names.keys()))[list(class_names.values()).index(i)]
class_num=np.count_nonzero(classes == class_indx)
dict[i]=class_num
# 这里的做法和上面类似,主要是从上面检测出的所有目标中,筛选指定类
names = []
deleted_indx = []
for i in range(num_objects):
#类名的index
class_indx = int(classes[i])
#拿到类名
class_name = class_names[class_indx]
if class_name not in allowed_classes:
deleted_indx.append(i)
else:
names.append(class_name)
names = np.array(names)
# 统计目前追踪的目标数
count = len(names)
# 如果要计数的话
if FLAGS.count:
#打印总个数
cv2.putText(frame, "Objects being tracked: {}".format(count), (5, 35), cv2.FONT_HERSHEY_COMPLEX_SMALL, 2, (0, 255, 0), 2)
y=70
#打印各类个数
for key, value in dict.items():
if value !=0:
cv2.putText(frame, "{} being tracked: {}".format(key,value), (5, y), cv2.FONT_HERSHEY_COMPLEX_SMALL, 2,(0, 255, 0), 2)
y+=35
print("Objects being tracked: {}".format(count))
# delete detections that are not in allowed_classes (删除不属于allowed_classes的检测)
bboxes = np.delete(bboxes, deleted_indx, axis=0)
scores = np.delete(scores, deleted_indx, axis=0)
# encode yolo detections and feed to tracker (编码yolo检测并提供给跟踪器)
features = encoder(frame, bboxes)
detections = [Detection(bbox, score, class_name, feature) for bbox, score, class_name, feature in zip(bboxes, scores, names, features)]
# initialize color map (初始化彩色地图)
cmap = plt.get_cmap('tab20b')
colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]
# run non-maxima supression (运行non-maxima压制)
boxs = np.array([d.tlwh for d in detections])
scores = np.array([d.confidence for d in detections])
classes = np.array([d.class_name for d in detections])
indices = preprocessing.non_max_suppression(boxs, classes, nms_max_overlap, scores)
detections = [detections[i] for i in indices]
# Call the tracker (调用跟踪器)
tracker.predict()
tracker.update(detections)
# 创建一个目标数列,用来存放ROI区域内的目标
target = []
track_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),
(0, 255, 255), (255, 0, 255), (255, 127, 255),
(127, 0, 255), (127, 0, 127)]
# update tracks (更新跟踪)
for track in tracker.tracks:
if not track.is_confirmed() or track.time_since_update > 1:
continue
bbox = track.to_tlbr()
class_name = track.get_class()
# draw bbox on screen (在屏幕上绘制bbox)
# 框索引:0-左上角点x,1-左上角点y,2-右下角点x,3-右下角点y
color = colors[int(track.track_id) % len(colors)]
color = [i * 255 for i in color]
cv2.rectangle(frame, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), color, 2)
cv2.rectangle(frame, (int(bbox[0]), int(bbox[1])), (int(bbox[0])+(len(class_name)+len(str(track.track_id)))*17, int(bbox[1]+20)), color, -1)
cv2.putText(frame, class_name + "-" + str(track.track_id),(int(bbox[0]), int(bbox[1]+10)),0, 0.75, (255,255,255),2)
# if enable info flag then print details about each track (如果启用信息标志,然后打印每个轨道的详细信息)
if FLAGS.info:
print("Tracker ID: {}, Class: {}, BBox Coords (xmin, ymin, xmax, ymax): {}".format(str(track.track_id), class_name, (int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]))))
# 检测框中心位置
center = [int((bbox[0] + bbox[2]) / 2), int((bbox[1] + bbox[3]) / 2), int(bbox[2] - bbox[0]),
int(bbox[3] - bbox[1])]
#像字典中添加目标
# 如果没有当前车的ID,则创建{'id':{'trace':[[],[],[],……[]],'trace_frames':num},'id':{'trace':[[],[],[],……[]]}}
if not "%d" % track.track_id in object_dic:
# 创建当前id的字典:key(ID):val{轨迹,丢帧计数器} 当丢帧数超过10帧就删除该对象
object_dic["%d" % track.track_id] = {"trace": [],'traced_frames': 10}
object_dic["%d" % track.track_id]["trace"].append(center)
object_dic["%d" % track.track_id]["traced_frames"] += 1
# 如果有,直接写入
else:
object_dic["%d" % track.track_id]["trace"].append(center)
object_dic["%d" % track.track_id]["traced_frames"] += 1
# 加坐标判断和roi区域设置及画轨迹
if FLAGS.roi:
# 这里提供roi的坐标 [173, 456], [966, 91], [1240, 122], [574, 515]
pts1 = np.array([[173, 456], [966, 91], [1240, 122], [574, 515]], np.int32)
pts1 = pts1.reshape((-1, 1, 2))
cv2.polylines(frame, [pts1], True, (0, 255, 255), thickness=2)
# 判断目标是否在roi区域内
# 拿检测框的下边线中心进行判断
x = int((bbox[0] + bbox[2]) / 2)
y = int(bbox[3]) - 10 # 给个偏移
# 这里是4条线,分别是Lad\Lbc\Lab\Lcd,a左下角、b左上角、c右上角、d右下角。
yab = round(-0.46*x+535.63, 2)
ybc = round(0.11*x+-18.29, 2)
ycd = round(-0.59*x+853.710, 2)
yda = round(0.15*x+430.55, 2)
# 判断中心点是否落入roi内
if (y > yab and y > ybc and y < ycd and y < yda):
target.append(x)
cv2.putText(frame, str('enter'), (int(bbox[2] - 65), int(bbox[3] - 5)),
cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, (255, 255, 0), 1)
cv2.putText(frame, "ROI count: {}".format(str(len(target))), (1500, 35), cv2.FONT_HERSHEY_COMPLEX_SMALL,2, (255, 0, 0), 2)
# 绘制轨迹
for s in object_dic:
i = int(s)
# 这里可以将目标的坐标存起来后面可以继续做目标速度,行驶方向的判断
# xlist, ylist, wlist, hlist = [], [], [], []
# 限制轨迹最大长度
if len(object_dic["%d" % i]["trace"]) > 20:
for k in range(len(object_dic["%d" % i]["trace"]) - 20):
del object_dic["%d" % i]["trace"][k]
# # # 绘制轨迹
if len(object_dic["%d" % i]["trace"]) > 2:
for j in range(1, len(object_dic["%d" % i]["trace"]) - 1):
pot1_x = object_dic["%d" % i]["trace"][j][0]
pot1_y = object_dic["%d" % i]["trace"][j][1]
pot2_x = object_dic["%d" % i]["trace"][j + 1][0]
pot2_y = object_dic["%d" % i]["trace"][j + 1][1]
# if pot2_x == pot1_x and pot1_y == pot2_y:
# del object_dic["%d" % i]
clr = i % 9 # 轨迹颜色随机
cv2.line(frame, (pot1_x, pot1_y), (pot2_x, pot2_y), track_colors[clr], 2)
# 对已经消失的目标予以排除
for s in object_dic:
if object_dic["%d" % int(s)]["traced_frames"] > 0:
object_dic["%d" % int(s)]["traced_frames"] -= 1
for n in list(object_dic):
if object_dic["%d" % int(n)]["traced_frames"] == 0:
del object_dic["%d" % int(n)]
fps = 1.0 / (time.time() - start_time)
print("FPS: %.2f" % fps)
result = np.asarray(frame)
result = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
if not FLAGS.dont_show:
cv2.imshow("Output Video", result)
# if output flag is set, save video file (如果设置了输出标志,保存视频文件)
if FLAGS.output:
out.write(result)
if cv2.waitKey(1) & 0xFF == ord('q'): break
cv2.destroyAllWindows()
if __name__ == '__main__':
try:
app.run(main)
except SystemExit:
pass
标签:ROI,deepsort,yolov4,int,object,cv2,track,bbox,class 来源: https://blog.csdn.net/weixin_48994268/article/details/115724165