transformer代码笔记----transformer.py
作者:互联网
import torch.nn as nn from .decoder import Decoder from .encoder import Encoder class Transformer(nn.Module): #定义类,继承父类nn.Module """An encoder-decoder framework only includes attention. """ def __init__(self, encoder=None, decoder=None): #参数encoder和decoder设置默认值None super(Transformer, self).__init__() #继承父类__init__() if encoder is not None and decoder is not None: #判断decoder和encoder是否被重新赋值 self.encoder = encoder self.decoder = decoder for p in self.parameters(): #获取网络参数 if p.dim() > 1: nn.init.xavier_uniform_(p) #参数初始化,torch.nn.init.xavier_uniform_是一个服从均匀分布的Glorot初始化器 # else: # self.encoder = Encoder() #对全局变量赋值 # self.decoder = Decoder() def forward(self, padded_input, input_lengths, padded_target): #编码器中的前向传播 """ Args: padded_input: B x Ti x D 表示编码器输入时数据结构 其中B(一维向量):批量中每个音频的具体长度;Ti:该批量中音频的最大长度; input_lengths: B 每个音频的具体长度,假设批量大小为32,则B可表示为[2,3,45,6....],维度32 padded_targets: B x To 表示解码器的输入数据结构,这里的B和上面的B不同,因为编码器中是音频的输入,解码器中的输入是字符 """ encoder_padded_outputs, *_ = self.encoder(padded_input, input_lengths) # pred is score before softmax pred, gold, *_ = self.decoder(padded_target, encoder_padded_outputs, input_lengths) return pred, gold def recognize(self, input, input_length, char_list, args): #解码器中的识别过程 """Sequence-to-Sequence beam search, decode one utterence now. Args: input: T x D char_list: list of characters args: args.beam Returns: nbest_hyps: """ encoder_outputs, *_ = self.encoder(input.unsqueeze(0), input_length) nbest_hyps = self.decoder.recognize_beam(encoder_outputs[0], char_list, args) return nbest_hyps
标签:padded,transformer,nn,self,py,encoder,----,decoder,input 来源: https://www.cnblogs.com/Uriel-w/p/15426153.html