【Inference】Variational Inference and VIEM
In inference problems involving latent variables, exact computation is often infeasible. For continuous random variables, the high dimensionality of the latent space and the complexity of the integrand make the integration intractable; for discrete random variables, the number of latent states grows exponentially, making exact evaluation prohibitively expensive. This motivates approximate inference, which falls into two main families: stochastic approximation and deterministic approximation. Variational inference is a deterministic approximation method.
1 Basic Idea of Variational Inference
Suppose the model is a joint probability distribution $p(x,z)$, where $x$ is the observed variable and $z$ is the latent variable (including the parameters). The goal is to learn the posterior distribution $p(z|x)$ and use it for probabilistic inference. However, this is a complex distribution, and estimating its parameters directly is difficult.
We therefore consider approximating the conditional distribution $p(z|x)$ with a distribution $q(z)$, measuring the discrepancy between the two with the KL divergence $D(q(z)\,\|\,p(z|x))$; $q(z)$ is called the variational distribution. If we can find the distribution $q^{*}(z)$ closest to $p(z|x)$ in the KL sense, we can use it as an approximation of $p(z|x)$.
2 KL Divergence and Problem Reformulation
The KL divergence measures how close two distributions are: the more similar the distributions, the smaller the divergence. In variational inference, the objective is therefore to minimize the KL divergence between the variational distribution and the target distribution.
The KL divergence can be written as:
$$
\begin{aligned}
D(q(z)\,\|\,p(z|x)) &= E_{q(z)}[\log q(z)] - E_{q(z)}[\log p(z|x)] \\
&= E_{q(z)}[\log q(z)] - E_{q(z)}\left[\log \frac{p(x,z)}{p(x)}\right] \\
&= E_{q(z)}[\log q(z)] - E_{q(z)}[\log p(x,z)] + \log p(x) \\
&= \log p(x) - \bigl(E_{q(z)}[\log p(x,z)] - E_{q(z)}[\log q(z)]\bigr) \qquad (1.1)
\end{aligned}
$$
Here $E_{q(z)}\left[\log \frac{p(x,z)}{p(x)}\right] = E_{q(z)}[\log p(x,z) - \log p(x)]$, and since $\log p(x)$ does not depend on $q(z)$, its expectation is simply itself.
Since the KL divergence is nonnegative, and equals zero if and only if the two distributions coincide, we always have
$$\log p(x) \ge E_{q(z)}[\log p(x,z)] - E_{q(z)}[\log q(z)] = L(q) \qquad (1.2)$$
In other words, the right-hand side is a lower bound on the left-hand side. The larger we can make the right-hand side, the closer it gets to the left-hand side and the closer the KL divergence gets to 0. Our goal therefore changes from minimizing the KL divergence to maximizing $L(q)$.
Note: in Eq. (1.2), the left-hand side is commonly called the evidence, and the right-hand side the evidence lower bound, abbreviated ELBO. Combining Eqs. (1.1) and (1.2), it follows that
$$
\begin{aligned}
\log p(x) &= \mathrm{ELBO} + KL(q(z)\,\|\,p(z|x)) \\
&= \bigl(E_{q(z)}[\log p(x,z)] - E_{q(z)}[\log q(z)]\bigr) + E_{q(z)}[\log q(z)] - E_{q(z)}[\log p(z|x)] \\
&= E_{q(z)}\left[\log \frac{p(x,z)}{q(z)}\right] - E_{q(z)}\left[\log \frac{p(z|x)}{q(z)}\right] \\
&= \int_{z} q(z)\log \frac{p(x,z)}{q(z)}\,dz - \int_{z} q(z)\log \frac{p(z|x)}{q(z)}\,dz \qquad (1.3)
\end{aligned}
$$
The first term is exactly the ELBO.
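To make the identity concrete, here is a minimal numeric check of Eq. (1.3) on a toy discrete model; the joint table and the choice of $q$ are made-up illustrative numbers, not from any real model:

```python
# A minimal numeric check of Eq. (1.3): log p(x) = ELBO + KL(q(z) || p(z|x)).
# The joint table p(x, z) for a fixed observed x and the variational q are
# made-up illustrative numbers.
import numpy as np

p_xz = np.array([0.10, 0.25, 0.05])   # p(x, z=k) for k = 0, 1, 2
p_x = p_xz.sum()                      # evidence p(x) = sum_z p(x, z)
p_z_given_x = p_xz / p_x              # posterior p(z|x)

q = np.array([0.5, 0.3, 0.2])         # an arbitrary variational distribution

elbo = np.sum(q * np.log(p_xz / q))       # E_q[log(p(x,z)/q(z))]
kl = np.sum(q * np.log(q / p_z_given_x))  # KL(q(z) || p(z|x))

print(np.log(p_x), elbo + kl)  # both print -0.916..., confirming Eq. (1.3)
```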
3 The Mean-Field Assumption in Variational Inference
When maximizing the evidence lower bound, we introduce mean-field theory.
Since the variational distribution is a simple, tractable distribution that we choose ourselves, it is usually assumed that $q(z)$ factorizes into mutually independent factors over all components of $z$ (strictly speaking, conditionally independent given the parameters), i.e.
$$q(z) = q_1(z_1)\,q_2(z_2)\cdots q_M(z_M) = \prod_{i=1}^{M} q_i(z_i) \qquad (1.4)$$
A variational distribution of this form is called a mean field. Minimizing the KL divergence, or maximizing the evidence lower bound, is then carried out over the set of mean-field distributions.
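As a small illustration of Eq. (1.4), the sketch below (with made-up numbers) builds a mean-field $q$ over two discrete latent variables as an outer product of its factors; note that such a $q$ can never represent correlation between $z_1$ and $z_2$:

```python
# A mean-field variational distribution over two discrete latent variables:
# q(z1, z2) = q1(z1) * q2(z2), per Eq. (1.4). Numbers are illustrative.
import numpy as np

q1 = np.array([0.7, 0.3])        # q1(z1), two states
q2 = np.array([0.2, 0.5, 0.3])   # q2(z2), three states

q = np.outer(q1, q2)             # full joint q(z1, z2) under the mean field
print(q.sum())                   # 1.0: still a valid joint distribution

# A correlated posterior (e.g. mass only on (z1,z2) = (0,0) and (1,2)) lies
# outside this family; a mean field can at best match its marginals.
```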
Next, we work out the evidence lower bound of Eq. (1.3) under the mean-field assumption.
Under the mean-field assumption,
$$
\begin{aligned}
L(q) = \mathrm{ELBO} &= \int_{z} q(z)\log \frac{p(x,z)}{q(z)}\,dz \\
&= \int_{z} q(z)\log p(x,z)\,dz - \int_{z} q(z)\log q(z)\,dz \\
&= \int_{z_1,z_2,\dots,z_M} \prod_{i=1}^{M} q_i(z_i)\,\log p(x,z)\,dz_1\,dz_2\cdots dz_M \quad (\text{term ①}) \\
&\quad - \int_{z_1,z_2,\dots,z_M} \prod_{i=1}^{M} q_i(z_i)\sum_{i=1}^{M}\log q_i(z_i)\,dz_1\,dz_2\cdots dz_M \quad (\text{term ②})
\end{aligned}
$$
In the derivation below, we assume that all mean-field factors other than $q_j$ (that is, $q_1,\dots,q_{j-1},q_{j+1},\dots,q_M$) are held fixed (which, as we will see, determines $q_j$ as well).
Term ①:
$$
\begin{aligned}
&\int_{z_1,z_2,\dots,z_M} \prod_{i=1}^{M} q_i(z_i)\,\log p(x,z)\,dz_1\,dz_2\cdots dz_M \\
&= \int_{z_j} q_j(z_j)\left[\int_{z_{i\ne j}} \prod_{i\ne j} q_i(z_i)\,\log p(x,z)\,dz_{i\ne j}\right] dz_j \\
&= \int_{z_j} q_j(z_j)\, E_{\prod_{i\ne j} q_i(z_i)}\bigl[\log p(x,z)\bigr]\,dz_j \\
&= \int_{z_j} q_j(z_j)\,\log \hat{p}(x,z_j)\,dz_j \qquad (1.5)
\end{aligned}
$$
where $\log \hat{p}(x,z_j)$ denotes $E_{\prod_{i\ne j} q_i(z_i)}[\log p(x,z)]$.
Term ②:
$$
\begin{aligned}
&\int_{z_1,z_2,\dots,z_M} \prod_{i=1}^{M} q_i(z_i)\sum_{i=1}^{M}\log q_i(z_i)\,dz_1\,dz_2\cdots dz_M \\
&= \int_{z_1,z_2,\dots,z_M} \prod_{i=1}^{M} q_i(z_i)\,\bigl[\log q_1(z_1) + \log q_2(z_2) + \cdots + \log q_M(z_M)\bigr]\,dz_1\,dz_2\cdots dz_M
\end{aligned}
$$
Expanding the sum, the first term is:
$$
\begin{aligned}
&\int_{z_1,z_2,\dots,z_M} \prod_{i=1}^{M} q_i(z_i)\,\log q_1(z_1)\,dz_1\,dz_2\cdots dz_M \\
&= \int_{z_1,z_2,\dots,z_M} q_1(z_1)\log q_1(z_1)\; q_2(z_2)\,q_3(z_3)\cdots q_M(z_M)\,dz_1\,dz_2\cdots dz_M \\
&= \int_{z_1} q_1(z_1)\log q_1(z_1)\,dz_1 \cdot \int_{z_2} q_2(z_2)\,dz_2 \cdot \int_{z_3} q_3(z_3)\,dz_3 \cdots \int_{z_M} q_M(z_M)\,dz_M \\
&= \int_{z_1} q_1(z_1)\log q_1(z_1)\,dz_1
\end{aligned}
$$
since each remaining factor integrates to one: $\int_{z_i} q_i(z_i)\,dz_i = 1$.
By the same argument, the $m$-th term of the expanded sum is:
$$\int_{z_m} q_m(z_m)\log q_m(z_m)\,dz_m$$
Since every factor except $q_j$ is fixed, term ② finally equals
$$② = \sum_{i=1}^{M}\int_{z_i} q_i(z_i)\log q_i(z_i)\,dz_i = \int_{z_j} q_j(z_j)\log q_j(z_j)\,dz_j + C \qquad (1.6)$$
where $C$ is a constant collecting the fixed terms with $i \ne j$.
Combining Eqs. (1.5) and (1.6),
$$
\begin{aligned}
L(q) &= \int_{z_j} q_j(z_j)\log \hat{p}(x,z_j)\,dz_j - \int_{z_j} q_j(z_j)\log q_j(z_j)\,dz_j + C \\
&= \int_{z_j} q_j(z_j)\log \frac{\hat{p}(x,z_j)}{q_j(z_j)}\,dz_j + C \\
&= -KL\bigl(q_j(z_j)\,\|\,\hat{p}(x,z_j)\bigr) + C
\end{aligned}
$$
Since the KL divergence is always nonnegative, $L(q)$ is at most $C$. As $L(q)$ approaches $C$, the KL divergence approaches zero, and $q_j(z_j)$ approaches $\hat{p}(x,z_j)$.
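Written out explicitly, this gives the optimal factor (restated in section 5 below):
$$\log q_j^{*}(z_j) = E_{\prod_{i\ne j} q_i(z_i)}\bigl[\log p(x,z)\bigr] + \text{const}, \qquad \text{i.e.} \quad q_j^{*}(z_j) = \hat{p}(x,z_j)$$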
4 Steps of Variational Inference
With the above assumptions and groundwork in place, variational inference proceeds in the following steps:
Step 1: define the variational distribution $q(z)$;
Step 2: derive the expression for its evidence lower bound;
Step 3: optimize the evidence lower bound with an optimization method such as coordinate ascent, obtaining the optimal distribution $q^{*}(z)$ as an approximation of the posterior $p(z|x)$.
The hardest part is maximizing the evidence lower bound. The EM algorithm (Expectation Maximization Algorithm) itself optimizes by repeatedly maximizing a lower bound until convergence, so we can apply EM-style updates to step 3, as the sketch below illustrates.
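Here is a coordinate-ascent (CAVI) sketch of steps 1-3 on a toy model with two discrete latent variables; the table `log_p` of $\log p(x,z_1,z_2)$ is randomly generated for illustration and is not from the article:

```python
# Steps 1-3 by coordinate ascent (CAVI) on a toy model with two discrete
# latent variables; log_p is an arbitrary illustrative table of
# log p(x, z1, z2) for the fixed observed x.
import numpy as np

rng = np.random.default_rng(0)
log_p = np.log(rng.dirichlet(np.ones(6)).reshape(2, 3))  # log p(x, z1, z2)

# Step 1: define the variational distribution q(z) = q1(z1) q2(z2)
q1 = np.full(2, 0.5)
q2 = np.full(3, 1 / 3)

def elbo(q1, q2):
    # Step 2: the evidence lower bound E_q[log p(x,z)] - E_q[log q(z)]
    q = np.outer(q1, q2)
    return np.sum(q * log_p) - np.sum(q1 * np.log(q1)) - np.sum(q2 * np.log(q2))

# Step 3: maximize the ELBO by coordinate ascent; each update applies
# log q_j*(z_j) = E_{q_{-j}}[log p(x, z)] + const, then normalizes.
for it in range(20):
    q1 = np.exp(log_p @ q2); q1 /= q1.sum()   # expectation over z2
    q2 = np.exp(q1 @ log_p); q2 /= q2.sum()   # expectation over z1
    print(it, elbo(q1, q2))                   # ELBO is non-decreasing
```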
5 Applying EM in Variational Inference (VIEM)
Suppose the model is the joint probability distribution $p(x,z|\theta)$, where $x$ is the observed variable, $z$ is the latent variable, and $\theta$ is the parameter.
The goal is to estimate the model parameter $\theta$ by maximizing the log probability (evidence) of the observed data, $\log p(x|\theta)$. Using variational inference, we introduce the mean field $q(z) = \prod_{i=1}^{M} q_i(z_i)$ and define the evidence lower bound
$$L(q,\theta) = E_{q}[\log p(x,z|\theta)] - E_{q}[\log q(z)]$$
Iteratively maximizing the evidence lower bound with respect to $q$ and $\theta$ in turn yields the variational EM algorithm.
The objective of the algorithm is
$$\hat{q} = \mathop{\arg\min}_{q}\, KL(q(z)\,\|\,p(z|x)) = \mathop{\arg\max}_{q}\, L(q,\theta)$$
In the $t$-th iteration:
(1) E-step: fix $\theta^{(t-1)}$ and maximize $L(q,\theta^{(t-1)})$ with respect to $q$, obtaining $q^{(t)}$.
(2) M-step: fix $q^{(t)}$ and maximize $L(q^{(t)},\theta)$ with respect to $\theta$, obtaining $\theta^{(t)}$.
In variational EM, the following relation holds:
$$\log p(x|\theta^{(t-1)}) = L(q^{(t)},\theta^{(t-1)}) \le L(q^{(t)},\theta^{(t)}) \le \log p(x|\theta^{(t)})$$
(the first equality holds when the E-step drives the KL term to zero; by Eq. (1.2) the final step is in general an inequality), so the evidence is non-decreasing across iterations.
From Eq. (1.3),
$$L(q,\theta) = \mathrm{ELBO} = E_{q(z)}\left[\log \frac{p_{\theta}(x,z)}{q(z)}\right] = E_{q(z)}[\log p_{\theta}(x,z)] + H[q(z)]$$
where $H[q(z)] = -E_{q(z)}[\log q(z)]$ is the entropy of $q$, a constant with respect to $\theta$.
By the earlier assumption, for each $q_j$ all the other factors $q_{i\ne j}$ are held fixed. Therefore
$$
\begin{aligned}
\log q_j(z_j) &= E_{\prod_{i\ne j} q_i(z_i)}\bigl[\log p_{\theta}(x,z)\bigr] + C \\
&= \int_{z_1}\int_{z_2}\cdots\int_{z_{j-1}}\int_{z_{j+1}}\cdots\int_{z_M} q_1\,q_2\cdots q_{j-1}\,q_{j+1}\cdots q_M\;\log p_{\theta}(x,z)\;dz_1\,dz_2\cdots dz_{j-1}\,dz_{j+1}\cdots dz_M
\end{aligned}
$$
This system can be solved iteratively by coordinate ascent. That is, within each iteration's update of $q$ (the E-step above), the factors are updated in turn:
$$
\left\{\begin{aligned}
\log \hat{q}_1(z_1) &= \int_{z_2}\int_{z_3}\cdots\int_{z_M} q_2\,q_3\cdots q_M\,\log p_{\theta}(x,z)\;dz_2\,dz_3\cdots dz_M + C\\
\log \hat{q}_2(z_2) &= \int_{z_1}\int_{z_3}\cdots\int_{z_M} \hat{q}_1\,q_3\cdots q_M\,\log p_{\theta}(x,z)\;dz_1\,dz_3\cdots dz_M + C\\
\log \hat{q}_3(z_3) &= \int_{z_1}\int_{z_2}\int_{z_4}\cdots\int_{z_M} \hat{q}_1\,\hat{q}_2\,q_4\cdots q_M\,\log p_{\theta}(x,z)\;dz_1\,dz_2\,dz_4\cdots dz_M + C\\
&\;\,\vdots \\
\log \hat{q}_M(z_M) &= \int_{z_1}\int_{z_2}\cdots\int_{z_{M-1}} \hat{q}_1\,\hat{q}_2\cdots\hat{q}_{M-1}\,\log p_{\theta}(x,z)\;dz_1\,dz_2\cdots dz_{M-1} + C
\end{aligned}\right.
$$
where each update uses the most recent estimates $\hat{q}_i$ of the factors already updated in this round.
Of course, the mean-field assumption is very strong; models such as neural networks are not well suited to it. This is where SGVI (Stochastic Gradient Variational Inference) comes in.
(To be continued.)
References
[1] David Bellot. Learning Probabilistic Graphical Models in R. Packt Publishing, 2016.
[2] Li Hang (李航). Statistical Learning Methods (《统计学习方法》), 2nd ed. Tsinghua University Press, 2019.
[3] Video: 【机器学习】【白板推导系列】【合集 1~23】, Bilibili.
PS: the derivation in the video runs in the opposite direction: it starts from $\log p(x)$ and shows it equals the ELBO plus the KL divergence, then shows the ELBO equals $-KL + \text{Constant}$ (the constant is in fact $\log p(x)$, which does not depend on $q(z)$). The reverse derivation is harder and good exercise for the reader, so the video is recommended.