其他分享
首页 > 其他分享> > R2D2:基于可微分树的预训练模型

R2D2:基于可微分树的预训练模型

作者:互联网

https://arxiv.org/abs/2107.00967
在一次分享中看到这篇论文,感觉有意思细读了一下
主要是讲基于可微分树的递归transformer来实现具有强解释性的层次预训练语言模型

论文主要章节涉及了三个方面

相关背景知识

  1. 基于CKY算法的语法分析 介绍 博客
乔姆斯基范式(CNF,Chomsky Normal Form)

任何语法都可以转化成一个弱等价的CNF形式,CNF语法都是二分叉
在这里插入图片描述

CYK算法

CYK算法(也称为Cocke–Younger–Kasami算法)是一种用来对 上下文无关文法(CFG,Context Free Grammar)进行语法分析(parsing)的算法。该算法最早由John Cocke, Daniel Younger and Tadao Kasami分别独立提出,其中John Cocke还是1987年度的图灵奖得主。CYK算法是基于动态规划思想设计的一种自底向上语法分析算法。
看过最易懂的博文
代码实现
2. Gumbel-Softmax estimation
在自底向上的计算过程中,每个格子会有多种组合方式,在各种组合方式中,选择概率最大的组合,即argmax函数。但是argmax函数是不可导的,没有办法反向传播。
通过reparameterization对logits的输出拟合为onehot,同时保证梯度可以反向传播
对离散变量再参数化
4. 基于大语料的预训练语言模型的大概套路

模型结构设计
Differentiable Tree

数据结构图
该论文定义了一个类似于CKY形式的可微二叉树解析器
句子 S={s1,s2,s3,…sn}
如上图,每一个格子 T ( i , j ) = < e i , j , p i , j , p ~ i , j > \Tau(i,j)=<e_{i,j},p_{i,j},\tilde{p}_{i,j}> T(i,j)=<ei,j​,pi,j​,p~​i,j​>
e i , j e_{i,j} ei,j​ 是向量表征
p i , j p_{i,j} pi,j​ 是每一个步所有组合的概率
p ~ i , j \tilde{p}_{i,j} p~​i,j​是在[ s i s_i si​, s j s_j sj​]的子树的概率
树的末端节点是 T i , i \Tau_{i,i} Ti,i​, e i , i e_{i,i} ei,i​以当前输入 s i s_i si​的向量初始化, p i , j p_{i,j} pi,j​ 和 p ~ i , j \tilde{p}_{i,j} p~​i,j​初始化为1。
在这里插入图片描述

上述公式的k是指( s i s_i si​, s j − 1 s_{j-1} sj−1​)之间的某一分割点(分割点不同,会对应出不同的组合)
第一个公式
f ( . ) f(.) f(.)是我们下一节Recursive Transformer定义的函数, p i , j k p_{i,j}^k pi,jk​ 和 p ~ i , j k \tilde{p}_{i,j}^k p~​i,jk​分别指一步中组合的概率和其子树的概率
第二个公式
以K为分割点的子树的概率,是当前组合的概率和左右子树概率的乘积,这个和CKY算法是一致的
第三个公式
这里放一个链接 Straight Through Gumbel-Softmax ,通过一定方式实现argmax函数的可微??
p i , j p_{i,j} pi,j​ 和 p ~ i , j \tilde{p}_{i,j} p~​i,j​是基于所有分割点得到的 p i , j k p_{i,j}^k pi,jk​ 和 p ~ i , j k \tilde{p}_{i,j}^k p~​i,jk​的组合
output: 计算得出权重
第四个公式
通过当前组合与权重系数的乘积计算出 e i , i e_{i,i} ei,i​
第五个公式
通过概率向量与权重系数的乘积计算出新的概率向量

Recursive Transformer

Recursive Transformer-based encoder
这个图对应了上一节第一个公式。
中间shape的转换过程看图,不想转述了,最终输出的 p i , j p_{i,j} pi,j​是 R 1 R^1 R1, c i , j k c_{i,j}^k ci,jk​是 R d R^d Rd

Tree Recovery

通过Straight-Through Gumbel-Softmax在每一个cell选择最佳的分割点,Tree( T 1 , n \Tau_{1,n} T1,n​), 从树的根节点自顶向下递归操作,选择的最佳分割点还原树的结构,类似于CKY算法最后的回溯过程

Complexity Optimization 复杂度优化

上述的 f ( . ) f(.) f(.)是整个模型的核心计算部分,我们可以通过树的剪枝归并算法来实现对 f ( . ) f(.) f(.)O( n 3 n^3 n3)
复杂度到线性复杂度

算法

在这里插入图片描述

寻找最佳的合并点

在这里插入图片描述

example

在这里插入图片描述
这张图展示了长度为6的句子的处理过程。
m表示设定的剪枝的阈值 T \Tau T 是一个二维数组,用来盛放自底向上计算的所有cell。
上上述图示的三个function:
TREEINDUCTION 是前向计算的过程,调用PRUNING进行剪枝,PRUNING调用FIND寻找最佳消并点。
计算m之下的cell,如上图(b)显示。
当cell的row大于等于m时,还原所有以第m行的节点为root节点的子树,调用PRUNING进行剪枝操作,
剪枝的第一步是找到局部最佳的merge点(上图c),剪掉部分的cell(上图d),返回一个新的 T \Tau T(上图e)
在FIND中,最佳分割点的候选集合需要满足两个条件
(1)在 T \Tau T的第二行
(2)在以第m行的节点为root节点的子树中有被使用到
然后在候选集合中选择(x.p *pl *pr)最高的cell T i , j \Tau_{i,j} Ti,j​做为最佳merge点,对应的将 T i , ∗ \Tau_{i,*} Ti,∗​和 T ∗ , j \Tau_{*,j} T∗,j​剪掉,得到 T 3 \Tau^3 T3

实验

预训练目标:

  1. 学习词汇表征,在实际实验中是对于word piece的表征,选择WikiText-2数据集,长度在128以内的句子,mask词汇,输入左子树和右子树的embedding进行词汇预测
    因为剪枝操作,存在左子树或者右子树为空,以临近的最长子树来替代
    在这里插入图片描述

  2. 无监督成分句法分析
    在 WSJ and CTB 测试集计算F1
    在这里插入图片描述

基于word-piece的word、NP等的召回
在这里插入图片描述

标签:Tau,剪枝,模型,微分,cell,算法,R2D2,pi,复杂度
来源: https://blog.csdn.net/qq_27965129/article/details/120953040