2021-03-21
作者:互联网
pytorch四种loss函数
cross_entropy vs nll_loss
适用于k分类问题
>>> labels = torch.tensor([1, 0, 2], dtype=torch.long)
>>> logits = torch.tensor([[2.5, -0.5, 0.1],
... [-1.1, 2.5, 0.0],
... [1.2, 2.2, 3.1]], dtype=torch.float)
>>> torch.nn.functional.cross_entropy(logits, labels)
tensor(2.4258)
>>> torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(logits, dim=1), labels)
tensor(2.4258)
3分类,每个分类都有自己的概率.
3 samples
nll_loss需要手动softmax,把三个分类的概率归一化(相加为1),再取log
cross_entropy内置了
[1]https://sebastianraschka.com/faq/docs/pytorch-crossentropy.html
标签:03,21,loss,torch,cross,entropy,2021,nll,tensor 来源: https://blog.csdn.net/weixin_44127535/article/details/115047634