极简CenterNet(二)核心代码
作者:互联网
本节给出网络结构、损失函数、训练和验证部分等主要代码,并使用几种简单数据集进行了训练验证。
1,resnet.py
原论文中centernet的主网络部分分别使用了hourglass,DLA,resnet三种网络,其中resnet是最简单的,我们的极简代码当然先从resnet18结构入手。
代码见https://github.com/zzzxxxttt/pytorch_simple_CenterNet_45/blob/master/nets/resnet.py,原封不动,我就不贴出来了。这是一个简单的网络结构:输入图像是B x 3 x 512 x 512的(1)然后在resnet基础上去掉最后的全连接层,经过layer1~4之后得到的特征图尺寸B x 512 x 16 x 16;(2)然后连接上三层反卷积层构成上采样层,使特征图上采样到B x 256 x 128 x 128的特征图;(3)然后连接三个分支,分别输出heatmap(热点图),regs(中心点偏移量),w_h_(宽高)。每个分支都是一个两层卷积结构,其中heatmap分支的输出是B x C x 128 x 128,C表示num_classes即检测的目标类别数,例如coco是80,它的每样本每通道的128 x 128图可以理解为对应类别在每个像素上的置信度;regs分支的输出是B x 2 x 128 x 128,表示预测的中心点的x方向和y方向偏移量,由于我们对原图片进行了4倍的缩小,所以再取整后会造成截断误差,这个regs就是为了补偿这个截断误差的,不是特别重要,不要它也不影响多少精度;w_h_分支的输出也是B x 2 x 128 x 128,表示检测框的宽高,这个当然是很重要的,再具体解释一下更好理解,其中每个128 x 128图中的点的两个通道的数值可以理解为“假设该点是目标中心时的检测框宽高”,至于这个点到底是不是真的目标中心,则由heatmap中该点的置信度来确定。
懒得画图了,网上介绍这个原理的图很多,下面借用https://www.jianshu.com/p/d5d7cd7ad200上的一张,看看就明白了:
2, train.py
我在https://github.com/zzzxxxttt/pytorch_simple_CenterNet_45基础上简化修改,使它更简洁:
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset
from resnet import get_pose_net
from data_loader import CustomizeDataset
import cv2
import torch.nn.functional as F
from utils.utils import _tranpose_and_gather_feature
from utils.post_process import ctdet_decode, _nms
import matplotlib.pyplot as plt
import time
t0 = time.time()
train_dataset = CustomizeDataset(mode='train',num_classes=2)
val_dataset = CustomizeDataset(mode='val',num_classes=2)
kwargs = {"num_workers": 0, "pin_memory": True}
train_loader = DataLoader(dataset=train_dataset, shuffle=False, batch_size=20, **kwargs)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=1, **kwargs)
def _neg_loss(preds, targets):
pos_inds = targets.eq(1).float()
neg_inds = targets.lt(1).float()
neg_weights = torch.pow(1 - targets, 4)
preds = torch.clamp(preds, min=1e-4, max=1 - 1e-4)
pos_loss = torch.log(preds) * torch.pow(1 - preds, 2) * pos_inds
neg_loss = torch.log(1 - preds) * torch.pow(preds, 2) * neg_weights * neg_inds
num_pos = pos_inds.float().sum()
pos_loss = pos_loss.sum()
neg_loss = neg_loss.sum()
loss = - (pos_loss + neg_loss) / num_pos
return loss / len(preds)
def _reg_loss(regs, gt_regs, mask):
mask = mask[:, :, None].expand_as(gt_regs).float()
loss = sum(F.l1_loss(r * mask, gt_regs * mask, reduction='sum') / (mask.sum() + 1e-4) for r in regs)
return loss / len(regs)
net = get_pose_net(num_layers=18, head_conv=64, num_classes=2)
net = net.cuda()
net.train()
optimizer = torch.optim.Adam(net.parameters(), 1e-3)
losses_record = []
for epoch in range(10):
for idx,data in enumerate(train_loader):
img, heatmap, labels, gt_regs, gt_wh, inds, masks, bbox = data
img, heatmap, gt_regs, gt_wh, inds, masks = \
img.cuda(), heatmap.cuda(), gt_regs.cuda(), gt_wh.cuda(), inds.cuda(), masks.cuda()
hmap, regs, wh = net(img)[0]
hmap = torch.sigmoid(hmap)
hmap_loss = _neg_loss(hmap, heatmap)
regs = _tranpose_and_gather_feature(regs, inds)
wh = _tranpose_and_gather_feature(wh, inds)
reg_loss = _reg_loss(regs, gt_regs, masks)
w_h_loss = _reg_loss(wh, gt_wh, masks)
loss = 10*hmap_loss + 1 * reg_loss + 0.1 * w_h_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(idx,'hmap_loss:%.4f, reg_loss:%.4f, w_h_loss:%.4f, loss:%.4f'%(
hmap_loss.item(), reg_loss.item(),w_h_loss.item(),loss.item()),'time:%.1f'%(time.time()-t0))
losses_record.append([hmap_loss.item(), reg_loss.item(),w_h_loss.item(),loss.item()])
if idx%10==0:
plt.figure();plt.imshow(_nms(hmap)[0,0].data.cpu().numpy())
plt.figure();plt.imshow(heatmap[0,0].data.cpu().numpy())
plt.figure();plt.imshow(_nms(hmap)[0,1].data.cpu().numpy())
plt.figure();plt.imshow(heatmap[0,1].data.cpu().numpy())
losses_record = np.array(losses_record)
plt.figure();plt.semilogy(losses_record);plt.legend(['hmap loss','regs loss','wh loss','total loss'])
def IOU(box1,box2):
xA = max(box1[0], box2[0])
yA = max(box1[1], box2[1])
xB = min(box1[2], box2[2])
yB = min(box1[3], box2[3])
interArea = max(0,(xB - xA + 1)) * max(0,(yB - yA + 1))
box1Area = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1)
box2Area = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1)
iou = interArea / float(box1Area + box2Area - interArea)
return iou
net.eval()
ious = []
with torch.no_grad():
for idx,data in enumerate(val_loader):
img, heatmap, labels, gt_regs, gt_wh, inds, masks, bbox = data
img, heatmap, gt_regs, gt_wh, inds, masks = \
img.cuda(), heatmap.cuda(), gt_regs.cuda(), gt_wh.cuda(), inds.cuda(), masks.cuda()
hmap, regs, wh = net(img)[0]
dets = ctdet_decode(hmap, regs, wh )
dets = dets.detach().cpu().numpy().reshape(1, -1, dets.shape[2])[0]
image = img[0].permute(1,2,0).data.cpu().numpy()
image = (image/2 + 0.5)*255
image = image.astype('uint8')[:,:,::-1]
image = image.copy()
bbox = bbox[0].data.cpu().numpy()
labels = labels[0].data.cpu().numpy()
pred_box = []
for label in range(heatmap.shape[1]):
bbox_l = bbox[labels==label].copy()
det_l = dets[dets[:,5]==label].copy()
box_num = bbox_l.shape[0]
for n in range(box_num):
det = det_l[n]
det[4] = det[4]*100
det[:4] = det[:4]*4
det = det.round().astype('int')
pred_box.append(det[:4])
box = bbox_l[n]
image = cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), (0, 220, 0), 2)
image = cv2.rectangle(image, (det[0], det[1]), (det[2], det[3]), (0, 0, 220), 2)
iou = IOU(box,det[:4]) #TODO:此计算方法对每类多目标时不正确,待后续修改
ious.append(iou)
cv2.imwrite('fig2/%06d.jpg'%(999-idx),image)
mean_iou = np.mean(ious)
print('mean iou: %.4f'%mean_iou)
其他一些utils中的函数也请参见https://github.com/zzzxxxttt/pytorch_simple_CenterNet_45,不做修改。
这其中关于heatmap loss,也就是代码中的_neg_loss函数需要说明一下:这个损失是centernet的关键,开始不是很容易理解,实际上它就是对focal loss的再升级,在focal loss的基础上加上了对标注框热点图附近的衰减。但需要注意的是它的目的不是引导preds和targets一致(这和通常的损失函数不一样),它的目的是使preds趋向于单个中心点为1,其他点为0的输出图。不信可以试一下,如果preds和targets都是heatmap时,heatmap loss并不是0,只有当preds是中心单点为1其他为0,而targets是heatmap时,heatmap loss才为0。个人认为这是一个大坎,理解了这一点之后就很容易理解别的了。
3,效果
我们先用几种简单的数据集来检验,训练集都用800张图片,验证集用200张图片。由于数据非常简单,我们不用mAP指标(将会太高),我们用mIOU指标来验证效果。其中单个五星数据集和单个正方+单个四芒星数据集验证结果中的检测框(红色)和标注框(绿色)的情况如下图(也有个别IOU低一些的这里没有画)。
具体试验数据对比见下表:
数据集 | lr | 训练速度 | mIOU | 备注 |
---|---|---|---|---|
单个正方形,不旋转 | 1e-3 | 42fps | 0.8740 | |
单个正方形,随机旋转 | 1e-3 | 42fps | 0.0000 | 不收敛 |
单个正方形,随机旋转 | 1e-4 | 42fps | 0.9525 | |
单个五角星,随机旋转 | 1e-4 | 42fps | 0.9750 | |
单个五角星+单个五边形,随机旋转 | 1e-4 | 40fps | 0.9638 | |
两个五角星,随机旋转 | 1e-4 | 40fps | 0.9668 |
从损失函数变化图中我们可以看出:
- regs loss只在训练初期会下降,后期一直不再变化,说明中心点偏移损失regs loss确实起点作用,但作用不大;
- hmap loss和wh loss表现出轮动的现象:hmap loss初期就开始下降,但wh loss初期不动,当hmap loss下降到一定程度后,也就是说当中心点找的比较靠谱后wh loss才开始下降;而wh loss后期基本没法再优化后,wh loss还会持续下降,这是因为此时输出的heatmap图开始继续提高中心点的聚焦度,并提高中心点的置信度,最终目的是趋向于中心点1,而其他点0。
- 对于单正方形旋转数据集与不旋转数据集相比,wh损失会高很多,这是因为我们的标注框是根据外接矩形标注的,而外接矩形这个几何概念网络较难掌握,对于一个倾斜的正方形,网络计算它的外接矩形会比较困难。
- 学习率很重要,过大学习率会导致heatmap损失难以继续下降,可能是输出的hmap已经和输入的heatmap接近了,此时必须用更小的学习率才能学习。
通过以上分析,其实我们可以得出几种网络的改进思路:
- regs loss初期可以不用,最后的这个网络分支都可以锁死不用,到最后几轮再打开训练一下即可,这样可以节省训练时间。
- heatmap在训练初期可使用,后期可以去掉高斯圆,就只用中心单点,效果应该更好一些。
关于heatmap的更多讨论,以及其他一些讨论,请见下节。
这几天腰疼病犯了,没法久坐,所以匆匆写了写,文字没有润色修剪,可能说的不是很好懂,各位将就看看吧。
标签:极简,gt,loss,代码,CenterNet,hmap,regs,heatmap,wh 来源: https://blog.csdn.net/Brikie/article/details/116274548