Implementing a CNN for MNIST recognition with TensorFlow
This is my first piece of CNN code. I am not yet familiar with backpropagation through a CNN, but writing it gave me an initial understanding of how TensorFlow executes a program.
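For context, TensorFlow 1.x splits every program into two phases: first you build a computation graph out of placeholders and ops, then you execute chosen nodes inside a Session while feeding in concrete values. Here is a minimal sketch of that build-then-run model (the placeholder names a and b are purely illustrative and not part of the program below):

import tensorflow as tf

# Phase 1: graph construction. These calls only describe the computation;
# no arithmetic happens here.
a = tf.placeholder(tf.float32, shape=[None])
b = tf.placeholder(tf.float32, shape=[None])
dot = tf.reduce_sum(a * b)

# Phase 2: execution. sess.run evaluates the requested node, pulling
# concrete inputs from feed_dict and returning a numpy value.
with tf.Session() as sess:
    print(sess.run(dot, feed_dict={a: [1.0, 2.0], b: [3.0, 4.0]}))   # prints 11.0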
'''
CNN classifier for MNIST (TensorFlow 1.x)

created on 2019.9.28
author: vince
'''
import logging

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets


def weight_variable(shape):
    # Small random weights break the symmetry between units.
    init = tf.random.truncated_normal(shape=shape, stddev=0.01)
    return tf.Variable(init)


def bias_variable(shape):
    # A small positive bias keeps ReLU units in their active region at the start.
    init = tf.constant(0.1, shape=shape)
    return tf.Variable(init)


def conv2d(x, w):
    # Stride 1 with SAME padding keeps the spatial size unchanged.
    return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME")


def max_pool_2x2(x):
    # 2x2 pooling with stride 2 halves the height and width.
    return tf.nn.max_pool2d(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")


def cnn(x, rate):
    with tf.name_scope('reshape'):
        x_image = tf.reshape(x, [-1, 28, 28, 1])

    # first layer: conv & pool
    with tf.name_scope('conv1'):
        w_conv1 = weight_variable([5, 5, 1, 32])
        b_conv1 = bias_variable([32])
        h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)   # 28 * 28 * 32
    with tf.name_scope('pool1'):
        h_pool1 = max_pool_2x2(h_conv1)                            # 14 * 14 * 32

    # second layer: conv & pool
    with tf.name_scope('conv2'):
        w_conv2 = weight_variable([5, 5, 32, 64])
        b_conv2 = bias_variable([64])
        h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)   # 14 * 14 * 64
    with tf.name_scope('pool2'):
        h_pool2 = max_pool_2x2(h_conv2)                            # 7 * 7 * 64

    # first fully connected layer: feature maps -> feature vector
    with tf.name_scope('fc1'):
        w_fc1 = weight_variable([7 * 7 * 64, 1024])
        b_fc1 = bias_variable([1024])
        h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
        h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)
    with tf.name_scope('dropout1'):
        # rate is the fraction of units to drop; feed 0.0 at evaluation time.
        h_fc1_drop = tf.nn.dropout(h_fc1, rate=rate)

    # second fully connected layer: 1024 features -> 10 logits
    with tf.name_scope('fc2'):
        w_fc2 = weight_variable([1024, 10])
        b_fc2 = bias_variable([10])
        h_fc2 = tf.matmul(h_fc1_drop, w_fc2) + b_fc2
    return h_fc2


def main():
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
                        datefmt='%a, %d %b %Y %H:%M:%S')

    # '../data/MNIST' is the folder the dataset is downloaded to;
    # one_hot=True encodes the labels as one-hot vectors.
    mnist = read_data_sets('../data/MNIST', one_hot=True)

    x = tf.placeholder(tf.float32, [None, 784])
    y_real = tf.placeholder(tf.float32, [None, 10])
    rate = tf.placeholder(tf.float32)

    y_pre = cnn(x, rate)

    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y_pre, labels=y_real))
    train_op = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

    correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.argmax(y_real, 1))
    prediction_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    sess = tf.InteractiveSession()
    # Initialize variables after the whole graph (including the optimizer) is built.
    sess.run(tf.global_variables_initializer())

    for step in range(300):
        batch_xs, batch_ys = mnist.train.next_batch(128)
        sess.run(train_op, feed_dict={x: batch_xs, y_real: batch_ys, rate: 0.5})
        if step % 10 == 0:
            accuracy = sess.run(prediction_op,
                                feed_dict={x: mnist.test.images, y_real: mnist.test.labels, rate: 0.0})
            logging.info("%s : %s" % (step, accuracy))


if __name__ == "__main__":
    main()
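One practical caveat about the accuracy check: it feeds all 10,000 test images through the network in a single sess.run, which can be memory-hungry, especially on a GPU. A possible workaround, sketched here under the assumption that it is called with the same sess, prediction_op, placeholders, and mnist.test arrays as in the script above (the evaluate helper itself is not part of the original code), is to score the test set in mini-batches and average the per-batch accuracies:

def evaluate(sess, prediction_op, x, y_real, rate, images, labels, batch_size=500):
    # Average per-batch accuracy; this is exact as long as batch_size divides
    # the number of examples (10,000 MNIST test images / 500 = 20 equal batches).
    batches = len(images) // batch_size
    total = 0.0
    for i in range(batches):
        lo, hi = i * batch_size, (i + 1) * batch_size
        total += sess.run(prediction_op,
                          feed_dict={x: images[lo:hi], y_real: labels[lo:hi], rate: 0.0})
    return total / batches

# Usage inside the training loop, replacing the single full-test-set run:
#   accuracy = evaluate(sess, prediction_op, x, y_real, rate,
#                       mnist.test.images, mnist.test.labels)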
Source: https://www.cnblogs.com/thsss/p/11695315.html