
Four PyTorch loss functions

cross_entropy vs nll_loss

Both apply to k-class classification: given N samples and k classes, the inputs are a score tensor of shape (N, k) and a tensor of integer class labels of shape (N,).

>>> import torch
>>> import torch.nn.functional as F
>>> labels = torch.tensor([1, 0, 2], dtype=torch.long)
>>> logits = torch.tensor([[2.5, -0.5, 0.1],
...                        [-1.1, 2.5, 0.0],
...                        [1.2, 2.2, 3.1]], dtype=torch.float)
>>> F.cross_entropy(logits, labels)
tensor(2.4258)
>>> F.nll_loss(F.log_softmax(logits, dim=1), labels)
tensor(2.4258)
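The same pair also exists as module wrappers, torch.nn.CrossEntropyLoss and torch.nn.NLLLoss; a quick sketch (reusing the tensors above) showing they produce the same value:

>>> torch.nn.CrossEntropyLoss()(logits, labels)
tensor(2.4258)
>>> torch.nn.NLLLoss()(F.log_softmax(logits, dim=1), labels)
tensor(2.4258)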

Here there are 3 classes (each sample gets a score for every class) and 3 samples. nll_loss expects log-probabilities as input, so you have to normalize the three class scores yourself with softmax (so they sum to 1) and then take the log; cross_entropy does both steps internally, which is why the two calls above agree. The manual path is sketched below.
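A minimal sketch of the manual computation, reusing the logits and labels above (in practice F.log_softmax fuses the two steps and is numerically more stable than softmax followed by log):

>>> probs = torch.softmax(logits, dim=1)        # normalize each row so it sums to 1
>>> log_probs = probs.log()                     # log-probabilities, what nll_loss expects
>>> -log_probs[torch.arange(3), labels].mean()  # negated true-class log-prob, averaged over samples
tensor(2.4258)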

