其他分享
首页 > 其他分享> > TinyBERT

TinyBERT

作者:互联网

TinyBERT

TinyBERT: Distilling BERT for Natural Language Understanding

Transformer蒸馏

img

Embedding-layer Distillation

L e m b d = M S E ( E S W e , E T ) E S ∈ R l × d 0 , E T ∈ R l × d l : s e q u e n c e l e n g t h d 0 : s t u d e n t e m b e d d i n g 维 度 d : t e a c h e r e m b e d d i n g 维 度 W e : d 0 × d 可 训 练 的 线 性 变 换 矩 阵 \mathcal{L}_{embd}=MSE(E^{S}W_e,E_T)\\ E^S\in R^{l \times d_0},E^T \in R^{l\times d}\\ l:sequence \quad length\\ d0:student\quad embedding维度\\ d:teacher\quad embedding维度\\ W_e:d_0\times d可训练的线性变换矩阵 Lembd​=MSE(ESWe​,ET​)ES∈Rl×d0​,ET∈Rl×dl:sequencelengthd0:studentembedding维度d:teacherembedding维度We​:d0​×d可训练的线性变换矩阵

Transformer-layer Distillation

Attention based loss

L a t t n = 1 h ∑ i = 1 h M S E ( A i S , A i T ) A i ∈ R l × l h : a t t e n t i o n 的 头 数 l : 输 入 长 度 A i S : s t u d e n t 网 络 第 i 个 a t t e n t i o n 头 的 a t t e n t i o n s c o r e 矩 阵 A i T : t e a c h e r 网 络 第 i 个 a t t e n t i o n 头 的 a t t e n t i o n s c o r e 矩 阵 \mathcal{L}_{attn}=\frac{1}{ h}\sum_{i=1}^h MSE(A_i^S,A_i^T)\\ A_i\in R^{l\times l}\\ h:attention的头数\\ l:输入长度\\ A_i^S:student网络第i个attention头的attention\quad score矩阵\\ A_i^T:teacher网络第i个attention头的attention\quad score矩阵 Lattn​=h1​i=1∑h​MSE(AiS​,AiT​)Ai​∈Rl×lh:attention的头数l:输入长度AiS​:student网络第i个attention头的attentionscore矩阵AiT​:teacher网络第i个attention头的attentionscore矩阵

hidden states based distillation

L h i d n = M S E ( H S W h , H T ) H S ∈ R l × d 0 H T ∈ R l × d H S : s t u d e n t t r a n s f o r m e r 的 隐 藏 层 输 出 H T : t e a c h e r t r a n s f o r m e r 的 隐 藏 层 输 出 W h : 投 射 矩 阵 \mathcal{L}_{hidn}=MSE(H^SW_h,H^T)\\ H^S\in R^{l\times d_0}\quad H^T\in R^{l\times d}\\ H^S:student\quad transformer的隐藏层输出\\ H^T:teacher\quad transformer的隐藏层输出\\ W_h:投射矩阵 Lhidn​=MSE(HSWh​,HT)HS∈Rl×d0​HT∈Rl×dHS:studenttransformer的隐藏层输出HT:teachertransformer的隐藏层输出Wh​:投射矩阵

Prediction-layer DIstillation

计算 teacher 输出的概率分布和 student 输出的概率分布的 softmax 交叉熵,这里用来模拟teacher在predict层的表现
L p r e d = − s o f t m a x ( z T ) ⋅ l o g _ s o f t m a x ( z S / t ) \mathcal{L}_{pred}=-softmax(z^T)\cdot log\_softmax(z^S/t) Lpred​=−softmax(zT)⋅log_softmax(zS/t)

loss总结

L m o d e l = ∑ m = 0 M + 1 λ m L l a y e r ( S m , T g ( m ) ) L l a y e r ( S m , T g ( m ) ) = { L e m b d ( S 0 , T 0 ) m = 0 L h i d n ( S m , T g ( m ) ) + L a t t n ( S m , T g ( m ) ) M ≥ m > 0 L p r e d ( S M + 1 , T N + 1 ) m = M + 1 \mathcal{L}_{model}=\sum_{m=0}^{M+1}\lambda_m \mathcal{L}_{layer}(S_m,T_{g(m)})\\ \mathcal{L}_{layer}(S_m,T_{g(m)})= \left\{ \begin{aligned} \mathcal{L}_{embd}(S_0,T_0)&& m=0\\ \mathcal{L}_{hidn}(S_m,T_{g(m)})+\mathcal{L}_{attn}(S_m,T_{g(m)}) && M\geq m > 0\\ \mathcal{L}_{pred}(S_{M+1},T_{N+1}) && m=M+1 \end{aligned} \right. Lmodel​=m=0∑M+1​λm​Llayer​(Sm​,Tg(m)​)Llayer​(Sm​,Tg(m)​)=⎩⎪⎨⎪⎧​Lembd​(S0​,T0​)Lhidn​(Sm​,Tg(m)​)+Lattn​(Sm​,Tg(m)​)Lpred​(SM+1​,TN+1​)​​m=0M≥m>0m=M+1​

two-step

效果

img

虽然效果略微下降,但影响不大

推理速度有了明显的提升

标签:attention,矩阵,student,quad,mathcal,teacher,TinyBERT
来源: https://blog.csdn.net/doyouseeman/article/details/113833217