# -*- encoding: utf-8 -*-
"""
@File    :   NearstNeighbour.py
@Time    :   2021/03/27 15:40:05
@Author  :   Wihau
@Version :   1.0
@Desc    :   None
"""

# here put the import lib
import gzip
import numpy as np
import struct
import operator
import time

from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, confusion_matrix

# Default locations of the gzip-compressed IDX files, given relative to the
# current working directory.

# Training split: images (idx3) and their labels (idx1).
train_images_idx3_ubyte_file = 'train-images-idx3-ubyte.gz'
train_labels_idx1_ubyte_file = 'train-labels-idx1-ubyte.gz'

# Test split: images (idx3) and their labels (idx1).
test_images_idx3_ubyte_file = 't10k-images-idx3-ubyte.gz'
test_labels_idx1_ubyte_file = 't10k-labels-idx1-ubyte.gz'

def decode_idx3_ubyte(idx3_ubyte_file, max_images=None):
    """Decode a gzip-compressed IDX3 image file into a 2-D float array.

    :param idx3_ubyte_file: path of the gzip-compressed idx3 file
    :param max_images: upper bound on the number of images to load. When
        ``None`` the historical subset sizes are used: 6000 if ``'train'``
        occurs in the path, otherwise 100. The bound is always clamped to
        the image count announced in the file header.
    :return: float64 array of shape ``(n_images, rows * cols)``, one
        flattened image per row
    :raises ValueError: if the file holds fewer pixel bytes than required
    """
    # Read the whole decompressed payload; ``with`` guarantees the handle
    # is closed even if reading fails.
    with gzip.open(idx3_ubyte_file, 'rb') as fp:
        bin_data = fp.read()

    # Header: magic number, image count, rows per image, columns per image
    # (four big-endian unsigned 32-bit integers).
    fmt_header = '>IIII'
    _magic, num_images, num_rows, num_cols = struct.unpack_from(
        fmt_header, bin_data, 0)
    offset = struct.calcsize(fmt_header)
    pixels_per_image = num_rows * num_cols

    # Only a subset is loaded by default to keep brute-force k-NN fast.
    if max_images is None:
        max_images = 6000 if 'train' in str(idx3_ubyte_file) else 100
    image_size = max(0, min(max_images, num_images))

    needed = image_size * pixels_per_image
    if len(bin_data) - offset < needed:
        raise ValueError(
            "%s is truncated: expected %d pixel bytes, found %d"
            % (idx3_ubyte_file, needed, len(bin_data) - offset))

    images = np.empty((image_size, pixels_per_image))
    if needed:
        # One vectorised copy instead of unpacking image by image.
        raw = np.frombuffer(bin_data, dtype=np.uint8, count=needed,
                            offset=offset)
        images[:] = raw.reshape(image_size, pixels_per_image)
    return images

def decode_idx1_ubyte(idx1_ubyte_file, max_labels=None):
    """Decode a gzip-compressed IDX1 label file into a 1-D integer array.

    :param idx1_ubyte_file: path of the gzip-compressed idx1 file
    :param max_labels: upper bound on the number of labels to load. When
        ``None`` the historical subset sizes are used: 6000 if ``'train'``
        occurs in the path, otherwise 100. The bound is always clamped to
        the label count announced in the file header.
    :return: int64 array holding one label per entry
    :raises ValueError: if the file holds fewer label bytes than required
    """
    # Read the whole decompressed payload; ``with`` guarantees the handle
    # is closed even if reading fails.
    with gzip.open(idx1_ubyte_file, 'rb') as fp:
        bin_data = fp.read()

    # Header: magic number and label count (two big-endian unsigned
    # 32-bit integers).
    fmt_header = '>II'
    _magic, num_labels = struct.unpack_from(fmt_header, bin_data, 0)
    offset = struct.calcsize(fmt_header)

    # Only a subset is loaded by default to keep brute-force k-NN fast.
    if max_labels is None:
        max_labels = 6000 if 'train' in str(idx1_ubyte_file) else 100
    label_size = max(0, min(max_labels, num_labels))

    if len(bin_data) - offset < label_size:
        raise ValueError(
            "%s is truncated: expected %d label bytes, found %d"
            % (idx1_ubyte_file, label_size, len(bin_data) - offset))

    # ``np.int`` was removed from NumPy; use an explicit integer dtype.
    labels = np.empty(label_size, np.int64)
    if label_size:
        # Each label is a single unsigned byte; copy them in one step.
        labels[:] = np.frombuffer(bin_data, dtype=np.uint8,
                                  count=label_size, offset=offset)
    return labels

def load_train_images(idx_ubyte_file=train_images_idx3_ubyte_file):
    """Decode the training image file.

    :param idx_ubyte_file: path of the idx3 image file; defaults to the
        module-level training image path
    :return: the result of ``decode_idx3_ubyte`` for that file
    """
    images = decode_idx3_ubyte(idx_ubyte_file)
    return images

def load_train_labels(idx_ubyte_file=train_labels_idx1_ubyte_file):
    """Decode the training label file.

    :param idx_ubyte_file: path of the idx1 label file; defaults to the
        module-level training label path
    :return: the result of ``decode_idx1_ubyte`` for that file
    """
    labels = decode_idx1_ubyte(idx_ubyte_file)
    return labels

def load_test_images(idx_ubyte_file=test_images_idx3_ubyte_file):
    """Decode the test image file.

    :param idx_ubyte_file: path of the idx3 image file; defaults to the
        module-level test image path
    :return: the result of ``decode_idx3_ubyte`` for that file
    """
    images = decode_idx3_ubyte(idx_ubyte_file)
    return images

def load_test_labels(idx_ubyte_file=test_labels_idx1_ubyte_file):
    """Decode the test label file.

    :param idx_ubyte_file: path of the idx1 label file; defaults to the
        module-level test label path
    :return: the result of ``decode_idx1_ubyte`` for that file
    """
    labels = decode_idx1_ubyte(idx_ubyte_file)
    return labels

class NearstNeighbour:
    """Brute-force k-nearest-neighbour classifier using Euclidean distance."""

    def __init__(self, k):
        """Store the number of neighbours that vote on each prediction.

        :param k: number of nearest training samples consulted per sample
        """
        self.k = k

    def train(self, X, y):
        """Memorise the training set; no model is fitted.

        :param X: training samples, one row per sample
        :param y: label of each row of ``X``
        :return: ``self`` so that calls can be chained
        """
        self.Xtr = X
        self.ytr = y
        return self

    def predict(self, test_images):
        """Predict a label for every row of ``test_images`` by majority vote.

        For each test sample the Euclidean distance to every training sample
        is computed, the ``k`` closest samples vote with their labels, and
        the most frequent label wins. On a tie the label that occurs first
        among the nearest neighbours wins. ``k`` is clamped to the size of
        the training set.

        :param test_images: iterable of samples with the same number of
            features as the training samples
        :return: list with one predicted label per test sample
        :raises ValueError: if ``k`` is smaller than 1 or the training set
            is empty
        """
        if self.k < 1:
            raise ValueError("k must be at least 1, got %r" % (self.k,))

        # Work in float64 so integer inputs cannot wrap around when
        # subtracted or overflow when squared.
        train_x = np.asarray(self.Xtr, dtype=np.float64)
        dataset_size = train_x.shape[0]
        if dataset_size == 0:
            raise ValueError("cannot predict with an empty training set")
        neighbours = min(self.k, dataset_size)

        predictions = []
        for test_item in test_images:
            # Broadcasting subtracts the sample from every training row.
            diff = train_x - np.asarray(test_item, dtype=np.float64)
            distances = np.sqrt((diff ** 2).sum(axis=1))
            # Indices of the training samples, closest first; a stable sort
            # keeps equidistant samples in their original order.
            nearest = distances.argsort(kind='stable')[:neighbours]

            # Tally the labels of the closest samples.
            class_count = {}
            for idx in nearest:
                vote = self.ytr[idx]
                class_count[vote] = class_count.get(vote, 0) + 1

            # ``max`` returns the first label that reaches the top count.
            winner = max(class_count.items(), key=operator.itemgetter(1))[0]
            predictions.append(winner)

        return predictions

# Load the training and test sets from the default file paths.
train_images = load_train_images()
train_labels = load_train_labels()
test_images = load_test_images()
test_labels = load_test_labels()

# Number of neighbours used by both classifiers below.
k = 5
# Hand-written k-nearest-neighbour prediction
print("-----Personal k nearest neighbour-----")
# Time the run; the window covers construction, train() and predict().
start = time.time()
knn = NearstNeighbour(k)
predictions = knn.train(train_images, train_labels).predict(test_images)
end = time.time()
print("time of prediction:%.3f s" % (end-start))
# Accuracy
accuracy = accuracy_score(test_labels, predictions)
print("accuracy score:", accuracy)
# Confusion matrix (stored in `matrix`; not printed here)
matrix = confusion_matrix(test_labels, predictions)

# scikit-learn k-nearest-neighbour prediction
print("-----Sklearn nearest neighbour-----")
# Time the run; the window covers construction, fit() and predict().
start = time.time()
sknn = KNeighborsClassifier(n_neighbors = k)
skpredictions = sknn.fit(train_images, train_labels).predict(test_images)
end = time.time()
print("time of prediction:%.3f s" % (end-start))
# Accuracy
skaccuracy = accuracy_score(test_labels, skpredictions)
print("accuracy score:", skaccuracy)
# Confusion matrix (stored in `skmatrix`; not printed here)
skmatrix = confusion_matrix(test_labels, skpredictions)


# When k = 5
# When k = 10

