编程语言
首页 > 编程语言> > [机器学习]三角不等式加速K均值聚类及C++实现

[机器学习]三角不等式加速K均值聚类及C++实现

作者:互联网

本博客涉及代码可在GitHub下载:传送门

K均值聚类

K均值聚类是常用的欧式距离聚类算法,即认为两个目标向量的差的模长越小,两个目标越可能是一类的。

通俗理解:牧师-村民模型

有四个牧师去郊区布道,一开始牧师们随意选了几个布道点,并且把这几个布道点的情况公告给了郊区所有的村民,于是每个村民到离自己家最近的布道点去听课。听课之后,大家觉得距离太远了,于是每个牧师统计了一下自己的课上所有的村民的地址,搬到了所有地址的中心地带,并且在海报上更新了自己的布道点的位置。牧师每一次移动不可能离所有人都更近,有的人发现A牧师移动以后自己还不如去B牧师处听课更近,于是每个村民又去了离自己最近的布道点……就这样,牧师每个礼拜更新自己的位置,村民根据自己的情况选择布道点,最终稳定了下来。

可以发现该牧师的目的是为了让每个村民到其最近中心点的距离和最小。

数学模型

参考K-Means聚类算法原理

收敛性证明

参考k-means聚类算法的收敛性证明与应用

三角不等式加速

定理1及优化

定理1:x是一个数据点,b和c是中点,如果 \(|bc|\geq 2|xb|\) 则有 \(|xc|\geq |xb|\)
使用三角不等式(两边之和大于第三边)即可证明。有了这条性质,我们如此考虑,若数据点x此时锚定了中点b,且我们知道 \(|xb|\) 或 \(|xb|\) 的一个上界,而 \(|bc|\geq 2|xb|\),那么我们不用考虑x点会从b点转移到c点去,节约了计算 \(|xc|\) 的时间。

定理2及优化

定理2:x是一个数据点,b和c是中点,有 \(|xc|\geq max\{0,|xb|-|bc|\}\)。
使用三角不等式(两边之和大于第三边)即可证明。
对于锚定了中点b的数据点x,考虑是否转移到另一个中点c,若我们知道 \(|xb|\) 的上界和 \(|xc|\) 的下界,且上界小于这个下界,那一定不考虑。在算法开始的时候上界和下界可以直接赋值为距离,在后续的迭代中,需要保证上下界的合法。

我们不维护每一对点的距离的上界,只维护一个数据点到它的锚定点的距离的上界u(x)。一开始数据点到锚定点的距离是确定的,上界也确定,若该点的锚定点发生了位移,根据定理1则 u(x)+=dis(m(c(x)),c(x))m(c)表示c位移后的位置(代码中为mean),c(x)表示x数据点点锚定的中点。同时当我们计算x到它锚定点的距离的时候,我们顺手更新一下这个上界为x到它当前的锚定点的距离,让它不会一直增大以至于算法后期失去约束能力。同时,我们可以记录一下这个点它的上界是否仍然是c到x的距离,如果是的话,我们又能省去一次计算距离。

我们维护每一个数据点x到中点c的距离的下界l(x,c),一开始赋值为距离,迭代的时候,根据定理2,l(x,c)=max{l(x,c)-dis(c,m(c)),0}

伪代码

C++实现思路

KMean

Point

某种意义上是个抽象类

Data_point: 继承Point

Center: 继承Point

代码

头文件

#include <cmath>
#include <cstdio>
#include <ctime>
#include <iostream>
#include <vector>
using namespace std;    

KMeans

class KMeans{
    class Point;
    class Data_point;
    class Center;
public:
    vector<int> result;
    int iterator_times;
    KMeans(vector<vector<double>> dataset, vector<int> label, int cluster_number){
        data_size = dataset.size();
        dimension = dataset[0].size();
        cluster = cluster_number;
        for (auto it = dataset.begin(); it != dataset.end();it++){
            vector<double> pos;
            for (int i = 0; i < dimension;i++)
                pos.push_back((*it)[i]);
            point.push_back(Data_point(pos));
        }
        for (auto it = label.begin(); it != label.end();it++)
            this->label.push_back(*it);
        //output(); 
        iterator_times = 0;
        mainThread();
       
    }
private:
    int data_size, dimension, cluster;
    vector<Data_point> point;
    vector<Center> center;
    vector<int> label;

point

    class Point{
    public:
        vector<double> pos;
        void build(vector<double> arr){
            pos.clear();
            for (auto it = arr.begin(); it != arr.end();it++)
                pos.push_back(*it);
        }
        void build(Point &rhs){
            pos.clear();
            for (auto it = rhs.pos.begin(); it != rhs.pos.end();it++)
                pos.push_back(*it);
        }
        void output(){
            cout << "point data:";
            for (auto it = pos.begin(); it != pos.end();it++)
                cout << (*it) << " ";
            cout << endl;
        }
    private:
        void clear(){
            pos.clear();
        }
    };

Data_point

    class Data_point : public Point{
    public:
        int center;
        vector<double> lowerBound;
        double upperBound;
        bool uOutofDate, transed;
        Data_point(vector<double> arr){
            this->build(arr);
            center = 0;
            uOutofDate = transed = false;
        }
    };

Center

    class Center : public Point{
    public:
        vector<double> mean;
        int cnt;
        vector<double> disc;
        double s;//s(c)=1/2 min(c,c')(c!=c')
        Center(Data_point &rhs){
            this->build(rhs);
            cnt = 0;
        }

        void trans(){
            int len = mean.size();
            for (int i = 0; i < len;i++){
                pos[i]=mean[i];
                mean[i] = 0;
            }
            cnt = 0;
        }

        void add_mean(Data_point &data){
            int len = mean.size();
            for (int i = 0; i < len;i++)
                mean[i] += data.pos[i];
            cnt++;
        }

        void cal_mean(){
            int len = mean.size();
            for (int i = 0; i < len;i++)
                mean[i] /= cnt;
        }

        void output(){
            int len = pos.size();
            cout << "center position:";
            for (int i = 0; i < pos.size();i++)
                cout << pos[i] << " ";
            cout << endl;
        }
    };

KMeans内成员函数


    double cal_dis(Point &a, Point &b){
        double dis = 0;
        for (int i = 1; i <= this->dimension; i++){
            dis += (a.pos[i] - b.pos[i]) * (a.pos[i] - b.pos[i]);
        }
        return sqrt(dis);
    }
    
    double cal_dis(Point &a, vector<double> &b){
        double dis = 0;
        for (int i = 1; i <= this->dimension; i++){
            dis += (a.pos[i] - b[i]) * (a.pos[i] - b[i]);
        }
        return sqrt(dis);
    }

    void output(){
        cout << "Data size: " << data_size << endl;
        cout << "Dimension: " << dimension << endl;
        cout << "Data:" << endl;
        // for (auto it = point.begin(); it != point.end();it++)
        //     it->output();
    }

    //Calculate distance between centers and s(c)
    void cal_disc(){
        for (int c1 = 0; c1 < cluster;c1++)
            for (int c2 = c1+1; c2 < cluster;c2++){
                double dis = cal_dis(center[c1], center[c2]);
                center[c1].disc[c2] = dis;
                center[c2].disc[c1] = dis;
            }
        for (int c1 = 0; c1 < cluster;c1++){
            double s = center[c1].disc[(c1 + 1) % cluster];
            for (int c2 = 0; c2 < cluster;c2++){
                if(c1==c2)
                    continue;
                double s2 = center[c1].disc[c2];
                if(s2<s)
                    s = s2;
            }
            center[c1].s = s / 2;
        }
    }

    void init(){
        //get initial centers
        vector<int> num;
        for (int i = 0; i < data_size;i++)
            num.push_back(i);
        srand(time(NULL));
        for (int i = 0; i < data_size; i++)
            swap(num[i],num[rand() % data_size]);
        for (int i = 0; i < cluster; i++)
            center.push_back(Center(point[num[i]]));
        num.clear();

        //Initial the center
        for (auto c = center.begin(); c != center.end();c++){
            for (int i = 0; i < cluster;i++)
                c->disc.push_back(0);
            for (int i = 0; i < dimension;i++)
                c->mean.push_back(0);
        }
        cal_disc();
        //Calculate initial c(x) and l(x,c) and u(x)
        for (auto x = point.begin(); x != point.end();x++){
            for (int i = 0; i < cluster;i++)
                x->lowerBound.push_back(0);
            double dis = cal_dis(*x, center[0]);
            x->center = 0;
            int len = center.size();
            for (int c = 1; c < len;c++){
                if(dis<=center[x->center].disc[c])
                    continue;
                double dis2 = cal_dis(*x, center[c]);
                x->lowerBound[c] = dis2;
                if(dis2<dis){
                    x->center = c;
                    dis = dis2;
                }
            }
            x->upperBound = dis;
            center[x->center].add_mean(*x);
        }
        for (auto c = center.begin(); c != center.end();c++)
            c->cal_mean();

        //Move the centers
        // cout << "first time center:" << endl;
        // for (auto c = center.begin(); c != center.end();c++)
        //     c->output();

        for (auto c = center.begin(); c != center.end();c++)
            c->trans();

        // cout << "second time center:" << endl;
        // for (auto c = center.begin(); c != center.end();c++)
        //     c->output();
    }   

    //Repeat steps
    bool repeat(){
        bool flag = false;
        cal_disc();
        for (auto x = point.begin(); x != point.end();x++){
            x->transed = false;
            double disc;
            //Ignore points which will not trans
            if(x->upperBound<=center[x->center].s)
                continue;
            //Update the upperBound
            if(x->uOutofDate){
                x->uOutofDate=false;
                disc = cal_dis(*x, center[x->center]);
                x->upperBound = disc;
            }else
                disc = x->upperBound;
            //trans points to closer centers
            for (int c = 0; c < cluster;c++){
                if(disc>x->lowerBound[x->center]||disc>0.5*center[x->center].disc[c]){
                    double disc2 = cal_dis(*x, center[c]);
                    //cout << "iter note:"<<disc<<" "<<disc2 << endl;
                    if(disc2<disc){
                        //cout << "yes" << endl;
                        //cout << "trans note:" << disc2 << " " << disc << endl;
                        x->center = c;
                        x->transed = true;
                        flag = true;
                    }
                }
            }
            //Add to the center
            center[x->center].add_mean(*x);
        }
        //Calculate the mean of centers
        for (auto c = center.begin(); c!=center.end();c++)
            c->cal_mean();
        
        //Update the lowerBound and upperBound
        for (auto x = point.begin(); x != point.end();x++){
            if(!x->transed)
                continue;
            for (int c = 0; c < cluster;c++)
                x->lowerBound[c] = max(x->lowerBound[c] - cal_dis(center[c], center[c].mean), 0.0);
            x->upperBound += cal_dis(center[x->center], center[x->center].mean);
            x->uOutofDate = true;
        }

        //Move the centers
        for (auto c = center.begin(); c != center.end();c++)
            c->trans();

        iterator_times++;
        return flag;
    }
    //Get result and save to `result`
    void getResult(){
        for (auto x = point.begin(); x != point.end();x++)
            result.push_back((*x).center);
    }

    void mainThread(){
        init();
        while(true){
            if(!repeat())
                break;
        }
        getResult();
    }   
    
};

测试代码

int main(){
    // Get the data
    freopen("seeds_dataset.txt", "r", stdin);
    vector<vector<double>> dataset;
    vector<int> label;
    double tmp;
    int tmp2;
    for (int i = 1; i <= 210; i++){
        vector<double> point;
        for (int j = 1; j <= 7; j++){
            cin >> tmp;
            point.push_back(tmp);
        }
        cin >> tmp2;
        label.push_back(tmp2);
        dataset.push_back(point);
    }
    //KMeans(vector<vector<double>> dataset, vector<int> label, int number of cluster)
    KMeans model(dataset, label, 3);

    cout << "result:" << endl;
    for (auto it = model.result.begin(); it != model.result.end();it++)
        cout << (*it) << " ";
    cout << endl;

    cout << "iteration times: " << model.iterator_times << endl;
}

标签:center,不等式,point,int,C++,++,vector,聚类,dis
来源: https://www.cnblogs.com/sherrlock/p/16265212.html