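# RESA推理 -- RESA lane-detection inference on a video stream.
# 作者 (author): 互联网 (collected from the internet).
# 参考 (based on) the blog post at:
# https://blog.csdn.net/qq_42178122/article/details/122787261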
import os
import os.path as osp
import time
import shutil
import torch
import torchvision
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim
import cv2
import numpy as np
import models
import argparse
from utils.config import Config
from runner.runner import Runner
from datasets import build_dataloader
# Lane drawing colours, passed straight to cv2.circle and therefore read in
# OpenCV's BGR channel order. The drawing code picks one entry per lane.
color_list = [
    (255, 0, 0),      # lane 0
    (255, 225, 0),    # lane 1
    (255, 0, 255),    # lane 2
    (125, 125, 125),  # lane 3
    (255, 125, 125),  # lane 4
    (0, 125, 0),      # lane 5
]
# Video used when the command line does not provide one (kept from the
# original script so existing invocations behave the same).
DEFAULT_VIDEO = "/media/gooddz/新加卷/检测视频/极弯场景.mp4"


def is_short(lane):
    """Return 1 if ``lane`` contains no positive x coordinate, else 0."""
    return 0 if any(x > 0 for x in lane) else 1


def fix_gap(coordinate):
    """Fill interior holes of a lane by linear interpolation.

    ``coordinate`` holds one x value per sampled row; values < 0 mean "no
    detection". Only the span between the first and last positive value is
    touched. Every negative entry lying strictly between a detected value and
    the next detected value is replaced by ``int`` of the linear interpolation
    between those two neighbours.

    Args:
        coordinate: 1-D numpy array or list of x values. It is updated in place.

    Returns:
        The same ``coordinate`` object.
    """
    positive = [i for i, x in enumerate(coordinate) if x > 0]
    if not positive:
        return coordinate
    start, end = positive[0], positive[-1]
    # For numpy input this slice is a view; for a list it is a copy. That is
    # why the result is always written back below.
    lane = coordinate[start:end + 1]
    if not any(x < 0 for x in lane):
        return coordinate

    # gap_start: last detected index before a hole.
    # gap_end: first detected index after it.
    gap_start = [i for i, x in enumerate(lane[:-1]) if x > 0 and lane[i + 1] < 0]
    gap_end = [i + 1 for i, x in enumerate(lane[:-1]) if x < 0 and lane[i + 1] > 0]
    if not gap_start or not gap_end:
        return coordinate

    gaps = list(zip(gap_start, gap_end))
    for idx in [i for i, x in enumerate(lane) if x < 0]:
        for g_start, g_end in gaps:
            if g_start < idx < g_end:
                width = float(g_end - g_start)
                lane[idx] = int((idx - g_start) / width * lane[g_end]
                                + (g_end - idx) / width * lane[g_start])
                break
    if not all(x > 0 for x in lane):
        print("Gaps still exist!")
    coordinate[start:end + 1] = lane
    return coordinate


def get_lane(prob_map, y_px_gap, pts, thresh, resize_shape=None, cut_height=0):
    """Sample one lane's x positions bottom-up from its probability map.

    Args:
        prob_map: 2-D array (h, w) of per-pixel probabilities for one lane.
        y_px_gap: vertical distance, in output pixels, between samples.
        pts: number of rows to sample.
        thresh: minimum row-maximum probability that counts as a detection.
        resize_shape: (H, W) of the full output image; defaults to the map's shape.
        cut_height: rows cropped off the top of the image before inference.
            The map is assumed to cover only the remaining ``H - cut_height`` rows.

    Returns:
        float array of ``pts`` x coordinates in output-image pixels, -1 where
        nothing was detected. It is all zeros when fewer than two rows were
        detected. Interior holes are filled by ``fix_gap``.
    """
    if resize_shape is None:
        resize_shape = prob_map.shape
    h, w = prob_map.shape
    H, W = resize_shape
    H -= cut_height
    coords = np.full(pts, -1.0)
    for i in range(pts):
        # Row in the (cropped) output image, mapped back to a prob-map row.
        y = int((H - 10 - i * y_px_gap) * h / H)
        if y < 0:
            break
        line = prob_map[y, :]
        col = np.argmax(line)
        if line[col] > thresh:
            coords[i] = int(col / w * W)
    if (coords > 0).sum() < 2:
        coords = np.zeros(pts)
    return fix_gap(coords)


def _to_points(coords, H, y_px_gap):
    """Pair each sampled x with its full-image row; non-positive x becomes -1."""
    return [[x if x > 0 else -1, H - 10 - j * y_px_gap] for j, x in enumerate(coords)]


def probmap2lane(seg_pred, exist, b, resize_shape=(720, 1280), smooth=True,
                 y_px_gap=10, pts=56, thresh=0.6, num_classes=None, cut_height=0):
    """Convert a segmentation probability volume into lane point lists.

    Args:
        seg_pred: array (num_classes, h, w); channel 0 is background and
            channel k (k >= 1) is lane k-1.
        exist: per-lane existence flags. Accepted for compatibility but
            currently unused (existence filtering is disabled).
        b: accepted for compatibility; currently unused.
        resize_shape: target (H, W); ``None`` uses ``seg_pred``'s spatial shape.
        smooth: apply a 9x9 box blur to each lane map before sampling.
        y_px_gap: vertical sampling step in output pixels.
        pts: number of points per lane.
        thresh: probability threshold for a detection.
        num_classes: number of channels to read; defaults to ``seg_pred.shape[0]``.
        cut_height: rows cropped off the top of the frame before inference.

    Returns:
        List of lanes, each a list of ``[x, y]`` pairs with x == -1 where
        absent. If no lane survives, a single all-absent lane is returned so
        callers always get at least one entry.
    """
    if resize_shape is None:
        resize_shape = seg_pred.shape[1:]
    if num_classes is None:
        num_classes = seg_pred.shape[0]
    H = resize_shape[0]
    coordinates = []
    for i in range(num_classes - 1):
        prob_map = seg_pred[i + 1]  # seg_pred[0] is the background channel
        if smooth:
            prob_map = cv2.blur(prob_map, (9, 9), borderType=cv2.BORDER_REPLICATE)
        coords = get_lane(prob_map, y_px_gap, pts, thresh, resize_shape, cut_height)
        if is_short(coords):
            continue
        coordinates.append(_to_points(coords, H, y_px_gap))
    if not coordinates:
        coordinates.append(_to_points(np.zeros(pts), H, y_px_gap))
    return coordinates


def view(img, coords, file_path=None):
    """Draw lane points on ``img`` in place and optionally save it.

    Points whose x or y is <= 0 are skipped. Lane k is drawn with
    ``color_list[k % len(color_list)]``, so any number of lanes is safe.
    If ``file_path`` is given, its directory is created if needed and the
    annotated image is written there.
    """
    for lane_idx, lane in enumerate(coords):
        color = color_list[lane_idx % len(color_list)]
        for x, y in lane:
            if x <= 0 or y <= 0:
                continue
            cv2.circle(img, (int(x), int(y)), 4, color, 2)
    if file_path is not None:
        out_dir = osp.dirname(file_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        cv2.imwrite(file_path, img)


def transform_val():
    """Build the validation-time preprocessing (resize to 640x368, mean subtraction)."""
    import utils.transforms as tf
    return torchvision.transforms.Compose([
        tf.SampleResize((640, 368)),
        tf.GroupNormalize(mean=([103.939, 116.779, 123.68], (0, )),
                          std=([1., 1., 1.], (1, ))),
    ])


def _preprocess(frame, transform, cut_height):
    """Crop the top ``cut_height`` rows of an OpenCV frame and make a 1-image CUDA batch.

    The transformed crop is converted HWC -> CHW and returned as a float
    tensor of shape (1, C, H, W) on the current CUDA device.
    """
    cropped = frame[cut_height:, :, :]
    sample = transform((cropped,))
    tensor = torch.from_numpy(sample[0]).permute(2, 0, 1).contiguous().float()
    return tensor.unsqueeze(0).cuda()


def _detect_lanes(net, tensor, cfg):
    """Run ``net`` on a single-image batch and return its lanes sorted by y."""
    with torch.no_grad():
        output = net(tensor)
    seg_pred = F.softmax(output['seg'], dim=1).detach().cpu().numpy()
    exist_pred = output['exist'].detach().cpu().numpy()
    # Exactly one frame is batched at a time, so only batch item 0 exists.
    exist = [1 if exist_pred[0, i] > 0.5 else 0 for i in range(cfg.num_classes - 1)]
    lanes = probmap2lane(seg_pred[0], exist, None, thresh=0.6,
                         num_classes=cfg.num_classes, cut_height=cfg.cut_height)
    return [sorted(lane, key=lambda pair: pair[1]) for lane in lanes]


def main():
    """Run RESA lane detection on a video and display the detected lanes.

    Builds the network from the command-line config/checkpoint. Every frame
    of the video (``--video`` if the parser provides it, else
    ``DEFAULT_VIDEO``) is resized to 1280x720 and cropped by
    ``cfg.cut_height``, then inferred on. Lane points are drawn onto the
    uncropped frame in an OpenCV window.

    The loop stops when the video ends or ``q`` is pressed. The capture is
    always released, and the total elapsed wall-clock time is printed.

    Raises:
        IOError: if the video cannot be opened.
    """
    args = parse_args()
    # Tolerate a scalar ``gpus`` value (argparse yields an int for a string default).
    gpus = args.gpus if isinstance(args.gpus, (list, tuple)) else [args.gpus]
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(gpu) for gpu in gpus)
    cfg = Config.fromfile(args.config)
    cfg.gpus = len(gpus)
    cfg.load_from = args.load_from
    cfg.finetune_from = args.finetune_from
    cfg.view = args.view
    cfg.work_dirs = args.work_dirs + '/' + cfg.dataset.train.type
    cudnn.benchmark = True
    # NOTE(review): `fastest` does not appear to be a documented cudnn flag; kept from upstream.
    cudnn.fastest = True
    runner = Runner(cfg)
    runner.net.eval()

    video_path = getattr(args, 'video', None) or DEFAULT_VIDEO
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        raise IOError('Cannot open video: %s' % video_path)

    # The preprocessing pipeline is frame-independent; build it once.
    transform = transform_val()
    fps = 0.0
    time_start = time.perf_counter()  # time.clock() was removed in Python 3.8
    try:
        while True:
            t1 = time.perf_counter()
            ok, frame = capture.read()
            if not ok:  # end of stream or read error
                break
            frame = cv2.resize(frame, (1280, 720))
            frame_copy = frame.copy()
            # Crop must match cfg.cut_height, which get_lane uses to map rows back.
            tensor = _preprocess(frame, transform, cfg.cut_height)
            lane_coords = _detect_lanes(runner.net, tensor, cfg)
            view(frame_copy, lane_coords)
            fps = (fps + (1. / (time.perf_counter() - t1))) / 2
            cv2.imshow('imshow', frame_copy)
            print("fps:", fps)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        capture.release()
        cv2.destroyAllWindows()
    time_end = time.perf_counter()
    print(time_end - time_start)
def parse_args():
    """Parse the command line for RESA video inference.

    Returns:
        argparse.Namespace with ``config``, ``work_dirs``, ``load_from``,
        ``finetune_from``, ``validate``, ``view``, ``gpus`` (list of int),
        ``seed`` and ``video``.
    """
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--work_dirs', type=str, default='work_dirs',
        help='work dirs')
    parser.add_argument(
        '--load_from',
        default='/home/llgj/桌面/ldz/resa-main_原/work_dirs/TuSimple/20220120_083126_lr_2e-02_b_4/ckpt/best.pth',
        help='checkpoint to load for inference')
    parser.add_argument(
        '--finetune_from', default=None,
        help='whether to finetune from the checkpoint')
    parser.add_argument(
        '--validate',
        action='store_true',
        help='whether to evaluate the checkpoint during training')
    parser.add_argument(
        '--view',
        action='store_true',
        help='whether to show visualization result')
    # The default must be a list: argparse feeds a *string* default through
    # ``type``, so default='0' would yield the int 0 instead of [0].
    parser.add_argument('--gpus', nargs='+', type=int, default=[0])
    parser.add_argument('--seed', type=int,
                        default=None, help='random seed')
    parser.add_argument(
        '--video', type=str,
        default='/media/gooddz/新加卷/检测视频/极弯场景.mp4',
        help='path of the video to run inference on')
    args = parser.parse_args()
    return args
if __name__ == '__main__':
    main()

# Example invocations (config path first, then options):
#   configs/tusimple.py --gpus 0
#   configs/tusimple.py --validate --load_from /media/gooddz/学习/culane_resnet50.pth --gpus 0 --view
# 标签 (tags): lane, cfg, gap, shape, coords, import, resa, 推理 (inference)
# 来源 (source): https://blog.csdn.net/qq_45013882/article/details/122832847