PyTorch loss functions
Author: (reposted from the Internet)
Official documentation: https://pytorch.org/docs/stable/nn.html#loss-functions
1:torch.nn.L1Loss
Measures the mean absolute error (MAE) between each element in the input x and target y.
MAE (mean absolute error) is also called the L1 loss:
import torch
import torch.nn as nn

loss = nn.L1Loss()
input = torch.randn(1, 2, requires_grad=True)
target = torch.randn(1, 2)
output = loss(input, target)  # scalar: mean of |input - target|
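A quick sanity check of the definition above, written as a minimal sketch that assumes a recent PyTorch release: with the default reduction='mean', L1Loss should equal the mean of |input - target| computed by hand, and its gradient with respect to input should be sign(input - target) / N, where N is the number of elements. The seed is an arbitrary choice made only for reproducibility.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
input = torch.randn(1, 2, requires_grad=True)
target = torch.randn(1, 2)

# Built-in L1 loss with the default reduction='mean'
builtin = nn.L1Loss()(input, target)
# Manual MAE: average of the element-wise absolute differences
manual = (input - target).abs().mean()
print(builtin.item(), manual.item())

# d|x - y|/dx = sign(x - y), divided by N because of the mean
builtin.backward()
print(input.grad)
```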
2:torch.nn.MSELoss
Measures the mean squared error (squared L2 norm) between each element in the input x and target y.
import torch
import torch.nn as nn

loss = nn.MSELoss()
input = torch.randn(1, 2, requires_grad=True)
target = torch.randn(1, 2)
output = loss(input, target)  # scalar: mean of (input - target)^2
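MSELoss can be checked the same way against (input - target)^2 averaged by hand. This sketch also shows the reduction argument that PyTorch's loss classes accept: 'mean' is the default, 'sum' adds up the per-element losses, and 'none' returns them unreduced. The variable names and the seed are illustrative choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
input = torch.randn(1, 2, requires_grad=True)
target = torch.randn(1, 2)

# Default reduction='mean': average of squared differences
mean_loss = nn.MSELoss()(input, target)
# reduction='sum' adds them up; reduction='none' keeps one value per element
sum_loss = nn.MSELoss(reduction='sum')(input, target)
per_elem = nn.MSELoss(reduction='none')(input, target)

# Manual squared L2 error, averaged over all elements
manual = ((input - target) ** 2).mean()
print(mean_loss.item(), manual.item())
print(sum_loss.item(), per_elem)
```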
3:torch.nn.NLLLoss && torch.nn.CrossEntropyLoss
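torch.nn.NLLLoss is the negative log likelihood loss for multi-class classification.
torch.nn.CrossEntropyLoss is the cross-entropy loss.
The difference between the two: CrossEntropyLoss combines LogSoftmax and NLLLoss in a single class. NLLLoss expects log-probabilities as input, so it is usually preceded by a LogSoftmax layer, whereas CrossEntropyLoss takes the raw, unnormalized class scores. The code below computes the same loss both ways: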
import torch
import torch.nn as nn

m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
# input is of size N x C = 3 x 5
input = torch.randn(3, 5, requires_grad=True)
# each element in target has to have 0 <= value < C
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(m(input), target)  # NLLLoss expects log-probabilities
print(output)

loss = nn.CrossEntropyLoss()
output = loss(input, target)  # raw scores; same value up to rounding
print(output)
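To make the relationship concrete, here is a minimal sketch that computes the negative log likelihood by hand: take log-softmax over the class dimension, pick the log-probability of each sample's target class, negate it, and average over the batch. It assumes the default reduction='mean' and no class weights; torch.nn.functional.log_softmax stands in for the LogSoftmax module used above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
input = torch.randn(3, 5, requires_grad=True)         # N x C raw scores
target = torch.empty(3, dtype=torch.long).random_(5)  # class indices in [0, C)

log_probs = F.log_softmax(input, dim=1)
# For each sample n, select log p[n, target[n]], negate, then average over N
manual_nll = -log_probs[torch.arange(input.size(0)), target].mean()

nll = nn.NLLLoss()(log_probs, target)       # expects log-probabilities
ce = nn.CrossEntropyLoss()(input, target)   # applies log-softmax itself
print(manual_nll.item(), nll.item(), ce.item())
```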
4:torch.nn.BCELoss && torch.nn.BCEWithLogitsLoss
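The difference between the two: BCELoss expects probabilities in [0, 1] (for example the output of a Sigmoid), while BCEWithLogitsLoss takes raw logits and combines a Sigmoid layer and BCELoss in a single class. According to the PyTorch documentation this is more numerically stable than a plain Sigmoid followed by BCELoss, because combining the two operations allows the log-sum-exp trick: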
import torch

m = torch.nn.Sigmoid()
loss = torch.nn.BCELoss()
input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)  # float labels, 0.0 or 1.0
output = loss(m(input), target)  # BCELoss expects probabilities
print(output)

loss = torch.nn.BCEWithLogitsLoss()
output = loss(input, target)  # raw logits; same value up to rounding
print(output)
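A minimal sketch of what both losses compute, assuming the default reduction='mean' and no weight or pos_weight: BCE is -[y*log(p) + (1-y)*log(1-p)] with p = sigmoid(x), and the logits version can be rewritten algebraically as max(x, 0) - x*y + log(1 + exp(-|x|)), which never takes the log of a probability that has rounded to 0. The second half uses one arbitrary large logit (200) to illustrate why this matters: in float32, sigmoid(200.) rounds to exactly 1.0, and BCELoss clamps its log outputs at -100, so it reports 100 instead of the true value of about 200. The rewritten formula is shown for illustration and is not necessarily how PyTorch implements it internally.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
input = torch.randn(3, requires_grad=True)  # logits
target = torch.empty(3).random_(2)          # 0.0 / 1.0 labels

p = torch.sigmoid(input)
# BCE on probabilities: -[y*log(p) + (1-y)*log(1-p)], averaged over elements
manual_bce = -(target * torch.log(p) + (1 - target) * torch.log(1 - p)).mean()
# Same loss written on logits: max(x, 0) - x*y + log(1 + exp(-|x|))
manual_logits = (input.clamp(min=0) - input * target
                 + torch.log1p(torch.exp(-input.abs()))).mean()

bce = nn.BCELoss()(p, target)
bce_logits = nn.BCEWithLogitsLoss()(input, target)
print(bce.item(), manual_bce.item(), bce_logits.item(), manual_logits.item())

# Large logit: sigmoid(200.) == 1.0 in float32, so BCELoss hits log(0),
# which it clamps to -100; the logits version keeps the true value.
big_x = torch.tensor([200.0])
zero_y = torch.tensor([0.0])
clamped = nn.BCELoss()(torch.sigmoid(big_x), zero_y)  # 100
exact = nn.BCEWithLogitsLoss()(big_x, zero_y)         # 200
print(clamped.item(), exact.item())
```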
ref:https://www.cnblogs.com/wanghui-garcia/p/10862733.html
Source: https://www.cnblogs.com/xiaoxiaomajinjiebiji/p/13984479.html