# 机器学习实战：k近邻算法识别手写数字
# 作者：互联网
# 代码如下：
import numpy as np
import operator
from os import listdir
def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its ``k`` nearest neighbours.

    Distance is Euclidean between ``inX`` and each row of ``dataSet``.

    Parameters
    ----------
    inX : array_like
        Query sample; anything that broadcasts against one row of
        ``dataSet`` (a 1-D feature vector or a ``(1, n_features)`` row).
    dataSet : array_like
        Training samples, one per row.
    labels : sequence
        Class label of each row of ``dataSet``.
    k : int
        Number of nearest neighbours that vote. Values larger than the
        number of training samples are clamped to that number.

    Returns
    -------
    The label with the most votes. On a tie, the label whose first vote
    came from the nearer neighbour wins.

    Raises
    ------
    ValueError
        If ``k < 1`` or ``dataSet`` contains no samples.
    """
    dataSet = np.asarray(dataSet)
    dataSetSize = dataSet.shape[0]
    if k < 1:
        raise ValueError("k must be at least 1, got %r" % (k,))
    if dataSetSize == 0:
        raise ValueError("dataSet must contain at least one sample")
    k = min(k, dataSetSize)

    # Broadcasting subtracts inX from every row without building a tiled copy.
    diffMat = np.asarray(inX) - dataSet
    # sqrt is monotonic, so squared distances rank neighbours identically.
    sqDistances = (diffMat ** 2).sum(axis=1)
    # Stable sort: among equidistant samples the earlier training row comes
    # first, making the vote deterministic (ties are common for 0/1 images).
    sortedDistIndices = sqDistances.argsort(kind="stable")

    classCount = {}
    for idx in sortedDistIndices[:k]:
        voteIlabel = labels[idx]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # max() returns the first maximal key in insertion order, i.e. the tied
    # label that was reached first when walking outward from the query.
    return max(classCount, key=classCount.get)
def img2vector(filename):
    """Read a 32x32 text image from ``filename`` into a ``(1, 1024)`` array.

    The first 32 characters of each of the first 32 lines are parsed with
    ``int()``. They are stored row by row, so character ``j`` of line ``i``
    lands at column ``32 * i + j``. The pixels are presumably ``'0'``/``'1'``
    characters; any other non-digit character raises ValueError.
    """
    returnVect = np.zeros((1, 1024))
    # ``with`` guarantees the file handle is closed (the original leaked it).
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect
def _loadDigitDir(dirName):
    """Load every digit file in ``dirName`` as ``(dataMat, labels)``.

    ``dataMat`` is an ``(m, 1024)`` matrix holding one flattened image per
    row, as produced by ``img2vector``. ``labels`` holds the integer class
    of each file, parsed from the text before the first ``'_'`` in its name
    (e.g. ``'7_42.txt'`` -> 7). Files are processed in sorted name order
    because ``os.listdir`` order is arbitrary.
    """
    fileNames = sorted(listdir(dirName))
    dataMat = np.zeros((len(fileNames), 1024))
    labels = []
    for i, fileNameStr in enumerate(fileNames):
        labels.append(int(fileNameStr.split('_')[0]))
        dataMat[i, :] = img2vector('%s/%s' % (dirName, fileNameStr))
    return dataMat, labels


def handwritingClassTest(trainingDir='trainingDigits', testDir='testDigits', k=3):
    """Evaluate the k-NN digit classifier on a held-out test directory.

    Every file in ``trainingDir`` becomes a labelled training sample. Each
    file in ``testDir`` is then classified with ``classify0`` using ``k``
    neighbours. The prediction and true label are printed per file, followed
    by the total error count and the error rate as a percentage.

    Returns
    -------
    float
        The error rate as a fraction in ``[0, 1]``.

    Raises
    ------
    ValueError
        If either directory contains no files.
    """
    trainingMat, hwLabels = _loadDigitDir(trainingDir)
    if not hwLabels:
        raise ValueError("no training files found in %r" % (trainingDir,))
    testMat, testLabels = _loadDigitDir(testDir)
    mTest = len(testLabels)
    if mTest == 0:
        raise ValueError("no test files found in %r" % (testDir,))

    errorCount = 0
    for vectorUnderTest, classNumber in zip(testMat, testLabels):
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
        print("the classifier came back with: %d\t the real answer is: %d" % (classifierResult, classNumber))
        if classifierResult != classNumber:
            errorCount += 1

    errorRate = errorCount / mTest
    # The message says "%", so scale the fraction to a percentage
    # (the original printed e.g. 0.0106% for a 1.06% error rate).
    print("the total number of errors is: %d\nthe total error rate is %f%%" % (errorCount, 100.0 * errorRate))
    return errorRate
if "__main__" == __name__:
    # Script entry point: run the end-to-end digit-recognition evaluation.
    handwritingClassTest()
# 标签：__, classCount, classNumber, 近邻, fileNameStr, range, 算法, np, 手写
# 来源：https://blog.csdn.net/fukangwei_lite/article/details/123122531