跟李沐学AI–锚框代码解析–3
作者:互联网
跟李沐学AI–锚框代码解析–3
非极大值抑制预测边界框
- 当存在许多锚框时,可能会输出许多相似的具有明显重带你的预测边界框,围绕同一目标,为了简化输出,使用给非极大值抑制(non-maximum suppression, NMS)合并对应目标为同一类的类似的预测边界框
- 其工作原理如下:
- 基础概念:对于一个预测边界框B,目标检测模型会计算每个类的预测概率,最大预测概率 p p p 所对应的类别,就是边框 B B B 的类别,这里 p p p 为 B B B的置信度,对于同一张图像,所有非背景预测边框按照置信都降序排序,生成列表 L L L
- 操作过程:
- 从
L
L
L 中选取置信度最高的预测边界框
B
1
B_1
B1作为基准,然后将所有与
B
1
B_1
B1的
I
o
U
IoU
IoU 超过预定阈值
ϵ
\epsilon
ϵ 的非基准预测边接框从
L
L
L 中移除。此时在
L
L
L 中,对于
B
1
B_1
B1 只剩下一个可用边界框,其他的相似锚框均基于上述标准被删除
- 注意这里求IoU值,是以 B 1 B_1 B1 为基准,使用其他预测锚框与 B 1 B_1 B1 进行计算IoU,而不是真实边框
- 从 L L L 中选取置信度第二稿的预测边框 B 2 B_2 B2 作为有一个基准,然后将所有与 B 2 B_2 B2的IoU大于 ϵ \epsilon ϵ的非基准预测边界框从 L L L 中移除;
- 重复上述过程,历遍 L L L中所有的锚框,直到 L L L中所有的预测边界框都曾被用作基准;此时 L L L中任意一对预测边界框的IoU都小于阈值 ϵ \epsilon ϵ,没有一对锚框相似
- 代码如下:
-
def nms(boxes, scores, iou_threshold): """对预测边界框的置信度进行排序 args: boxes: 预测边框 [anchors_num, 4] scores: 置信度 [anchors_num] iou_threshold: iou阈值 """ B = torch.argsort(scores, dim=-1, descending=True) '''返回scores排序后的下标 B --> tensor([0, 3, 1, 2]) ''' keep = [] # 保留预测边界框的指标 '''B.numel() 返回的tensor中的元素个数''' while B.numel() > 0: i = B[0] keep.append(i) if B.numel() == 1: break iou = box_iou(boxes[i, :].reshape(-1, 4), boxes[B[1:], :].reshape(-1, 4)).reshape(-1) '''iou 计算的为 B1与 B2, B3,...的iou一维矩阵 iou --> tensor([0.00, 0.74, 0.55])''' inds = torch.nonzero(iou <= iou_threshold).reshape(-1) ''' inds 返回的为所有iou小于阈值的下标 ''' B = B[inds + 1] '''由于iou矩阵长度为 anchors_num-1, 剔除了最大的数值,因此在这里需要加1''' return torch.tensor(keep, device=boxes.device)
非极大抑制方法的应用:
- 该部分由一个函数实现,主要步骤简述如下:
- a. 根据锚框与类的置信度矩阵,求取每个锚框的最大置信度和其最大置信度所对应的类
- b. 利用转换函数,将带偏移量的锚框转为预测锚框,并基于预测锚框使用非极大值抑制方法进行筛选,并将keep和non_keep的下标进行排序(使用的为torch.cat直接拼接),其中keep对应的是物种种类,non_keep对应的为背景,利用合并排序后的下标对 锚框最大置信度 和 预测边框的顺序进行重排
- c. 对置信度小于置信度阈值的锚框进行处理,设置为背景锚框,在锚框类的预测概率中储存的为 1- p p p
- d. 最后对上述结果进行合并,最内层维度中的六个元素提供了同一预测边界框的输出信息。 第一个元素是预测的类索引,从 0 开始(0代表狗,1代表猫),值 -1 表示背景或在非极大值抑制中被移除了。 第二个元素是预测的边界框的置信度。 其余四个元素分别是预测边界框左上角和右下角的 (x,y) 轴坐标(范围介于 0 和 1 之间)。
- 代码:
def multibox_detection(cls_probs, offset_preds, anchors, nms_threshold=0.5,
pos_threshold=0.009999999):
"""使用非极大值抑制来预测边界框
args:
cls_probs: 锚框对于不同类别的概率
[batch_size, 1+class_num, anchors_num]
offset_preds: 不同锚框的偏移量
[anchors_num * 4]
anchors: 锚框矩阵
[anchors_num]
"""
device, batch_size = cls_probs.device, cls_probs.shape[0]
anchors = anchors.squeeze(0)
num_classes, num_anchors = cls_probs.shape[1], cls_probs.shape[2]
out = []
for i in range(batch_size):
cls_prob, offset_pred = cls_probs[i], offset_preds[i].reshape(-1, 4)
'''得到最大置信度,及所对应的种类
cls_prob 每一列代表的为单一锚框对应的不同类的置信度
conf: 每个锚框对于不同类的最大置信度 --> [anchors_num]
class_id: 每个锚框最大置信度对应的种类 --> [anchors_num]'''
conf, class_id = torch.max(cls_prob[1:], 0)
'''将带偏移量的边框转变为预测边框'''
predicted_bb = offset_inverse(anchors, offset_pred)
keep = nms(predicted_bb, conf, nms_threshold)
# 找到所有的 non_keep 索引,并将类设置为背景,就是设置为-1
all_idx = torch.arange(num_anchors, dtype=torch.long, device=device)
'''找到没有非极大值边框的编号,并排序,keep在前,non_keep在后'''
combined = torch.cat((keep, all_idx))
uniques, counts = combined.unique(return_counts=True)
non_keep = uniques[counts == 1]
''' all_id_sorted作为之后置信度和预测边框的索引 '''
all_id_sorted = torch.cat((keep, non_keep))
'''对于没有保留的锚框,认为是背景锚框,基于non_keep将class_id变成 -1
并按all_id_sorted 对class_id 进行重排'''
class_id[non_keep] = -1
class_id = class_id[all_id_sorted]
''' 将各个锚框的最大置信度和各个预测框,按照NMS值进行排序 '''
conf, predicted_bb = conf[all_id_sorted], predicted_bb[all_id_sorted]
# `pos_threshold` 是一个用于非背景预测的阈值
'''将置信度小于阈值的预测框的id设置为 -1, 抛弃'''
below_min_idx = (conf < pos_threshold)
class_id[below_min_idx] = -1
'''对小于阈值的预测锚框的概率值,进行计算处理'''
conf[below_min_idx] = 1 - conf[below_min_idx]
''' 将类别信息、置信度和预测边框按列合并
pred_info --> [anchors_num, 6] '''
pred_info = torch.cat((class_id.unsqueeze(1),
conf.unsqueeze(1),
predicted_bb), dim=1)
out.append(pred_info)
return torch.stack(out)
- 带偏移量的锚框转换函数的代码如下:
def offset_inverse(anchors, offset_preds):
"""根据带有预测偏移量的锚框来计算预测边界框。"""
anc = d2l.box_corner_to_center(anchors)
pred_bbox_xy = (offset_preds[:, :2] * anc[:, 2:] / 10) + anc[:, :2]
pred_bbox_wh = torch.exp(offset_preds[:, 2:] / 5) * anc[:, 2:]
pred_bbox = torch.cat((pred_bbox_xy, pred_bbox_wh), axis=1)
predicted_bbox = d2l.box_center_to_corner(pred_bbox)
return predicted_bbox
标签:置信度,预测,AI,锚框,李沐学,keep,id,anchors 来源: https://blog.csdn.net/qq_34992900/article/details/120728762