nn.CrossEntropyLoss()使用是label参数的注意点
作者:互联网
遇到个离谱的事情,自定义数据集跑cross entropy loss的时候,
loss1 = w_loss.loss(log_ps1, source_batch_labels)
loss1.backward()
backward()
这里总是报错,搞了半天最后发现是数据集设定的时候,给labels是int32,但是实际上得设置成int64
# toy_source数据类型转化
source_datas_t = torch.tensor(source_datas, dtype=torch.float64)
source_labels_t = torch.tensor(source_labels, dtype=torch.int64)
# toy_target数据类型转化
target_datas_t = torch.tensor(target_datas, dtype=torch.float64)
target_labels_t = torch.tensor(target_labels, dtype=torch.int64)
label部分这样设置,CrossEntropyLoss()就不报错了
标签:torch,target,nn,dtype,labels,label,source,CrossEntropyLoss,tensor 来源: https://www.cnblogs.com/huzhengyu/p/16414883.html