昨天学习了一些tensorflow入门知识,经历各种奇葩错误,现在奉献一份安装tensorflow2就可以跑的demo
作者:互联网
本程序使用minist图像集合作为数据源,使用tensorflow内部的数据加载方式(如果没有数据集,会自动从网上下载).神经网络内层有三层,依靠纯手工搭建网络模式,比较贴近数学模型
1 #encoding: utf-8 2 import os 3 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # 关闭低级的调试信息 4 5 import tensorflow as tf 6 from tensorflow import keras 7 from tensorflow.keras import datasets 8 9 lr = 0.001 10 (x,y), _ = datasets.mnist.load_data() 11 12 # x 归一化 13 x = tf.convert_to_tensor(x,dtype=tf.float32)/255. 14 y = tf.convert_to_tensor(y,dtype=tf.int32) 15 print(x.shape,y.shape,x.dtype,y.dtype) 16 print(tf.reduce_min(x),tf.reduce_max(x)) 17 print(tf.reduce_min(y),tf.reduce_max(y)) 18 19 # 将源数据分割,一次处理128条数据。分60000/128此处理结束,返回一个迭代器 20 train_db = tf.data.Dataset.from_tensor_slices((x,y)).batch(128) 21 train_iter = iter(train_db) 22 sample = next(train_iter) 23 print("batch:",sample[0].shape,sample[1].shape) 24 25 #生成参数矩阵 注意tf.Variable 大小写 26 w1 = tf.Variable(tf.random.truncated_normal([784,256],stddev=0.1)) 27 b1 = tf.Variable(tf.zeros([256])) 28 w2 = tf.Variable(tf.random.truncated_normal([256,128],stddev=0.1)) 29 b2 = tf.Variable(tf.zeros([128])) 30 w3 = tf.Variable(tf.random.truncated_normal([128,10],stddev=0.1)) 31 b3 = tf.Variable(tf.zeros([10])) 32 33 for epoch in range(30): #迭代一次即完成60k循环 34 for step,(x,y) in enumerate(train_db): #迭代一次即完成一次128次训练 35 # x =>128,28,28 需要转换为128,28*28 36 x = tf.reshape(x,[-1,28*28]) 37 # 注意加() 38 with tf.GradientTape() as tape: 39 h1 = x@w1 + b1 40 h1 = tf.nn.relu(h1) 41 h2 = h1@w2 + b2 42 h2 = tf.nn.relu(h2) 43 out = h2@w3 + b3 44 45 # 要对y进行one_hot编码,好处之一,使得不同结果之前的距离可以保持一致 46 y_one_hot = tf.one_hot(y,depth=10) 47 loss = tf.square(y_one_hot-out) 48 loss = tf.reduce_mean(loss) 49 50 grads = tape.gradient(loss,[w1,b1,w2,b2,w3,b3]) 51 # w1 = w1 - lr*grads[0],这个更新方式会更改原来的值,不再是tf.Variable 52 # 此处采用原地更新的方式,如下 (注意函数名字不要写错) 53 w1.assign_sub(lr * grads[0]) 54 b1.assign_sub(lr * grads[1]) 55 w2.assign_sub(lr * grads[2]) 56 b2.assign_sub(lr * grads[3]) 57 w3.assign_sub(lr * grads[4]) 58 b3.assign_sub(lr * grads[5]) 59 60 if step%100 == 0: 61 print(epoch,step,"loss:",float(loss))
标签:tensorflow2,demo,28,Variable,128,lr,tf,tensorflow,grads 来源: https://www.cnblogs.com/RYSBlog/p/13500051.html