基于k近邻的MNIST图像分类对比
作者:互联网
数据集读取
由于数据来源网站不稳定,个人将数据集下载到本地后进行读取
网上多数都是将数据集读取为三维数组方便进行显示,但因计算方便和用sklearn时都是二维数组,所以个人后来修改了下
def decode_idx3_ubyte(idx3_ubyte_file):
"""
解析idx3文件的通用函数
:param idx3_ubyte_file: idx3文件路径
:return: 数据集
"""
# 读取二进制数据
bin_data = gzip.open(idx3_ubyte_file, 'rb').read()
# 解析文件头信息,依次为魔数、图片数量、每张图片高、每张图片宽
offset = 0
fmt_header = '>IIII'
# 解析数据集
offset += struct.calcsize(fmt_header)
fmt_image = '>784B'
image_size = 100
# 判断是否是训练集
if 'train' in idx3_ubyte_file:
image_size = 6000
images = np.empty((image_size, 784))
for i in range(image_size):
temp = struct.unpack_from(fmt_image, bin_data, offset)
images[i] = np.reshape(temp, 784)
offset += struct.calcsize(fmt_image)
return images
def decode_idx1_ubyte(idx1_ubyte_file):
"""
解析idx1文件的通用函数
:param idx1_ubyte_file: idx1文件路径
:return: 数据集
"""
# 读取二进制数据
bin_data = gzip.open(idx1_ubyte_file, 'rb').read()
# 解析文件头信息,依次为魔数和标签数
offset = 0
fmt_header = '>II'
# 解析数据集
offset += struct.calcsize(fmt_header)
fmt_label = '>B'
label_size = 100
# 判断是否是训练集
if 'train' in idx1_ubyte_file:
label_size = 6000
labels = np.empty(label_size, np.int)
for i in range(label_size):
labels[i] = struct.unpack_from(fmt_label, bin_data, offset)[0]
offset += struct.calcsize(fmt_label)
return labels
这里控制了读取的数量,只使用了原数据集的十分之一
实现k近邻算法
class NearstNeighbour:
def __init__(self, k):
self.k = k
def train(self, X, y):
self.Xtr = X
self.ytr = y
return self
def predict(self, test_images):
predictions = []
# 这段代码借鉴https://github.com/Youngphone/KNN-MNIST/blob/master/KNN-MNIST.ipynb
# 当前运行的测试用例坐标
for test_item in test_images:
datasetsize = self.Xtr.shape[0]
#距离计算公式
diffMat = np.tile(test_item, (datasetsize, 1)) - self.Xtr
sqDiffMat = diffMat ** 2
sqDistances = sqDiffMat.sum(axis = 1)
distances = sqDistances ** 0.5
# 距离从大到小排序,返回距离的序号
sortedDistIndicies = distances.argsort()
# 字典
classCount = {}
# 前K个距离最小的
for i in range(self.k):
# sortedDistIndicies[0]返回的是距离最小的数据样本的序号
# labels[sortedDistIndicies[0]]距离最小的数据样本的标签
voteIlabel = self.ytr[sortedDistIndicies[i]]
# 若属于某类则权重加一
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
# 排序
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
predictions.append(sortedClassCount[0][0])
return predictions
与sklearn的k近邻对比
# -*- encoding: utf-8 -*-
'''
@File : NearstNeighbour.py
@Time : 2021/03/27 15:40:05
@Author : Wihau
@Version : 1.0
@Desc : None
'''
# here put the import lib
import gzip
import numpy as np
import struct
import operator
import time
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
train_images_idx3_ubyte_file = 'train-images-idx3-ubyte.gz'
train_labels_idx1_ubyte_file = 'train-labels-idx1-ubyte.gz'
test_images_idx3_ubyte_file = 't10k-images-idx3-ubyte.gz'
test_labels_idx1_ubyte_file = 't10k-labels-idx1-ubyte.gz'
def decode_idx3_ubyte(idx3_ubyte_file):
"""
解析idx3文件的通用函数
:param idx3_ubyte_file: idx3文件路径
:return: 数据集
"""
# 读取二进制数据
bin_data = gzip.open(idx3_ubyte_file, 'rb').read()
# 解析文件头信息,依次为魔数、图片数量、每张图片高、每张图片宽
offset = 0
fmt_header = '>IIII'
# 解析数据集
offset += struct.calcsize(fmt_header)
fmt_image = '>784B'
image_size = 100
# 判断是否是训练集
if 'train' in idx3_ubyte_file:
image_size = 6000
images = np.empty((image_size, 784))
for i in range(image_size):
temp = struct.unpack_from(fmt_image, bin_data, offset)
images[i] = np.reshape(temp, 784)
offset += struct.calcsize(fmt_image)
return images
def decode_idx1_ubyte(idx1_ubyte_file):
"""
解析idx1文件的通用函数
:param idx1_ubyte_file: idx1文件路径
:return: 数据集
"""
# 读取二进制数据
bin_data = gzip.open(idx1_ubyte_file, 'rb').read()
# 解析文件头信息,依次为魔数和标签数
offset = 0
fmt_header = '>II'
# 解析数据集
offset += struct.calcsize(fmt_header)
fmt_label = '>B'
label_size = 100
# 判断是否是训练集
if 'train' in idx1_ubyte_file:
label_size = 6000
labels = np.empty(label_size, np.int)
for i in range(label_size):
labels[i] = struct.unpack_from(fmt_label, bin_data, offset)[0]
offset += struct.calcsize(fmt_label)
return labels
def load_train_images(idx_ubyte_file=train_images_idx3_ubyte_file):
return decode_idx3_ubyte(idx_ubyte_file)
def load_train_labels(idx_ubyte_file=train_labels_idx1_ubyte_file):
return decode_idx1_ubyte(idx_ubyte_file)
def load_test_images(idx_ubyte_file=test_images_idx3_ubyte_file):
return decode_idx3_ubyte(idx_ubyte_file)
def load_test_labels(idx_ubyte_file=test_labels_idx1_ubyte_file):
return decode_idx1_ubyte(idx_ubyte_file)
class NearstNeighbour:
def __init__(self, k):
self.k = k
def train(self, X, y):
self.Xtr = X
self.ytr = y
return self
def predict(self, test_images):
predictions = []
# 当前运行的测试用例坐标
for test_item in test_images:
datasetsize = self.Xtr.shape[0]
#距离计算公式
diffMat = np.tile(test_item, (datasetsize, 1)) - self.Xtr
sqDiffMat = diffMat ** 2
sqDistances = sqDiffMat.sum(axis = 1)
distances = sqDistances ** 0.5
# 距离从大到小排序,返回距离的序号
sortedDistIndicies = distances.argsort()
# 字典
classCount = {}
# 前K个距离最小的
for i in range(self.k):
# sortedDistIndicies[0]返回的是距离最小的数据样本的序号
# labels[sortedDistIndicies[0]]距离最小的数据样本的标签
voteIlabel = self.ytr[sortedDistIndicies[i]]
# 若属于某类则权重加一
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
# 排序
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
predictions.append(sortedClassCount[0][0])
return predictions
train_images = load_train_images()
train_labels = load_train_labels()
test_images = load_test_images()
test_labels = load_test_labels()
k = 5
# 个人k近邻预测
print("-----Personal k nearest neighbour-----")
# 预测时间
start = time.time()
knn = NearstNeighbour(k)
predictions = knn.train(train_images, train_labels).predict(test_images)
end = time.time()
print("time of prediction:%.3f s" % (end-start))
# 准确率
accuracy = accuracy_score(test_labels, predictions)
print("accuracy score:", accuracy)
# 混淆矩阵
matrix = confusion_matrix(test_labels, predictions)
print(matrix)
# sklearn的k近邻预测
print("-----Sklearn nearest neighbour-----")
# 预测时间
start = time.time()
sknn = KNeighborsClassifier(n_neighbors = k)
skpredictions = sknn.fit(train_images, train_labels).predict(test_images)
end = time.time()
print("time of prediction:%.3f s" % (end-start))
# 准确率
skaccuracy = accuracy_score(test_labels, skpredictions)
print("accuracy score:", skaccuracy)
# 混淆矩阵
skmatrix = confusion_matrix(test_labels, skpredictions)
print(skmatrix)
结果如下
k = 5 时
k = 10 时
标签:近邻,labels,test,train,file,图像,images,ubyte,MNIST 来源: https://blog.csdn.net/qq_39376697/article/details/115283738