编程语言
首页 > 编程语言> > TransE的程序实现——学习,查阅,注释

TransE的程序实现——学习,查阅,注释

作者:互联网

代码来源:https://github.com/thunlp/KB2E/blob/master/TransE/

很久没碰C++了,为了实验成功,还是要仔仔细细抠一抠代码才行。

第一部分:头文件的引入以及常量定义

 1 #include<iostream>  /*输入输出流*/
 2 #include<cstring>   /*C字符串操作函数*/
 3 #include<cstdio>    /*标准输入输出的C++形式*/
 4 #include<map>   /*定义了一种关联容器,数据结构*/
 5 #include<vector>  /*顺序容器,常用于表示向量*/
 6 #include<string>  /*字符串操作函数*/
 7 #include<ctime>  /*日期时间结构体*/
 8 #include<cmath>  /*数学操作*/
 9 #include<cstdlib>  /*提供一些函数和符号常量,如第二部分出现的RAND_MAX*/
10 using namespace std; //添加命名空间
11 
12 #define pi 3.1415926535897932384626433832795
13 
14 bool L1_flag=1;  /*标识数为1时,该变量为L1范数;否则,则为L2范数*/

第二部分:简单计算函数的定义

 1 //normal distribution  正态分布
 2 double rand(double min, double max)  /*返回[min,max)之间的随机数*/
 3 {
 4     return min+(max-min)*rand()/(RAND_MAX+1.0);
 5 }
 6 double normal(double x, double miu,double sigma)  /*返回一个均值为miu标准差为sigma的正态分布函数在x处的函数值*/
 7 {
 8     return 1.0/sqrt(2*pi)/sigma*exp(-1*(x-miu)*(x-miu)/(2*sigma*sigma));
 9 }
10 double randn(double miu,double sigma, double min ,double max)  /*通过产生随机数的方式,返回满足某条件的自变量x*/
11 {
12     double x,y,dScope;
13     do{
14         x=rand(min,max);  /*调用自定义函数rand(min,max)*/
15         y=normal(x,miu,sigma);  /*调用自定义函数normal(x,miu,sigma)*/
16         dScope=rand(0.0,normal(miu,miu,sigma));  /*当x值取miu时,正态函数达到最大值,该变量在0与正态函数最大值取随机数*/
17     }while(dScope>y);  /*x处的函数值小于该随机数时,循环停止*/
18     return x;
19 }
20 
21 double sqr(double x)  /*平方函数*/
22 {
23     return x*x;
24 }
25 
26 double vec_len(vector<double> &a)  /*返回向量a的模,即L2范数*/
27 {
28     double res=0;
29     for (int i=0; i<a.size(); i++)  /*遍历向量长度*/
30         res+=a[i]*a[i];  
31     res = sqrt(res);
32     return res;
33 }

rand()返回一个0到最大随机数RAND_MAX(确定)的任意整数,RAND_MAX至少为32767。

第三部分:变量的定义

1 string version;
2 char buf[100000],buf1[100000];
3 int relation_num,entity_num;  /*定义关系数量,实体数量*/
4 map<string,int> relation2id,entity2id;  /*关系和实体使用关联容器,字符串作关键词,整型作关键字的值*/
5 map<int,string> id2entity,id2relation;  /*与上相反*/
6 
7 map<int,map<int,int> > left_entity,right_entity; 
8 map<int,double> left_num,right_num;

第7行,left_entity表示在此relation下头实体对应的尾实体的个数,3个int分别表示relation_id,headentity_id,个数。right_entity表示在此relation下尾实体对应的头实体的个数,3个int分别表示relation_id,tailentity_id,个数。主要用于计算采样概率p。

第8行,leftnum表示平均每个头实体对应多少个尾实体。rightnum表示平均每个尾实体对应多少头实体。int仍然表示relation_id。

第四部分:训练类的定义

  1 class Train{
  2 
  3 public:
  4     map<pair<int,int>, map<int,int> > ok;
  5     void add(int x,int y,int z)
  6     {
  7         fb_h.push_back(x);
  8         fb_r.push_back(z);
  9         fb_l.push_back(y);
 10         ok[make_pair(x,z)][y]=1;
 11     }
 12     void run(int n_in,double rate_in,double margin_in,int method_in)
 13     {
 14         n = n_in;
 15         rate = rate_in;
 16         margin = margin_in;
 17         method = method_in;
 18         relation_vec.resize(relation_num);
 19         for (int i=0; i<relation_vec.size(); i++)
 20             relation_vec[i].resize(n);
 21         entity_vec.resize(entity_num);
 22         for (int i=0; i<entity_vec.size(); i++)
 23             entity_vec[i].resize(n);
 24         relation_tmp.resize(relation_num);
 25         for (int i=0; i<relation_tmp.size(); i++)
 26             relation_tmp[i].resize(n);
 27         entity_tmp.resize(entity_num);
 28         for (int i=0; i<entity_tmp.size(); i++)
 29             entity_tmp[i].resize(n);
 30         for (int i=0; i<relation_num; i++)
 31         {
 32             for (int ii=0; ii<n; ii++)
 33                 relation_vec[i][ii] = randn(0,1.0/n,-6/sqrt(n),6/sqrt(n));
 34         }
 35         for (int i=0; i<entity_num; i++)
 36         {
 37             for (int ii=0; ii<n; ii++)
 38                 entity_vec[i][ii] = randn(0,1.0/n,-6/sqrt(n),6/sqrt(n));
 39             norm(entity_vec[i]);
 40         }
 41 
 42 
 43         bfgs();
 44     }
 45 
 46 private:
 47     int n,method;
 48     double res;//loss function value
 49     double count,count1;//loss function gradient
 50     double rate,margin;
 51     double belta;
 52     vector<int> fb_h,fb_l,fb_r;
 53     vector<vector<int> > feature;
 54     vector<vector<double> > relation_vec,entity_vec;
 55     vector<vector<double> > relation_tmp,entity_tmp;
 56     double norm(vector<double> &a)
 57     {
 58         double x = vec_len(a);
 59         if (x>1)
 60         for (int ii=0; ii<a.size(); ii++)
 61                 a[ii]/=x;
 62         return 0;
 63     }
 64     int rand_max(int x)
 65     {
 66         int res = (rand()*rand())%x;
 67         while (res<0)
 68             res+=x;
 69         return res;
 70     }
 71 
 72     void bfgs()
 73     {
 74         res=0;
 75         int nbatches=100;
 76         int nepoch = 1000;
 77         int batchsize = fb_h.size()/nbatches;
 78             for (int epoch=0; epoch<nepoch; epoch++)
 79             {
 80 
 81                 res=0;
 82                  for (int batch = 0; batch<nbatches; batch++)
 83                  {
 84                      relation_tmp=relation_vec;
 85                     entity_tmp = entity_vec;
 86                      for (int k=0; k<batchsize; k++)
 87                      {
 88                         int i=rand_max(fb_h.size());
 89                         int j=rand_max(entity_num);
 90                         double pr = 1000*right_num[fb_r[i]]/(right_num[fb_r[i]]+left_num[fb_r[i]]);
 91                         if (method ==0)
 92                             pr = 500;
 93                         if (rand()%1000<pr)
 94                         {
 95                             while (ok[make_pair(fb_h[i],fb_r[i])].count(j)>0)
 96                                 j=rand_max(entity_num);
 97                             train_kb(fb_h[i],fb_l[i],fb_r[i],fb_h[i],j,fb_r[i]);
 98                         }
 99                         else
100                         {
101                             while (ok[make_pair(j,fb_r[i])].count(fb_l[i])>0)
102                                 j=rand_max(entity_num);
103                             train_kb(fb_h[i],fb_l[i],fb_r[i],j,fb_l[i],fb_r[i]);
104                         }
105                         norm(relation_tmp[fb_r[i]]);
106                         norm(entity_tmp[fb_h[i]]);
107                         norm(entity_tmp[fb_l[i]]);
108                         norm(entity_tmp[j]);
109                      }
110                     relation_vec = relation_tmp;
111                     entity_vec = entity_tmp;
112                  }
113                 cout<<"epoch:"<<epoch<<' '<<res<<endl;
114                 FILE* f2 = fopen(("relation2vec."+version).c_str(),"w");
115                 FILE* f3 = fopen(("entity2vec."+version).c_str(),"w");
116                 for (int i=0; i<relation_num; i++)
117                 {
118                     for (int ii=0; ii<n; ii++)
119                         fprintf(f2,"%.6lf\t",relation_vec[i][ii]);
120                     fprintf(f2,"\n");
121                 }
122                 for (int i=0; i<entity_num; i++)
123                 {
124                     for (int ii=0; ii<n; ii++)
125                         fprintf(f3,"%.6lf\t",entity_vec[i][ii]);
126                     fprintf(f3,"\n");
127                 }
128                 fclose(f2);
129                 fclose(f3);
130             }
131     }
132     double res1;
133     double calc_sum(int e1,int e2,int rel)
134     {
135         double sum=0;
136         if (L1_flag)
137             for (int ii=0; ii<n; ii++)
138                 sum+=fabs(entity_vec[e2][ii]-entity_vec[e1][ii]-relation_vec[rel][ii]);
139         else
140             for (int ii=0; ii<n; ii++)
141                 sum+=sqr(entity_vec[e2][ii]-entity_vec[e1][ii]-relation_vec[rel][ii]);
142         return sum;
143     }
144     void gradient(int e1_a,int e2_a,int rel_a,int e1_b,int e2_b,int rel_b)
145     {
146         for (int ii=0; ii<n; ii++)
147         {
148 
149             double x = 2*(entity_vec[e2_a][ii]-entity_vec[e1_a][ii]-relation_vec[rel_a][ii]);
150             if (L1_flag)
151                 if (x>0)
152                     x=1;
153                 else
154                     x=-1;
155             relation_tmp[rel_a][ii]-=-1*rate*x;
156             entity_tmp[e1_a][ii]-=-1*rate*x;
157             entity_tmp[e2_a][ii]+=-1*rate*x;
158             x = 2*(entity_vec[e2_b][ii]-entity_vec[e1_b][ii]-relation_vec[rel_b][ii]);
159             if (L1_flag)
160                 if (x>0)
161                     x=1;
162                 else
163                     x=-1;
164             relation_tmp[rel_b][ii]-=rate*x;
165             entity_tmp[e1_b][ii]-=rate*x;
166             entity_tmp[e2_b][ii]+=rate*x;
167         }
168     }
169     void train_kb(int e1_a,int e2_a,int rel_a,int e1_b,int e2_b,int rel_b)
170     {
171         double sum1 = calc_sum(e1_a,e2_a,rel_a);
172         double sum2 = calc_sum(e1_b,e2_b,rel_b);
173         if (sum1+margin>sum2)
174         {
175             res+=margin+sum1-sum2;
176             gradient( e1_a, e2_a, rel_a, e1_b, e2_b, rel_b);
177         }
178     }
179 };

第五部分:类变量的定义和数据集准备

 1 Train train;
 2 void prepare()
 3 {
 4     FILE* f1 = fopen("../data/entity2id.txt","r");
 5     FILE* f2 = fopen("../data/relation2id.txt","r");
 6     int x;
 7     while (fscanf(f1,"%s%d",buf,&x)==2)
 8     {
 9         string st=buf;
10         entity2id[st]=x;
11         id2entity[x]=st;
12         entity_num++;
13     }
14     while (fscanf(f2,"%s%d",buf,&x)==2)
15     {
16         string st=buf;
17         relation2id[st]=x;
18         id2relation[x]=st;
19         relation_num++;
20     }
21     FILE* f_kb = fopen("../data/train.txt","r");
22     while (fscanf(f_kb,"%s",buf)==1)
23     {
24         string s1=buf;
25         fscanf(f_kb,"%s",buf);
26         string s2=buf;
27         fscanf(f_kb,"%s",buf);
28         string s3=buf;
29         if (entity2id.count(s1)==0)
30         {
31             cout<<"miss entity:"<<s1<<endl;
32         }
33         if (entity2id.count(s2)==0)
34         {
35             cout<<"miss entity:"<<s2<<endl;
36         }
37         if (relation2id.count(s3)==0)
38         {
39             relation2id[s3] = relation_num;
40             relation_num++;
41         }
42         left_entity[relation2id[s3]][entity2id[s1]]++;
43         right_entity[relation2id[s3]][entity2id[s2]]++;
44         train.add(entity2id[s1],entity2id[s2],relation2id[s3]);
45     }
46     for (int i=0; i<relation_num; i++)
47     {
48         double sum1=0,sum2=0;
49         for (map<int,int>::iterator it = left_entity[i].begin(); it!=left_entity[i].end(); it++)
50         {
51             sum1++;
52             sum2+=it->second;
53         }
54         left_num[i]=sum2/sum1;
55     }
56     for (int i=0; i<relation_num; i++)
57     {
58         double sum1=0,sum2=0;
59         for (map<int,int>::iterator it = right_entity[i].begin(); it!=right_entity[i].end(); it++)
60         {
61             sum1++;
62             sum2+=it->second;
63         }
64         right_num[i]=sum2/sum1;
65     }
66     cout<<"relation_num="<<relation_num<<endl;
67     cout<<"entity_num="<<entity_num<<endl;
68     fclose(f_kb);
69 }

第六部分:未知还没搞懂功能的函数

 1 int ArgPos(char *str, int argc, char **argv) {
 2   int a;
 3   for (a = 1; a < argc; a++) if (!strcmp(str, argv[a])) {
 4     if (a == argc - 1) {
 5       printf("Argument missing for %s\n", str);
 6       exit(1);
 7     }
 8     return a;
 9   }
10   return -1;
11 }

第七部分:主函数流程

 1 int main(int argc,char**argv)
 2 {
 3     srand((unsigned) time(NULL));
 4     int method = 1;
 5     int n = 100;
 6     double rate = 0.001;
 7     double margin = 1;
 8     int i;
 9     if ((i = ArgPos((char *)"-size", argc, argv)) > 0) n = atoi(argv[i + 1]);
10     if ((i = ArgPos((char *)"-margin", argc, argv)) > 0) margin = atoi(argv[i + 1]);
11     if ((i = ArgPos((char *)"-method", argc, argv)) > 0) method = atoi(argv[i + 1]);
12     cout<<"size = "<<n<<endl;
13     cout<<"learing rate = "<<rate<<endl;
14     cout<<"margin = "<<margin<<endl;
15     if (method)
16         version = "bern";
17     else
18         version = "unif";
19     cout<<"method = "<<version<<endl;
20     prepare();
21     train.run(n,rate,margin,method);
22 }

 

left_entity:在此relation下头实体对应的尾实体的个数,3个int分别表示relation_id,headentity_id,个数

标签:tmp,程序实现,double,entity,int,fb,relation,TransE,查阅
来源: https://www.cnblogs.com/real-zz-11/p/real_zz.html