其他分享
首页 > 其他分享 > 数字识别手写

数字识别手写

作者:互联网

数字识别手写

import numpy as np
# bmp 图片后缀
import matplotlib.pyplot as plt
%matplotlib inline

from sklearn.neighbors import KNeighborsClassifier

# Peek at a single sample image to see what the data looks like
sample_path = './data/3/3_33.bmp'
img = plt.imread(sample_path)
plt.imshow(img)
img.shape
# (28, 28)

# Read all 5000 images: digits 0-9, files 1-500 for each digit.
# Files are laid out as ./data/<digit>/<digit>_<index>.bmp,
# e.g. ./data/3/3_33.bmp

feature = []
target = []
for digit in range(10):
    for index in range(1, 501):
        # Build the path of this sample and load its pixel array
        sample = plt.imread('./data/{0}/{0}_{1}.bmp'.format(digit, index))
        feature.append(sample)
        # The label of a sample is the digit of the folder it came from
        target.append(digit)

# Turn the Python lists into NumPy arrays
feature = np.array(feature)
target = np.array(target)


feature.shape  # features are three-dimensional: (5000, 28, 28)
# Flatten every image into one row so the array becomes two-dimensional.
# The row count is taken from the data and -1 lets NumPy infer the row
# length, so this keeps working if the number or size of images changes.
feature = feature.reshape(len(feature), -1)
feature.shape   # (5000, 784)
# target.shape  the dimensionality of target does not matter
# Shuffle the samples.
# One seeded permutation indexes both arrays, so each row stays paired with
# its label by construction rather than by re-seeding before a second shuffle.
np.random.seed(3)  # fixed seed so the shuffle is the same on every run
order = np.random.permutation(len(feature))
feature = feature[order]
target = target[order]


# Split into training data and test data:
# the first 4950 samples train the model, the rest are held out for testing
split_at = 4950
x_train, x_test = feature[:split_at], feature[split_at:]
y_train, y_test = target[:split_at], target[split_at:]


# Instantiate the k-nearest-neighbours classifier and train it
knn = KNeighborsClassifier(n_neighbors=15)
knn.fit(x_train,y_train)
knn.score(x_test,y_test)
# 0.96  (score recorded from the original run)

# Save the trained model.
# `sklearn.externals.joblib` was deprecated in scikit-learn 0.21 and removed
# in 0.23, so use the standalone `joblib` package and only fall back to the
# vendored copy on old installations where `joblib` is not importable.
try:
    import joblib
except ImportError:
    from sklearn.externals import joblib
joblib.dump(knn,'./digist_knn.m')

# Load the saved model file back from disk
knn = joblib.load('./digist_knn.m')

# Inspect the reloaded model
knn
# KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
#            metric_params=None, n_jobs=None, n_neighbors=15, p=2,
#            weights='uniform')



# Check the model's accuracy on the test data by printing the true labels
# next to the labels the model predicts
print('已知分类:',y_test)
predicted = knn.predict(x_test)
print('模型分类结果:',predicted)

# 已知分类: [4 5 7 9 7 5 7 6 8 6 4 1 3 4 8 4 2 0 1 2 0 5 8 6 5 9 3 9 1 8 9 6 4 1 5 2 8
#  7 7 2 5 3 5 5 6 1 1 3 6 3]
# 模型分类结果: [4 5 7 9 7 5 7 6 8 6 1 1 3 4 8 4 1 0 1 2 0 5 8 6 5 9 3 9 1 8 9 6 4 1 5 2 8
#  7 7 2 5 3 5 5 6 1 1 3 6 3]
# Run an external picture through the model
img_arr = plt.imread('./数字.jpg')
plt.imshow(img_arr)

# Crop the digit 8 out of the picture
eight = img_arr[180:230,95:125]
plt.imshow(eight)

eight.shape # (50, 30, 3): three-dimensional

# Drop the colour axis by averaging over it
eight = eight.mean(axis=2)
eight.shape  # (50, 30)

# Rescale the crop to the 28x28 size of the training images.
# The zoom factors are derived from the crop's actual shape rather than
# hard-coded, so changing the crop window above still yields a 28x28 image.
import scipy.ndimage as ndimage
side = 28
rows, cols = eight.shape
eight = ndimage.zoom(eight,zoom = (side/rows,side/cols))
eight.shape # (28, 28)

plt.imshow(eight)

# The model expects one flat row per sample. The explicit length makes the
# reshape fail loudly if the rescaled image is not side x side.
eight = eight.reshape((1,side*side))
eight.shape  # (1, 784)


knn.predict(eight)  # array([4])  the recognition result may be wrong

标签:knn,数字,img,feature,shape,eight,手写,识别,target
来源: https://www.cnblogs.com/Quantum-World/p/11353720.html