其他分享
首页 > 其他分享> > transformer代码笔记----decoder.py

transformer代码笔记----decoder.py

作者:互联网

import torch
import torch.nn as nn
import torch.nn.functional as F

from config import IGNORE_ID
from .attention import MultiHeadAttention
from .module import PositionalEncoding, PositionwiseFeedForward
from .utils import get_attn_key_pad_mask, get_attn_pad_mask, get_non_pad_mask, get_subsequent_mask, pad_list


# filename = 'bigram_freq.pkl'
# print('loading {}...'.format(filename))
# with open(filename, 'rb') as file:
#     bigram_freq = pickle.load(file)


class Decoder(nn.Module):
    """Transformer decoder: embedding + positional encoding + a stack of DecoderLayers.

    ``forward`` runs teacher-forced decoding over a padded batch and returns
    per-position vocabulary logits together with the target labels.
    ``recognize_beam`` runs beam search for a single utterance.
    """

    def __init__(
            self, sos_id=0, eos_id=1,
            n_tgt_vocab=4335, d_word_vec=512,
            n_layers=6, n_head=8, d_k=64, d_v=64,
            d_model=512, d_inner=2048, dropout=0.1,
            tgt_emb_prj_weight_sharing=True,
            pe_maxlen=5000):
        super(Decoder, self).__init__()
        # Hyper-parameters, stored as attributes.
        self.sos_id = sos_id  # start-of-sentence token id
        self.eos_id = eos_id  # end-of-sentence token id; also the padding value of decoder inputs
        self.n_tgt_vocab = n_tgt_vocab
        self.d_word_vec = d_word_vec
        self.n_layers = n_layers
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        self.d_model = d_model
        self.d_inner = d_inner
        self.tgt_emb_prj_weight_sharing = tgt_emb_prj_weight_sharing
        self.pe_maxlen = pe_maxlen

        self.tgt_word_emb = nn.Embedding(n_tgt_vocab, d_word_vec)
        self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen)
        # Note: ``self.dropout`` is the Dropout module (not the float rate).
        self.dropout = nn.Dropout(dropout)

        # n_layers identical decoder layers.
        self.layer_stack = nn.ModuleList([
            DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])

        # Output projection from d_model to vocabulary logits.
        self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False)
        nn.init.xavier_normal_(self.tgt_word_prj.weight)

        if tgt_emb_prj_weight_sharing:
            # Share the weight matrix between the target word embedding and the
            # final logit projection (this replaces the xavier init above).
            self.tgt_word_prj.weight = self.tgt_word_emb.weight
            # Scale multiplied onto the target embeddings in forward/recognize_beam.
            # NOTE(review): the original Transformer multiplies embeddings by
            # sqrt(d_model); this uses d_model ** -0.5. Confirm it is intended
            # before changing — changing it alters the behavior of trained models.
            self.x_logit_scale = (d_model ** -0.5)
        else:
            self.x_logit_scale = 1.

    def preprocess(self, padded_input):
        """Build decoder inputs and output labels from ``padded_input``.

        Each row has its IGNORE_ID entries removed, then ``<sos>`` is prepended
        to form the decoder input and ``<eos>`` appended to form the label.
        Inputs are re-padded with ``eos_id``, labels with ``IGNORE_ID``.

        Args:
            padded_input: N x To tensor of token ids padded with IGNORE_ID.
        Returns:
            (ys_in_pad, ys_out_pad): two tensors of identical size.
        """
        ys = [y[y != IGNORE_ID] for y in padded_input]  # strip padding per row
        # One-element tensors with the same dtype/device as the data.
        eos = ys[0].new_tensor([self.eos_id])
        sos = ys[0].new_tensor([self.sos_id])
        ys_in = [torch.cat([sos, y], dim=0) for y in ys]   # <sos> + y
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]  # y + <eos>
        ys_in_pad = pad_list(ys_in, self.eos_id)
        ys_out_pad = pad_list(ys_out, IGNORE_ID)
        assert ys_in_pad.size() == ys_out_pad.size()
        return ys_in_pad, ys_out_pad

    def forward(self, padded_input, encoder_padded_outputs,
                encoder_input_lengths, return_attns=False):
        """Teacher-forced decoding over a padded batch.

        Args:
            padded_input: N x To target token ids padded with IGNORE_ID.
            encoder_padded_outputs: N x Ti x H encoder outputs.
            encoder_input_lengths: valid lengths of the encoder outputs, passed
                to ``get_attn_pad_mask`` for the encoder-decoder attention mask.
            return_attns: when True, also return per-layer attention maps.
        Returns:
            ``(pred, gold)``: the output-projection logits for every decoder
            position and the label tensor from ``preprocess``. When
            ``return_attns`` is True, the per-layer self-attention and
            encoder-decoder attention lists are appended.
        """
        dec_slf_attn_list, dec_enc_attn_list = [], []

        # Decoder inputs (<sos> + y) and labels (y + <eos>).
        ys_in_pad, ys_out_pad = self.preprocess(padded_input)

        # Masks. Decoder inputs are padded with eos_id, so eos_id is the pad index.
        non_pad_mask = get_non_pad_mask(ys_in_pad, pad_idx=self.eos_id)
        slf_attn_mask_subseq = get_subsequent_mask(ys_in_pad)  # causal mask
        slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=ys_in_pad,
                                                     seq_q=ys_in_pad,
                                                     pad_idx=self.eos_id)
        # A position is masked if either mask marks it.
        slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

        output_length = ys_in_pad.size(1)
        dec_enc_attn_mask = get_attn_pad_mask(encoder_padded_outputs,
                                              encoder_input_lengths,
                                              output_length)

        # Scaled word embedding + positional encoding, then dropout.
        dec_output = self.dropout(self.tgt_word_emb(ys_in_pad) * self.x_logit_scale +
                                  self.positional_encoding(ys_in_pad))

        for dec_layer in self.layer_stack:
            dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
                dec_output, encoder_padded_outputs,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask,
                dec_enc_attn_mask=dec_enc_attn_mask)

            if return_attns:
                dec_slf_attn_list.append(dec_slf_attn)
                dec_enc_attn_list.append(dec_enc_attn)

        # Logits (before softmax).
        seq_logit = self.tgt_word_prj(dec_output)

        pred, gold = seq_logit, ys_out_pad
        if return_attns:
            return pred, gold, dec_slf_attn_list, dec_enc_attn_list
        return pred, gold

    def recognize_beam(self, encoder_outputs, char_list, args):
        """Beam search, decoding one utterance.

        Args:
            encoder_outputs: T x H encoder outputs of a single utterance
                (a batch dimension is added internally).
            char_list: list of characters; not used by this method, kept for
                interface compatibility.
            args: object with ``beam_size``, ``nbest`` and ``decode_max_len``
                (0 means "use T as the maximum number of decoding steps").
        Returns:
            nbest_hyps: up to ``args.nbest`` dicts sorted best-first, each with
                ``'score'`` (accumulated log-probability) and ``'yseq'``
                (list of token ids starting with sos and ending with eos).
        """
        beam = args.beam_size
        nbest = args.nbest
        if args.decode_max_len == 0:
            maxlen = encoder_outputs.size(0)
        else:
            maxlen = args.decode_max_len

        encoder_outputs = encoder_outputs.unsqueeze(0)  # add batch dimension
        device = encoder_outputs.device

        # Token ids are created directly as int64 on the encoder's device; going
        # through the encoder's float dtype would be lossy for large ids in fp16.
        ys = torch.full((1, 1), self.sos_id, dtype=torch.long, device=device)

        hyps = [{'score': 0.0, 'yseq': ys}]  # yseq: 1 x i
        ended_hyps = []

        for i in range(maxlen):
            hyps_best_kept = []
            for hyp in hyps:
                ys = hyp['yseq']  # 1 x i
                # A single hypothesis has no padding; only the causal mask is needed.
                non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1)  # 1 x i x 1
                slf_attn_mask = get_subsequent_mask(ys)

                dec_output = self.dropout(
                    self.tgt_word_emb(ys) * self.x_logit_scale +
                    self.positional_encoding(ys))

                for dec_layer in self.layer_stack:
                    dec_output, _, _ = dec_layer(
                        dec_output, encoder_outputs,
                        non_pad_mask=non_pad_mask,
                        slf_attn_mask=slf_attn_mask,
                        dec_enc_attn_mask=None)

                # Only the last position predicts the next token.
                seq_logit = self.tgt_word_prj(dec_output[:, -1])
                local_scores = F.log_softmax(seq_logit, dim=1)
                local_best_scores, local_best_ids = torch.topk(
                    local_scores, beam, dim=1)

                for j in range(beam):
                    hyps_best_kept.append({
                        'score': hyp['score'] + local_best_scores[0, j],
                        'yseq': torch.cat([ys, local_best_ids[:, j:j + 1]], dim=1),
                    })

                # Keep only the best `beam` candidates seen so far.
                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda x: x['score'],
                                        reverse=True)[:beam]
            hyps = hyps_best_kept

            # Force eos on the final step so that at least one hypothesis ends.
            if i == maxlen - 1:
                eos = torch.full((1, 1), self.eos_id, dtype=torch.long, device=device)
                for hyp in hyps:
                    hyp['yseq'] = torch.cat([hyp['yseq'], eos], dim=1)

            # Move finished hypotheses to ended_hyps (fewer than `beam` may remain).
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][0, -1] == self.eos_id:
                    ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)
            hyps = remained_hyps

            # Every hypothesis has ended: further steps could not add anything.
            if not hyps:
                break

        nbest_hyps = sorted(ended_hyps, key=lambda x: x['score'], reverse=True)[:nbest]
        # Convert to plain lists, compatible with the LAS implementation.
        for hyp in nbest_hyps:
            hyp['yseq'] = hyp['yseq'][0].cpu().numpy().tolist()
        return nbest_hyps


class DecoderLayer(nn.Module):
    """One decoder layer composed of three sub-layers.

    1. Self-attention over the decoder input.
    2. Attention from the decoder states to the encoder output.
    3. A position-wise feed-forward network.

    After each sub-layer, padded positions are zeroed with ``non_pad_mask``
    when one is given.
    """

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    @staticmethod
    def _apply_non_pad_mask(x, non_pad_mask):
        """Multiply ``x`` by ``non_pad_mask``; return ``x`` unchanged when the mask is None."""
        return x if non_pad_mask is None else x * non_pad_mask

    def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None):
        """Run the three sub-layers.

        Args:
            dec_input: decoder hidden states fed to self-attention.
            enc_output: encoder outputs used as keys/values of the second attention.
            non_pad_mask: optional multiplicative mask (broadcast against the
                hidden states) that zeroes padded positions; skipped when None.
            slf_attn_mask: mask passed to the self-attention.
            dec_enc_attn_mask: mask passed to the encoder-decoder attention.
        Returns:
            (dec_output, dec_slf_attn, dec_enc_attn): the layer output and the
            second return values of the two attention modules.
        """
        dec_output, dec_slf_attn = self.slf_attn(
            dec_input, dec_input, dec_input, mask=slf_attn_mask)
        dec_output = self._apply_non_pad_mask(dec_output, non_pad_mask)

        dec_output, dec_enc_attn = self.enc_attn(
            dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
        dec_output = self._apply_non_pad_mask(dec_output, non_pad_mask)

        dec_output = self.pos_ffn(dec_output)
        dec_output = self._apply_non_pad_mask(dec_output, non_pad_mask)

        return dec_output, dec_slf_attn, dec_enc_attn

 

标签:transformer,self,py,mask,----,pad,attn,ys,dec
来源: https://www.cnblogs.com/Uriel-w/p/15426155.html