其他分享
首页 > 其他分享> > 笔记3:Tensorflow2.0实战之MNSIT数据集

笔记3:Tensorflow2.0实战之MNSIT数据集

作者:互联网

最近Tensorflow相继推出了alpha和beta两个版本,这两个都属于tensorflow2.0版本;早听说新版做了很大的革新,今天就来用一下看看
这里还是使用MNSIT数据集进行测试

导入必要的库

import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics

数据的准备

(xs, ys),_ = datasets.mnist.load_data()
print('datasets:', xs.shape, ys.shape, xs.min(), xs.max())
xs = tf.convert_to_tensor(xs, dtype=tf.float32) / 255.
db = tf.data.Dataset.from_tensor_slices((xs,ys))
db = db.batch(32).repeat(10)

网络结构和优化器准备

network = Sequential([layers.Dense(256, activation='relu'),
                     layers.Dense(256, activation='relu'),
                     layers.Dense(256, activation='relu'),
                     layers.Dense(10)])
network.build(input_shape=(None, 28*28))
network.summary()
optimizer = optimizers.SGD(lr=0.01)
acc_meter = metrics.Accuracy()

对数据集进行迭代

for step, (x,y) in enumerate(db):
    with tf.GradientTape() as tape:
        # [b, 28, 28] => [b, 784]
        x = tf.reshape(x, (-1, 28*28))
        # [b, 784] => [b, 10]
        out = network(x)
        # [b] => [b, 10]
        y_onehot = tf.one_hot(y, depth=10)
        # [b, 10]
        loss = tf.square(out-y_onehot)
        # [b]
        loss = tf.reduce_sum(loss) / 32
    acc_meter.update_state(tf.argmax(out, axis=1), y)
    grads = tape.gradient(loss, network.trainable_variables)
    optimizer.apply_gradients(zip(grads, network.trainable_variables))
    if step % 200==0:
        print(step, 'loss:', float(loss), 'acc:', acc_meter.result().numpy())
        acc_meter.reset_states()

最终的训练结果

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                multiple                  200960    
_________________________________________________________________
dense_1 (Dense)              multiple                  65792     
_________________________________________________________________
dense_2 (Dense)              multiple                  65792     
_________________________________________________________________
dense_3 (Dense)              multiple                  2570      
=================================================================
Total params: 335,114
Trainable params: 335,114
Non-trainable params: 0
_________________________________________________________________
0 loss: 1.6250096559524536 acc: 0.125
200 loss: 0.4169953465461731 acc: 0.68828124
400 loss: 0.3840298056602478 acc: 0.84796876
600 loss: 0.3416569232940674 acc: 0.8721875
800 loss: 0.26919665932655334 acc: 0.898125
1000 loss: 0.3009098470211029 acc: 0.8985937
1200 loss: 0.2646633982658386 acc: 0.9103125
1400 loss: 0.22533166408538818 acc: 0.91765624
1600 loss: 0.18222002685070038 acc: 0.9165625
1800 loss: 0.18201570212841034 acc: 0.9290625
2000 loss: 0.19700995087623596 acc: 0.94484377
2200 loss: 0.14550982415676117 acc: 0.93125
2400 loss: 0.20129463076591492 acc: 0.9265625
2600 loss: 0.20377759635448456 acc: 0.93921876
2800 loss: 0.1372058093547821 acc: 0.93859375
3000 loss: 0.2261861264705658 acc: 0.93359375
3200 loss: 0.1720336377620697 acc: 0.940625
3400 loss: 0.12401969730854034 acc: 0.93953127
3600 loss: 0.10386896133422852 acc: 0.9401562
3800 loss: 0.16028286516666412 acc: 0.95734376
4000 loss: 0.17534957826137543 acc: 0.95171875
4200 loss: 0.14097453653812408 acc: 0.940625
4400 loss: 0.14199058711528778 acc: 0.9490625
4600 loss: 0.19402430951595306 acc: 0.94671875
4800 loss: 0.15967118740081787 acc: 0.94625
5000 loss: 0.1375979483127594 acc: 0.953125
5200 loss: 0.22316312789916992 acc: 0.9453125
5400 loss: 0.21779394149780273 acc: 0.9501563
5600 loss: 0.08099132776260376 acc: 0.9632813
5800 loss: 0.15826722979545593 acc: 0.9603125
6000 loss: 0.11169645190238953 acc: 0.95140624
6200 loss: 0.16848763823509216 acc: 0.95203125
6400 loss: 0.10312280058860779 acc: 0.9571875
6600 loss: 0.12469235062599182 acc: 0.9521875
6800 loss: 0.1130545362830162 acc: 0.9557812
7000 loss: 0.10152068734169006 acc: 0.95921874
7200 loss: 0.2921682596206665 acc: 0.9484375
7400 loss: 0.12305493652820587 acc: 0.9625
7600 loss: 0.13934454321861267 acc: 0.97
7800 loss: 0.07814794033765793 acc: 0.9571875
8000 loss: 0.16721022129058838 acc: 0.95921874
8200 loss: 0.0795113667845726 acc: 0.9632813
8400 loss: 0.07537689059972763 acc: 0.9557812
8600 loss: 0.11197802424430847 acc: 0.9585937
8800 loss: 0.12252026051282883 acc: 0.96203125
9000 loss: 0.13232024013996124 acc: 0.9578125
9200 loss: 0.08134811371564865 acc: 0.9590625
9400 loss: 0.07900199294090271 acc: 0.973125
9600 loss: 0.18293496966362 acc: 0.9671875
9800 loss: 0.04652725160121918 acc: 0.9609375
10000 loss: 0.16443191468715668 acc: 0.9615625
10200 loss: 0.1063765436410904 acc: 0.9640625
10400 loss: 0.11965180188417435 acc: 0.95890623
10600 loss: 0.07027075439691544 acc: 0.96796876
10800 loss: 0.15515664219856262 acc: 0.961875
11000 loss: 0.07982166111469269 acc: 0.9584375
11200 loss: 0.09907206147909164 acc: 0.9715625
11400 loss: 0.10707013309001923 acc: 0.9734375
11600 loss: 0.13629364967346191 acc: 0.9646875
11800 loss: 0.09981581568717957 acc: 0.96515626
12000 loss: 0.07304492592811584 acc: 0.96734375
12200 loss: 0.065020851790905 acc: 0.96375
12400 loss: 0.13698126375675201 acc: 0.96671873
12600 loss: 0.17877255380153656 acc: 0.96546876
12800 loss: 0.10601920634508133 acc: 0.9625
13000 loss: 0.0823143720626831 acc: 0.96765625
13200 loss: 0.15889227390289307 acc: 0.974375
13400 loss: 0.10676314681768417 acc: 0.9715625
13600 loss: 0.08510202169418335 acc: 0.9684375
13800 loss: 0.08974016457796097 acc: 0.969375
14000 loss: 0.05767156183719635 acc: 0.9678125
14200 loss: 0.21300190687179565 acc: 0.96796876
14400 loss: 0.08092547953128815 acc: 0.97078127
14600 loss: 0.17201107740402222 acc: 0.96515626
14800 loss: 0.07620300352573395 acc: 0.965625
15000 loss: 0.09742768108844757 acc: 0.9767187
15200 loss: 0.09786351025104523 acc: 0.97484374
15400 loss: 0.09377723932266235 acc: 0.9689062
15600 loss: 0.08926095068454742 acc: 0.970625
15800 loss: 0.08965814113616943 acc: 0.9703125
16000 loss: 0.1047641858458519 acc: 0.96734375
16200 loss: 0.0918944925069809 acc: 0.97328126
16400 loss: 0.08902822434902191 acc: 0.9709375
16600 loss: 0.06524112075567245 acc: 0.96421874
16800 loss: 0.09127143025398254 acc: 0.975
17000 loss: 0.09830828011035919 acc: 0.9784375
17200 loss: 0.030347194522619247 acc: 0.97125
17400 loss: 0.07646052539348602 acc: 0.9728125
17600 loss: 0.09391230344772339 acc: 0.9717187
17800 loss: 0.05943562090396881 acc: 0.9714062
18000 loss: 0.08575735241174698 acc: 0.9734375
18200 loss: 0.09358179569244385 acc: 0.973125
18400 loss: 0.0565657876431942 acc: 0.97015625
18600 loss: 0.04309214651584625 acc: 0.96984375

总结
整体上使用起来比1.X版本有顺手多了,感觉和pytorch差不多,使用起来更加的丝滑流畅,真的是对这个版本爱不释手了,如果要是再有一张性能好一些的显卡就更好了

标签:acc,loss,MNSIT,Dense,28,笔记,Tensorflow2.0,tf,xs
来源: https://www.cnblogs.com/taotaoName/p/16258530.html