其他分享
首页 > 其他分享> > Pytorch nn.BCEWithLogitsLoss() 的简单理解与用法

Pytorch nn.BCEWithLogitsLoss() 的简单理解与用法

作者:互联网

这个东西,本质上和nn.BCELoss()没有区别,只是在BCELoss上加了个logits函数(也就是sigmoid函数),例子如下:

import torch
import torch.nn as nn

label = torch.Tensor([1, 1, 0])
pred = torch.Tensor([3, 2, 1])
pred_sig = torch.sigmoid(pred)
loss = nn.BCELoss()
print(loss(pred_sig, label))

loss = nn.BCEWithLogitsLoss()
print(loss(pred, label))

loss = nn.BCEWithLogitsLoss()
print(loss(pred_sig, label))

输出结果分别为:

tensor(0.4963)
tensor(0.4963)
tensor(0.5990)

可以看到,nn.BCEWithLogitsLoss()相当于是在nn.BCELoss()中预测结果pred的基础上先做了个sigmoid,然后继续正常算loss。所以这就涉及到一个比较奇葩的bug,如果网络本身在输出结果的时候已经用sigmoid去处理了,算loss的时候用nn.BCEWithLogitsLoss()…那么就会相当于预测结果算了两次sigmoid,可能会出现各种奇奇怪怪的问题——

 

标签:loss,nn,sigmoid,pred,torch,BCEWithLogitsLoss,Pytorch
来源: https://www.cnblogs.com/BlairGrowing/p/15970528.html