其他分享
首页 > 其他分享 > bayes-tan

bayes-tan

作者:互联网

#ifndef TAN_RAND_H
#define TAN_RAND_H

#include "incrementalLearner.h"
#include "xxyDist.h"
#include <limits>
#include <vector>

/// Tree-augmented naive Bayes style learner.
///
/// Training counts instances into xxyDist_. finalisePass() then links the
/// attributes into a tree: parents_[a] holds attribute a's parent, or NOPARENT
/// for the root. classify() multiplies the class prior by each attribute's
/// probability given the class and, when present, its parent's value.
class tan_rand: public IncrementalLearner {
public:
	tan_rand();
	tan_rand(char* const *& argv, char* const * end);
	~tan_rand(void);

	void reset(InstanceStream &is);   ///< reset the learner prior to training
	void initialisePass(); ///< must be called to initialise a pass through an instance stream before calling train(const instance). should not be used with train(InstanceStream)
	void train(const instance &inst); ///< primary training method. train from a single instance. used in conjunction with initialisePass and finalisePass
	void finalisePass(); ///< must be called to finalise a pass through an instance stream using train(const instance). should not be used with train(InstanceStream)
	bool trainingIsFinished(); ///< true iff no more passes are required. updated by finalisePass()
	void getCapabilities(capabilities &c);

	/// Writes the normalised class distribution for inst into classDist.
	/// classDist is indexed, not resized: it must already hold one entry per class.
	virtual void classify(const instance &inst, std::vector<double> &classDist);

private:
	// Default member initializers give every constructor (including ones
	// defined elsewhere) well-defined state before reset() is called.
	unsigned int noCatAtts_ = 0;          ///< the number of categorical attributes.
	unsigned int noClasses_ = 0;          ///< the number of classes

	InstanceStream* instanceStream_ = nullptr;   ///< non-owning; set by reset()
	std::vector<CategoricalAttribute> parents_;  ///< parent of each attribute, NOPARENT for the root
	xxyDist xxyDist_;

	bool trainingIsFinished_ = false; ///< true iff the learner is trained

	/// Sentinel meaning "no parent": all bits set (printing it with printf("%d", ...)
	/// shows -1). std::numeric_limits<CategoricalAttribute>::max() is not used because
	/// some compilers will not allow it here.
	const static CategoricalAttribute NOPARENT = 0xFFFFFFFFUL;
};


#endif // TAN_RAND_H
#include "tan_rand.h"
#include "utils.h"
#include "correlationMeasures.h"
#include <assert.h>
#include <math.h>
#include <set>
#include <stdlib.h>
#include <queue>

/// Priority-queue entry for building the maximum spanning tree with Prim's
/// algorithm: attach attribute `x` to the tree through parent `fa` with edge
/// weight `val`. std::priority_queue is a max-heap under operator<, so the
/// entry with the largest weight is popped first.
struct node
{
    CategoricalAttribute x, fa;
    // Edge weight: a fractional CMI score. It must be floating-point; an
    // integral type truncates the scores (usually to 0), breaking the ordering.
    double val;

    bool operator <(const node &v) const{
        return val < v.val;
    }
};

/// Constructs the learner. The command-line argument range is accepted for
/// interface compatibility but not parsed (the parameters are unnamed).
/// Scalar members get defined values here so that calling, e.g.,
/// trainingIsFinished() before reset() does not read uninitialised memory.
tan_rand::tan_rand(char* const *&, char* const *)
    : noCatAtts_(0), noClasses_(0), instanceStream_(nullptr), trainingIsFinished_(false)
{
    name_ = "tan_rand";
}

/// Destructor. Nothing to release by hand: instanceStream_ only points at the
/// caller's stream (see reset()), and the other members clean up after themselves.
tan_rand::~tan_rand() {}

/// Declares the supported input: categorical attributes only (numeric
/// attributes are not handled by this learner at the moment).
void tan_rand::getCapabilities(capabilities &c) {
    c.setCatAtts(true);
}

/// @return true once finalisePass() has completed; reset() sets it back to false.
bool tan_rand::trainingIsFinished() {
    return trainingIsFinished_;
}

/// Prepares the learner for a fresh training run on `is`.
///
/// Records the stream and its dimensions, clears the finished flag, marks
/// every attribute as having no parent, and resets the pairwise statistics.
void tan_rand::reset(InstanceStream &is)
{
    instanceStream_ = &is;
    noCatAtts_ = is.getNoCatAtts();
    noClasses_ = is.getNoClasses();

    trainingIsFinished_ = false;

    // One slot per attribute, all detached until finalisePass() links them.
    // (Plain assignment is used so the in-class constant NOPARENT is not
    // bound to a reference, which would need an out-of-class definition.)
    parents_.resize(noCatAtts_);
    for (CategoricalAttribute &parent : parents_)
        parent = NOPARENT;

    xxyDist_.reset(is);
}


    /// Loads one training instance: accumulates it into xxyDist_, the
    /// statistics later read by finalisePass() and classify().
    void tan_rand::train(const instance &inst) {
        xxyDist_.update(inst);
    }


/// Starts a pass over the training data. No per-pass state is needed; this
/// only checks (in debug builds) that training has not already been finalised.
void tan_rand::initialisePass() {
    assert(trainingIsFinished_ == false);
}

void tan_rand::finalisePass() {
    //printf("finalisePass\n");
    assert(trainingIsFinished_ == false);

    crosstab<float> cmi = crosstab<float>(noCatAtts_);
    getCondMutualInf(xxyDist_, cmi);

    CategoricalAttribute firstAtt = 0;
    parents_[firstAtt] = NOPARENT;

    std::vector<double>dis;
    dis.resize(noCatAtts_);
    bool vis[noCatAtts_];
    memset(vis, 0, sizeof(vis));

    std::priority_queue<node>que;

    for(CategoricalAttribute i = 1; i < noCatAtts_; i++){
        dis[i] = cmi[firstAtt][i];
        que.push({i, firstAtt, dis[i]});
    }

    vis[firstAtt] = 1;
    while(!que.empty()){
        node v = que.top(); que.pop();

        if (vis[v.x]) continue;
        vis[v.x] = true;
        parents_[v.x] = v.fa;

        for(CategoricalAttribute i = 0; i < noCatAtts_; i++){
            if (!vis[i] && cmi[v.x][i] > dis[i]){
                dis[i] = cmi[v.x][i];
                que.push({i, v.x, dis[i]});
            }
        }
    }

    trainingIsFinished_ = true;
}

void tan_rand::classify(const instance &inst, std::vector<double> &classDist)
{
    for (CatValue y = 0; y < noClasses_; y++)
    {
        classDist[y] = xxyDist_.xyCounts.p(y);
    }

    for (unsigned int x1 = 0; x1 < noCatAtts_; x1++)
    {
        const CategoricalAttribute parent = parents_[x1];

        if (parent == NOPARENT)
        {
            for (CatValue y = 0; y < noClasses_; y++)
            {
                classDist[y] *= xxyDist_.xyCounts.p(x1, inst.getCatVal(x1), y);
            }
        } else
        {
            for (CatValue y = 0; y < noClasses_; y++)
            {
                classDist[y] *= xxyDist_.p(x1, inst.getCatVal(x1), parent, inst.getCatVal(parent), y);
            }
        }
    }

    normalise(classDist);
}

  

标签:include,const,void,instance,bayes,tan,x1,xxyDist
来源: https://www.cnblogs.com/ccut-ry/p/13577980.html