7. Fashion-MNIST 数据识别
作者:互联网
本篇介绍如何在实际中使用上一篇训练好的模型,对图片进行识别。
import os
from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Flatten ,Dense,Conv1D,MaxPool1D
class Mymodel(Model):
    """Conv1D classifier that ends in a 10-way softmax.

    Pipeline: Conv1D(256, width 2, relu) -> MaxPool1D(2) -> Flatten
    -> Dense(128, relu) -> Dense(10, softmax).

    NOTE(review): the attribute names c1 / m1 / flatten / d1 / d2 are kept
    on purpose. The weights are restored from a TF-format checkpoint,
    and such checkpoints match layers by attribute name, so renaming
    these attributes would presumably break loading of existing weights.
    """

    def __init__(self):
        super().__init__()
        # Assumes inputs shaped (batch, steps, channels), e.g. (batch, 28, 28)
        # grayscale images where each row is one step -- TODO confirm.
        self.c1 = Conv1D(
            filters=256,
            kernel_size=2,
            padding='same',
            activation='relu',
        )
        self.m1 = MaxPool1D(pool_size=2, strides=2, padding='same')
        self.flatten = Flatten()
        self.d1 = Dense(units=128, activation=tf.keras.activations.relu)
        # One probability per class.
        self.d2 = Dense(units=10, activation=tf.keras.activations.softmax)

    def call(self, x):
        """Run *x* through every layer in order and return the class probabilities."""
        h = x
        for layer in (self.c1, self.m1, self.flatten, self.d1, self.d2):
            h = layer(h)
        return h
# Build the network and restore the weights produced by the training script.
model = Mymodel()
checkpoint_save_path = './models/fashionconv/fashion.ckpt'
# A TF-format checkpoint is written as "<prefix>.index" plus data shards,
# so the .index file serves as the "weights exist" marker.
if not os.path.exists(checkpoint_save_path + '.index'):
    # Without trained weights every prediction below would be meaningless,
    # so stop here instead of silently running with random initial weights.
    raise FileNotFoundError(
        'no trained weights found at {} (train the model first)'.format(
            checkpoint_save_path))
print('-------------load model-----------')
model.load_weights(checkpoint_save_path)
# Maps the model's output index (argmax of the 10 softmax outputs) to a name.
predict_label = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                 'sandal', 'shirt', 'sneaker', 'bag', 'ankle_boot']


def _preprocess(img_path, size=28, threshold=200):
    """Load an image file and return a (1, size, size) float64 batch for the model.

    The image is resized (Lanczos filter) and converted to grayscale. It is
    then binarised and inverted: pixels darker than *threshold* become 1.0
    and all others 0.0. NOTE(review): the inversion presumably turns a dark
    object on a light background into the light-on-dark style of the
    training data -- confirm against the training script.

    Raises:
        OSError: if the file is missing or is not a readable image
            (FileNotFoundError and PIL's UnidentifiedImageError are both
            OSError subclasses).
    """
    # Close the source file promptly; resize() returns a new in-memory image.
    with Image.open(img_path) as img:
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        gray = img.resize((size, size), Image.LANCZOS).convert('L')
    pixels = np.asarray(gray)
    # Equivalent to "dark -> 255, light -> 0" followed by "/ 255.0".
    binary = np.where(pixels < threshold, 1.0, 0.0)
    return binary[np.newaxis, ...]


# Interactive loop: read an image number, classify that image, print the
# label. An empty line or EOF ends the session.
while True:
    try:
        raw = input('请输入要识别的图片:').strip()
    except EOFError:
        break
    if not raw:
        break
    try:
        perNum = int(raw)
    except ValueError:
        # Bad input should not kill the whole session.
        print('请输入图片编号(整数),直接回车退出')
        continue
    img_path = './data/class4/FASHION_FC/{}.jpeg'.format(perNum)
    try:
        x_predict = _preprocess(img_path)
    except OSError as err:
        print('无法读取图片 {}: {}'.format(img_path, err))
        continue
    result = model.predict(x_predict)
    # Single-image batch: take the argmax of row 0 as a plain Python int.
    i = int(np.argmax(result, axis=1)[0])
    print('\n')
    print(predict_label[i])
标签:arr, Fashion, img, self, path, tf, import, 识别, 数据
来源:https://blog.csdn.net/qq_35779738/article/details/110227413