bayes-tan
作者:互联网
#ifndef TAN_RAND_H #define TAN_RAND_H #include "incrementalLearner.h" #include "xxyDist.h" #include <limits> class tan_rand: public IncrementalLearner { public: tan_rand(); tan_rand(char* const *& argv, char* const * end); ~tan_rand(void); void reset(InstanceStream &is); ///< reset the learner prior to training void initialisePass(); ///< must be called to initialise a pass through an instance stream before calling train(const instance). should not be used with train(InstanceStream) void train(const instance &inst); ///< primary training method. train from a single instance. used in conjunction with initialisePass and finalisePass void finalisePass(); ///< must be called to finalise a pass through an instance stream using train(const instance). should not be used with train(InstanceStream) bool trainingIsFinished(); ///< true iff no more passes are required. updated by finalisePass() void getCapabilities(capabilities &c); virtual void classify(const instance &inst, std::vector<double> &classDist); private: unsigned int noCatAtts_; ///< the number of categorical attributes. unsigned int noClasses_; ///< the number of classes InstanceStream* instanceStream_; std::vector<CategoricalAttribute> parents_; xxyDist xxyDist_; bool trainingIsFinished_; ///< true iff the learner is trained const static CategoricalAttribute NOPARENT = 0xFFFFFFFFUL; //使用printf("%d",0xFFFFFFFFUL);输出是-1 cannot use std::numeric_limits<categoricalAttribute>::max() because some compilers will not allow it here }; #endif // TAN_RAND_H
#include "tan_rand.h"
#include "utils.h"
#include "correlationMeasures.h"

#include <assert.h>
#include <math.h>
#include <set>
#include <stdlib.h>
#include <queue>
#include <vector>

namespace {

/// Candidate edge for the maximum spanning tree built in tan_rand::finalisePass():
/// attach `child` to the tree with `parent` as its attribute parent.
/// Kept in an anonymous namespace so it cannot collide with other file-scope
/// types of the same name elsewhere in the program.
struct SpanningEdge {
  CategoricalAttribute child;
  CategoricalAttribute parent;
  // Edge weight = conditional mutual information between child and parent.
  // Must be floating point: storing it in an integral type (as the previous
  // CatValue field did) truncates fractional CMI values, destroying the ordering.
  double weight;

  /// std::priority_queue is a max-heap, so the heaviest edge is popped first.
  bool operator<(const SpanningEdge &other) const {
    return weight < other.weight;
  }
};

} // namespace

tan_rand::tan_rand(char* const *&, char* const *) {
  name_ = "tan_rand";
}

tan_rand::~tan_rand() {
}

/// Only categorical attributes are supported at the moment.
void tan_rand::getCapabilities(capabilities &c) {
  c.setCatAtts(true);
}

bool tan_rand::trainingIsFinished() {
  return trainingIsFinished_;
}

/// Prepare for training on `is`: record its shape, clear all attribute
/// parents, and reset the accumulated counts.
void tan_rand::reset(InstanceStream &is) {
  instanceStream_ = &is;
  noCatAtts_ = is.getNoCatAtts();
  noClasses_ = is.getNoClasses();
  trainingIsFinished_ = false;

  // Until finalisePass() runs, no attribute has an attribute parent.
  // The cast yields a temporary so the in-class static constant is not
  // odr-used (passing it directly by const& would need an out-of-class definition).
  parents_.assign(noCatAtts_, static_cast<CategoricalAttribute>(NOPARENT));

  xxyDist_.reset(is);
}

/// Accumulate the counts of a single training instance into xxyDist_.
void tan_rand::train(const instance &inst) {
  xxyDist_.update(inst);
}

void tan_rand::initialisePass() {
  assert(trainingIsFinished_ == false);
}

/// Choose each attribute's parent by growing a maximum-weight spanning tree
/// (Prim's algorithm, lazy-deletion priority queue) over the pairwise
/// conditional mutual information, rooted at attribute 0.
void tan_rand::finalisePass() {
  assert(trainingIsFinished_ == false);

  // No attributes: nothing to connect, and parents_[0] would be out of range.
  if (noCatAtts_ == 0) {
    trainingIsFinished_ = true;
    return;
  }

  crosstab<float> cmi = crosstab<float>(noCatAtts_);
  getCondMutualInf(xxyDist_, cmi);

  const CategoricalAttribute root = 0;
  parents_[root] = NOPARENT;

  // bestWeight[i]: heaviest edge seen so far linking attribute i to the tree.
  std::vector<double> bestWeight(noCatAtts_, 0.0);
  // inTree[i] != 0 once attribute i has been attached to the tree.
  std::vector<char> inTree(noCatAtts_, 0);
  std::priority_queue<SpanningEdge> frontier;

  // Seed with every edge from the root so all attributes are reachable.
  inTree[root] = 1;
  for (CategoricalAttribute i = 1; i < noCatAtts_; i++) {
    bestWeight[i] = cmi[root][i];
    const SpanningEdge seed = {i, root, bestWeight[i]};
    frontier.push(seed);
  }

  while (!frontier.empty()) {
    const SpanningEdge best = frontier.top();
    frontier.pop();
    if (inTree[best.child]) continue;  // stale entry: already attached via a heavier edge

    inTree[best.child] = 1;
    parents_[best.child] = best.parent;

    // Offer the newly attached attribute as a parent to every detached attribute.
    for (CategoricalAttribute i = 0; i < noCatAtts_; i++) {
      if (!inTree[i] && cmi[best.child][i] > bestWeight[i]) {
        bestWeight[i] = cmi[best.child][i];
        const SpanningEdge candidate = {i, best.child, bestWeight[i]};
        frontier.push(candidate);
      }
    }
  }

  trainingIsFinished_ = true;
}

/// Fill classDist with the normalised class distribution for `inst`.
/// classDist must already hold at least noClasses_ entries; it is indexed, not resized.
void tan_rand::classify(const instance &inst, std::vector<double> &classDist) {
  // Start from the class term xyCounts.p(y).
  for (CatValue y = 0; y < noClasses_; y++) {
    classDist[y] = xxyDist_.xyCounts.p(y);
  }

  // Multiply in each attribute's estimate given the class, and given its
  // attribute parent's value when it has one.
  for (unsigned int x1 = 0; x1 < noCatAtts_; x1++) {
    const CategoricalAttribute parent = parents_[x1];
    const CatValue v1 = inst.getCatVal(x1);

    if (parent == NOPARENT) {
      for (CatValue y = 0; y < noClasses_; y++) {
        classDist[y] *= xxyDist_.xyCounts.p(x1, v1, y);
      }
    } else {
      const CatValue parentVal = inst.getCatVal(parent);
      for (CatValue y = 0; y < noClasses_; y++) {
        classDist[y] *= xxyDist_.p(x1, v1, parent, parentVal, y);
      }
    }
  }

  normalise(classDist);
}
标签:include,const,void,instance,bayes,tan,x1,xxyDist 来源: https://www.cnblogs.com/ccut-ry/p/13577980.html