统计学习方法学习笔记第二章(感知机)
作者:互联网
感知机是一个二分类的线性分类模型,是神经网络和支持向量机的基础。
考虑统计学习方法三要素:
模型:f(x) = sign(w*x+b)
策略:收敛前提条件:数据集是线性可分的
学习策略:考虑每一个点到超平面的距离:(二维的点到平面距离公式),对于分类错误的数据,yi*(w*xi_b)<0,则令损失函数为-Σyi*(w*xi+b)
算法:SGD,若分类错误则计算梯度然后更新。
算法的收敛性:
收敛次数k有上界(R/γ)^2,其中R是x1到xn中二阶范式的最大值,γ是满足对任意i,yi*(wopt * xi + bopt) >= γ的一个常数(由于判别是否分类正确的条件是yi*(wopt * xi + bopt) > 0, 所以肯定可以找到这样一个γ)
对偶形式的算法:
观察更新的过程可以发现,令αi = ni * η,则w = ∑ αi * yi * xi, b = ∑ αi * yi, 其中ni是算法运行过程中数据i被误分类的次数。
实际处理的过程中,可以预处理出Gram矩阵(作用是保存任意两个特征向量的叉积)
两种不同是算法对应的支持向量机的原始形式和对偶形式。
算法测试代码:
import random import numpy as np import math import torch import matplotlib.pyplot as plt n = 10 k = -1 bias = 1 lr = 1 def line(x): return k * x + bias def learn(x, y): w = np.random.randn(2) b = 0 epoch = 0 while True: flag = False count = 0 for i in range(len(x)): if (y[i] * (np.dot(w, x[i]) + b)) <= 0: # print('updated') w = w + lr * y[i] * x[i] b = b + lr * y[i] flag = True count += 1 if not flag: break epoch += 1 print('epoch count', epoch, count) return w, b Gram = np.zeros((n, n)) def dual_learn(x, y): alpha = np.zeros(n) b = 0 epoch = 0 while True: flag = False count = 0 for i in range(len(x)): tot = 0 for j in range(len(x)): tot += alpha[j] * y[j] * Gram[j][i] tot = y[i] * (tot + b) if tot <= 0: print('updated') alpha[i] += lr b += lr * y[i] flag = True count += 1 if not flag: break epoch += 1 print('epoch count', epoch, count) return alpha, b if __name__ == '__main__': x = np.random.uniform(0, 1, (n, 2)) y = np.zeros(n) for i in range(n): for j in range(n): Gram[i][j] = np.dot(x[i], x[j]) print(Gram) for i in range(len(y)): if x[i][1] > line(x[i][0]): y[i] = 1 else: y[i] = -1 # plt.show() print(x, y) alpha, b = dual_learn(x, y) print(alpha) w = np.zeros(2) for i in range(n): w += alpha[i] * y[i] * x[i] vector = [0, 0] ax = plt.gca() ax.plot(x) plt.show() for i in range(10): vector[0] = float(input()) vector[1] = float(input()) # print(w) print(vector) print(np.dot(w, np.array(vector).astype('float64')) + b)
以y=-x+1(0 <= x <= 1, 0 <= y <= 1)为分界线,随机生成一些点进行标记,通过SGD进行更新。
标签:yi,print,学习,感知机,算法,vector,np,import,第二章 来源: https://www.cnblogs.com/pkgunboat/p/15754780.html