深度学习_用LSTM构建单词纠错神器(3)
作者:互联网
六、模型预测
由于预测的word修正不知道何时结束, 所以我们需要对输入的值进行不断的修正,直到预测到末尾符为止。
即预测时候输入的input2为仅有一个起始符的全都为0的初始向量,然后每次预测都更新下字母位置的值,直到遇到末尾符。
## 由于预测的word 不知道何时结束, 所以我们需要对输入的值进行不断的修正,直到预测到末尾符为止
## 所以我们对于一个全新的输入,进行预测的时候需要先使得input2 输入为 一个全空的单词矩阵
def predict(m, input1_test):
input2_orign = np.zeros((1, 36, 39))
input2_orign[:, 0, g.char2int['\t']] = 1
input_word = ''
pred_word = ''
for idx in range(input2_orign.shape[1] - 1): # max_encode_len
p_tmp = m.predict([tf.constant(input1_test), tf.constant(input2_orign)])
# update input
input2_w_idx = np.argmax(p_tmp[:, idx, :], axis=1)[0]
# input2_orign[:, idx+1, :] = p_tmp[:, idx, :]
input2_orign[:, idx+1, input2_w_idx] = 1
input1_w_idx = np.argmax(input1_test[:, idx, :], axis=1)[0]
pred_word += g.int2char[input2_w_idx]
input_word += g.int2char[input1_w_idx]
if (pred_word[-1] == '\n'):
break
print(f'[{idx}] input_word: {input_word[:-1]}, pred_word : {pred_word}' )
return pred_word
def word2tensor(word):
"""
当没有提供embedding的方法的时候,
采用最简单的字母位置及出现则标记为1, 否则标记为0。 便于后面一个一个字母预测的时候抽取字母
"""
char_set = [chr(i) for i in range(ord('a'), ord('z')+1)] + '0 1 2 3 4 5 6 7 8 9'.split() + ['\t', '\n', '#']
char2int = dict(zip(char_set, range(len(char_set))))
# int2char = dict(zip(range(len(char_set)), char_set))
input1_encode_data = np.zeros((1, 34, len(char_set)), dtype='float64')
# 将矩阵填充上数据 某个字母出现一次则标记增加1
for w_idx, chr_tmp in enumerate(list(word)):
if w_idx == 34:
break
input1_encode_data[0, w_idx, char2int[chr_tmp]] = 1
return input1_encode_data
def word_correct(m, word):
input1_encode_data = word2tensor(word)
return predict(m, input1_encode_data)
预测测试
for i in range(1, 50):
print('--'*23, f'[i]', '--'*23)
input1_test = input1_encode_data[i:i+1, :, :]
predict(m, input1_test)
word_correct('hellp')
"""
---------------------------------------------- [12] ----------------------------------------------
[11] input_word: applicat5on, pred_word : application
---------------------------------------------- [13] ----------------------------------------------
[11] input_word: requrements, pred_word : requrements
---------------------------------------------- [14] ----------------------------------------------
[11] input_word: advertisrng, pred_word : advertisrng
---------------------------------------------- [15] ----------------------------------------------
[12] input_word: construction, pred_word : construction
---------------------------------------------- [16] ----------------------------------------------
[11] input_word: engineer5ng, pred_word : engineering
"""
标签:input2,input1,word,idx,pred,神器,input,LSTM,纠错 来源: https://blog.csdn.net/Scc_hy/article/details/120794296