GANs(生成对抗网络)浅析
作者:互联网
机器学习众多算法,经典如决策树,罗吉斯特回归等,以及近些年热火朝天的众多深度神经网络模型,其本质终究是分类器。2014年发表在人工智能顶会NIPS上的一篇文章《Generative Adversarial Networks》一举成名,在CV,NLP等众多领域都取得了突出性的成果。今天让我们一起来看下这一“神”算法究竟是如何工作的。
目录
1. 欢迎来“对抗”
世间万物,皆须以辩证的角度来细细思量。我们人学习的过程既是如此,从小我们学的便是哪些能做哪些不能做,好坏善恶需对比后方能体会到人间至真至善至美。要想让机器也能学习到数据世界的“真善美”,免不了也需要让它见见“世间险恶”,尤其是这种险恶越真实,它能学到的才更多。这时候小伙伴们可能要问了,这哪来的险恶呀,不就是一堆没有感情的数据吗! (震惊!都2021年了,不会还有人认为数据没有感情吧!)这些险恶在训练数据里面的表现方式当然就是:伪!造!数!据!
有了世间险恶,怎么从中学习呢?对比!俗话说的好,没有对比就没有伤害。(没有女朋友的你不吃几吨狗粮怎么会下定决心主动求偶呢!)我们的可爱机器需要在伪造数据和真实数据中机智的分辨出真假,才能变成一个合格的分类模型。所以,目前我们需要考虑的两个部分:1)生成假数据骗机器;2)训练识别器识别出假数据。(听起来是不是听残忍,一边骗你一边还要骂你没长脑子总被骗,真金都是要打磨出来的!)
1.1 生成假数据
现在有一个问题:生成的数据是看起来越假越好还是越真越好? 毫无疑问是越真越好,达到以假乱真的效果最好。
例:给一瓶可乐和一瓶开水让你辨别,你经过0.5s的简单思考回答出了有颜色的是可乐这能说你学会新知识了嘛;但如果给你一瓶可乐和一瓶酸梅汁,你可能需要靠闻它们的气味或者观察有没有气泡分辨哪个是可乐,这时候你学会了可乐和酸梅汁在气味和气泡这两个维度上有差别。
要想造出以假乱真的数据,就对生成器提出了很大的要求,所以生成器需要训练,那么训练总得有个目标呀,目标就是让生成器看不出来真假(即识别器有一半概率分辨不出来是假数据),换而言之,就是以降低识别器识别出假样本的准确率为生成器的优化目标。
生成器为了生成更真的样本,一般还需要将真样本输入给它学习,它对真样本加以噪声处理。
这一思想很像之前看过的赖账公司和讨债公司,两者互相博弈实现共赢。
1.2 鉴别假数据
对于识别器,给其输入的就是真假样本,其目标是为了更好的鉴别出真样本,这一点和我们之前接触过的普通分类器十分相似。
2. GAN的基本结构解析
2.1 基本结构
经过上面简单的介绍,下面来具体看下GAN的基本结构。以Minist手写数据集为例。
GAN包括一个生成器G和一个判别器D。如图所示:
- 生成器G接受一个随机噪声输入,输出一个图片样本,和真样本规模相等;
- 判别器D接受真样本和G生成的假样本,输出对该样本的判别结果,若判定为真样本则输出True,否则输出False。
2.2 训练方式
既然GAN由生成器和判别器两部分组成,且两部分训练目标不同,所以使用了交替训练的思想。基本流程如下:
- 初始化生成器参数 和判别器参数 ;
- 从先验噪声分布 中采样出 m 个样本 输入生成器生成样本,从真实样本分布 中采样 m 个样本 ;用生成样本和真实样本来更新判别器D的参数;
- 步骤 2 循环 k次;
- 从先验噪声分布 中采样出 m 个样本 ,更新生成器的参数;
- 重复上面2-4步直至收敛。
这里需要注意,对判别器循环k次后再更新生成器,原因是先拥有了一个足够可以的判别器之后,才有更新生成器的意义。就像只有受骗者对骗局已经能轻而易举的分辨出之后,骗子才会去发明新的骗术。
3. 相关理论推导
符号说明(任何不加符号说明的公式推导都是耍流氓):
- 判别器 , 表示判别器判定样本 x 为真实样本的概率为 p;
- 生成器 , 表示生成器根据输入 z 生成的样本;
- 表示噪声分布, 表示噪声 z 的概率; 表示真实样本分布, 表示样本 x 的概率;
- 表示求期望。
3.1 判别器
首先需要明确判别器实际上是一个二分类器,其优化目标就是为了其判别出真负样本的概率大,这里使用交叉熵函数作为其目标函数。
简单介绍下交叉熵以及其公式推导:
1. 为了搞清楚交叉熵先要引入相对熵的概念(又叫 KL散度)
对于两个分布 P(x) 和 Q(x),要用Q分布去贴近P分布,KL散度可以衡量Q分布到P分布的差异,计算公式:
2. 交叉熵
KL 散度可以改写成如下形式:
可以看到,拆成两部分后前面部分为分布 P 的熵,只要分布P不变,前面是不变的,后面一部分即为交叉熵,其衡量了 Q 与 P 的差异。
在分类任务中,可以看成是实际标签和预测值两个分布之间的差异。对分类任务交叉熵通常表示为:
, 其中 为真实标签, 为预测值
在此处,目标函数为:
由于样本和噪声都是采样得到的,用期望表示如下:
3.2 生成器
前面说了生成器的目的是为了和判别器“唱反调”,判别器的目标是最小化上述的 ,那么生成器的优化目标则是最大化 .
生成器目标是最小化上面 .当 D 固定时,第一项为常数, 于是loss变为:
3.3 最终目标
有了上面两个分析,得到最终的公式:
现在问题来了,提出了这种对抗的方法怎么能保证能使得生成器生成的样本分布贴近真实分布呢?请看原文中下述推导:
- 首先固定生成器G,若能得到一个最优的判别器使得交叉熵最小,那么有:
,即求 ,求得
-
将 带入原式,得到:
这样来看,如果对于每个G,都能取到最优的D,这样实际就是在最小化 两个分布之间的JS散度,如此一来理论上能够获得JS散度标准下最接近原分布的分布。
4. 问题
4.1 梯度消失问题
在生成器刚开始生成的样本还不够逼真的时候,判别器非常容易完全分辨出真样本和生成的样本,这就导致了 D(x) 在x为生成样本时的值非常小,导致了上面的 对G梯度的贡献十分小,G收敛的速度十分慢甚至会出现梯度消失的问题。
一种改进措施:用 代替 . 如下图,绿色为 -log(x),在接近0的地方梯度较大,而蓝色曲线在接近0的地方梯度十分小。
这种改进实际上是对上面的损失函数交换了真实样本和生成样本的分布。
改进前:
改进后:
4.2 模型坍塌
模型坍塌实际上就是生成模型只关注样本的某一小部分特征,对输入的噪声不在敏感,导致生成的样本失去了多样性。这里相关介绍参考博文 《Model Collapse in GANs》。
在很多分布中,通常有多了概率峰值,但是GANs有时候却只能专注于一个峰值,导致产生的样本只专注于某一区域。举例说明如下:
现在假设原数据分布为全球各地的气温分布,其中最热的地方为澳大利亚的 Alice Spring,最冷的地方为南极。分布在这两个地方达到两个峰值,如下:
我们希望学到的生成模型生成的温度在两个峰值温度都有相应的生成,可是如果出现模型坍塌,则可能模型只会生成在South Pole附近的温度,为什么会发生这种情况咧?
- 生成器发现生成south pole附件的温度值更能让判别器分辨不出,于是生成的样本更靠近south pole附近;
- 判别器发现Alice Spring附件的温度值大部分都是真样本,而South Pole附件的样本是真是假比较难分辨;
- 生成器放弃生成Alice Spring附件的温度值,集中在South Pole附近的温度值;
- 判别器判定Alice Spring附件的温度值都是真样本,South Pole附近的值都是假样本;
- 重复上述过程。
如此一来,最终生成器走向了“万劫不复”之路,导致最后的模型坍塌。
标签:GANs,判别,样本,生成器,生成,分布,数据,浅析,对抗 来源: https://blog.csdn.net/angus_huang_xu/article/details/116374775