RNN模型中输入的重要性的评估
作者:互联网
I. Saliency Maps for RNN
RNN是很多序列任务的不二法门,比如文本分类任务的常用方法就是“词向量+LSTM+全连接分类器”。如下图
假如这样的一个模型可以良好地工作,那么现在考虑一个任务是:如何衡量输入 w1,…,wn
对最终的分类结果的影响的重要程度(Saliency)呢?例如假设这是一个情感分类任务,那么怎么找出是哪些词对最终的分类有较为重要的影响呢?本文给出了一个较为直接的思路。
思路的原理很简单,因为我们是将RNN最后一步的状态向量(也就是绿色阴影所代表的向量)传递给后面的分类器进行分类的,因此最后一步的状态向量 hn
就是一个目标向量。而RNN是一个递推的过程, h0,h1,…,hn−1 (一般 h0 就是全零初始化)是逐步逼近 hn的过程。
所以我们可以依次考虑中间向量到目标向量的距离
而由 hi 到 hi+1 ,是因为多考虑了词 wi+1 ,造成的后果是: 本来 hi 与目标向量的距离是 ∥hn−hi∥ ,现在距离变成了 ∥hn−hi+1∥ ,所以我们可以用差值
∥hn−hi∥−∥hn−hi+1∥
来衡量词 wi+1 对最终分类所造成的影响。它可以是正的,意味着 wi+1 的引入缩小了与目标的距离,因此它对完成正确分类有着促进作用;反之它也可以是负的,代表着对分类有反作用;而它的值越大,就表示作用的程度越大,所以可以用这个指标来降序排列,得到各个词的重要程度了。当然,也可以通过除以目标向量的范数,来排除量纲的影响:
∥hn−hi∥∥hn∥−∥hn−hi+1∥∥hn∥
II. 简单的实验
除了理论说通,还要有实验才有说服力。这里还是以文本情感分类为例,下面的代码修改自文章《文本情感分类(三):分词 OR 不分词》。
#! -*- coding:utf-8 -*-
#实验环境:tensorflow 1.2 + Keras 2.0.6
import numpy as np
import pandas as pd
import jieba
pos = pd.read_excel('pos.xls', header=None)
neg = pd.read_excel('neg.xls', header=None)
pos['words'] = pos[0].apply(jieba.lcut)
neg['words'] = neg[0].apply(jieba.lcut)
words = {}
for l in pos['words'].append(neg['words']): #统计得到词表
for w in l:
if w in words:
words[w] += 1
else:
words[w] = 1
min_count = 10 #词频低于min_count的舍弃
maxlen = 100 #句子截断为100字
words = {i:j for i,j in words.items() if j >= min_count}
id2word = {i+1:j for i,j in enumerate(words)} #id映射到词,未登录词全部用0表示
word2id = {j:i for i,j in id2word.items()} #词映射到id
def doc2num(s):
s = [word2id.get(i,0) for i in s[:maxlen]]
return s + [0]*(maxlen-len(s))
pos['id'] = pos['words'].apply(doc2num)
neg['id'] = neg['words'].apply(doc2num)
x = np.vstack([np.array(list(pos['id'])), np.array(list(neg['id']))])
y = np.array([[1]]*len(pos)+[[0]]*len(neg))
#手动打乱数据
idx = range(len(x))
np.random.shuffle(idx)
x = x[idx]
y = y[idx]
from keras.models import Model
from keras.layers import Input, Dense, Dropout, Embedding, Lambda
from keras.layers import LSTM
from keras import backend as K
#建立模型
input = Input(shape=(None,))
input_vecs = Embedding(len(words)+1, 128, mask_zero=True)(input) #用了mask_zero,填充部分自动为0
lstm = LSTM(128, return_sequences=True, return_state=True)(input_vecs) #返回一个list
lstm_state = Lambda(lambda x: x[1])(lstm) #list的第二个元素就是lstm最后的状态
dropout = Dropout(0.5)(lstm_state)
predict = Dense(1, activation='sigmoid')(dropout)
#list的第一个元素就是lstm的状态向量序列,先补充一个0向量(h_0),然后与
lstm_sequence = Lambda(lambda x: K.concatenate([K.zeros_like(x[0])[:,:1], x[0]], 1))(lstm)
lstm_dist = Lambda(lambda x: K.sqrt(K.sum((x[0]-K.expand_dims(x[1], 1))**2, 2)/K.sum(x[1]**2,1,keepdims=True)))([lstm_sequence, lstm_state])
model = Model(inputs=input, outputs=predict) #文本情感分类模型
model.compile(loss='binary_crossentropy',
optimizer='adam',
metrics=['accuracy'])
model_dist = Model(inputs=input, outputs=lstm_dist) #计算权重的模型
model_dist.compile(loss='mse',
optimizer='adam')
batch_size = 128
train_num = 15000
model.fit(x[:train_num], y[:train_num], batch_size = batch_size, epochs=5, validation_data=(x[train_num:],y[train_num:]))
import uniout #这个库使得python交互界面中可以直接现实unicode字符而非原始编码
def saliency(s): #简单的按saliency排序输出的函数
ws = jieba.lcut(s)[:maxlen]
x_ = np.array([[word2id.get(w,0) for w in ws]])
score = np.diff(model_dist.predict(x_)[0])
idxs = score.argsort()
return [(i,ws[i],-score[i]) for i in idxs] #输出结果为:(词位置、词、词权重)
III. 一些结果
5个epoch后,模型的valid准确率为90%上下。然后开始测试我们的saliency函数:
>>> s = u'发货很快,物流也及时,热水器包装很好,已经打电话找师傅安装好了,用了加热挺快的,非常好的热水器'
>>> saliency(s)
[(1, u'很快', 0.27192765), (23, u'挺快', 0.24280858), (5, u'及时', 0.212547), (17, u'好', 0.20004171), (10, u'好', 0.18096809), (6, u',', 0.091015637), (27, u'好', 0.089454532), (29, u'热水器', 0.075981312), (9, u'很', 0.073918283), (19, u',', 0.056054845), (2, u',', 0.054350078), (25, u',', 0.05254671), (16, u'安装', 0.047204345), (4, u'也', 0.032162607), (20, u'用', 0.032066017), (7, u'热水器', 0.020733953), (28, u'的', 0.015783943), (11, u',', 0.0068980753), (24, u'的', 0.0040297061), (21, u'了', -0.0028512329), (8, u'包装', -0.0061192214), (12, u'已经', -0.0071464926), (26, u'非常', -0.016652927), (3, u'物流', -0.034254551), (18, u'了', -0.041289926), (13, u'打电话', -0.067411274), (14, u'找', -0.076897413), (0, u'发货', -0.12187159), (15, u'师傅', -0.13093886), (22, u'加热', -0.25505972)]>>> s = u'用过了才来评价,挺好用的和我在商场买的一样,应该是正品,烧水也快。五分。'
>>> saliency(s)
[(13, u'商场', 0.18483591), (20, u'正品', 0.18375245), (7, u'好', 0.10582721), (6, u'挺', 0.091109157), (27, u'。', 0.078018099), (17, u',', 0.068635911), (21, u',', 0.060936138), (8, u'用', 0.05191648), (16, u'一样', 0.045136154), (18, u'应该', 0.038288802), (25, u'。', 0.036342919), (12, u'在', 0.033013284), (26, u'五分', 0.032422632), (23, u'也', 0.030365154), (15, u'的', 0.028619051), (10, u'和', 0.026087582), (9, u'的', 0.02335906), (22, u'烧水', 0.021587744), (5, u',', 0.020465493), (11, u'我', 0.012083113), (3, u'来', 0.011826038), (1, u'了', 0.010910749), (24, u'快', 0.0078535229), (4, u'评价', 0.0051870346), (19, u'是', -0.0072228611), (14, u'买', -0.045606673), (0, u'用过', -0.048195302), (2, u'才', -0.10755491)]>>> s = u'新购入好吃的噻...左边的是光明酸奶保质期很贴心的是150天哦..盒子也很心水.味道超级棒哦...当然价钱也贵一点...'
>>> saliency(s)
[(9, u'光明', 0.23620945), (2, u'好吃', 0.20977235), (0, u'新', 0.098622561), (19, u'..', 0.082070053), (24, u'.', 0.051661924), (10, u'酸奶', 0.051210225), (26, u'超级', 0.050095648), (33, u'贵', 0.045776516), (34, u'一点', 0.043491587), (5, u'...', 0.042084634), (28, u'哦', 0.041375414), (12, u'很', 0.039076477), (17, u'天', 0.036537394), (3, u'的', 0.035894275), (13, u'贴心', 0.035299033), (14, u'的', 0.034476012), (32, u'也', 0.034225687), (35, u'...', 0.03399314), (29, u'...', 0.030394047), (31, u'价钱', 0.022595212), (22, u'很', 0.017542139), (15, u'是', 0.017276704), (30, u'当然', 0.016185507), (11, u'保质期', 0.015443504), (7, u'的', 0.013372183), (16, u'150', 0.0039001554), (4, u'噻', -0.0), (23, u'心水', -0.0), (1, u'购入', -0.0), (8, u'是', -0.0022000074), (21, u'也', -0.0026017427), (20, u'盒 子', -0.019797415), (6, u'左边', -0.026007473), (25, u'味道', -0.088489026), (18, u'哦', -0.092244744), (27, u'棒', -0.10724148)]>>> s = u'安装我自己花了500多,美的够黑心的,真的是烦心,安装的售后叼的要死!差评!!!!!'
>>> saliency(s)
[(10, u'黑心', 0.42798662), (22, u'要死', 0.27835941), (3, u'花', 0.19574976), (23, u'!', 0.1150609), (25, u'!', 0.099666387), (26, u'!', 0.086483672), (27, u'!', 0.075068578), (28, u'!', 0.065046221), (5, u'500', 0.064557791), (29, u'!', 0.056182593), (19, u'售后', 0.053749442), (8, u'美的', 0.033454597), (4, u'了', 0.011106014), (24, u'差评', -0.0), (15, u'烦心', -0.0), (20, u'叼', -0.0), (21, u' 的', -0.0019193888), (12, u',', -0.007355392), (14, u'是', -0.0092425346), (11, u'的', -0.010129571), (13, u'真的', -0.018997073), (18, u'的', -0.023140371), (16, u',', -0.025478244), (6, u'多', -0.028243661), (1, u'我', -0.028675437), (7, u',', -0.032953143), (2, u'自己', -0.040592432), (0, u'安装', -0.086136341), (17, u'安装', -0.11075348), (9, u'够', -0.1388548)]>>> s = u'作者的文笔一般,观点也是和市面上的同类书大同小异,不推荐读者购买。'
>>> saliency(s)
[(12, u'书', 0.21486527), (3, u'一般', 0.21002132), (2, u'文笔', 0.1983965), (15, u'不', 0.13881186), (5, u'观点', 0.10552621), (19, u'。', 0.10454651), (17, u'读者', 0.1005006), (11, u'同类', 0.067096055), (16, u'推荐', 0.034221202), (8, u'和', 0.0024021864), (4, u',', 0.0023730397), (9, u'市面上', 6.4194202e-05), (13, u'大同小异', -0.0), (0, u'作者', -0.0029646158), (18, u'购买', -0.0097944811), (7, u'是', -0.018220723), (14, u',', -0.029574722), (1, u'的', -0.031183362), (6, u'也', -0.03150624), (10, u'的', -0.055580795)]>>> s = u'总的来说有点乱七八糟的感觉。重复又重复。'
>>> saliency(s)
[(2, u'乱七八糟', 0.48706663), (1, u'有点', 0.25306857), (5, u'。', 0.22749269), (9, u'。', 0.1819987), (4, u'感觉', 0.091413289), (6, u'重复', 0.042275667), (3, u'的', 0.035568655), (8, u'重复', 0.035231754), (7, u'又', 0.0035995394), (0, u'总的来说', -0.35771555)]>>> s = u'太离谱了,现在什么年代了,都不能下载铃声!我几乎找遍了都不支持!有时还会突然、死机,真是后悔了!'
>>> saliency(s)
[(19, u'不', 0.24356699), (1, u'离谱', 0.22871125), (11, u'下载', 0.20161489), (0, u'太', 0.14231563), (27, u'死机', 0.12861626), (10, u'不能', 0.1119701), (5, u'什么', 0.095636666), (13, u'!', 0.092601627), (21, u'!', 0.078016117), (32, u'!', 0.070392698), (6, u'年代', 0.048469543), (22, u'有时', 0.046845198), (30, u'后悔', 0.044225916), (31, u'了', 0.041907579), (25, u'突然', 0.024277031), (28, u',', 0.019716278), (24, u'会', 0.0099375397), (12, u'铃声', 0.002569586), (16, u'找遍', -0.0), (9, u'都', -0.0023496151), (2, u'了', -0.0070439577), (4, u'现在', -0.013223112), (7, u'了', -0.014739335), (18, u'都', -0.018258482), (3, u',', -0.020016789), (8, u',', -0.025391042), (17, u'了', -0.032811344), (29, u'真是', -0.053523436), (23, u'还', -0.072828561), (15, u'几乎', -0.073504835), (26, u'、', -0.081194341), (14, u'我', -0.089508265), (20, u'支持', -0.12699783)]
效果可见一斑,排在前面的词语基本上是情感倾向比较强烈的词语。值得指出的是,这种重要性的评估方案还会自动地考虑词语的位置所造成的影响,假如一个情感词在句子中重复出现,那么后出现的词语一般来说权重会更低(因为前面的已经能让我们完成分类了,后面的权重就下降了),比如
>>> s = u'很糟糕,没有比这更加糟糕的了,真是太糟糕了'
>>> saliency(s)
[(1, u'糟糕', 0.35864168), (3, u'没有', 0.20132971), (7, u'糟糕', 0.20045981), (11, u'真是太', 0.18850464), (13, u'了', 0.15443647), (6, u'更加', 0.087445587), (4, u'比', 0.086736709), (5, u'这', 0.033832848), (2, u',', -0.013917089), (8, u'的', -0.021302044), (9, u'了', -0.04099898), (12, u'糟糕', -0.051040836), (10, u',', -0.055929884), (0, u'很', -0.12819862)]>>> s = u'在快乐的地方做着快乐的事情,拥有快乐的心情'
>>> saliency(s)
[(1, u'快乐', 0.27486306), (6, u'快乐', 0.20637855), (10, u'拥有', 0.11172111), (11, u'快乐', 0.092637397), (12, u'的', 0.065387651), (9, u',', 0.062397212), (5, u'着', 0.059462607), (7, u'的', 0.053004891), (2, u'的', 0.04632777), (13, u'心情', 0.044671077), (0, u'在', 0.023381054), (4, u'做', 0.014107406), (8, u'事情', -0.014113277), (3, u'地方', -0.040226519)]
IV. 自我的评价
自我感觉这是评估RNN模型的输入重要性的一种简单明快的方案,而且不需要太多的数学知识,欢迎读者在更多的任务上试用。
标签:10,RNN,saliency,neg,重要性,hn,words,lstm,输入 来源: https://blog.51cto.com/u_14540820/2759449