编程语言
首页 > 编程语言> > java手撕KMeans算法实现手写数字聚类(失败案例)

java手撕KMeans算法实现手写数字聚类(失败案例)

作者:互联网

最近几天刚刚接触机器学习,学完K-Means聚类算法。正好又赶上一个课程项目是识别“手写数字”,因为KMeans能够实现聚类,因此自然而然地想要通过KMeans来实现。

前排提示:这是kmeans聚类的一个失败案例,没有成功聚类,仅供参考。

一,什么是KMeans聚类算法??

非常传统的聚类算法,目的是将一堆数据进行分类。

它的思想很朴素:假设这里有一群点,要将这些点分成两类。要是分成的类很合理的话,那不同类之间的中心点相聚是不是应该足够大,中心点附近的同一类的点是不是应该足够多?

举个例子:

a表示的是一堆原始点,没有处理。要将a聚类成两类,先随便找到两个点,计算所有点到这两个点的距离(欧式距离,曼哈顿距离,闵式距离等等都可以),根据距离最近的原则分配成两类。这时候是不是就能够得到两类的中心点,然后再次重复操作,直到最后聚出来的类不会发生变化。

so easy 是不是

二,使用的手写数字测试集??

我们在这里使用的是mnist测试集。这家伙的知名程度在机器学习中相当于是hello world了。不知道的小伙伴可以去查查。

但是一定有人会问到,mnist测试集应该怎么通过java使用呢?

不用担心,我用Python通过TensorFlow将mnist测试集打包成了txt文件,用java的文件操作直接调用就可以了。

具体效果像这样:

 这是28 * 28的二维int数组,每个值介于0到255之间,熟悉图像处理的小伙伴一定知道这是灰度值,0表示最黑,255表示最亮,因此这是黑纸白字的测试集,大家要是自己写测试数据的使用要记着对图片进行预处理,要不然可能会出错。

我将txt命名为:数字名-标号的形式,方便之后训练和测试。

 三,java手撕KMeans算法

先摆上一个算法流程图

 1.首先定义:

           训练图片(50000 * 28 * 28 的三维数组)

           聚类中心(10 * 28 * 28的三维数组)

           每张图片到聚类中心的距离(50000 * 10 的二维数组)

           旧的类和新的类(ArrayList[] 数组,因为不知道一个类中到底会有多少个图片)

    static float[][][] num = new float[50000][28][28];
    static float[][][] center = new float[10][28][28];// 聚类中心
    static long[][] distance = new long[num.length][10];
    static ArrayList<Integer>[] oldKinds = new ArrayList[10];// 旧的聚类
    static ArrayList<Integer>[] newKinds = new ArrayList[10];

 2.定义方法:

        从Txt文件导入测试数据的方法

public static void getTXT(String path,int img,int x,int y) throws IOException {
        File file = new File(path);
        FileInputStream fis = new FileInputStream(file);
        InputStreamReader isr = new InputStreamReader(fis);
        BufferedReader br = new BufferedReader(isr);

        String line;
        while((line = br.readLine()) != null){
            boolean isNum = false;
            for(int i = 0;i < line.length();i ++){
                if(line.charAt(i) != ' ' && !isNum){
                    // 如果遇到数字
                    isNum = true;
                    float tempNum = 0;
                    // 取数字
                    while(i < line.length() && line.charAt(i) != ' '){
                        tempNum = tempNum * 10 + line.charAt(i) - '0';
                        i++;
                    }
                    isNum = false;
                    if(y < 28){
                    }
                    else{
                        y = 0;
                        x ++;
                    }
                    num[img][x][y] = tempNum;
                    y++;
                }
            }
        }
        br.close();
    }

        获得图片到聚类中心距离的方法

    // 得到距离
    public static long getDistance(float[][] n,float[][] k){
        long ret = 0;
        for (int i = 0; i < 28;i ++){
            for (int j = 0; j < 28; j ++){
                ret += Math.pow((n[i][j] - k[i][j]),2);
            }
        }
        return ret;
    }

        得到图片距离最近聚类中心索引的方法

    // 获得数组元素最小值对应的下标
    public static int getMinIndex(long dis[]){
        int index = -1;
        long min = Integer.MAX_VALUE;
        for(int i = 0; i < 10;i ++){
            if(dis[i] < min){
                index = i;
                min = dis[i];
            }
        }
        return index;
    }

        比较旧的聚类和新的聚类是否相同的方法

    public static boolean isSame(){
        for(int i = 0; i < 10 ;i ++){
            for(int j = 0; j < newKinds[i].size();j ++){
                if(newKinds[i].size() != oldKinds[i].size()) return false;
                if (newKinds[i].get(j).intValue() != oldKinds[i].get(j).intValue() ) {
                    return false;
                }
            }
        }
        return true;
    }

需要注意的是!!!

两个Integer的比较需要通过.intValue()的方法先转换成为int!!!再进行比较,否则会因为内存什么什么奇奇怪怪的原因导致出现130 != 130这种很天真的错误。

我在这里被坑了一次,希望看到这片文章的人能够避一下坑。

3.开始while(true)死循环,直到旧类和新类相等不发生改变

        int kindTime = 0;
        while(true){
            // 3.计算每个文件和当前类中心之间的距离
            for (int i = 0; i < num.length; i++){
                for (int j = 0; j < 10; j++){
                    distance[i][j] = getDistance(num[i],center[j]);
                }
            }
            // 更新旧类
            for(int i = 0;i < 10;i ++){
                oldKinds[i].clear();
                for(int j = 0 ; j < newKinds[i].size();j ++){
                    oldKinds[i].add(newKinds[i].get(j));
                }
            }
            // 更新新类
            for (int i = 0; i < 10 ; i ++){
                newKinds[i].clear();
            }
            for (int i = 0; i < num.length; i ++){
                // 获得距离最小值,将其放到对应的类中
                newKinds[getMinIndex(distance[i])].add(i);
            }
            // 4.更新聚类中心
            for(int i = 0; i < 10; i ++){
                for(int x = 0; x < 28; x++){
                    for(int y = 0; y < 28;y ++){
                        center[i][x][y] = getAverage(newKinds[i],x,y);
                    }
                }
            }
            // 5.重复步骤,直到类不再发生改变
            if(isSame()){
                break;
            }
            System.out.println("第"+kindTime+"次聚类");
            kindTime++;
        }

4.保存类中心点

因为如果训练数据不变的话,聚类聚出的中心是不会变化的,所以为了避免之后聚类的重复操作,我们还是将得到的聚类中心点保存成为txt文件放到电脑上比较好。

    // 保存聚类中心点
    public static void saveKind(int index){
        FileWriter out = null;
        String path = "D:\\java\\workSpace\\KMeans\\" + index + "kinds.txt";
        File file = new File(path);
        try {
            out = new FileWriter(file);
            //二维数组按行存入到文件中
            for (int i = 0; i < center[index].length; i++) {
                for (int j = 0; j < center[index][i].length; j++) {
                    //将每个元素转换为字符串
                    String content = String.valueOf(center[index][i][j]) + " ";
                    out.write(content + "\t");
                }
                out.write("\r\n");
            }
            out.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

到现在,所有kmeans要求的操作我们都已经实现了。我们看看效果怎么样吧

1.我从test测试集(刚刚是train训练集)中导入了8000张图片,0到9每个数字各800张。

导入的方式和上文中的相同,这里就不在赘述了。

然后通过刚刚聚出来的类中心对测试数据进行聚类。(因为kmeans是无监督聚类吗,所以我也不知道每个类中心代表的哪个数字)

这是最后聚出来的结果:

发现大问题!!!我将每个类聚到的数字分别列出来。比如第0类,聚到4个数字0,3个数字1……

最后得到的结果,很!不!理!想!

 通过分析可以看到,数字1的聚类效果最好,800张图片中有787张被聚到第7类中了,但是第7类也混入了不少其他数字,还有129张2是什么鬼?!

其他的类就更不用说了,混杂了很多数字。

经过缜密思考之后,我认为是k的数值设置的问题,因为我们想要聚类出10个数字,所以很主观地将k设置成为了10,没有思考相同数字,因为书写原因而出现的数字内部聚类的问题。

就像数字0,分别被聚到了第1类和第4类中,这两类很少有其他数字。因此是将数字0进行了分类,把高的0矮的0胖的0瘦的0分开了!而不是将0之外的数字分开。

或许可以通过改变k的值进行改进呢!

这片文章才差不多就是这样了。最后贴上代码。

如果有朋友想要mnist手写数字数据集的txt文件,可以给我留言邮箱信息哦,我抽时间会发送的。

欢迎大佬们批评指正!

// 首先是kmeans聚类的代码
import java.io.*;
import java.util.ArrayList;

public class KMeans {
    // KMeans算法实现手写数字聚类
    static float[][][] num = new float[50000][28][28];
    static float[][][] center = new float[10][28][28];// 聚类中心
    static long[][] distance = new long[num.length][10];
    static ArrayList<Integer>[] oldKinds = new ArrayList[10];// 旧的聚类
    static ArrayList<Integer>[] newKinds = new ArrayList[10];

    public static void main(String[] args) throws IOException {
        // 1.读取文件
        System.out.println("导入文件中……");
        for (int i = 0;i < num.length;i ++){
            getTXT("D:\\Python\\jupyter\\trains2\\" + Integer.toString(i/5000) + "-" + Integer.toString(i%5000 + 1) + ".txt",i,0,0);
            if(i % 1000 == 0) System.out.println("已导入文件:" + i);
        }
        System.out.println("导入文件成功!!!");
        // 随机选择聚类中心
        for(int i = 0; i < 10; i ++){
            oldKinds[i] = new ArrayList<>();
        }
        for(int i = 0 ; i < 10;i ++) {
            transTwoArray(num[i], center[i]);
            newKinds[i] = new ArrayList<>();
            newKinds[i].add(i);
        }

        int kindTime = 0;
        while(true){
            // 3.计算每个文件和当前类中心之间的距离
            for (int i = 0; i < num.length; i++){
                for (int j = 0; j < 10; j++){
                    distance[i][j] = getDistance(num[i],center[j]);
                }
            }
            // 更新旧类
            for(int i = 0;i < 10;i ++){
                oldKinds[i].clear();
                for(int j = 0 ; j < newKinds[i].size();j ++){
                    oldKinds[i].add(newKinds[i].get(j));
                }
            }
            // 更新新类
            for (int i = 0; i < 10 ; i ++){
                newKinds[i].clear();
            }
            for (int i = 0; i < num.length; i ++){
                // 获得距离最小值,将其放到对应的类中
                newKinds[getMinIndex(distance[i])].add(i);
            }
            // 4.更新聚类中心
            for(int i = 0; i < 10; i ++){
                for(int x = 0; x < 28; x++){
                    for(int y = 0; y < 28;y ++){
                        center[i][x][y] = getAverage(newKinds[i],x,y);
                    }
                }
            }
            // 5.重复步骤,直到类不再发生改变
            if(isSame()){
                break;
            }
            System.out.println("第"+kindTime+"次聚类");
            kindTime++;
        }
        // 保存聚类中心
        System.out.println("聚类成功!!!");
        System.out.println("-------------------------");
        System.out.println("保存类中心点中……");
        for(int i = 0; i < 10;i ++){
            saveKind(i);
        }
        System.out.println("保存类中心点成功!!!");
    }


    // 读取文件
    public static void getTXT(String path,int img,int x,int y) throws IOException {
        File file = new File(path);
        FileInputStream fis = new FileInputStream(file);
        InputStreamReader isr = new InputStreamReader(fis);
        BufferedReader br = new BufferedReader(isr);

        String line;
        while((line = br.readLine()) != null){
            boolean isNum = false;
            for(int i = 0;i < line.length();i ++){
                if(line.charAt(i) != ' ' && !isNum){
                    // 如果遇到数字
                    isNum = true;
                    float tempNum = 0;
                    // 取数字
                    while(i < line.length() && line.charAt(i) != ' '){
                        tempNum = tempNum * 10 + line.charAt(i) - '0';
                        i++;
                    }
                    isNum = false;
                    if(y < 28){
                    }
                    else{
                        y = 0;
                        x ++;
                    }
                    num[img][x][y] = tempNum;
                    y++;
                }
            }
        }
        br.close();
    }
    // 转移两个数组
    public static void transTwoArray(float[][] array1,float[][] array2){
        for(int i = 0; i < 28;i ++){
            for (int j = 0; j < 28;j ++){
                array2[i][j] = array1[i][j];
            }
        }
    }
    // 得到距离
    public static long getDistance(float[][] n,float[][] k){
        long ret = 0;
        for (int i = 0; i < 28;i ++){
            for (int j = 0; j < 28; j ++){
                ret += Math.pow((n[i][j] - k[i][j]),2);
            }
        }
        return ret;
    }
    // 获得数组元素最小值对应的下标
    public static int getMinIndex(long dis[]){
        int index = -1;
        long min = Integer.MAX_VALUE;
        for(int i = 0; i < 10;i ++){
            if(dis[i] < min){
                index = i;
                min = dis[i];
            }
        }
        return index;
    }
    // 计算均值
    public static float getAverage(ArrayList<Integer> arr,int x,int y){
        float ret = 0;
        for(int i = 0; i < arr.size(); i ++){
            ret += num[arr.get(i)][x][y];// 将同一类中所有相同位置元素相加
        }
        return ret / arr.size();
    }
    // 保存聚类中心点
    public static void saveKind(int index){
        FileWriter out = null;
        String path = "D:\\java\\workSpace\\KMeans\\" + index + "kinds.txt";
        File file = new File(path);
        try {
            out = new FileWriter(file);
            //二维数组按行存入到文件中
            for (int i = 0; i < center[index].length; i++) {
                for (int j = 0; j < center[index][i].length; j++) {
                    //将每个元素转换为字符串
                    String content = String.valueOf(center[index][i][j]) + " ";
                    out.write(content + "\t");
                }
                out.write("\r\n");
            }
            out.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
    // 是否相等
    public static boolean isSame(){
        for(int i = 0; i < 10 ;i ++){
            for(int j = 0; j < newKinds[i].size();j ++){
                if(newKinds[i].size() != oldKinds[i].size()) return false;
                if (newKinds[i].get(j).intValue() != oldKinds[i].get(j).intValue() ) {
                    return false;
                }
            }
        }
        return true;
    }
}

测试聚类中心的代码

import java.io.*;
import java.util.ArrayList;

public class myKMeansTest {
    static float[][][] kMeans = new float[10][28][28];
    static float[][][] test = new float[8000][28][28];// 测试数据,每个数字有800张
    static long[][] distance = new long[8000][10];// 每张图片聚类类中心的距离
    static ArrayList<Integer>[] kinds = new ArrayList[10];// 每个类中包含的图片索引

    public static void main(String[] args) throws IOException {
        System.out.println("-----获取文件中-----");
        // 读取聚类中心文件
        for(int i = 0; i < 10;i ++){
            String img = "D:\\java\\workSpace\\KMeans\\" + i + "kinds.txt";
            getKMeansTxt(img,i);
        }
        // 读取测试文件
        for(int i = 0;i < 8000;i ++){
            String img = "D:\\Python\\jupyter\\test\\" + i/800 + "-" + (i%800 + 1) + ".txt";
            getTestTxt(img,i,0,0);
            if(i % 800 == 0) System.out.println("已导入数据:"+i);
        }
        System.out.println("获取文件成功!!");
        // 进行测试
        System.out.println("开始聚类……");
        for(int i = 0; i < 10;i ++){
            kinds[i] = new ArrayList<>();
        }
        for(int i = 0; i < 8000;i ++){
            for (int j = 0; j < 10;j ++){
                distance[i][j] = GoodKMeans.getDistance(kMeans[j],test[i]);// 获得每张图片对应聚类中心的距离
            }
        }
        for(int i= 0;i< 8000;i++){
            kinds[GoodKMeans.getMinIndex(distance[i])].add(i);// 将图片归为最小距离的类中
        }
        System.out.println("聚类成功!!");

        int[][] ans = new int[10][10];
        for(int i = 0; i < 10;i ++){
            for(int j = 0; j < kinds[i].size();j ++){
                if(kinds[i].get(j) < 800) ans[i][0]++;
                else if(kinds[i].get(j) >= 800 && kinds[i].get(j) < 1600) ans[i][1]++;
                else if(kinds[i].get(j) >= 1600 && kinds[i].get(j)< 2400) ans[i][2]++;
                else if(kinds[i].get(j) >= 2400 && kinds[i].get(j)< 3200) ans[i][3]++;
                else if(kinds[i].get(j) >= 3200 && kinds[i].get(j)< 4000) ans[i][4]++;
                else if(kinds[i].get(j) >= 4000 && kinds[i].get(j)< 4800) ans[i][5]++;
                else if(kinds[i].get(j) >= 4800 && kinds[i].get(j)< 5600) ans[i][6]++;
                else if(kinds[i].get(j) >= 5600 && kinds[i].get(j)< 6400) ans[i][7]++;
                else if(kinds[i].get(j) >= 6400 && kinds[i].get(j)< 7200) ans[i][8]++;
                else if(kinds[i].get(j) >= 7200 && kinds[i].get(j)< 8000) ans[i][9]++;
            }
        }
        for (int i = 0; i < 10;i ++){
            System.out.print("第"+i+"类中:");
            for (int j = 0; j < 10;j ++){
                System.out.print(j+":");
                System.out.printf("%3d",ans[i][j]);
                System.out.print("\t");
            }
            System.out.println();
        }
    }

    // 获得聚类中心文件
    public static void getKMeansTxt(String img,int index) throws IOException {
        File file = new File(img);
        FileInputStream fis = new FileInputStream(file);
        InputStreamReader isr = new InputStreamReader(fis);
        BufferedReader br = new BufferedReader(isr);

        int x = 0;
        int y = 0;
        String line;
        while((line = br.readLine()) != null){
            boolean isNum = false;
            for(int i = 0;i < line.length();i ++){
                if(line.charAt(i)-'0' <10 && line.charAt(i)-'0' >=0 && !isNum){
                    // 如果遇到数字
                    isNum = true;
                    // 取数字
                    int j = i + 1;
                    while(j < line.length() && line.charAt(j) != ' '){
                        j++;
                    }
                    isNum = false;
                    if(y < 28){
                    }
                    else{
                        y = 0;
                        x ++;
                    }
                    kMeans[index][x][y] = Float.valueOf(line.substring(i,j)).floatValue();
                    i = j;
                    y++;
                }
            }
        }
        br.close();
    }
    // 获得测试文件
    public static void getTestTxt(String path,int img,int x,int y) throws IOException {
        File file = new File(path);
        FileInputStream fis = new FileInputStream(file);
        InputStreamReader isr = new InputStreamReader(fis);
        BufferedReader br = new BufferedReader(isr);

        String line;
        while((line = br.readLine()) != null){
            boolean isNum = false;
            for(int i = 0;i < line.length();i ++){
                if(line.charAt(i) != ' ' && !isNum){
                    // 如果遇到数字
                    isNum = true;
                    float tempNum = 0;
                    // 取数字
                    while(i < line.length() && line.charAt(i) != ' '){
                        tempNum = tempNum * 10 + line.charAt(i) - '0';
                        i++;
                    }
                    isNum = false;
                    if(y < 28){
                    }
                    else{
                        y = 0;
                        x ++;
                    }
                    test[img][x][y] = tempNum;
                    y++;
                }
            }
        }
        br.close();
    }
}

标签:10,java,int,28,KMeans,++,聚类,new
来源: https://blog.csdn.net/m0_51418456/article/details/123601187