其他分享
首页 > 其他分享> > TensorFlow MNIST手写数字识别学习笔记(二)

TensorFlow MNIST手写数字识别学习笔记(二)

作者:互联网

接下来我们具体解析一下mnist.py这个文件
我们先看下初始的参数定义

"""MNIST数据集有10类, 分别是从0到9的数字."""
NUM_CLASSES = 10
"""MNIST数据集的图片都是28*28的像素."""
IMAGE_SIZE = 28
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE

定义 inference函数,具体设置一下网络的结构

def inference(images, hidden1_units, hidden2_units):
  """隐藏层1"""
  with tf.name_scope('hidden1'):
  """从输入图片的向量28*28以及hidden1节点计算截断的正态分布中输出随机值,产生正太分布的值如果与均值的差值大于两倍的标准差,那就重新生成"""
    weights = tf.Variable(
        tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
                            stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
        name='weights')
    """生成为零的hidden1_units大小的矩阵"""
    biases = tf.Variable(tf.zeros([hidden1_units]),
                         name='biases')
    """计算hidden1的RELU函数值,也就是image*weights+biases"""                   
    hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
  """ 隐藏层2""" 
  with tf.name_scope('hidden2'):
    """从hidden1节点以及hidden2节点计算截断的正态分布中输出随机值,产生正太分布的值如果与均值的差值大于两倍的标准差,那就重新生成"""
    weights = tf.Variable(
        tf.truncated_normal([hidden1_units, hidden2_units],
                            stddev=1.0 / math.sqrt(float(hidden1_units))),
        name='weights')
    """生成为零的hidden2_units大小的矩阵"""
    biases = tf.Variable(tf.zeros([hidden2_units]),
                         name='biases')
    """计算hidden2的RELU函数值,也就是hidden1*weights+biases"""             
    hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
  """线性函数Softmax"""
  with tf.name_scope('softmax_linear'):
  """从hidden2节点以及NUM_CLASSES计算截断的正态分布中输出随机值,产生正太分布的值如果与均值的差值大于两倍的标准差,那就重新生成"""
    weights = tf.Variable(
        tf.truncated_normal([hidden2_units, NUM_CLASSES],
                            stddev=1.0 / math.sqrt(float(hidden2_units))),
        name='weights')
    """生成为零的NUM_CLASSES大小的矩阵"""
    biases = tf.Variable(tf.zeros([NUM_CLASSES]),
                         name='biases')
    """定义网络的逻辑"""                
    logits = tf.matmul(hidden2, weights) + biases
  return logits

定义 loss函数,计算损失

def loss(logits, labels):
  """把labels转换为64位整数"""
  labels = tf.to_int64(labels)
  """计算交叉熵并返回"""
  return tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

定义 training函数,计算出一个最小的梯度

def training(loss, learning_rate):
  """添加summary.scalar函数用来显示标量信息,计算损失."""
  tf.summary.scalar('loss', loss)
  """创建一个梯度下降优化器来计算梯度."""
  optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  """创建一个variable去计算global_step."""
  global_step = tf.Variable(0, name='global_step', trainable=False)
  """用优化器去求出一个最小的梯度""" 
  train_op = optimizer.minimize(loss, global_step=global_step)
  return train_op

定义evaluation函数,

def evaluation(logits, labels):
  """计算logits预测的结果和labels实际结果的是否相等 """
  correct = tf.nn.in_top_k(logits, labels, 1)
  """返回correct的结果转换成int32并进行压缩求和."""
  return tf.reduce_sum(tf.cast(correct, tf.int32))

最后我们看下input_data.py这个文件

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import
import gzip
import os
import tempfile

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
"""从网上下载要用的数据并读取"""
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

标签:hidden2,hidden1,units,biases,import,tf,TensorFlow,手写,MNIST
来源: https://blog.csdn.net/qq_42616932/article/details/98085507