PyTorch笔记--交叉熵损失函数实现
作者:互联网
交叉熵(cross entropy):用于度量两个概率分布间的差异信息。交叉熵越小,代表这两个分布越接近。
函数表示(这是使用softmax作为激活函数的损失函数表示):
(是真实值,是预测值。)
命名说明:
pred=F.softmax(logits),logits是softmax函数的输入,pred代表预测值,是softmax函数的输出。
pred_log=F.log_softmax(logits),pred_log代表对预测值再取对数后的结果。也就是将logits作为log_softmax()函数的输入。
方法一,使用log_softmax()+nll_loss()实现
torch.nn.functional.log_softmax(input)
对输入使用softmax函数计算,再取对数。
torch.nn.functional.nll_loss(input, target)
input是经log_softmax()函数处理后的结果,pred_log
target代表的是真实值。
有了这两个输入后,该函数对其实现交叉熵损失函数的计算,即上面公式中的L。
>>> import torch >>> import torch.nn.functional as F >>> x = torch.randn(1, 28) >>> w = torch.randn(10,28) >>> logits = x @ w.t() >>> pred_log = F.log_softmax(logits, dim=1) >>> pred_log tensor([[ -0.8779, -6.7271, -9.1801, -6.8515, -9.6900, -6.3061, -3.7304, -8.1933, -11.5704, -0.5873]]) >>> F.nll_loss(pred_log, torch.tensor([3])) tensor(6.8515)
logits的维度是(1, 10)这里可以理解成是1个输入,最终可能得到10个分类的结果中的一个。pred_log就是。
这里的参数target=torch.tensor([3]),我的理解是,他代表真正的分类的值是在第3类(从0编号)。
使用独热编码代表真实值是[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],即这个输入它是属于第三类的。
根据上述公式进行计算,现在我们 和都已经知道了。
对其进行点乘操作
方法二,使用cross_entropy()实现
torch.nn.functional.cross_entropy(input, target)
这里的input是没有经过处理的logits,这个函数会自动根据logits计算出pred_log
target是真实值
>>> import torch >>> import torch.nn.functional as F >>> x = torch.randn(1, 28) >>> w = torch.randn(10,28) >>> logits = x @ w.t() >>> F.cross_entropy(logits, torch.tensor([3])) tensor(6.8515)
这里我删除了上面使用方法一的代码部分,x和w没有重新随机生成,所以计算结果是一样的。
还在学习过程,做此纪录,如有不对,请指正。
标签:log,--,pred,torch,笔记,PyTorch,softmax,logits,函数 来源: https://www.cnblogs.com/xxmrecord/p/15123626.html