其他分享
首页 > 其他分享> > nn.CrossEntropyLoss()使用是label参数的注意点

nn.CrossEntropyLoss()使用是label参数的注意点

作者:互联网

遇到个离谱的事情,自定义数据集跑cross entropy loss的时候,

    loss1 = w_loss.loss(log_ps1, source_batch_labels)
    loss1.backward()

backward()这里总是报错,搞了半天最后发现是数据集设定的时候,给labels是int32,但是实际上得设置成int64

#  toy_source数据类型转化
source_datas_t = torch.tensor(source_datas, dtype=torch.float64)
source_labels_t = torch.tensor(source_labels, dtype=torch.int64)

# toy_target数据类型转化
target_datas_t = torch.tensor(target_datas, dtype=torch.float64)
target_labels_t = torch.tensor(target_labels, dtype=torch.int64)

label部分这样设置,CrossEntropyLoss()就不报错了

标签:torch,target,nn,dtype,labels,label,source,CrossEntropyLoss,tensor
来源: https://www.cnblogs.com/huzhengyu/p/16414883.html