其他分享
首页 > 其他分享> > Tensorflow2.x代码实现计算Top-k Accuracy

Tensorflow2.x代码实现计算Top-k Accuracy

作者:互联网

图像分类或是识别任务中,一般要求计算top-1,top-2,tor-5等准确率,下面是用Tensorflow2实现这一功能的基本代码,可以根据要求改代码分别计算:

def accuracy(output,target,topk(1,)):
    maxk=max(topk)
    batch_size=target.shape[0]
    
    pred=tf.math.top_k(output,maxk).indices
    pred=tf.transpose(pred,perm=[1,0])
    target_=tf.broadcast_to(target,pred.shape)
    correct=tf.equal(target_,pred)

    res=[]
    for k in topk:
        correct_k=tf.cast(tf.reshape(correct[:k],[-1]),dtype=tf.float32)
        correct_k=tf.reduce_sum(correct_k)
        acc=float(correct_k/batch_size)
        res.append(acc)
    return res

 

标签:Tensorflow2,target,pred,Top,top,topk,tf,correct,Accuracy
来源: https://blog.51cto.com/u_15242250/2870181