
Learning AI with Mu Li – Anchor Box Code Walkthrough – 3

Author: Internet


Predicting bounding boxes with non-maximum suppression
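multibox_detection calls an nms helper that is not shown in this post. For reference, here is a minimal greedy NMS sketch, assuming boxes in corner format (xmin, ymin, xmax, ymax). box_iou and nms below are my own re-implementations, not d2l's exact code, although d2l's nms follows the same greedy scheme and likewise returns the kept indices in descending-score order.

```python
import torch


def box_iou(boxes1, boxes2):
    """Pairwise IoU between corner-format boxes: [N, 4] x [M, 4] -> [N, M]."""
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    # Top-left and bottom-right corners of every pairwise intersection
    lt = torch.max(boxes1[:, None, :2], boxes2[None, :, :2])
    rb = torch.min(boxes1[:, None, 2:], boxes2[None, :, 2:])
    wh = (rb - lt).clamp(min=0)  # zero width/height when boxes do not overlap
    inter = wh[..., 0] * wh[..., 1]
    return inter / (area1[:, None] + area2[None, :] - inter)


def nms(boxes, scores, iou_threshold):
    """Greedy NMS: return indices of kept boxes, highest score first."""
    order = torch.argsort(scores, dim=-1, descending=True)
    keep = []
    while order.numel() > 0:
        i = order[0]              # best remaining box
        keep.append(i.item())
        if order.numel() == 1:
            break
        # IoU between the best box and every remaining candidate
        iou = box_iou(boxes[i].unsqueeze(0), boxes[order[1:]]).reshape(-1)
        # Discard candidates that overlap the kept box too much
        order = order[1:][iou <= iou_threshold]
    return torch.tensor(keep, dtype=torch.long, device=boxes.device)
```

For example, with two identical boxes and one disjoint box scored 0.9, 0.8 and 0.7, nms(boxes, scores, 0.5) keeps indices [0, 2]: the second box is suppressed because its IoU with the first is 1.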

Applying non-maximum suppression:

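import torch
from d2l import torch as d2l

# `nms` is the greedy non-maximum suppression helper from the d2l anchor-box
# chapter (not shown here); it returns the indices of the kept boxes sorted
# by descending confidence.

def multibox_detection(cls_probs, offset_preds, anchors, nms_threshold=0.5,
                       pos_threshold=0.009999999):
    """Predict bounding boxes using non-maximum suppression.

    Args:
        cls_probs: predicted class probabilities of each anchor box,
            shape [batch_size, 1 + num_classes, num_anchors]
            (row 0 is the background class)
        offset_preds: predicted offsets of each anchor box,
            shape [batch_size, num_anchors * 4]
        anchors: anchor boxes in corner format (xmin, ymin, xmax, ymax),
            shape [1, num_anchors, 4]

    Returns:
        Tensor of shape [batch_size, num_anchors, 6]; each row is
        (class_id, confidence, xmin, ymin, xmax, ymax), where class_id = -1
        marks background or suppressed boxes.
    """
    device, batch_size = cls_probs.device, cls_probs.shape[0]
    anchors = anchors.squeeze(0)
    num_classes, num_anchors = cls_probs.shape[1], cls_probs.shape[2]
    out = []
    for i in range(batch_size):
        cls_prob, offset_pred = cls_probs[i], offset_preds[i].reshape(-1, 4)
        # Each column of cls_prob holds one anchor's probabilities over all
        # classes. Skipping row 0 (background), take the max per column:
        #   conf:     highest non-background probability -> [num_anchors]
        #   class_id: index of that class (0 = first object class) -> [num_anchors]
        conf, class_id = torch.max(cls_prob[1:], 0)
        # Apply the predicted offsets to the anchors to get predicted boxes
        predicted_bb = offset_inverse(anchors, offset_pred)
        keep = nms(predicted_bb, conf, nms_threshold)
        # Find all indices NMS did not keep (non_keep): an index appears
        # twice in `combined` if it is in `keep`, and once otherwise
        all_idx = torch.arange(num_anchors, dtype=torch.long, device=device)
        combined = torch.cat((keep, all_idx))
        uniques, counts = combined.unique(return_counts=True)
        non_keep = uniques[counts == 1]
        # New ordering: kept boxes first, then suppressed boxes; used below
        # to reorder class_id, conf and predicted_bb
        all_id_sorted = torch.cat((keep, non_keep))
        # Suppressed boxes are treated as background: set their class to -1,
        # then reorder class_id by all_id_sorted
        class_id[non_keep] = -1
        class_id = class_id[all_id_sorted]
        # Reorder confidences and predicted boxes the same way
        conf, predicted_bb = conf[all_id_sorted], predicted_bb[all_id_sorted]
        # `pos_threshold` is the minimum confidence for a non-background
        # prediction; boxes below it are also marked as background (-1)
        below_min_idx = (conf < pos_threshold)
        class_id[below_min_idx] = -1
        # For those boxes, report 1 - conf instead, i.e. roughly the
        # confidence that the box is background
        conf[below_min_idx] = 1 - conf[below_min_idx]
        # Concatenate class id, confidence and box coordinates column-wise:
        # pred_info -> [num_anchors, 6]
        pred_info = torch.cat((class_id.unsqueeze(1),
                               conf.unsqueeze(1),
                               predicted_bb), dim=1)
        out.append(pred_info)
    return torch.stack(out)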
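The keep/non_keep bookkeeping in multibox_detection is easy to misread, so here it is on hypothetical toy values (6 anchors, NMS keeping indices 4, 0 and 2). The trick relies on Tensor.unique(return_counts=True): after concatenating keep with all indices, kept indices occur twice and suppressed ones once.

```python
import torch

# Hypothetical toy values: 6 anchors; suppose NMS kept indices 4, 0, 2
# (already in descending-confidence order)
keep = torch.tensor([4, 0, 2])
all_idx = torch.arange(6, dtype=torch.long)

# Every index in `keep` appears twice in `combined`, every other index once
combined = torch.cat((keep, all_idx))
uniques, counts = combined.unique(return_counts=True)
non_keep = uniques[counts == 1]
print(non_keep.tolist())  # [1, 3, 5]

# Kept boxes first (NMS order), then suppressed boxes (ascending index)
all_id_sorted = torch.cat((keep, non_keep))
print(all_id_sorted.tolist())  # [4, 0, 2, 1, 3, 5]

# Hypothetical per-anchor class ids: mark suppressed anchors as
# background (-1), then reorder with all_id_sorted
class_id = torch.tensor([0, 1, 0, 1, 1, 0])
class_id[non_keep] = -1
class_id = class_id[all_id_sorted]
print(class_id.tolist())  # [1, 0, 0, -1, -1, -1]
```

In recent PyTorch versions, all_idx[~torch.isin(all_idx, keep)] produces the same non_keep and may read more directly; the counting approach works on older versions too.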
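
def offset_inverse(anchors, offset_preds):
    """Compute predicted bounding boxes from anchors and predicted offsets."""
    # Convert anchors from corner format to center format (cx, cy, w, h)
    anc = d2l.box_corner_to_center(anchors)
    # Undo the offset scaling used for training targets: center offsets were
    # multiplied by 10 and log width/height ratios by 5 when encoded
    pred_bbox_xy = (offset_preds[:, :2] * anc[:, 2:] / 10) + anc[:, :2]
    pred_bbox_wh = torch.exp(offset_preds[:, 2:] / 5) * anc[:, 2:]
    pred_bbox = torch.cat((pred_bbox_xy, pred_bbox_wh), dim=1)
    # Convert back to corner format (xmin, ymin, xmax, ymax)
    predicted_bbox = d2l.box_center_to_corner(pred_bbox)
    return predicted_bbox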
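offset_inverse undoes the offset encoding used for training targets. For an anchor with center (a_x, a_y) and size (a_w, a_h), predicted offsets (t_x, t_y, t_w, t_h) decode to center x = a_x + a_w * t_x / 10, y = a_y + a_h * t_y / 10, width w = a_w * exp(t_w / 5) and height h = a_h * exp(t_h / 5). The sketch below checks this round trip without d2l. The corner/center helpers are local re-implementations, and encode_offsets is a hypothetical encoder written to mirror those constants (d2l's own offset_boxes additionally adds a small eps inside the log).

```python
import torch


def box_corner_to_center(boxes):
    """(xmin, ymin, xmax, ymax) -> (cx, cy, w, h)"""
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    return torch.stack(((x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1), dim=-1)


def box_center_to_corner(boxes):
    """(cx, cy, w, h) -> (xmin, ymin, xmax, ymax)"""
    cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    return torch.stack((cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2), dim=-1)


def encode_offsets(anchors, boxes):
    """Hypothetical encoder: scale center offsets by 10, log size ratios by 5."""
    anc, bb = box_corner_to_center(anchors), box_corner_to_center(boxes)
    offset_xy = 10 * (bb[:, :2] - anc[:, :2]) / anc[:, 2:]
    offset_wh = 5 * torch.log(bb[:, 2:] / anc[:, 2:])
    return torch.cat((offset_xy, offset_wh), dim=1)


def decode_offsets(anchors, offsets):
    """Same computation as offset_inverse, using the local helpers."""
    anc = box_corner_to_center(anchors)
    xy = offsets[:, :2] * anc[:, 2:] / 10 + anc[:, :2]
    wh = torch.exp(offsets[:, 2:] / 5) * anc[:, 2:]
    return box_center_to_corner(torch.cat((xy, wh), dim=1))


anchors = torch.tensor([[0.1, 0.1, 0.5, 0.6], [0.4, 0.3, 0.9, 0.8]])
boxes = torch.tensor([[0.12, 0.08, 0.52, 0.62], [0.45, 0.25, 0.85, 0.9]])
offsets = encode_offsets(anchors, boxes)
print(torch.allclose(decode_offsets(anchors, offsets), boxes, atol=1e-6))  # True
```

A useful sanity check in the same spirit: all-zero offsets decode to the anchors themselves, which is why an untrained predictor that outputs zeros simply returns the anchor boxes.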

Source: https://blog.csdn.net/qq_34992900/article/details/120728762