TF2 Convolutional Neural Network for MNIST Handwritten Digit Recognition

# Runs on TensorFlow 2.x: the compat.v1 API (graph mode, Session, placeholders)
# drives the training loop, while the native tf2 namespace supplies ops whose
# v1 aliases were removed or renamed.
import tensorflow.compat.v1 as tf
import tensorflow as tf2
tf.disable_v2_behavior()
import numpy as np
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt

def weight_variable(shape):
    # truncated normal initialization breaks symmetry between units
    initial = tf2.random.truncated_normal(shape, stddev=.1)
    return tf.Variable(initial)

def bias_variable(shape):
    # a small positive bias keeps ReLU units active at the start
    initial = tf2.constant(.1, shape=shape)
    return tf.Variable(initial)

def conv2d(x, W):           # x: input tensor, W: kernel weights
    # strides is [batch, height, width, channels]; stride 1 with SAME padding
    # keeps the spatial size unchanged
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding="SAME")

# max_pool's padding behaves differently from conv2d's padding:
# VALID discards the rightmost columns and bottom rows when fewer rows or
# columns remain than the pooling window needs, keeping only complete windows;
# SAME zero-pads when the remainder falls short, so every element is covered
# by a window of the full pooling size.
def max_pool_2x2(x):        # x: input tensor
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
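
# A minimal illustration of the two padding modes (not used by the script;
# shapes are for a 5x5 input with a 2x2 window and stride 2, and are static,
# so no session is needed):
#   demo = tf.ones([1, 5, 5, 1])
#   tf.nn.max_pool(demo, ksize=[1,2,2,1], strides=[1,2,2,1], padding="SAME")   # shape [1, 3, 3, 1]
#   tf.nn.max_pool(demo, ksize=[1,2,2,1], strides=[1,2,2,1], padding="VALID")  # shape [1, 2, 2, 1]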

# add_layer was split into three parts: weight_variable, bias_variable, and
# the matmul code inside model
# def add_layer(inputs,in_size,out_size,func=None):
#     W=tf.Variable(tf2.random.normal([in_size,out_size]))
#     b=tf.Variable(tf2.zeros([1,out_size])+.1)
#     Wx_b=tf2.matmul(inputs,W)+b
#     if not func:
#         outputs=Wx_b
#     else:
#         outputs=func(Wx_b)
#     return outputs

def model(x_data, y_data):  # y_data is unused; x_data only fixes the input width
    global pred, xs, ys, keep_prob
    xs = tf.placeholder(tf2.float32, [None, x_data.shape[-1]])
    ys = tf.placeholder(tf2.float32, [None, 10])
    keep_prob = tf.placeholder(tf.float32)

    # restore the flat 784-vectors to 28x28 single-channel images
    x_image = tf2.reshape(xs, [-1, 28, 28, 1])

    # The conv stack shrinks the spatial size through pooling while the
    # convolutions increase the number of channels.
    # conv1 layer
    # 32 kernels of shape 5*5*1: the first two dims are a free choice, the
    # third must equal the input's channel count, and the fourth (number of
    # filters) is usually chosen as a power of 2
    W_conv1 = weight_variable([5, 5, 1, 32])
    b_conv1 = bias_variable([32])
    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)
    # conv1 layer's per-example output shape is [14, 14, 32]

    #conv2 layer
    W_conv2=weight_variable([5,5,32,64])
    b_conv2=bias_variable([64])
    h_conv2=tf.nn.relu(conv2d(h_pool1,W_conv2)+b_conv2)
    h_pool2=max_pool_2x2(h_conv2)
    # conv2 layer's per-example output shape is [7, 7, 64]
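    # Shape bookkeeping: a stride-1 SAME convolution keeps H and W unchanged,
    # and each 2x2/stride-2 SAME pooling computes ceil(n/2), so per example:
    # 28x28x1 -conv1-> 28x28x32 -pool-> 14x14x32 -conv2-> 14x14x64 -pool-> 7x7x64.
    # That product, 7*7*64, is the input width of fc1 below.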

    #fc1 layer(fully connected layer)
    W_fc1=weight_variable([7*7*64,1024])
    b_fc1 = bias_variable([1024])
    # flatten each example to one dimension before the dense layer
    h_pool2_flat=tf.reshape(h_pool2,[-1,7*7*64])
    h_fc1=tf.nn.relu(tf2.matmul(h_pool2_flat,W_fc1)+b_fc1)
    # tf2's tf.nn.dropout takes different arguments than the v1 version, so
    # this call cannot simply be switched to the tf2 namespace
    # dropout guards against overfitting
    h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)
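    # Concretely, the v1 API takes the probability of keeping a unit while the
    # TF2 API takes the probability of dropping one, so these are equivalent:
    #   tf.nn.dropout(h_fc1, keep_prob=0.5)   # compat.v1
    #   tf2.nn.dropout(h_fc1, rate=0.5)       # native TF2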

    #fc2 layer
    W_fc2=weight_variable([1024,10])
    b_fc2 = bias_variable([10])
    pred=tf.nn.softmax(tf2.matmul(h_fc1_drop,W_fc2)+b_fc2)
    # dropout can be skipped on the output layer
    # pred = tf.nn.dropout(pred, keep_prob)

    return pred

def run(x_data, y_data):
    pred = model(x_data, y_data)
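    # cross entropy -sum(y * log(pred)), averaged over the batch; clipping pred
    # into [1e-10, 1.0] prevents log(0) = -inf when the softmax saturates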
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(tf.clip_by_value(pred, 1e-10, 1.0)), axis=1))

    train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

    se.run(tf.global_variables_initializer())

    results=[]
    for i in range(1000):
        # train on a randomly sampled mini-batch of 10 examples
        random_index=np.random.choice(x_train.shape[0], 10, replace=False)
        batch_xs,batch_ys=x_train[random_index],y_train[random_index]

        se.run(train_step,feed_dict={xs:batch_xs,ys:batch_ys,keep_prob:.5})

        if i%50==0:
            acc=compute_accuracy(x_test,y_test,1)
            results.append(acc)
            print(i,acc)

    # scatter plot of test accuracy sampled every 50 steps, y ticks every 0.1
    plt.scatter([50*i for i in range(len(results))], results)
    y_major_locator = plt.MultipleLocator(.1)
    ax = plt.gca()
    ax.yaxis.set_major_locator(y_major_locator)
    plt.ylim(0, 1)
    plt.show()

def compute_accuracy(v_xs, v_ys, v_kp=1):
    # note: this builds fresh argmax/equal/mean ops on every call, which keeps
    # growing the graph; acceptable for a short demo run
    y_pre = se.run(pred, feed_dict={xs: v_xs, keep_prob: v_kp})
    correct_pred = tf.equal(tf.argmax(y_pre, 1), tf.argmax(v_ys, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf2.float32))
    result = se.run(accuracy, feed_dict={xs: v_xs, ys: v_ys, keep_prob: v_kp})
    return result


if __name__ == "__main__":
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    # flatten the 28x28 images to 784-vectors and scale pixels to [0, 1]
    x_train = x_train.reshape(x_train.shape[0], -1) / 255
    x_test = x_test.reshape(x_test.shape[0], -1) / 255
    se = tf.Session()

    # one-hot encode the integer labels
    y_train = np.eye(10)[y_train]
    y_test = np.eye(10)[y_test]

    # run's arguments only fix the input width and act as the evaluation set;
    # training batches are drawn from the global x_train/y_train
    run(x_test, y_test)

[Figure: scatter plot of test accuracy, sampled every 50 training steps]
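
For reference, the same architecture can be written in native TF2 Keras, with no compat.v1 graph-and-session machinery (run it as a separate script, without disable_v2_behavior). This is a minimal sketch rather than part of the original program: the layer sizes mirror the ones above, while the optimizer and loss are assumptions chosen to match the script's AdamOptimizer(1e-4) and one-hot labels.

import tensorflow as tf
from tensorflow.keras import layers

def build_keras_model():
    # same stack as model() above: two SAME conv+pool blocks, then two dense layers
    m = tf.keras.Sequential([
        layers.Reshape((28, 28, 1), input_shape=(784,)),
        layers.Conv2D(32, 5, padding="same", activation="relu"),
        layers.MaxPooling2D(2),
        layers.Conv2D(64, 5, padding="same", activation="relu"),
        layers.MaxPooling2D(2),
        layers.Flatten(),
        layers.Dense(1024, activation="relu"),
        layers.Dropout(0.5),  # Keras takes the drop rate, not keep_prob
        layers.Dense(10, activation="softmax"),
    ])
    m.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
              loss="categorical_crossentropy",
              metrics=["accuracy"])
    return m

# usage sketch with the preprocessed arrays from __main__:
#   build_keras_model().fit(x_train, y_train, batch_size=10,
#                           validation_data=(x_test, y_test))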

Source: https://blog.csdn.net/weixin_43292547/article/details/117115962