编程语言
首页 > 编程语言> > K邻近算法

K邻近算法

作者:互联网

1. k邻近算法概述:

k邻近算法简单直观,给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的k个实例,这k个实例的多数属于某个类,就把该输入实例分为这个类

2. k邻近算法的模型复杂度体现在哪里?什么情况下会造成过拟合

k邻近算法模型复杂度体现在k值,k比较小时容易造成过拟合,k较大时容易造成欠拟合

3. 线性扫描算法

线性扫描算法步骤如下:

输入:训练数据集T={(x1,y1),(x2,y2),...(xn,yn)}

     待预测数据:(x_test)

     k值

(1)计算x_test与 xi的欧式距离

(2)欧式距离排序

(3)取前k个最小距离,对应训练数据点的类型y

(4)对k个y值进行统计

(5)返回频率出现最高的点

 1 import numpy as np
 2 from collection import Counter
 3 from draw import draw
 4 class KNN:
 5     def _init_(self,x_train,y_train,k=3):
 6         self.k=k
 7         self.x_train=x_train
 8         self.y_train=y_train
 9         
10     def predict(self,x_new):
11         #计算欧式距离
12         dist_list=[(np.linalg.norm(x_new=self.x_train[i],ord=2),self.y_train[i])
13                   for i in range(self.x_train.shape[0])]
14         dist_list.sort(key=lambda x: x[0])                  
15         y_list=[dist_list[i][-1] for i in range(self.k)]
16         #对上述k个点的分类进行统计
17         y_count=Counter(y_list).most_common()
18         return y_count[0][0]
19     
20 def main():
21     #首先输入训练数据
22     x_train=np.array([[5,4],
23                       [9,6],
24                       [4,7],
25                       [2,3],
26                       [8,1],
27                       [7,2]])
28     y_train=np.array([1,1,1,-1,-1,-1])
29     输入测试数据
30     x_new=np.array([5,3])
31     #绘图
32     draw(x_train,y_train,x_new)
33     for k in range(1,6,2):
34         #构建KNN实例
35         elf=KNN(x_train,y_train,k=k)
36         #对测试数据进行分类
37         y_predict=elf.predict(x_new)
38         print("k={},被分类为:{}").format(k,y_predict)
39 if __name__ == "main":
40     main()

 

标签:邻近,self,list,算法,train,np,new
来源: https://www.cnblogs.com/Cucucudeblog/p/10830615.html