其他分享
首页 > 其他分享> > 深度学习-----从零开始实现识别手写字体任务(六)计算测试集的准确率和Tensorflow的执行阶段

深度学习-----从零开始实现识别手写字体任务(六)计算测试集的准确率和Tensorflow的执行阶段

作者:互联网

计算测试集的准确率

def compute_accuracy(v_xs, v_ys):
    global prediction
    # y_pre将v_xs输入模型后得到的预测值 (10000,10)
    y_pre = sess.run(prediction, feed_dict={xs: v_xs, keep_prob: 1})
    # argmax(axis) axis = 1 返回结果为:数组中每一行最大值所在“列”索引值
    # tf.equal返回布尔值,correct_prediction (10000,1)
    correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.argmax(v_ys, 1))
    # tf.cast将bool转成float32, tf.reduce_mean求均值,作为accuracy值(0到1)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    result = sess.run(accuracy, feed_dict={xs: v_xs, ys: v_ys, keep_prob: 1})
    return result

该函数有两个参数,v_xs为测试集,v_ys为测试集的标签。

y_pre将测试集输入后的模型转化为预测值

argmax函数将数组每一行的最大值所在的列返回出来

使用tf.equal比较测试集和标签的值是否相等,若相等返回ture

tf.cast将布尔值转化为浮点数,通过求均值来计算准确率

最后返回准确率。

Tensorflow的执行阶段

TensorFlow 程序通常被组织成一个构建阶段和一个执行阶段,之前的都是构建阶段,现在是执行阶段,需要创立一个session对象来一遍遍执行上述程序。

Session对象在使用完后需要关闭以释放资源. 除了显式调用 close 外, 也可以使用 "with" 代码块 来自动完成关闭动作。

keep_prob_rate = 0.6
with tf.Session() as sess:
    # 初始化图中所有Variables
    init = tf.global_variables_initializer()
    sess.run(init)
    # 总迭代次数(batch)为max_epoch=1000,每次取100张图做batch梯度下降
    print("step 0, test accuracy %g" % (compute_accuracy(
        mnist.test.images, mnist.test.labels)))
    for i in range(max_epoch):
        # mnist.train.next_batch 默认shuffle=True,随机读取,batch大小为100
        batch_xs, batch_ys = mnist.train.next_batch(100)
        # 此batch是个2维tuple,batch[0]是(100,784)的样本数据数组,batch[1]是(100,10)的样本标签数组,分别赋值给batch_xs, batch_ys
        sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: keep_prob_rate})
        # 暂时不进行赋值的元素叫占位符(如xs、ys),run需要它们时得赋值,feed_dict就是用来赋值的,格式为字典型
        if (i + 1) % 50 == 0:
            print("step %d, test accuracy %g" % (i + 1, compute_accuracy(
                mnist.test.images, mnist.test.labels)))

第一步我们要先初始化所有的variables 

先输出模型一开始的准确率

然后进行迭代

经过训练通过train_step改变所有variables的值,从而提高准确率

最后输出准确率

标签:batch,准确率,从零开始,-----,xs,tf,Tensorflow,ys,accuracy
来源: https://blog.csdn.net/qq_57173265/article/details/118424139