其他分享
首页 > 其他分享> > 模型压缩 -- 知识蒸馏

模型压缩 -- 知识蒸馏

作者:互联网

轻量化网络已经是一个热点,主要的技术路线如下:

追求模型在参数量、计算量、内存访问量等方面耗时、能耗尽可能的减小,以便在边缘设备上进行部署


本节来了解下蒸馏的开山之作:知识蒸馏是由Hinton老爷子在2015年首次在《Distilling the Knowledge in a Neural Network》论文中提出的,论文地址:https://arxiv.org/pdf/1503.0253

知识蒸馏就是把一个大模型或者多个模型ensemble学到的知识迁移到另一个轻量级的单模型上,最主要的目的是为了方便线上部署。知识蒸馏主要有两个方面:第一个是将大而深的模型迁移到一个轻量级的小模型上。这就像我们线上把大而深的BERT模型学到的知识迁移到轻量级的TextCNN小模型上;另一个就是将多个模型ensemble学到的知识迁移到单个轻量级的模型。

问题:

  1. 知识蒸馏要解决什么问题
  2. 什么是知识蒸馏?它具有什么优势?和其他网络的区别?
  3. 具体的创新点在哪里
  4. 效果如何

现在来回到上面这几个问题:

1、知识蒸馏要解决什么问题?或者为什么要提出知识蒸馏?

简单来说,它就是为了实现更加轻量化的网络而提出的。

在训练过程中,我们需要使用复杂的模型,大量的计算资源,以便从非常大、高度冗余的数据集中提取出信息。在实验中,效果最好的模型往往规模很大,甚至由多个模型集成得到。而大模型不方便部署到服务中去,常见的瓶颈如下:
• 推断速度慢
• 对部署资源要求高(内存,显存等) • 在部署时,我们对延迟以及计算资源都有着严格的限制。

因此,模型压缩(在保证性能的前提下减少模型的参数量)成为了一个重要的问题。而”模型蒸馏“属于模型压缩的一种方法

一般地,大模型往往都是单个复杂网络或者是若干网络的集合,具有良好的性能和泛化能力;而小模型因为网络规模较小,表达能力有限。因此,可以利用大模型学习到的知识去指导小模型训练,使得小模型具有与大模型相当的性能,但是参数数量大幅降低,从而实现模型压缩与加速,这就是知识蒸馏与迁移学习在模型优化中的应用。

举个例子:
BERT这一类模型优点在于效果好,但是如果用于线上推理就比较麻烦了,因为基础版本的BERT模型接近330M包含一亿的参数,你想让一个一亿参数的模型完成线上10ms内的线上推理基本有点不现实。而传统的文本分类算法比如TextCNN可以轻松满足线上推理的需求,但是效果相比BERT还是有点不如人意。知识蒸馏通俗的理解就是BERT当老师,TextCNN当学生,让BERT这个老师把学到的知识传授给TextCNN这个学生,这样就能让TextCNN达到和BERT媲美的效果,最后我们线上去部署TextCNN,就能做到模型效果和线上推理速度兼得。这就是知识蒸馏的作用。

2、什么是知识蒸馏?它具有什么优势?它为什么有效?

知识蒸馏本质上是一种映射关系,将老师学到的东西传递给学生网络

Hinton等人在《Distilling the Knowledge in a Neural Network》一文中提出的知识蒸馏概念,其核心思想是先训练一个复杂网络模型,然后使用这个复杂网络的输出和数据的真实标签去训练一个更小的网络,因此知识蒸馏框架通常包含了一个复杂模型(被称为Teacher模型)和一个小模型(被称为Student模型)。
image

2.1 它具有什么优势?

一个好的模型最重要的是通过训练数据获得一定的泛化能力,不仅仅是拟合训练数据,最重要的是在新数据集上能有一定的泛化识别能力。而知识蒸馏的目的是让学生去学习老师的这种泛化能力,所以从理论上来说学生比老师单纯的去拟合训练数据能获得更多的知识。

(1)提升模型精度

如果对目前的网络模型A的精度不是很满意,那么可以先训练一个更高精度的teacher模型B(通常参数量更多,时延更大),然后用这个训练好的teacher模型B对student模型A进行知识蒸馏,得到一个更高精度的A模型。

(2)降低模型时延,压缩网络参数

如果对目前的网络模型A的时延不满意,可以先找到一个时延更低,参数量更小的模型B,通常来讲,这种模型精度也会比较低,然后通过训练一个更高精度的teacher模型C来对这个参数量小的模型B进行知识蒸馏,使得该模型B的精度接近最原始的模型A,从而达到降低时延的目的。

(3)标签之间的域迁移

假如使用狗和猫的数据集训练了一个teacher模型A,使用香蕉和苹果训练了一个teacher模型B,那么就可以用这两个模型同时蒸馏出一个可以识别狗、猫、香蕉以及苹果的模型,将两个不同域的数据集进行集成和迁移。

2.2 为什么有效

image
1) 对于复杂的模型,理论搜索空间要大于较小网络的搜索空间,但如果使用较小的网络可以实现相同(甚至相似)的收敛,则教师网络的收敛空间应与学生网络的解空间重叠,如上图。

但仅此一项并不能保证学生网络在同一位置收敛,学生网络的收敛可能与教师网络的收敛大不相同。但如果指导学生网络复制教师网络的行为(教师网络已经在更大的解空间中进行搜索了),则可以预期其收敛空间与原始教师网络收敛空间重叠。

2) softmax层的输出,除了正例之外,负标签也带有大量的信息,比如某些负标签对应的概率远远大于其他负标签(宝马,兔子和垃圾车)。而在传统的训练过程(hard target)中,所有负标签都被统一对待。也就是说,KD的训练方式使得每个样本给Net-S带来的信息量大于传统的训练方式。

3、创新体的创新点在哪里

3.0 Hard-target和Soft-target

在Hinton这篇论文中,Hinton将问题限定在分类问题,分类问题的共同点是模型最后会有一个softmax层,其输出值对应了相应类别的概率值。在知识蒸馏时,由于我们已经有了一个泛化能力较强的Teacher模型,我们在利用Teacher模型来蒸馏训练Student模型时,可以直接让Student模型去学习Teacher模型的泛化能力。一个很直白且高效的迁移泛化能力的方法就是:使用softmax层输出的类别的概率来作为“Soft-target” 。

传统的神经网络训练方法是定义一个损失函数,目标是使预测值尽可能接近于真实值(Hard- target),损失函数就是使神经网络的损失值和尽可能小。这种训练过程是对ground truth求极大似然。在知识蒸馏中,是使用大模型的类别概率作为Soft-target的训练过程。
image

知识蒸馏用Teacher模型预测的 Soft-target 来辅助 Hard-target 训练 Student模型的方式为什么有效呢?

softmax层的输出,除了正例之外,负标签也带有Teacher模型归纳推理的大量信息,比如某些负标签对应的概率远远大于其他负标签,则代表 Teacher模型在推理时认为该样本与该负标签有一定的相似性。而在传统的训练过程(Hard-target)中,所有负标签都被统一对待。也就是说,知识蒸馏的训练方式使得每个样本给Student模型带来的信息量大于传统的训练方式。

如在MNIST数据集中做手写体数字识别任务,假设某个输入的“2”更加形似"3",softmax的输出值中"3"对应的概率会比其他负标签类别高;而另一个"2"更加形似"7",则这个样本分配给"7"对应的概率会比其他负标签类别高。这两个"2"对应的Hard-target的值是相同的,但是它们的Soft-target却是不同的,由此我们可见Soft-target蕴含着比Hard-target更多的信息。

在使用 Soft-target 训练时,Student模型可以很快学习到 Teacher模型的推理过程;而传统的 Hard-target 的训练方式,所有的负标签都会被平等对待。因此,Soft-target 给 Student模型带来的信息量要大于 Hard-target,并且Soft-target分布的熵相对高时,其Soft-target蕴含的知识就更丰富。同时,使用 Soft-target 训练时,梯度的方差会更小,训练时可以使用更大的学习率,所需要的样本也更少。这也解释了为什么通过蒸馏的方法训练出的Student模型相比使用完全相同的模型结构和训练数据只使用Hard-target的训练方法得到的模型,拥有更好的泛化能力。

3.1 带“温度”的softmax

将要训练的小模型称为新模型,将以及训练的大模型称为原模型。

我们的目标是让新模型与原模型的softmax输出的分布充分接近。直接这样做是有问题的:在一般的softmax函数中,自然指数image先拉大logits之间的差距,然后作归一化,最终得到的分布是一个arg max的近似 ,其输出是一个接近one-hot的向量,其中一个值很大,其他的都很小。这种情况下,前面说到的「可能是垃圾车,但绝不是萝卜」这种知识的体现是非常有限的。相较类似one-hot这样的硬性输出,我们更希望输出更「软」一些。一种方法是直接比较logits来避免这个问题。具体地,对于每一条数据,记原模型产生的某个logits是image,新模型产生的logits是 image,我们需要最小化
image

该《Distilling the Knowledge in a Neural Network》论文中提出了带有“温度”的softmax,作为一种广义是softmax函数。
原始的softmax函数:
image
但要是直接使用softmax层的输出值作为soft target, 这又会带来一个问题: 当softmax输出的概率分布熵相对较小时,负标签的值都很接近0,对损失函数的贡献非常小,小到可以忽略不计。因此"温度"这个变量就派上了用场
image
其中 T 是温度,这是从统计力学中的玻尔兹曼分布中借用的概念。容易证明,当温度 T 趋向于0时,softmax输出将收敛为一个one-hot向量;温度 T 趋向于无穷时,softmax的输出则更「软」。因此,在训练新模型的时候,可以使用较高的 T 使得softmax产生的分布足够软,这时让新模型(同样温度下)的softmax输出近似原模型;在训练结束以后再使用正常的温度 T = 1 来预测。具体地,在训练时我们需要最小化两个分布的交叉熵(Cross-entropy),记新模型利用公式 (2) 产生的分布是 q ,原模型产生的分布是 p ,则我们需要最小化
image

在化学中,蒸馏是一个有效的分离沸点不同的组分的方法,大致步骤是先升温使低沸点的组分汽化,然后降温冷凝,达到分离出目标物质的目的。在前面提到的这个过程中,我们先让温度 T 升高,然后在测试阶段恢复「低温」,从而将原模型中的知识提取出来,因此将其称为是蒸馏,实在是妙。

当然,如果转移时使用的是有标签的数据,那么也可以将标签与新模型softmax分布的交叉熵加入到损失函数中去。这里需要将式 (3) 乘上一个\(T^2\),这是为了让损失函数的两项的梯度大致在一个数量级上(参考公式 (9) ),实验表明这将大大改善新模型的表现(考虑到加入了更多的监督信号)。

与直接优化logits差异相比
由公式(2)(3),对于交叉熵损失来说,其对于新模型的某个logit 的梯度是 \(z_i\)的梯度是
image
由于\(e^2 - 2\)与 \(x\)是等价无穷小(\(x -> 0\) 时),易知,当 \(T\) 充分大时,有
image

假设所有logits对每个样本都是零均值化的,即\(\sum_{j}z_j=\sum_j v_j = 0\),则有
image

PS:
温度:
原来的softmax函数是T = 1的特例。 T越高,softmax的output probability distribution越趋于平滑,其分布的熵越大,负标签携带的信息会被相对地放大,模型训练将更加关注负标签
image

3.2 通用的知识蒸馏方法

知识蒸馏训练的具体方法如下图所示,主要包括以下几个步骤:

image

知识蒸馏示意图(来自:https://nervanasystems.github.io/distiller/knowledge_distillation.html)

Net-T和Net-S同时输入transfer set(这里可以这里可以直接复用训练Net-T用到的training set),用Net-T产生的softmax distribution(with high temperature)来作为soft target,Net-S在相同温度T条件下的softmax输出和soft target的cross entropy就是Loss函数的一部分$$L_{soft}$$

Net-S在T=1的条件下的softmax输出和ground truth的cross entropy就是Loss函数的第二部分$$L_{hard}$$

第二部分Loss必要性其实很好理解:Net-T也有一定的错误率,使用round truth可以有效降低错误被传播给Net-S的可能。打个比方,老师虽然学识远远超过学生,但是他仍然有出错的可能,而这时候如果学生在老师的教授之外,可以同时参考到标准答案,

【注意】
在Net-S训练完毕后,做inference时其Softmax的温度T要恢复到1

image

image

image

最后,\(\alpha\)和\(\beta\) 是关于\({L_{soft}}\) 和 \({L_{hard}}\)的权重,实验发现,当\({L_{hard}}\) 权重较小时,能产生最好的效果,这是一个经验性的结论。直接给出结论:由于\({L_{soft}}\) 贡献的梯度大约为\({L_{hard}}\)的\(\frac{1}{T^2}\),因此在同时使用Soft-target和Hard-target的时候,需要在\({L_{soft}}\)的权重上乘以\({T^2}\)的系数,这样才能保证Soft-target和Hard-target贡献的梯度量基本一致。

4、效果如何

Hinton等人做了三组实验,其中两组都验证了知识蒸馏方法的有效性。在MNIST数据集上的实验表明,即便有部分类别的样本缺失,新模型也可以表现得很不错,只需要修改相应的偏置项,就可以与原模型表现相当。在语音任务的实验也表明,蒸馏得到的模型比从头训练的模型捕捉了更多数据集中的有效信息,表现仅比集成模型低了0.3个百分点。总体来说知识蒸馏是一个简单而有效的模型压缩/训练方法。这大体上是因为原模型的softmax提供了比one-hot标签更多的监督信号

参考:
https://www.bilibili.com/video/BV1gS4y1k7vj/?spm_id_from=333.788
https://zhuanlan.zhihu.com/p/90049906

标签:target,训练,--,压缩,知识,softmax,模型,蒸馏
来源: https://www.cnblogs.com/whiteBear/p/16344812.html