pytorch API

Author: Internet

  1. PyTorch: loss functions for multi-label classification
import torch
import numpy as np

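## Raw model outputs (logits): 3 samples x 3 classes
pred = np.array([[-0.4089, -1.2471, 0.5907],
                [-0.4897, -0.8267, -0.7349],
                [0.5241, -0.1246, -0.4751]])
## Multi-hot targets: each sample may belong to several classes at once
label = np.array([[0, 1, 1],
                  [0, 0, 1],
                  [1, 0, 1]])

## The BCE-style losses expect float tensors for both inputs and targets
pred = torch.from_numpy(pred).float()
label = torch.from_numpy(label).float()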

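## BCEWithLogitsLoss: compute the loss directly from the raw logits (preferred)
criterion1 = torch.nn.BCEWithLogitsLoss()
loss1 = criterion1(pred, label)
print(loss1)

## MultiLabelSoftMarginLoss: also takes raw logits; with the default
## reduction it gives the same value as BCEWithLogitsLoss
criterion2 = torch.nn.MultiLabelSoftMarginLoss()
loss2 = criterion2(pred, label)
print(loss2)

## BCELoss: expects probabilities, so apply sigmoid to the logits first
criterion3 = torch.nn.BCELoss()
loss3 = criterion3(torch.sigmoid(pred), label)
print(loss3)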
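All three calls should print the same value, up to floating-point rounding. The sketch below checks this by writing the binary cross-entropy out by hand: for a logit x and a 0/1 target y, the per-element loss is -[y·log σ(x) + (1 - y)·log(1 - σ(x))], averaged over every element. It reuses the data from the example above; the name `manual` is introduced here only for illustration, and the identity log(1 - σ(x)) = log σ(-x) keeps the hand-written version numerically stable.

```python
import torch
import torch.nn.functional as F

# Same logits and multi-hot labels as in the example above
pred = torch.tensor([[-0.4089, -1.2471, 0.5907],
                     [-0.4897, -0.8267, -0.7349],
                     [0.5241, -0.1246, -0.4751]])
label = torch.tensor([[0., 1., 1.],
                      [0., 0., 1.],
                      [1., 0., 1.]])

# Element-wise binary cross-entropy from logits, averaged over all N*C elements.
# log(1 - sigmoid(x)) is written as logsigmoid(-x) to avoid log(0).
manual = -(label * F.logsigmoid(pred) + (1 - label) * F.logsigmoid(-pred)).mean()

loss1 = torch.nn.BCEWithLogitsLoss()(pred, label)
loss2 = torch.nn.MultiLabelSoftMarginLoss()(pred, label)
loss3 = torch.nn.BCELoss()(torch.sigmoid(pred), label)

print("manual                  :", manual.item())
print("BCEWithLogitsLoss       :", loss1.item())
print("MultiLabelSoftMarginLoss:", loss2.item())
print("BCELoss(sigmoid(pred))  :", loss3.item())
```

MultiLabelSoftMarginLoss averages over the classes of each sample and then over the batch, while BCEWithLogitsLoss averages over all elements at once; when every sample has the same number of classes, both are the same mean, which is why they agree here.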

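The author flags BCEWithLogitsLoss as the recommended choice. A likely reason is numerical stability: the PyTorch documentation describes it as more stable than a plain sigmoid followed by BCELoss, because it combines the two steps. A minimal sketch with one hypothetical, extremely confident but wrong prediction (logit 200 for a negative label) shows the difference:

```python
import torch

# Hypothetical example: a very confident positive logit for a negative label
logit = torch.tensor([200.0])
target = torch.tensor([0.0])

# BCEWithLogitsLoss works on the logit directly, so the large value is kept exactly
bce_logits = torch.nn.BCEWithLogitsLoss()(logit, target)

# sigmoid(200) rounds to exactly 1.0 in float32, so log(1 - p) would be log(0);
# BCELoss clamps its log outputs at -100, which caps the loss at 100
prob = torch.sigmoid(logit)
bce_prob = torch.nn.BCELoss()(prob, target)

print(prob.item())        # 1.0
print(bce_logits.item())  # 200.0
print(bce_prob.item())    # 100.0
```

The sigmoid + BCELoss route silently under-reports the loss for saturated predictions, while BCEWithLogitsLoss returns the exact value.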
Source: https://www.cnblogs.com/mercurysun/p/16334222.html