TransE的程序实现——学习,查阅,注释
作者:互联网
代码来源:https://github.com/thunlp/KB2E/blob/master/TransE/
很久没碰C++了,为了实验成功,还是要仔仔细细抠一抠代码才行。
第一部分:头文件的引入以及常量定义
1 #include<iostream> /*输入输出流*/ 2 #include<cstring> /*C字符串操作函数*/ 3 #include<cstdio> /*标准输入输出的C++形式*/ 4 #include<map> /*定义了一种关联容器,数据结构*/ 5 #include<vector> /*顺序容器,常用于表示向量*/ 6 #include<string> /*字符串操作函数*/ 7 #include<ctime> /*日期时间结构体*/ 8 #include<cmath> /*数学操作*/ 9 #include<cstdlib> /*提供一些函数和符号常量,如第二部分出现的RAND_MAX*/ 10 using namespace std; //添加命名空间 11 12 #define pi 3.1415926535897932384626433832795 13 14 bool L1_flag=1; /*标识数为1时,该变量为L1范数;否则,则为L2范数*/
第二部分:简单计算函数的定义
1 //normal distribution 正态分布 2 double rand(double min, double max) /*返回[min,max)之间的随机数*/ 3 { 4 return min+(max-min)*rand()/(RAND_MAX+1.0); 5 } 6 double normal(double x, double miu,double sigma) /*返回一个均值为miu标准差为sigma的正态分布函数在x处的函数值*/ 7 { 8 return 1.0/sqrt(2*pi)/sigma*exp(-1*(x-miu)*(x-miu)/(2*sigma*sigma)); 9 } 10 double randn(double miu,double sigma, double min ,double max) /*通过产生随机数的方式,返回满足某条件的自变量x*/ 11 { 12 double x,y,dScope; 13 do{ 14 x=rand(min,max); /*调用自定义函数rand(min,max)*/ 15 y=normal(x,miu,sigma); /*调用自定义函数normal(x,miu,sigma)*/ 16 dScope=rand(0.0,normal(miu,miu,sigma)); /*当x值取miu时,正态函数达到最大值,该变量在0与正态函数最大值取随机数*/ 17 }while(dScope>y); /*x处的函数值小于该随机数时,循环停止*/ 18 return x; 19 } 20 21 double sqr(double x) /*平方函数*/ 22 { 23 return x*x; 24 } 25 26 double vec_len(vector<double> &a) /*返回向量a的模,即L2范数*/ 27 { 28 double res=0; 29 for (int i=0; i<a.size(); i++) /*遍历向量长度*/ 30 res+=a[i]*a[i]; 31 res = sqrt(res); 32 return res; 33 }
rand()返回一个0到最大随机数RAND_MAX(确定)的任意整数,RAND_MAX至少为32767。
第三部分:变量的定义
1 string version; 2 char buf[100000],buf1[100000]; 3 int relation_num,entity_num; /*定义关系数量,实体数量*/ 4 map<string,int> relation2id,entity2id; /*关系和实体使用关联容器,字符串作关键词,整型作关键字的值*/ 5 map<int,string> id2entity,id2relation; /*与上相反*/ 6 7 map<int,map<int,int> > left_entity,right_entity; 8 map<int,double> left_num,right_num;
第7行,left_entity表示在此relation下头实体对应的尾实体的个数,3个int分别表示relation_id,headentity_id,个数。right_entity表示在此relation下尾实体对应的头实体的个数,3个int分别表示relation_id,tailentity_id,个数。主要用于计算采样概率p。
第8行,leftnum表示平均每个头实体对应多少个尾实体。rightnum表示平均每个尾实体对应多少头实体。int仍然表示relation_id。
第四部分:训练类的定义
1 class Train{ 2 3 public: 4 map<pair<int,int>, map<int,int> > ok; 5 void add(int x,int y,int z) 6 { 7 fb_h.push_back(x); 8 fb_r.push_back(z); 9 fb_l.push_back(y); 10 ok[make_pair(x,z)][y]=1; 11 } 12 void run(int n_in,double rate_in,double margin_in,int method_in) 13 { 14 n = n_in; 15 rate = rate_in; 16 margin = margin_in; 17 method = method_in; 18 relation_vec.resize(relation_num); 19 for (int i=0; i<relation_vec.size(); i++) 20 relation_vec[i].resize(n); 21 entity_vec.resize(entity_num); 22 for (int i=0; i<entity_vec.size(); i++) 23 entity_vec[i].resize(n); 24 relation_tmp.resize(relation_num); 25 for (int i=0; i<relation_tmp.size(); i++) 26 relation_tmp[i].resize(n); 27 entity_tmp.resize(entity_num); 28 for (int i=0; i<entity_tmp.size(); i++) 29 entity_tmp[i].resize(n); 30 for (int i=0; i<relation_num; i++) 31 { 32 for (int ii=0; ii<n; ii++) 33 relation_vec[i][ii] = randn(0,1.0/n,-6/sqrt(n),6/sqrt(n)); 34 } 35 for (int i=0; i<entity_num; i++) 36 { 37 for (int ii=0; ii<n; ii++) 38 entity_vec[i][ii] = randn(0,1.0/n,-6/sqrt(n),6/sqrt(n)); 39 norm(entity_vec[i]); 40 } 41 42 43 bfgs(); 44 } 45 46 private: 47 int n,method; 48 double res;//loss function value 49 double count,count1;//loss function gradient 50 double rate,margin; 51 double belta; 52 vector<int> fb_h,fb_l,fb_r; 53 vector<vector<int> > feature; 54 vector<vector<double> > relation_vec,entity_vec; 55 vector<vector<double> > relation_tmp,entity_tmp; 56 double norm(vector<double> &a) 57 { 58 double x = vec_len(a); 59 if (x>1) 60 for (int ii=0; ii<a.size(); ii++) 61 a[ii]/=x; 62 return 0; 63 } 64 int rand_max(int x) 65 { 66 int res = (rand()*rand())%x; 67 while (res<0) 68 res+=x; 69 return res; 70 } 71 72 void bfgs() 73 { 74 res=0; 75 int nbatches=100; 76 int nepoch = 1000; 77 int batchsize = fb_h.size()/nbatches; 78 for (int epoch=0; epoch<nepoch; epoch++) 79 { 80 81 res=0; 82 for (int batch = 0; batch<nbatches; batch++) 83 { 84 relation_tmp=relation_vec; 85 entity_tmp = entity_vec; 86 for (int k=0; k<batchsize; k++) 87 { 88 int i=rand_max(fb_h.size()); 89 int j=rand_max(entity_num); 90 double pr = 1000*right_num[fb_r[i]]/(right_num[fb_r[i]]+left_num[fb_r[i]]); 91 if (method ==0) 92 pr = 500; 93 if (rand()%1000<pr) 94 { 95 while (ok[make_pair(fb_h[i],fb_r[i])].count(j)>0) 96 j=rand_max(entity_num); 97 train_kb(fb_h[i],fb_l[i],fb_r[i],fb_h[i],j,fb_r[i]); 98 } 99 else 100 { 101 while (ok[make_pair(j,fb_r[i])].count(fb_l[i])>0) 102 j=rand_max(entity_num); 103 train_kb(fb_h[i],fb_l[i],fb_r[i],j,fb_l[i],fb_r[i]); 104 } 105 norm(relation_tmp[fb_r[i]]); 106 norm(entity_tmp[fb_h[i]]); 107 norm(entity_tmp[fb_l[i]]); 108 norm(entity_tmp[j]); 109 } 110 relation_vec = relation_tmp; 111 entity_vec = entity_tmp; 112 } 113 cout<<"epoch:"<<epoch<<' '<<res<<endl; 114 FILE* f2 = fopen(("relation2vec."+version).c_str(),"w"); 115 FILE* f3 = fopen(("entity2vec."+version).c_str(),"w"); 116 for (int i=0; i<relation_num; i++) 117 { 118 for (int ii=0; ii<n; ii++) 119 fprintf(f2,"%.6lf\t",relation_vec[i][ii]); 120 fprintf(f2,"\n"); 121 } 122 for (int i=0; i<entity_num; i++) 123 { 124 for (int ii=0; ii<n; ii++) 125 fprintf(f3,"%.6lf\t",entity_vec[i][ii]); 126 fprintf(f3,"\n"); 127 } 128 fclose(f2); 129 fclose(f3); 130 } 131 } 132 double res1; 133 double calc_sum(int e1,int e2,int rel) 134 { 135 double sum=0; 136 if (L1_flag) 137 for (int ii=0; ii<n; ii++) 138 sum+=fabs(entity_vec[e2][ii]-entity_vec[e1][ii]-relation_vec[rel][ii]); 139 else 140 for (int ii=0; ii<n; ii++) 141 sum+=sqr(entity_vec[e2][ii]-entity_vec[e1][ii]-relation_vec[rel][ii]); 142 return sum; 143 } 144 void gradient(int e1_a,int e2_a,int rel_a,int e1_b,int e2_b,int rel_b) 145 { 146 for (int ii=0; ii<n; ii++) 147 { 148 149 double x = 2*(entity_vec[e2_a][ii]-entity_vec[e1_a][ii]-relation_vec[rel_a][ii]); 150 if (L1_flag) 151 if (x>0) 152 x=1; 153 else 154 x=-1; 155 relation_tmp[rel_a][ii]-=-1*rate*x; 156 entity_tmp[e1_a][ii]-=-1*rate*x; 157 entity_tmp[e2_a][ii]+=-1*rate*x; 158 x = 2*(entity_vec[e2_b][ii]-entity_vec[e1_b][ii]-relation_vec[rel_b][ii]); 159 if (L1_flag) 160 if (x>0) 161 x=1; 162 else 163 x=-1; 164 relation_tmp[rel_b][ii]-=rate*x; 165 entity_tmp[e1_b][ii]-=rate*x; 166 entity_tmp[e2_b][ii]+=rate*x; 167 } 168 } 169 void train_kb(int e1_a,int e2_a,int rel_a,int e1_b,int e2_b,int rel_b) 170 { 171 double sum1 = calc_sum(e1_a,e2_a,rel_a); 172 double sum2 = calc_sum(e1_b,e2_b,rel_b); 173 if (sum1+margin>sum2) 174 { 175 res+=margin+sum1-sum2; 176 gradient( e1_a, e2_a, rel_a, e1_b, e2_b, rel_b); 177 } 178 } 179 };
第五部分:类变量的定义和数据集准备
1 Train train; 2 void prepare() 3 { 4 FILE* f1 = fopen("../data/entity2id.txt","r"); 5 FILE* f2 = fopen("../data/relation2id.txt","r"); 6 int x; 7 while (fscanf(f1,"%s%d",buf,&x)==2) 8 { 9 string st=buf; 10 entity2id[st]=x; 11 id2entity[x]=st; 12 entity_num++; 13 } 14 while (fscanf(f2,"%s%d",buf,&x)==2) 15 { 16 string st=buf; 17 relation2id[st]=x; 18 id2relation[x]=st; 19 relation_num++; 20 } 21 FILE* f_kb = fopen("../data/train.txt","r"); 22 while (fscanf(f_kb,"%s",buf)==1) 23 { 24 string s1=buf; 25 fscanf(f_kb,"%s",buf); 26 string s2=buf; 27 fscanf(f_kb,"%s",buf); 28 string s3=buf; 29 if (entity2id.count(s1)==0) 30 { 31 cout<<"miss entity:"<<s1<<endl; 32 } 33 if (entity2id.count(s2)==0) 34 { 35 cout<<"miss entity:"<<s2<<endl; 36 } 37 if (relation2id.count(s3)==0) 38 { 39 relation2id[s3] = relation_num; 40 relation_num++; 41 } 42 left_entity[relation2id[s3]][entity2id[s1]]++; 43 right_entity[relation2id[s3]][entity2id[s2]]++; 44 train.add(entity2id[s1],entity2id[s2],relation2id[s3]); 45 } 46 for (int i=0; i<relation_num; i++) 47 { 48 double sum1=0,sum2=0; 49 for (map<int,int>::iterator it = left_entity[i].begin(); it!=left_entity[i].end(); it++) 50 { 51 sum1++; 52 sum2+=it->second; 53 } 54 left_num[i]=sum2/sum1; 55 } 56 for (int i=0; i<relation_num; i++) 57 { 58 double sum1=0,sum2=0; 59 for (map<int,int>::iterator it = right_entity[i].begin(); it!=right_entity[i].end(); it++) 60 { 61 sum1++; 62 sum2+=it->second; 63 } 64 right_num[i]=sum2/sum1; 65 } 66 cout<<"relation_num="<<relation_num<<endl; 67 cout<<"entity_num="<<entity_num<<endl; 68 fclose(f_kb); 69 }
第六部分:未知还没搞懂功能的函数
1 int ArgPos(char *str, int argc, char **argv) { 2 int a; 3 for (a = 1; a < argc; a++) if (!strcmp(str, argv[a])) { 4 if (a == argc - 1) { 5 printf("Argument missing for %s\n", str); 6 exit(1); 7 } 8 return a; 9 } 10 return -1; 11 }
第七部分:主函数流程
1 int main(int argc,char**argv) 2 { 3 srand((unsigned) time(NULL)); 4 int method = 1; 5 int n = 100; 6 double rate = 0.001; 7 double margin = 1; 8 int i; 9 if ((i = ArgPos((char *)"-size", argc, argv)) > 0) n = atoi(argv[i + 1]); 10 if ((i = ArgPos((char *)"-margin", argc, argv)) > 0) margin = atoi(argv[i + 1]); 11 if ((i = ArgPos((char *)"-method", argc, argv)) > 0) method = atoi(argv[i + 1]); 12 cout<<"size = "<<n<<endl; 13 cout<<"learing rate = "<<rate<<endl; 14 cout<<"margin = "<<margin<<endl; 15 if (method) 16 version = "bern"; 17 else 18 version = "unif"; 19 cout<<"method = "<<version<<endl; 20 prepare(); 21 train.run(n,rate,margin,method); 22 }
left_entity:在此relation下头实体对应的尾实体的个数,3个int分别表示relation_id,headentity_id,个数
标签:tmp,程序实现,double,entity,int,fb,relation,TransE,查阅 来源: https://www.cnblogs.com/real-zz-11/p/real_zz.html