其他分享
首页 > 其他分享> > YOLO添加Focal loss

YOLO添加Focal loss

作者:互联网

将YOLOv3及以上的网络中的BCE loss更改为Focal loss
loss函数分为三部分,位置损失、置信度损失、类别损失,此处只需要将置信度损失更换为Focal loss,具体原理请仔细理解置信度损失的含义。
YOLOX链接:https://link.zhihu.com/?target=https%3A//github.com/Megvii-BaseDetection/YOLOX

1 找到置信度预测损失计算位置loss_obj,并进行替换(位置在386-405行左右)

loss_iou:定位损失;loss_obj:置信度预测损失;loss_cls:预测损失
        loss_iou = (
            self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)
        ).sum() / num_fg
        #loss_obj = (  
        #    self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)
        #).sum() / num_fg
        loss_obj = (
            self.focal_loss(obj_preds.sigmoid().view(-1, 1), obj_targets)
        ).sum() / num_fg
        loss_cls = (
            self.bcewithlog_loss(
                cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets
            )
        ).sum() / num_fg

2 创建focal_loss方法,放到def get_l1_target(…)之前即可,代码如下:

def focal_loss(self, pred, gt):
        pos_inds = gt.eq(1).float()
        neg_inds = gt.eq(0).float()
        pos_loss = torch.log(pred+1e-5) * torch.pow(1 - pred, 2) * pos_inds * 0.75
        neg_loss = torch.log(1 - pred+1e-5) * torch.pow(pred, 2) * neg_inds * 0.25
        loss = -(pos_loss + neg_loss)
        return loss

标签:loss,obj,self,YOLO,num,fg,置信度,Focal
来源: https://blog.csdn.net/weixin_43850171/article/details/123140328