java手撕KMeans算法实现手写数字聚类(失败案例)
作者:互联网
最近几天刚刚接触机器学习,学完K-Means聚类算法。正好又赶上一个课程项目是识别“手写数字”,因为KMeans能够实现聚类,因此自然而然地想要通过KMeans来实现。
前排提示:这是kmeans聚类的一个失败案例,没有成功聚类,仅供参考。
一,什么是KMeans聚类算法??
非常传统的聚类算法,目的是将一堆数据进行分类。
它的思想很朴素:假设这里有一群点,要将这些点分成两类。要是分成的类很合理的话,那不同类之间的中心点相聚是不是应该足够大,中心点附近的同一类的点是不是应该足够多?
举个例子:
a表示的是一堆原始点,没有处理。要将a聚类成两类,先随便找到两个点,计算所有点到这两个点的距离(欧式距离,曼哈顿距离,闵式距离等等都可以),根据距离最近的原则分配成两类。这时候是不是就能够得到两类的中心点,然后再次重复操作,直到最后聚出来的类不会发生变化。
so easy 是不是
二,使用的手写数字测试集??
我们在这里使用的是mnist测试集。这家伙的知名程度在机器学习中相当于是hello world了。不知道的小伙伴可以去查查。
但是一定有人会问到,mnist测试集应该怎么通过java使用呢?
不用担心,我用Python通过TensorFlow将mnist测试集打包成了txt文件,用java的文件操作直接调用就可以了。
具体效果像这样:
这是28 * 28的二维int数组,每个值介于0到255之间,熟悉图像处理的小伙伴一定知道这是灰度值,0表示最黑,255表示最亮,因此这是黑纸白字的测试集,大家要是自己写测试数据的使用要记着对图片进行预处理,要不然可能会出错。
我将txt命名为:数字名-标号的形式,方便之后训练和测试。
三,java手撕KMeans算法
先摆上一个算法流程图
1.首先定义:
训练图片(50000 * 28 * 28 的三维数组)
聚类中心(10 * 28 * 28的三维数组)
每张图片到聚类中心的距离(50000 * 10 的二维数组)
旧的类和新的类(ArrayList[] 数组,因为不知道一个类中到底会有多少个图片)
static float[][][] num = new float[50000][28][28];
static float[][][] center = new float[10][28][28];// 聚类中心
static long[][] distance = new long[num.length][10];
static ArrayList<Integer>[] oldKinds = new ArrayList[10];// 旧的聚类
static ArrayList<Integer>[] newKinds = new ArrayList[10];
2.定义方法:
从Txt文件导入测试数据的方法
public static void getTXT(String path,int img,int x,int y) throws IOException {
File file = new File(path);
FileInputStream fis = new FileInputStream(file);
InputStreamReader isr = new InputStreamReader(fis);
BufferedReader br = new BufferedReader(isr);
String line;
while((line = br.readLine()) != null){
boolean isNum = false;
for(int i = 0;i < line.length();i ++){
if(line.charAt(i) != ' ' && !isNum){
// 如果遇到数字
isNum = true;
float tempNum = 0;
// 取数字
while(i < line.length() && line.charAt(i) != ' '){
tempNum = tempNum * 10 + line.charAt(i) - '0';
i++;
}
isNum = false;
if(y < 28){
}
else{
y = 0;
x ++;
}
num[img][x][y] = tempNum;
y++;
}
}
}
br.close();
}
获得图片到聚类中心距离的方法
// 得到距离
public static long getDistance(float[][] n,float[][] k){
long ret = 0;
for (int i = 0; i < 28;i ++){
for (int j = 0; j < 28; j ++){
ret += Math.pow((n[i][j] - k[i][j]),2);
}
}
return ret;
}
得到图片距离最近聚类中心索引的方法
// 获得数组元素最小值对应的下标
public static int getMinIndex(long dis[]){
int index = -1;
long min = Integer.MAX_VALUE;
for(int i = 0; i < 10;i ++){
if(dis[i] < min){
index = i;
min = dis[i];
}
}
return index;
}
比较旧的聚类和新的聚类是否相同的方法
public static boolean isSame(){
for(int i = 0; i < 10 ;i ++){
for(int j = 0; j < newKinds[i].size();j ++){
if(newKinds[i].size() != oldKinds[i].size()) return false;
if (newKinds[i].get(j).intValue() != oldKinds[i].get(j).intValue() ) {
return false;
}
}
}
return true;
}
需要注意的是!!!
两个Integer的比较需要通过.intValue()的方法先转换成为int!!!再进行比较,否则会因为内存什么什么奇奇怪怪的原因导致出现130 != 130这种很天真的错误。
我在这里被坑了一次,希望看到这片文章的人能够避一下坑。
3.开始while(true)死循环,直到旧类和新类相等不发生改变
int kindTime = 0;
while(true){
// 3.计算每个文件和当前类中心之间的距离
for (int i = 0; i < num.length; i++){
for (int j = 0; j < 10; j++){
distance[i][j] = getDistance(num[i],center[j]);
}
}
// 更新旧类
for(int i = 0;i < 10;i ++){
oldKinds[i].clear();
for(int j = 0 ; j < newKinds[i].size();j ++){
oldKinds[i].add(newKinds[i].get(j));
}
}
// 更新新类
for (int i = 0; i < 10 ; i ++){
newKinds[i].clear();
}
for (int i = 0; i < num.length; i ++){
// 获得距离最小值,将其放到对应的类中
newKinds[getMinIndex(distance[i])].add(i);
}
// 4.更新聚类中心
for(int i = 0; i < 10; i ++){
for(int x = 0; x < 28; x++){
for(int y = 0; y < 28;y ++){
center[i][x][y] = getAverage(newKinds[i],x,y);
}
}
}
// 5.重复步骤,直到类不再发生改变
if(isSame()){
break;
}
System.out.println("第"+kindTime+"次聚类");
kindTime++;
}
4.保存类中心点
因为如果训练数据不变的话,聚类聚出的中心是不会变化的,所以为了避免之后聚类的重复操作,我们还是将得到的聚类中心点保存成为txt文件放到电脑上比较好。
// 保存聚类中心点
public static void saveKind(int index){
FileWriter out = null;
String path = "D:\\java\\workSpace\\KMeans\\" + index + "kinds.txt";
File file = new File(path);
try {
out = new FileWriter(file);
//二维数组按行存入到文件中
for (int i = 0; i < center[index].length; i++) {
for (int j = 0; j < center[index][i].length; j++) {
//将每个元素转换为字符串
String content = String.valueOf(center[index][i][j]) + " ";
out.write(content + "\t");
}
out.write("\r\n");
}
out.close();
} catch (IOException e) {
e.printStackTrace();
}
}
到现在,所有kmeans要求的操作我们都已经实现了。我们看看效果怎么样吧
1.我从test测试集(刚刚是train训练集)中导入了8000张图片,0到9每个数字各800张。
导入的方式和上文中的相同,这里就不在赘述了。
然后通过刚刚聚出来的类中心对测试数据进行聚类。(因为kmeans是无监督聚类吗,所以我也不知道每个类中心代表的哪个数字)
这是最后聚出来的结果:
发现大问题!!!我将每个类聚到的数字分别列出来。比如第0类,聚到4个数字0,3个数字1……
最后得到的结果,很!不!理!想!
通过分析可以看到,数字1的聚类效果最好,800张图片中有787张被聚到第7类中了,但是第7类也混入了不少其他数字,还有129张2是什么鬼?!
其他的类就更不用说了,混杂了很多数字。
经过缜密思考之后,我认为是k的数值设置的问题,因为我们想要聚类出10个数字,所以很主观地将k设置成为了10,没有思考相同数字,因为书写原因而出现的数字内部聚类的问题。
就像数字0,分别被聚到了第1类和第4类中,这两类很少有其他数字。因此是将数字0进行了分类,把高的0矮的0胖的0瘦的0分开了!而不是将0之外的数字分开。
或许可以通过改变k的值进行改进呢!
这片文章才差不多就是这样了。最后贴上代码。
如果有朋友想要mnist手写数字数据集的txt文件,可以给我留言邮箱信息哦,我抽时间会发送的。
欢迎大佬们批评指正!
// 首先是kmeans聚类的代码
import java.io.*;
import java.util.ArrayList;
public class KMeans {
// KMeans算法实现手写数字聚类
static float[][][] num = new float[50000][28][28];
static float[][][] center = new float[10][28][28];// 聚类中心
static long[][] distance = new long[num.length][10];
static ArrayList<Integer>[] oldKinds = new ArrayList[10];// 旧的聚类
static ArrayList<Integer>[] newKinds = new ArrayList[10];
public static void main(String[] args) throws IOException {
// 1.读取文件
System.out.println("导入文件中……");
for (int i = 0;i < num.length;i ++){
getTXT("D:\\Python\\jupyter\\trains2\\" + Integer.toString(i/5000) + "-" + Integer.toString(i%5000 + 1) + ".txt",i,0,0);
if(i % 1000 == 0) System.out.println("已导入文件:" + i);
}
System.out.println("导入文件成功!!!");
// 随机选择聚类中心
for(int i = 0; i < 10; i ++){
oldKinds[i] = new ArrayList<>();
}
for(int i = 0 ; i < 10;i ++) {
transTwoArray(num[i], center[i]);
newKinds[i] = new ArrayList<>();
newKinds[i].add(i);
}
int kindTime = 0;
while(true){
// 3.计算每个文件和当前类中心之间的距离
for (int i = 0; i < num.length; i++){
for (int j = 0; j < 10; j++){
distance[i][j] = getDistance(num[i],center[j]);
}
}
// 更新旧类
for(int i = 0;i < 10;i ++){
oldKinds[i].clear();
for(int j = 0 ; j < newKinds[i].size();j ++){
oldKinds[i].add(newKinds[i].get(j));
}
}
// 更新新类
for (int i = 0; i < 10 ; i ++){
newKinds[i].clear();
}
for (int i = 0; i < num.length; i ++){
// 获得距离最小值,将其放到对应的类中
newKinds[getMinIndex(distance[i])].add(i);
}
// 4.更新聚类中心
for(int i = 0; i < 10; i ++){
for(int x = 0; x < 28; x++){
for(int y = 0; y < 28;y ++){
center[i][x][y] = getAverage(newKinds[i],x,y);
}
}
}
// 5.重复步骤,直到类不再发生改变
if(isSame()){
break;
}
System.out.println("第"+kindTime+"次聚类");
kindTime++;
}
// 保存聚类中心
System.out.println("聚类成功!!!");
System.out.println("-------------------------");
System.out.println("保存类中心点中……");
for(int i = 0; i < 10;i ++){
saveKind(i);
}
System.out.println("保存类中心点成功!!!");
}
// 读取文件
public static void getTXT(String path,int img,int x,int y) throws IOException {
File file = new File(path);
FileInputStream fis = new FileInputStream(file);
InputStreamReader isr = new InputStreamReader(fis);
BufferedReader br = new BufferedReader(isr);
String line;
while((line = br.readLine()) != null){
boolean isNum = false;
for(int i = 0;i < line.length();i ++){
if(line.charAt(i) != ' ' && !isNum){
// 如果遇到数字
isNum = true;
float tempNum = 0;
// 取数字
while(i < line.length() && line.charAt(i) != ' '){
tempNum = tempNum * 10 + line.charAt(i) - '0';
i++;
}
isNum = false;
if(y < 28){
}
else{
y = 0;
x ++;
}
num[img][x][y] = tempNum;
y++;
}
}
}
br.close();
}
// 转移两个数组
public static void transTwoArray(float[][] array1,float[][] array2){
for(int i = 0; i < 28;i ++){
for (int j = 0; j < 28;j ++){
array2[i][j] = array1[i][j];
}
}
}
// 得到距离
public static long getDistance(float[][] n,float[][] k){
long ret = 0;
for (int i = 0; i < 28;i ++){
for (int j = 0; j < 28; j ++){
ret += Math.pow((n[i][j] - k[i][j]),2);
}
}
return ret;
}
// 获得数组元素最小值对应的下标
public static int getMinIndex(long dis[]){
int index = -1;
long min = Integer.MAX_VALUE;
for(int i = 0; i < 10;i ++){
if(dis[i] < min){
index = i;
min = dis[i];
}
}
return index;
}
// 计算均值
public static float getAverage(ArrayList<Integer> arr,int x,int y){
float ret = 0;
for(int i = 0; i < arr.size(); i ++){
ret += num[arr.get(i)][x][y];// 将同一类中所有相同位置元素相加
}
return ret / arr.size();
}
// 保存聚类中心点
public static void saveKind(int index){
FileWriter out = null;
String path = "D:\\java\\workSpace\\KMeans\\" + index + "kinds.txt";
File file = new File(path);
try {
out = new FileWriter(file);
//二维数组按行存入到文件中
for (int i = 0; i < center[index].length; i++) {
for (int j = 0; j < center[index][i].length; j++) {
//将每个元素转换为字符串
String content = String.valueOf(center[index][i][j]) + " ";
out.write(content + "\t");
}
out.write("\r\n");
}
out.close();
} catch (IOException e) {
e.printStackTrace();
}
}
// 是否相等
public static boolean isSame(){
for(int i = 0; i < 10 ;i ++){
for(int j = 0; j < newKinds[i].size();j ++){
if(newKinds[i].size() != oldKinds[i].size()) return false;
if (newKinds[i].get(j).intValue() != oldKinds[i].get(j).intValue() ) {
return false;
}
}
}
return true;
}
}
测试聚类中心的代码
import java.io.*;
import java.util.ArrayList;
public class myKMeansTest {
static float[][][] kMeans = new float[10][28][28];
static float[][][] test = new float[8000][28][28];// 测试数据,每个数字有800张
static long[][] distance = new long[8000][10];// 每张图片聚类类中心的距离
static ArrayList<Integer>[] kinds = new ArrayList[10];// 每个类中包含的图片索引
public static void main(String[] args) throws IOException {
System.out.println("-----获取文件中-----");
// 读取聚类中心文件
for(int i = 0; i < 10;i ++){
String img = "D:\\java\\workSpace\\KMeans\\" + i + "kinds.txt";
getKMeansTxt(img,i);
}
// 读取测试文件
for(int i = 0;i < 8000;i ++){
String img = "D:\\Python\\jupyter\\test\\" + i/800 + "-" + (i%800 + 1) + ".txt";
getTestTxt(img,i,0,0);
if(i % 800 == 0) System.out.println("已导入数据:"+i);
}
System.out.println("获取文件成功!!");
// 进行测试
System.out.println("开始聚类……");
for(int i = 0; i < 10;i ++){
kinds[i] = new ArrayList<>();
}
for(int i = 0; i < 8000;i ++){
for (int j = 0; j < 10;j ++){
distance[i][j] = GoodKMeans.getDistance(kMeans[j],test[i]);// 获得每张图片对应聚类中心的距离
}
}
for(int i= 0;i< 8000;i++){
kinds[GoodKMeans.getMinIndex(distance[i])].add(i);// 将图片归为最小距离的类中
}
System.out.println("聚类成功!!");
int[][] ans = new int[10][10];
for(int i = 0; i < 10;i ++){
for(int j = 0; j < kinds[i].size();j ++){
if(kinds[i].get(j) < 800) ans[i][0]++;
else if(kinds[i].get(j) >= 800 && kinds[i].get(j) < 1600) ans[i][1]++;
else if(kinds[i].get(j) >= 1600 && kinds[i].get(j)< 2400) ans[i][2]++;
else if(kinds[i].get(j) >= 2400 && kinds[i].get(j)< 3200) ans[i][3]++;
else if(kinds[i].get(j) >= 3200 && kinds[i].get(j)< 4000) ans[i][4]++;
else if(kinds[i].get(j) >= 4000 && kinds[i].get(j)< 4800) ans[i][5]++;
else if(kinds[i].get(j) >= 4800 && kinds[i].get(j)< 5600) ans[i][6]++;
else if(kinds[i].get(j) >= 5600 && kinds[i].get(j)< 6400) ans[i][7]++;
else if(kinds[i].get(j) >= 6400 && kinds[i].get(j)< 7200) ans[i][8]++;
else if(kinds[i].get(j) >= 7200 && kinds[i].get(j)< 8000) ans[i][9]++;
}
}
for (int i = 0; i < 10;i ++){
System.out.print("第"+i+"类中:");
for (int j = 0; j < 10;j ++){
System.out.print(j+":");
System.out.printf("%3d",ans[i][j]);
System.out.print("\t");
}
System.out.println();
}
}
// 获得聚类中心文件
public static void getKMeansTxt(String img,int index) throws IOException {
File file = new File(img);
FileInputStream fis = new FileInputStream(file);
InputStreamReader isr = new InputStreamReader(fis);
BufferedReader br = new BufferedReader(isr);
int x = 0;
int y = 0;
String line;
while((line = br.readLine()) != null){
boolean isNum = false;
for(int i = 0;i < line.length();i ++){
if(line.charAt(i)-'0' <10 && line.charAt(i)-'0' >=0 && !isNum){
// 如果遇到数字
isNum = true;
// 取数字
int j = i + 1;
while(j < line.length() && line.charAt(j) != ' '){
j++;
}
isNum = false;
if(y < 28){
}
else{
y = 0;
x ++;
}
kMeans[index][x][y] = Float.valueOf(line.substring(i,j)).floatValue();
i = j;
y++;
}
}
}
br.close();
}
// 获得测试文件
public static void getTestTxt(String path,int img,int x,int y) throws IOException {
File file = new File(path);
FileInputStream fis = new FileInputStream(file);
InputStreamReader isr = new InputStreamReader(fis);
BufferedReader br = new BufferedReader(isr);
String line;
while((line = br.readLine()) != null){
boolean isNum = false;
for(int i = 0;i < line.length();i ++){
if(line.charAt(i) != ' ' && !isNum){
// 如果遇到数字
isNum = true;
float tempNum = 0;
// 取数字
while(i < line.length() && line.charAt(i) != ' '){
tempNum = tempNum * 10 + line.charAt(i) - '0';
i++;
}
isNum = false;
if(y < 28){
}
else{
y = 0;
x ++;
}
test[img][x][y] = tempNum;
y++;
}
}
}
br.close();
}
}
标签:10,java,int,28,KMeans,++,聚类,new 来源: https://blog.csdn.net/m0_51418456/article/details/123601187