其他分享
首页 > 其他分享 > fashion_mnist多分类训练,两种模型的保存与加载

fashion_mnist多分类训练,两种模型的保存与加载

作者:互联网

from tensorflow.python.keras.preprocessing.image import load_img,img_to_array
from tensorflow.python.keras.models import Sequential,Model
from tensorflow.python.keras.layers import Dense,Flatten,Input
import tensorflow as tf
from tensorflow.python.keras.losses import sparse_categorical_crossentropy
from tensorflow.python import keras
import os
import numpy as np

class SingleNN(object):
    """Single-hidden-layer classifier for the Fashion-MNIST dataset.

    The Keras model is a class attribute: it is built once, when the class
    body executes, and is shared by every instance. It is also reachable as
    ``SingleNN.model`` so callers can save or load weights directly.
    """

    # Network: flatten the 28x28 input, one 128-unit ReLU layer, then a
    # 10-way softmax output layer.
    model = keras.Sequential([
        keras.layers.Flatten(input_shape=(28, 28)),
        keras.layers.Dense(128, activation=tf.nn.relu),
        keras.layers.Dense(10, activation=tf.nn.softmax),
    ])

    def __init__(self):
        """Load the Fashion-MNIST train/test split and rescale the inputs."""
        (self.x_train, self.y_train), (self.x_test, self.y_test) = \
            keras.datasets.fashion_mnist.load_data()
        # Normalisation: divide the raw inputs by 255.0.
        self.x_train = self.x_train / 255.0
        self.x_test = self.x_test / 255.0

    def singlenn_compile(self, learning_rate=0.01):
        """Compile the shared model with SGD, sparse CE loss and accuracy.

        :param learning_rate: step size handed to the SGD optimizer.
        :return: None
        """
        SingleNN.model.compile(
            # The rate is passed positionally on purpose: the keyword is
            # spelled ``lr`` in older Keras releases and ``learning_rate`` in
            # newer ones, while the first positional slot is the rate in both.
            optimizer=keras.optimizers.SGD(learning_rate),
            loss=keras.losses.sparse_categorical_crossentropy,
            metrics=['accuracy']
        )

    def singlenn_fit(self, epochs=5):
        """Train the shared model on the training split.

        :param epochs: number of passes over the training data.
        :return: None
        """
        SingleNN.model.fit(self.x_train, self.y_train, epochs=epochs)

    def single_evalute(self):
        """Evaluate the shared model on the test split and print the metrics.

        The model must have been compiled first (see ``singlenn_compile``).

        :return: ``(test_loss, test_acc)`` as reported by ``model.evaluate``.
        """
        test_loss, test_acc = SingleNN.model.evaluate(self.x_test, self.y_test)
        print(test_loss, test_acc)
        return test_loss, test_acc

    def single_predict(self, weights_path="./ckpt/SingleNN.h5"):
        """Predict on the test split, loading saved weights when available.

        :param weights_path: HDF5 weights file to load before predicting.
        :return: the output of ``model.predict`` for ``self.x_test``.
        """
        # Alternative (TensorFlow checkpoint format instead of HDF5):
        # if os.path.exists("./ckpt/checkpoint"):
        #     SingleNN.model.load_weights("./ckpt/SingleNN")

        if os.path.exists(weights_path):
            SingleNN.model.load_weights(weights_path)
        else:
            # Keep the original best-effort behaviour (predict anyway), but
            # make it visible that no saved weights were loaded.
            import warnings
            warnings.warn(
                "weights file %r not found; predicting with the model's "
                "current in-memory weights" % (weights_path,)
            )

        predictions = SingleNN.model.predict(self.x_test)

        return predictions

if __name__ == "__main__":
    # Inference-only run: build the classifier, predict on the test split,
    # then reduce each row of scores to the index of its largest entry.
    classifier = SingleNN()

    # To train and persist the weights first, enable these steps:
    #   classifier.singlenn_compile()
    #   classifier.singlenn_fit()
    #   classifier.single_evalute()
    #   SingleNN.model.save_weights("./ckpt/SingleNN")      # TF checkpoint format
    #   SingleNN.model.save_weights("./ckpt/SingleNN.h5")   # HDF5 format

    class_scores = classifier.single_predict()
    print(class_scores)

    predicted_labels = np.argmax(class_scores, axis=1)
    print(predicted_labels)

  

标签:fashion,keras,self,SingleNN,test,import,model,mnist,加载
来源: https://www.cnblogs.com/LiuXinyu12378/p/12250596.html