线性判别准则与线性分类编程实践
作者:互联网
一、线性判别分析介绍
线性判别分析(Linear Discriminant Analysis,简称 L D A LDALDA)是一种经典的线性学习方法,亦称"Fisher 判别分析"。
线性判别分析思想:给定训练样本集,设法将样例投影到一条直线上。使得同类样例的投影点尽可能接近、异类样例的投影点尽可能远;在对新样本进行分类时,将其投影到该直线上,再根据投影点的位置来确定新样本的类别。
二、线性判别分析原理
1. 类内散度矩阵(within-class scatter matrix)
类内散度矩阵
用来判断同类样例的投影点之间的距离。
2. 类间散度矩阵(between-class scatter matrix)
类间散度矩阵
用来判断异类样例的投影点之间的距离。
3. 广义瑞利商(generalized Rayleigh quotiet)
广义瑞利商(generalized Rayleigh quotiet)就是 L D A LDALDA欲最大化的目标,使用类内散度矩阵和类间散度矩阵将最大化目标改写为:
LDA可从贝叶斯决策理论的角度来阐释,并可证明,当两类数据同先验、满足高斯分布且协方差相等时,LDA可达到最优分类。
三、sklearn库实现线性判别分析LDA
- 数据生成
#生成200个三个维度样本 import numpy as np import pandas as pd import matplotlib.pyplot as plt from matplotlib.colors import ListedColormap from sklearn.datasets import make_classification x, y = make_classification(n_samples=200, n_features=2, n_redundant=0, n_classes=2, n_informative=2,n_clusters_per_class=2,class_sep =1, random_state =0) fig = plt.figure() plt.scatter(x[:, 0], x[:, 1], c=y)
- 数据处理
#设置分类平滑度 h = .01 #设置X和Y的边界值 x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 #使用meshgrid函数返回X和Y两个坐标向量矩阵 xx, yy = np.meshgrid(np.arange(x_min, x_max,h), np.arange(y_min, y_max,h)) Z = lda.predict(np.c_[xx.ravel(), yy.ravel()])
- 数据集划分
from sklearn.model_selection import train_test_split x_train,x_test,y_train,y_test = train_test_split(x, y, random_state=33, test_size=0.25)
- LDA分类
#使用LDA进行降维 from sklearn.discriminant_analysis import LinearDiscriminantAnalysis from sklearn.linear_model import LogisticRegression lda = LinearDiscriminantAnalysis(n_components=1) x_train_lda = lda.fit_transform(x_train, y_train) # LDA是有监督方法,需要用到标签 x_test_lda = lda.fit_transform(x_test, y_test) # 预测时候特征向量正负问题,乘-1反转镜像
- 绘制训练集分类图像
#设置colormap颜色 cm_bright = ListedColormap(['#D9E021', '#0D8ECF']) #绘制数据点 plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train, cmap=cm_bright) plt.title('Linear Discriminant Analysis Classifiers') plt.axis('tight') plt.show()
- 绘制测试集分类图
plt.title('Linear Discriminant Analysis Classifiers') plt.scatter(x_test[:, 0], x_test[:, 1], c=y_test, cmap=cm_bright) plt.show()
四、总结
LDA算法既可以用来降维,也可以用来分类,但是目前来说,主要还是用于降维,和PCA类似,LDA降维基本也不用调参,只需要指定降维到的维数即可。
五、参考
【机器学习】机器学习之线性判别分析(LDA)_YangMax1的博客-CSDN博客
标签:LDA,plt,判别,编程,判别分析,test,train,线性,import 来源: https://blog.csdn.net/w798214705/article/details/121167607