其他分享
首页 > 其他分享> > VAE-变分推断

VAE-变分推断

作者:互联网

1.推荐材料

1.PRML 第十章节 变分推断
2.B站 白板推导 这部分讲解的很详细
https://www.bilibili.com/video/BV1aE411o7qd?p=70
https://www.bilibili.com/video/BV1aE411o7qd?p=71
https://www.bilibili.com/video/BV1aE411o7qd?p=72
https://www.bilibili.com/video/BV1aE411o7qd?p=73
https://www.bilibili.com/video/BV1aE411o7qd?p=74
3.鲁鹏老师 计算机视觉与深度学习 这个讲的比较浅显易懂
https://www.bilibili.com/video/BV1V54y1B7K3?p=14
4.知乎 - 有不少笔误,需要详细斟酌
https://zhuanlan.zhihu.com/p/94638850
5.邱锡鹏 - 蒲公英书 第13.2章节 变分自编码器
6.https://blog.csdn.net/lrt366/article/details/83154048
这篇文章解释了两个我一直困惑的点

2.VAE的目标

对隐藏变量\(Z\)的特征提取(鲁鹏视频)


正常数据很难真的达到高斯分布,一般都是由多个高斯分布叠加组成,俗称GMM(混合高斯模型)

为了达到上面的目的需要完成下列两个事情
1.隐变量\(Z\)的真实概率分布\(P(Z)\)
2.求出生成概率模型\(P(X|Z)\)的参数\(\theta\)

\(问题1通过神经网络学习 输入X,输出P(Z),神经网络参数为\phi,这是VAE的前半段,称为推断网络\)
\(问题2通过神经网络血虚,输入Z,输出\hat X,神经网络参数为\theta,这是VAE的后半段,称为生成网络\)

模型结构

注意

3.流程梳理

1.梳理的第一步-主要网络结构

回到第一张图,我们先不看\(X\to Z\)这根虚线,\(Z\to X\)这条线是很明确的,就是为了实现2.2 这个小目标(邱锡鹏蒲公英书)
由贝叶斯公式得到\(P(X)=\frac{P(X|Z)P(Z)}{P(Z|X)}\)
\(\ln P(X)=\ln P(Z,X) - \ln P(Z|X) = \ln \frac{P(Z,X)}{q(Z)} - \ln \frac{P(Z|X)}{q(Z)} -\color{red}{公式1}\)
这里引入了一个新的分布\(q(Z)\),为什么要引入它?因为真实分布\(P(Z)\)无法获得,引入\(q(Z)\),希望\(q(Z)\)可以无限逼近\(P(Z)\)
那么这个\(q(Z)\)是怎么算出来的呢?这里就用到了VAE的前半段,也就是推断网络学习出来的
至此先梳理清楚VAE就是通过一个推断网络+生成网络组成的
另外一个很重要的点,前半段的生成网络\(q(Z)\)是通过学习样本\(X\)获得的,所以应该写成q(Z|X)的形式,并且我们令推断网络的参数为\(\phi\),生成网络的参数为\(\theta\),避免搞混
所以最终梳理一下标记
\(由样本X通过推断网络f_{I}(X,\phi)学习得到近似隐变量真实分布P(Z)的近似分布q(Z|X;\phi)\)
\(通过q(Z|X;\phi) \color{red}{采样}后得到一个样本数据Z,然后通过生成网络f_{G}(Z,\theta)学习得到X经过\color{red}{数据降维}后的数据\hat X,并且\hat X的数据分布满足P(X|Z;\theta)\)
\(采样:通过采样能够学习到所有数据概率值情况下的数据,比如输入一张满月的照片,一张半月(1/2个月亮),则通过采样,就可以获取1/2月亮-全月亮所有可能形状的月亮(鲁鹏视频),这个在图形处理中很有用,可以用于图像增强\)

\(数据降维:一般来说Z是比X低维的数据,这样就能对主要特征进行抽取\)

2.梳理的第二步-重要的假设

回到这张图

还有这张

假设1 - 可以看到通常我们假设这些隐变量都是服从正态分布的,也就是GMM模型,现在假设\(P(Z)\)服从标准正态分布\(\sim \mathcal{N}(0,1)\)(\color{red}{为什么这么假设?先放一放,后面再解释,或者直接看我推荐的材料6})

3.梳理的第三步-推导目标函数

假设2 - 假设我们已知真实分布\(P(Z)\),这样我们先处理VAE的后半段-生成模型,那么为了求解生成模型中的参数\(\theta\),我们用最常用的最大似然法,通过贝叶斯方法求解边缘分布\(P(X)\),使得对数似然函数\(\log \prod P(X;\theta)\)最大即可
\(\ln P(X;\theta)=\ln P(Z,X;\theta) - \ln P(Z|X;\phi) = \ln \frac{P(Z,X;\theta)}{q(Z|X,\phi)} - \ln \frac{P(Z|X;\theta)}{q(Z|X,\phi)} -\color{red}{公式1}\)
大家可以看到这个公式无意中已经把前半段的推断网络也牵扯进来了,不再是单纯的求生成网络
还有一点,根据模型结构中的图1,其实也就是一个概率图模型,\(\theta\)同时决定了\(Z\)和\(X\),所以P(Z,X)添加了一个解释说明的参数,P(Z,X;\theta)代表P(Z,X)这个模型中的参数也是\(\theta\)
继续推导上面的公式
\(\ln P(X;\theta)=\ln P(Z,X;\theta) - \ln P(Z|X;\phi) = \ln \frac{P(Z,X;\theta)}{q(Z|X,\phi)} - \ln \frac{P(Z|X;\theta)}{q(Z|X,\phi)}\)
\(两边对q(Z;\phi)求期望\)
\(左边=\int \ln P(X;\theta)q(Z|X;\phi)dZ=\ln P(X;\theta),不变\)
\(右边=\int[\ln \frac{P(Z,X;\theta)}{q(Z|X,\phi)} - \ln \frac{P(Z|X;\theta)}{q(Z|X,\phi)}]q(Z|X;\phi)dZ\)
\(=\int\ln \frac{P(Z,X;\theta)}{q(Z|X,\phi)}q(Z|X;\phi)dZ - \int\ln \frac{P(Z|X;\theta)}{q(Z|X,\phi)}q(Z|X;\phi)dZ\)
\(这个式子的前一半称为\color{red}{ELBO}(evidence\ lower\ bounds),后一半是KL散度,\color{red}{KL(q(Z|X;\phi) || P(Z|X;\theta))}\)


小结一下
\(\ln P(X;\theta) =\int\ln \frac{P(Z,X;\theta)}{q(Z|X,\phi)}q(Z|X;\phi)dZ - \int\ln \frac{P(Z|X;\theta)}{q(Z|X,\phi)}q(Z|X;\phi)dZ\)
\(=ELBO + KL(q(Z|X;\phi) || P(Z|X;\theta))\)

4.梳理第四步-详解ELBO

ELBO = \(\int\ln \frac{P(Z,X;\theta)}{q(Z|X,\phi)}q(Z|X;\phi)dZ\)
\(=\mathbb{E}_{Z\sim q(Z|X;\phi)}[\ln \frac{P(Z,X;\theta)}{q(Z|X,\phi)}]\)
\(=\mathbb{E}_{Z\sim q(Z|X;\phi)}[\ln \frac{P(Z|X;\theta)P(X;\theta)}{q(Z|X,\phi)}]\)
\(=\mathbb{E}_{Z\sim q(Z|X;\phi)}[\ln P(Z|X;\theta)]-KL(q(Z|X;\phi) || P(Z|X;\theta))\)
先处理后一半
\(其中 P(Z|X;\theta) 是先验,之前我们假设了是标准正态分布\)
\(q(Z|X;\phi)是后验,我们仍然假设是正态分布(GMM模型),不过参数未知\sim \mathcal{N}(\mu_I,\sigma_I^2 I),下标I代表是从推断网络得到的\)
\(两个正态分布的KL散度可以直接求出,不推导了,直接看邱锡鹏老师的蒲公英书 公式13.24\)
\(KL(q(Z|X;\theta) || P(Z;\theta))=\frac{1}{2}(tr(\sigma^2_II)+\mu_I^T\mu_I-d-\log(|\sigma_I^2I|))\)
\(d是维度\)

5.梳理第四步-重参数化技巧

第四步中的前一半\(\mathbb{E}_{Z\sim q(Z|X;\phi)}[\ln P(Z|X;\theta)]\)可以通过收集数据后取平均的方式,但是最大的问题是Z是通过采样得到的,没有确定的表达式,没办法求梯度
所以这里引入了重参数化技巧

重参数化

引入公式
\(Z=\mu_I +\sigma_I \times \epsilon\)
\(\epsilon \sim \mathcal{N}(0,I)\)
\(这样\mathbb{E}_{Z\sim q(Z|X;\phi)}[\ln P(Z|X;\theta)] 可以转化为 \mathbb{E}_{\epsilon\sim p(\epsilon}[\ln P(Z|g(\phi,\epsilon);\theta)]\)
这样我们就改写了网络结构

6.目标函数总结

\(目标函数最终定义为\)
\(L(\phi,\theta|X)=\sum_{n=1}^{N}(\frac{1}{M}\sum_{n=1}^{N}\log p(x^{(n)}|z^{(n,m)};\theta) - KL(q(z|x^{(n);\theta},N(z;0,I)))) - 蒲公英书 13.27\)
\(\color{red}{这一步的推导看起来很自然,但总感觉不知道怎么推导出来的}\)
\(\color{red}{另外还有一个点\mu_G书中介绍的是生产网络的输出,我看了一些代码的确也是这么写的,但我不太明白的是为什么用一个均值\mu_G的符号表示,且在蒲公英书的 13.18 公式中使用过\mu_G,明确写着这是一个均值符号,这有什么意义吗?}\)
\(L(\phi,\theta|X)=-\frac{1}{2}||x-\mu_G||^2 -\lambda KL(q(z|x^{(n);\theta},N(z;0,I)))) - 蒲公英书 13.27\)
\(今天重温了下概率论与数理统计,居然发现邱锡鹏老师的公式推导是完全正确的\)
\(有一个隐藏的等式要引入,因为在实际使用的正态分布的似然函数推导的时候用的是误差服从正态分布的推导,也就是\epsilon = x-\mu =x-\hat x,\epsilon\sim N(0,1),这样就完全串联起来了\)

4.其他变分方法

PRML书上介绍了除了本章节要讲的方法,变分方法包括有限元方法,最大熵方法
除了基于梯度的变分方法,PRML书中还着重说明了基于平均场理论的变分方法

5.其他

为什么要求先验\(P(Z)是标准正态分布?\)

https://blog.csdn.net/lrt366/article/details/83154048

标签:phi,frac,ln,变分,VAE,推断,theta,KL
来源: https://www.cnblogs.com/boyknight/p/16290582.html