CIFAR-10 offline download, data loading, and CNN training in an Anaconda Jupyter Notebook on Windows 10
Author: reposted from the Internet
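1. Downloading the data from the official site
Network restrictions sometimes prevent Keras from downloading CIFAR-10 directly. In that case, download the offline data package yourself from the official URL: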
https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
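2. Renaming and extracting the archive
Rename the downloaded archive from cifar-10-python.tar.gz to cifar-10-batches-py.tar.gz, which is the name tf.keras looks for in its cache folder. Place it in user/xxx/.keras/datasets, where xxx is your Windows user name, and extract it there. Next, create a folder named cifar-10-batches-py in the datasets directory. Copy all of the extracted files into it (the files only, not the folder that contains them): data_batch_1 to data_batch_5, test_batch, batches.meta, and so on. When you are done, these batch files should sit directly in user/xxx/.keras/datasets/cifar-10-batches-py, with no extra nested folder.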
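3. Loading the data
Import the datasets module from tensorflow.keras, then read the data. load_data() returns 50,000 training and 10,000 test images (32x32 RGB, uint8) with integer labels 0 to 9. Dividing by 255.0 scales the pixel values from 0 to 255 down to [0, 1]:
```python
from tensorflow.keras import datasets

(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
```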
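You can also read the extracted batch files with pickle and NumPy. This shows what load_data() does with them, or lets you skip Keras entirely. The sketch mirrors the layout Keras produces: channels-last images and labels of shape (N, 1). The function names are hypothetical, and `folder` is assumed to be the cifar-10-batches-py folder from step 2.

```python
import pickle
from pathlib import Path

import numpy as np


def load_batch(path):
    """Read one CIFAR-10 python batch file into (images, labels)."""
    with open(path, "rb") as f:
        # The batches were pickled under Python 2, so dict keys come back as bytes
        d = pickle.load(f, encoding="bytes")
    data = np.asarray(d[b"data"], dtype=np.uint8)      # (N, 3072): 1024 R, then 1024 G, then 1024 B
    labels = np.asarray(d[b"labels"], dtype=np.uint8)  # N class ids in 0..9
    # (N, 3, 32, 32) channels-first -> (N, 32, 32, 3) channels-last
    images = data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    return images, labels.reshape(-1, 1)


def load_cifar10_from_folder(folder):
    """Hypothetical helper returning ((x_train, y_train), (x_test, y_test)) as uint8 arrays."""
    folder = Path(folder)
    parts = [load_batch(folder / f"data_batch_{i}") for i in range(1, 6)]
    x_train = np.concatenate([p[0] for p in parts])
    y_train = np.concatenate([p[1] for p in parts])
    x_test, y_test = load_batch(folder / "test_batch")
    return (x_train, y_train), (x_test, y_test)
```

With the real files, `load_cifar10_from_folder(Path.home() / ".keras" / "datasets" / "cifar-10-batches-py")` should give arrays of shape (50000, 32, 32, 3) and (10000, 32, 32, 3). Normalize them with `/ 255.0` as above.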
4. Training a CNN on CIFAR-10 (the code is mainly taken from https://blog.csdn.net/yanghe4405/article/details/107521797)
```python
import os

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.keras import Model, datasets
from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D,
                                     Dense, Dropout, Flatten, MaxPool2D)

np.set_printoptions(threshold=np.inf)  # print full arrays when dumping weights

# cifar10 = tf.keras.datasets.cifar10
# (x_train, y_train), (x_test, y_test) = cifar10.load_data()
(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0


class Baseline(Model):
    def __init__(self):
        # Prepare every layer the network needs here, i.e. CBAPD
        # (Convolution, BatchNorm, Activation, Pooling, Dropout)
        super(Baseline, self).__init__()
        self.c1 = Conv2D(filters=6, kernel_size=(5, 5), padding='same')  # convolution layer
        self.b1 = BatchNormalization()  # BN layer
        self.a1 = Activation('relu')  # activation layer
        self.p1 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')  # pooling layer
        self.d1 = Dropout(0.2)  # dropout layer

        self.flatten = Flatten()
        self.f1 = Dense(128, activation='relu')
        self.d2 = Dropout(0.2)
        self.f2 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.c1(x)
        x = self.b1(x)
        x = self.a1(x)
        x = self.p1(x)
        x = self.d1(x)

        x = self.flatten(x)
        x = self.f1(x)
        x = self.d2(x)
        y = self.f2(x)
        return y


model = Baseline()

# The last layer already applies softmax, so the loss receives probabilities (from_logits=False)
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

# Resume from saved weights if a previous run left a checkpoint.
# Note: this TF-checkpoint path assumes tf.keras (Keras 2); Keras 3 expects
# a filepath ending in .weights.h5 when save_weights_only=True.
checkpoint_save_path = "./checkpoint/Baseline.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=32, epochs=20,
                    validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()

# Dump every trainable variable (name, shape, values) to a text file
# print(model.trainable_variables)
with open('./weights.txt', 'w') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        file.write(str(v.numpy()) + '\n')

# Plot the training and validation accuracy and loss curves
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
```
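model.summary() reports the parameter count of each layer, and you can reproduce that arithmetic by hand. The small sketch below does this for the Baseline network with its default sizes: 32x32x3 input, 6 filters of 5x5, a 128-unit hidden layer and 10 classes. The function name is hypothetical. BatchNormalization weights are split into trainable gamma/beta and non-trainable moving mean/variance, which follows how Keras counts them.

```python
def baseline_param_counts(height=32, width=32, channels=3, filters=6,
                          kernel=5, dense_units=128, classes=10):
    """Hypothetical helper: parameter counts of the Baseline model above."""
    # Conv2D: kernel*kernel*channels weights per filter, plus one bias per filter
    conv = kernel * kernel * channels * filters + filters
    # BatchNormalization: gamma and beta are trainable, moving mean/variance are not
    bn_trainable = 2 * filters
    bn_non_trainable = 2 * filters
    # padding='same' with stride 1 keeps 32x32; MaxPool2D(2, strides=2, padding='same') gives ceil(n / 2)
    pooled_h = -(-height // 2)
    pooled_w = -(-width // 2)
    flat = pooled_h * pooled_w * filters
    # Dense layers: inputs*units weights plus one bias per unit
    dense1 = flat * dense_units + dense_units
    dense2 = dense_units * classes + classes
    trainable = conv + bn_trainable + dense1 + dense2
    return {"flatten_size": flat, "trainable": trainable,
            "non_trainable": bn_non_trainable, "total": trainable + bn_non_trainable}


print(baseline_param_counts())
# {'flatten_size': 1536, 'trainable': 198494, 'non_trainable': 12, 'total': 198506}
```

Almost all of the parameters (196,736) sit in the first Dense layer. Flatten turns the 16x16x6 feature map into a 1,536-element vector, and every one of those inputs connects to all 128 units.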
Source: https://www.cnblogs.com/zzx1905/p/load_cifar10_cnn.html