其他分享
首页 > 其他分享> > 深度学习_用LSTM构建单词纠错神器(3)

深度学习_用LSTM构建单词纠错神器(3)

作者:互联网

六、模型预测

由于预测的word修正不知道何时结束, 所以我们需要对输入的值进行不断的修正,直到预测到末尾符为止。
即预测时候输入的input2为仅有一个起始符的全都为0的初始向量,然后每次预测都更新下字母位置的值,直到遇到末尾符。


## 由于预测的word 不知道何时结束, 所以我们需要对输入的值进行不断的修正,直到预测到末尾符为止
## 所以我们对于一个全新的输入,进行预测的时候需要先使得input2 输入为 一个全空的单词矩阵
def predict(m, input1_test):
    input2_orign = np.zeros((1, 36, 39))
    input2_orign[:, 0, g.char2int['\t']] = 1

    input_word = ''
    pred_word = ''

    for idx in range(input2_orign.shape[1] - 1): # max_encode_len
        p_tmp =  m.predict([tf.constant(input1_test), tf.constant(input2_orign)])
        # update input
        input2_w_idx = np.argmax(p_tmp[:, idx, :], axis=1)[0]
        # input2_orign[:, idx+1, :] = p_tmp[:, idx, :]
        input2_orign[:, idx+1, input2_w_idx] = 1
        
        input1_w_idx = np.argmax(input1_test[:, idx, :], axis=1)[0]
        pred_word += g.int2char[input2_w_idx]
        input_word += g.int2char[input1_w_idx]

        if (pred_word[-1] == '\n'):
            break
    print(f'[{idx}] input_word: {input_word[:-1]},  pred_word : {pred_word}' )
    return pred_word



def word2tensor(word):
    """
    当没有提供embedding的方法的时候,
    采用最简单的字母位置及出现则标记为1, 否则标记为0。 便于后面一个一个字母预测的时候抽取字母
    """
    char_set = [chr(i) for i in range(ord('a'), ord('z')+1)] + '0 1 2 3 4 5 6 7 8 9'.split() + ['\t', '\n', '#']
    char2int = dict(zip(char_set, range(len(char_set))))
    # int2char = dict(zip(range(len(char_set)), char_set))
    input1_encode_data = np.zeros((1, 34, len(char_set)), dtype='float64')

    # 将矩阵填充上数据 某个字母出现一次则标记增加1
    for w_idx, chr_tmp in enumerate(list(word)):
        if w_idx == 34:
            break
        input1_encode_data[0, w_idx, char2int[chr_tmp]] = 1

    return input1_encode_data


def word_correct(m, word):
    input1_encode_data = word2tensor(word)
    return predict(m, input1_encode_data)

预测测试

for  i in range(1, 50):
    print('--'*23, f'[i]', '--'*23)
    input1_test = input1_encode_data[i:i+1, :, :]
    predict(m, input1_test)


word_correct('hellp')

"""
---------------------------------------------- [12] ----------------------------------------------
[11] input_word: applicat5on,  pred_word : application

---------------------------------------------- [13] ----------------------------------------------
[11] input_word: requrements,  pred_word : requrements

---------------------------------------------- [14] ----------------------------------------------
[11] input_word: advertisrng,  pred_word : advertisrng

---------------------------------------------- [15] ----------------------------------------------
[12] input_word: construction,  pred_word : construction

---------------------------------------------- [16] ----------------------------------------------
[11] input_word: engineer5ng,  pred_word : engineering
"""

标签:input2,input1,word,idx,pred,神器,input,LSTM,纠错
来源: https://blog.csdn.net/Scc_hy/article/details/120794296