(y_hat.argmax(dim=1) ==lable).sum().cpu().item()
作者:互联网
tensor([4, 4, 5, 0, 4, 2, 8, 5, 8, 2, 4, 4, 4, 2, 8, 4, 1, 8, 2, 5, 0, 7, 4, 4,
6, 5, 6, 2, 5, 3, 5, 4, 4, 8, 4, 5, 2, 4, 2, 4, 6, 2, 5, 6, 5, 8, 4, 4,
4, 2], device='cuda:0')
第一个输出把每行最大的索引输出
tensor([False, False, False, False, False, False, False, False, True, False,
True, False, False, True, False, False, False, False, True, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, True, False, False, False, False, False,
False, True, False, False, False, False, False, False, False, False],
device='cuda:0')
第二个输出判断索引和lable是否相等,相等为true否则为false。
tensor(6, device='cuda:0')
第三个输出进行sum求和true算1,flase算0。
tensor(6)
第四个输出将cuda变为cpu
6
第五个item将tensor变为整形
标签:dim,False,lable,argmax,hat,True 来源: https://www.cnblogs.com/hahaah/p/15386837.html