其他分享
首页 > 其他分享> > torch.topk【按维度求前k个最值以及其索引】以及top1和top5

torch.topk【按维度求前k个最值以及其索引】以及top1和top5

作者:互联网

分类中常用到top1和top5的指标:
top1是指 预测错误样本数/总样本数,此处的预测错误的样本是指 预测的最大概率对应的类别与真实类别不同;
top5是指 预测错误样本数/总样本数,此处的预测错误的样本是指 预测的排序前五的概率对应的类别中没有真实类别。

一般来讲,top5 error要小于top1 error,排序前五的概率中没有真实标签几率小。
有的博客解释需要top5 error是因为人工标记可能出现偏差,因此要放宽标准。

【torch.topk函数】
目的:求取tensor中某个dim的排序前k个值(val) 以及其索引(index).

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
#input:tensor数据
#k:排序后的前k个数据
#dim:沿着某个维度
#largest:True是指从大到小,否则,从小到大
#sorted:True返回的结果
import torch

if __name__=="__main__":
pred = torch.rand((4, 5))
print(pred)
print("------------k=1------------------")
vals, indices = pred.topk(k=1, dim=1, largest=True, sorted=True)
print(indices)
print("------------k=2------------------")
vals, indices = pred.topk(k=2, dim=1, largest=True, sorted=True)
print(indices)

#output:
tensor([[0.2219, 0.9817, 0.7909, 0.7659, 0.1657],
        [0.6779, 0.9653, 0.1959, 0.3108, 0.1755],
        [0.9107, 0.5243, 0.2525, 0.1543, 0.4314],
        [0.5417, 0.0409, 0.5777, 0.3693, 0.2606]])
------------k=1------------------
tensor([[1],
        [1],
        [0],
        [2]])
------------k=2------------------
tensor([[1, 2],
        [1, 0],
        [0, 1],
        [2, 0]])

标签:dim,tensor,求前,top5,torch,print,True
来源: https://blog.csdn.net/YJYS_ZHX/article/details/113457825