其他分享
首页 > 其他分享> > (y_hat.argmax(dim=1) ==lable).sum().cpu().item()

(y_hat.argmax(dim=1) ==lable).sum().cpu().item()

作者:互联网

        print(y_hat.argmax(dim=1))         print(y_hat.argmax(dim=1) ==lable)         print((y_hat.argmax(dim=1) ==lable).sum())         print((y_hat.argmax(dim=1) ==lable).sum().cpu())              print((y_hat.argmax(dim=1) ==lable).sum().cpu().item()) 输出:

tensor([4, 4, 5, 0, 4, 2, 8, 5, 8, 2, 4, 4, 4, 2, 8, 4, 1, 8, 2, 5, 0, 7, 4, 4,
6, 5, 6, 2, 5, 3, 5, 4, 4, 8, 4, 5, 2, 4, 2, 4, 6, 2, 5, 6, 5, 8, 4, 4,
4, 2], device='cuda:0')

第一个输出把每行最大的索引输出
tensor([False, False, False, False, False, False, False, False, True, False,
True, False, False, True, False, False, False, False, True, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, True, False, False, False, False, False,
False, True, False, False, False, False, False, False, False, False],
device='cuda:0')

第二个输出判断索引和lable是否相等,相等为true否则为false。
tensor(6, device='cuda:0')

第三个输出进行sum求和true算1,flase算0。
tensor(6)

第四个输出将cuda变为cpu
6

第五个item将tensor变为整形

标签:dim,False,lable,argmax,hat,True
来源: https://www.cnblogs.com/hahaah/p/15386837.html