其他分享
首页 > 其他分享> > PyTorch笔记--交叉熵损失函数实现

PyTorch笔记--交叉熵损失函数实现

作者:互联网

交叉熵(cross entropy):用于度量两个概率分布间的差异信息。交叉熵越小,代表这两个分布越接近。

函数表示(这是使用softmax作为激活函数的损失函数表示):

是真实值,是预测值。)

命名说明:

pred=F.softmax(logits),logits是softmax函数的输入,pred代表预测值,是softmax函数的输出。

pred_log=F.log_softmax(logits),pred_log代表对预测值再取对数后的结果。也就是将logits作为log_softmax()函数的输入。

方法一,使用log_softmax()+nll_loss()实现

torch.nn.functional.log_softmax(input)

  对输入使用softmax函数计算,再取对数。

torch.nn.functional.nll_loss(input, target)

  input是经log_softmax()函数处理后的结果,pred_log

  target代表的是真实值。

  有了这两个输入后,该函数对其实现交叉熵损失函数的计算,即上面公式中的L。

>>> import torch
>>> import torch.nn.functional as F
>>> x = torch.randn(1, 28)
>>> w = torch.randn(10,28)
>>> logits = x @ w.t()
>>> pred_log = F.log_softmax(logits, dim=1)
>>> pred_log
tensor([[ -0.8779,  -6.7271,  -9.1801,  -6.8515,  -9.6900,  -6.3061,  -3.7304,
          -8.1933, -11.5704,  -0.5873]])
>>> F.nll_loss(pred_log, torch.tensor([3]))
tensor(6.8515)

logits的维度是(1, 10)这里可以理解成是1个输入,最终可能得到10个分类的结果中的一个。pred_log就是

这里的参数target=torch.tensor([3]),我的理解是,他代表真正的分类的值是在第3类(从0编号)。

使用独热编码代表真实值是[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],即这个输入它是属于第三类的。

根据上述公式进行计算,现在我们 都已经知道了。

对其进行点乘操作

 

 

 

 方法二,使用cross_entropy()实现

torch.nn.functional.cross_entropy(input, target)

  这里的input是没有经过处理的logits,这个函数会自动根据logits计算出pred_log

  target是真实值

>>> import torch
>>> import torch.nn.functional as F
>>> x = torch.randn(1, 28)
>>> w = torch.randn(10,28)
>>> logits = x @ w.t()
>>> F.cross_entropy(logits, torch.tensor([3]))
tensor(6.8515)

这里我删除了上面使用方法一的代码部分,x和w没有重新随机生成,所以计算结果是一样的。

还在学习过程,做此纪录,如有不对,请指正。

标签:log,--,pred,torch,笔记,PyTorch,softmax,logits,函数
来源: https://www.cnblogs.com/xxmrecord/p/15123626.html