其他分享
首页 > 其他分享> > 知识图到文本的生成——陆

知识图到文本的生成——陆

作者:互联网

2021SC@SDUSC

mkiters函数也是dataset类中的一个重要的类函数,我的队友已经在她的博客中详细分析过这个函数,此处不再赘述。

  def mktestset(self, args):
    path = args.path.replace("train",'test')
    fields=self.fields
    ds = data.TabularDataset(path=path, format='tsv',fields=fields)
    ds.fields["rawent"] = data.RawField()
    for x in ds:
      x.rawent = x.ent.split(" ; ")
      x.ent = self.vec_ents(x.ent,self.ENT)
      x.rel = self.mkGraphs(x.rel,len(x.ent[1]))
      if args.sparse:
        x.rel = (self.adjToSparse(x.rel[0]),x.rel[1])
      x.tgt = x.out
      x.out = [y.split("_")[0]+">" if "_" in y else y for y in x.out]
      x.sordertgt = torch.LongTensor([int(y)+3 for y in x.sorder.split(" ")])
      x.sorder = [[int(z) for z in y.strip().split(" ")] for y in x.sorder.split("-1")[:-1]]
    ds.fields["tgt"] = self.TGT
    ds.fields["rawent"] = data.RawField()
    ds.fields["sordertgt"] = data.RawField()
    dat_iter = data.Iterator(ds,1,device=args.device,sort_key=lambda x:len(x.src), train=False, sort=False)
    return dat_iter

mktestset函数是dataset类中一个用来形成测试集的函数,对数据集进行遍历之后,返回一个迭代器。其余的函数都是对数据集进行一些修饰工作,不再一一展开详细分析。

让我们回到最最开始的地方, 继续分析train.py程序。之前我们详细分析了dataset类,而pargs.py由我的队友来着重分析,那么我们就继续看:

m = model(args)

我们开始分析model类。

class model(nn.Module):

首先我们看model类的init函数。

  def __init__(self,args):
    super().__init__()
    self.args = args
    cattimes = 3 if args.title else 2
    self.emb = nn.Embedding(args.ntoks,args.hsz)
    self.lstm = nn.LSTMCell(args.hsz*cattimes,args.hsz)
    self.out = nn.Linear(args.hsz*cattimes,args.tgttoks)
    self.le = list_encode(args)
    self.entout = nn.Linear(args.hsz,1)
    self.switch = nn.Linear(args.hsz*cattimes,1)

这个model类继承了torch.nn,其中的参数都是调用了torch.nn中的函数。cattimes是分类的次数,如果有标题,就设置为3,如果没有标题,就为2。emb为用args.ntoks和args.hsz组成的矩阵(args.ntoks是输出的vocab长度,会在pargs.py的代码分析中详细介绍)。lstm是用hsz和分类次数乘积作为构建LSTM中的一个Cell的输入特征维度,hsz作为构建LSTM中的一个Cell的隐状态的维度,torch.nn中的LSTM和LSTMCell的操作如下图:

 

 out、entout、switch都是调用了nn.Linear()函数,其中的参数都是指维度,对二维变量进行线性变换,如图所示。

    self.attn = MultiHeadAttention(args.hsz,args.hsz,args.hsz,h=4,dropout_p=args.drop)
    self.mattn = MatrixAttn(args.hsz*cattimes,args.hsz)
    self.graph = (args.model in ['graph','gat','gtrans'])
    print(args.model)

MultiHeadAttention()是attention.py中的类,继承Module,这里的操作是返回一个连接后的4*4维度的attn。MatrixAttn()也是attention.py中的类,继承Module,这里是对hsz和分类次数的乘积和hsz作线性变换(如上图所示)。graph则是模型生成的图,然后在终端打印出来。

标签:nn,图到,fields,self,args,生成,hsz,文本,ds
来源: https://blog.csdn.net/qq_50729659/article/details/121542713