Loss functions in PyTorch

Author: Internet (reposted)

Official documentation: https://pytorch.org/docs/stable/nn.html#loss-functions

1: torch.nn.L1Loss

Measures the mean absolute error (MAE) between each element in the input x and the target y.

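MAE stands for mean absolute error, also called the L1 loss. With the default reduction='mean', it averages over all N elements:

$$\ell(x, y) = \frac{1}{N}\sum_{n=1}^{N} \lvert x_n - y_n \rvert$$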

import torch
import torch.nn as nn

loss = nn.L1Loss()
input = torch.randn(1, 2, requires_grad=True)
target = torch.randn(1, 2)
output = loss(input, target)  # mean of |input - target|
print(output)
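To check what L1Loss actually computes, the sketch below recomputes it by hand and also tries the other two `reduction` modes ('sum' and 'none'). The fixed seed and the variable names `mean_loss`, `manual_mean`, `sum_loss` and `per_elem` are illustrative additions, not part of the original example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # illustrative: fixed seed so the run is repeatable
input = torch.randn(1, 2, requires_grad=True)
target = torch.randn(1, 2)

# Default reduction='mean': average of |x_n - y_n| over all elements
mean_loss = nn.L1Loss()(input, target)
manual_mean = (input - target).abs().mean()

# reduction='sum' adds the absolute errors; reduction='none' keeps one value per element
sum_loss = nn.L1Loss(reduction='sum')(input, target)
per_elem = nn.L1Loss(reduction='none')(input, target)

print(mean_loss.item(), manual_mean.item())
print(sum_loss.item(), per_elem)
```

The 'none' mode is handy when individual samples need their own weights before averaging.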

 

2: torch.nn.MSELoss

Measures the mean squared error (squared L2 norm) between each element in the input x and the target y.
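Written out in the same notation as the L1 formula, and assuming the default reduction='mean' over all N elements, MSELoss computes:

$$\ell(x, y) = \frac{1}{N}\sum_{n=1}^{N} (x_n - y_n)^2$$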

import torch
import torch.nn as nn

loss = nn.MSELoss()
input = torch.randn(1, 2, requires_grad=True)
target = torch.randn(1, 2)
output = loss(input, target)  # mean of (input - target) ** 2
print(output)
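A quick way to confirm the formula is to recompute MSELoss by hand and compare gradients: for the mean-reduced squared error, the gradient with respect to each input element should be 2(x_n - y_n)/N. The seed and the names `manual` and `expected_grad` are illustrative additions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # illustrative: fixed seed so the run is repeatable
input = torch.randn(1, 2, requires_grad=True)
target = torch.randn(1, 2)

loss = nn.MSELoss()  # reduction='mean' by default
output = loss(input, target)

# Same value by hand: mean of the squared element-wise differences
manual = ((input - target) ** 2).mean()

# Analytic gradient of the mean squared error w.r.t. input: 2 * (x - y) / N
output.backward()
expected_grad = 2 * (input.detach() - target) / input.numel()

print(output.item(), manual.item())
print(input.grad, expected_grad)
```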

 

3: torch.nn.NLLLoss and torch.nn.CrossEntropyLoss

torch.nn.NLLLoss is the negative log likelihood loss (NLL loss) for multi-class classification.

torch.nn.CrossEntropyLoss is the cross-entropy loss.
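As formulas, assuming the default reduction='mean' and no class weights: take a batch of N samples, raw scores x_{n,c} over C classes, and target class y_n. NLLLoss takes the log-probabilities p and averages the negated entry of the target class, while CrossEntropyLoss does the same after a log-softmax over the raw scores:

$$\mathrm{NLL}(p, y) = -\frac{1}{N}\sum_{n=1}^{N} p_{n, y_n}$$

$$\mathrm{CE}(x, y) = -\frac{1}{N}\sum_{n=1}^{N} \log\frac{\exp(x_{n, y_n})}{\sum_{c=1}^{C}\exp(x_{n, c})} = \mathrm{NLL}(\operatorname{log\_softmax}(x), y)$$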

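The difference between the two: NLLLoss expects log-probabilities as input, so it is normally preceded by a LogSoftmax layer, while CrossEntropyLoss combines LogSoftmax and NLLLoss in a single class and takes the raw, unnormalized scores (logits) directly. The two calls below therefore print the same value: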

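import torch
import torch.nn as nn

m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
# input is of size N x C = 3 x 5 (raw, unnormalized scores)
input = torch.randn(3, 5, requires_grad=True)
# each element in target has to have 0 <= value < C
target = torch.empty(3, dtype=torch.long).random_(5)
# NLLLoss expects log-probabilities, so apply LogSoftmax first
output = loss(m(input), target)
print(output)

# CrossEntropyLoss applies LogSoftmax internally and takes the raw scores
loss = nn.CrossEntropyLoss()
output = loss(input, target)
print(output)  # same value as above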
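To see that NLLLoss is nothing more than "pick the target entry, negate, average", the sketch below computes it by hand in two equivalent ways (advanced indexing and `gather`) and compares it with both built-in losses. The seed and the names `manual_nll` and `manual_gather` are illustrative additions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # illustrative: fixed seed so the run is repeatable
input = torch.randn(3, 5, requires_grad=True)          # logits, N x C = 3 x 5
target = torch.empty(3, dtype=torch.long).random_(5)  # class indices in [0, 5)

log_probs = F.log_softmax(input, dim=1)

# NLLLoss picks -log_probs[n, target[n]] for each row and averages the rows
manual_nll = -log_probs[torch.arange(3), target].mean()
# The same selection written with gather along the class dimension
manual_gather = -log_probs.gather(1, target.unsqueeze(1)).squeeze(1).mean()

nll = nn.NLLLoss()(log_probs, target)
ce = nn.CrossEntropyLoss()(input, target)
print(nll.item(), ce.item(), manual_nll.item())
```

A common mistake follows from this: feeding softmax probabilities (instead of logits) into CrossEntropyLoss applies the softmax twice and silently produces a wrong loss.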

 

 

4: torch.nn.BCELoss and torch.nn.BCEWithLogitsLoss
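As formulas, assuming the default reduction='mean' and no weights: BCELoss takes probabilities p_n and targets y_n in [0, 1], and BCEWithLogitsLoss takes raw scores x_n and applies the sigmoid σ itself:

$$\mathrm{BCE}(p, y) = -\frac{1}{N}\sum_{n=1}^{N}\left[y_n \log p_n + (1 - y_n)\log(1 - p_n)\right]$$

$$\mathrm{BCEWithLogits}(x, y) = \mathrm{BCE}(\sigma(x), y), \qquad \sigma(x) = \frac{1}{1 + e^{-x}}$$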

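The difference between the two: BCELoss expects probabilities in [0, 1], so it is normally preceded by a Sigmoid layer, while BCEWithLogitsLoss combines the Sigmoid and BCELoss in a single class and takes raw scores (logits). Combining the two steps lets PyTorch use the log-sum-exp trick, which makes BCEWithLogitsLoss more numerically stable than a plain Sigmoid followed by BCELoss. The two calls below print the same value: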

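import torch

m = torch.nn.Sigmoid()
loss = torch.nn.BCELoss()
input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)  # labels 0. or 1., float dtype
# BCELoss expects probabilities in [0, 1], so apply Sigmoid first
output = loss(m(input), target)
print(output)

# BCEWithLogitsLoss applies the sigmoid internally and takes raw scores
loss = torch.nn.BCEWithLogitsLoss()
output = loss(input, target)
print(output)  # same value as above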
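The sketch below writes binary cross-entropy out by hand, checks it against both built-in losses, and then shows the stability difference on one extreme, confidently wrong logit. In float32, sigmoid(-200) underflows to exactly 0, so BCELoss would see log(0); per the PyTorch docs it clamps its log outputs at -100, giving a loss of 100, whereas BCEWithLogitsLoss works from the logit and returns the true value of about 200. The seed and the names `manual`, `big`, `one`, `bce_big` and `logits_big` are illustrative additions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # illustrative: fixed seed so the run is repeatable
input = torch.randn(3, requires_grad=True)  # raw logits
target = torch.empty(3).random_(2)         # 0./1. labels as floats

p = torch.sigmoid(input)
# Binary cross-entropy written out element by element, then averaged
manual = -(target * torch.log(p) + (1 - target) * torch.log(1 - p)).mean()

bce = nn.BCELoss()(p, target)
bce_logits = nn.BCEWithLogitsLoss()(input, target)
print(bce.item(), bce_logits.item(), manual.item())

# Stability check: true label 1, logit -200 (a very confident wrong prediction)
big = torch.tensor([-200.0])
one = torch.tensor([1.0])
bce_big = nn.BCELoss()(torch.sigmoid(big), one).item()  # log(0) clamped to -100
logits_big = nn.BCEWithLogitsLoss()(big, one).item()     # computed from the logit
print(bce_big, logits_big)
```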

 

Reference: https://www.cnblogs.com/wanghui-garcia/p/10862733.html

Source: https://www.cnblogs.com/xiaoxiaomajinjiebiji/p/13984479.html