变分自编码器(VAE)

发布时间:2024-01-10 11:00

在本篇文章中,我将从变分自编码器的来源出发,从两个角度分别引出其网络和目标函数

VAE的思想来源

如果我们有一批样本,然后想生成一个新样本,我们应该怎么做呢?

首先,最直接的想法是根据已有样本得到真实分布 P d a t a ( x ) P_{data}(x) Pdata(x),从而根据 P d a t a ( x ) P_{data}(x) Pdata(x)采样即可获取新样本

但是很可惜,我们很难获得 P d a t a ( x ) P_{data}(x) Pdata(x),并且通常 P d a t a ( x ) P_{data}(x) Pdata(x)是复杂的,不容易采样

于是我们有了一个新的想法,要是 P d a t a ( x ) P_{data}(x) Pdata(x)是一个我们已知的很简单的分布就好了,比如我们假设它就是一个高斯分布。这样我们可以通过极大似然估计求出高斯分布的参数。然而这种假设太强了,对于很多情况是不适用的。

于是我们再想到,既然这个假设太强的话,我们假设 P d a t a ( x ) P_{data}(x) Pdata(x)是多个高斯分布的叠加呢,即高斯混合模型(GMM)。我们假设有一个隐变量z服从多项式分布,表示样本来自于哪个高斯分布。我们只要先对z采样,得到相应的高斯分布P(x|z),然后对P(x|z)采样就能得到新样本了。

这么一来,显然模型的表达能力变强了

但是我们仍然不满意,高斯混合模型只能解决有限的多峰问题。

假如我们的样本的真实分布有无数个峰呢,高斯混合模型就不行了。

那我们再改进一下,我们假设 P d a t a ( x ) P_{data}(x) Pdata(x)是无数个高斯分布的叠加那不就解决了前面的问题了吗。

比如,只要z从原来的离散的多项式分布变为连续的高斯分布,那么P(x|z)就有无限多个,从而有
P d a t a ( x ) = ∫ z P ( x ∣ z ) P ( z ) d z P_{data}(x) = \int_z P(x|z)P(z)dz Pdata(x)=zP(xz)P(z)dz
这就是变分自编码器的根本思想,其本质上就是GMM的连续性扩展。其它一切都是在这个基础上推导而来的。

那么,问题来了,P(z)可以随意设定为一个连续性的分布,通常假设为N(0,1)的高斯分布,但是我们怎么得到P(x|z)呢?或者说,我们怎么得到P(x|z)的均值和方差?

从训练的角度

变分自编码器首先提出,用神经网络模型计算P(x|z)的均值和方差

在这里插入图片描述

但是我们知道,这个网络没法训练,因为这个网络输入的是一个在N(0,1)上随机采样的z,输出的是 μ \mu μ σ 2 \sigma^2 σ2或者说是一个高斯分布,我们并没有一个跟z相关的高斯分布的标签计算KL散度,然后让KL散度最小用于更新参数,或者如果有一个跟z对应的x,我们可以根据采用得到的x与标签x求mse,这样也能更新参数。

那么怎么办呢?既然没有标签,那我们就创造标签。

自编码器网络通过一个编码器和解码器创造了z和x的双向映射关系,那么我们可不可以借鉴这种思路呢?

自编码器的映射关系是确定性的,即z=f(x),x = g(z)

上面的网络已经创建了z->x的映射,输入z,输出的是p(x|z)这个高斯分布,采样得到x,我们很容易想到从x->z的映射就是,输入x,输出的是q(z|x)这个高斯分布,采样得到z

本质上,变分自编码器创造的是q(z|x)分布与P(x|z)分布的映射关系,与具体的值无关。

这么一来的话,我们的新网络结构如图:

在这里插入图片描述

我们分析一下这两个网络的最终目标,推断网络希望获得一个合适的后验分布,使得根据这个后验分布采样的z能表示代码x的编码,即可解码。生成网络则希望根据给定的编码,获得一个合适的似然分布,根据这个分布采样的x要尽可能与原来的x一致。

但是,我们之前已经假定了z服从N(0,1)这个高斯先验,而最终学出来的生成网络适应的z是q(z|x)生成的z,而我们是不知道q(z|x)的,因为我们得先有x,才有q(z|x),但是生成新样本要求我们先有z。

但是如果q(z|x)与p(z)先验很接近的话,那么我们直接使用p(z)近似q(z|x)进行采样也是可以的。

因此,总体来说,我们有两个目标,一个是让 x ^ \hat x x^ x x x尽可能接近,即重构损失 1 2 ( x − x ^ ) 2 \frac{1}{2}(x-\hat x)^2 21(xx^)2最小,另一个是q(z|x)与p(z)尽可能接近,我们用KL散度表示两个分布的接近程度 K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) KL(q(z|x)||p(z)) KL(q(zx)p(z))

尽管我们现在有了标签,并且定义好了损失函数我们的网络仍然无法训练,因为涉及了两次采样操作,采样是没办法计算梯度的。

所以有两个解决办法,重参数化和梯度估计。我们这里只介绍重参数化的方法。

重参数化

我们可以将z的随机性转移,假如有 ϵ ∼ N ( 0 , 1 ) \epsilon \sim N(0,1) ϵN(0,1),那么另 z = μ + ϵ ∗ σ z = \mu+\epsilon*\sigma z=μ+ϵσ,这样z仍然服从 N ( μ , σ 2 ) N(\mu,\sigma^2) N(μ,σ2),但是z关于 ϕ \phi ϕ的梯度就不再为0了。那么第一个采样问题解决了,第二个难道也这么做吗?

其实没有必要,因为第二次采样在训练网络时不是必要的,我们可以直接用 μ G \mu_G μG代替 x ^ \hat x x^,因为 μ G \mu_G μG是样本均值,更能代表普遍的采样结果。那为什么第一次采样不用 μ z \mu_z μz代替呢?这样做的话,就退化成自编码器了。没有噪声输入,生成网络的鲁棒性就不够强,其适应的是一批特定的编码,而非一个简单的分布

所以最终的损失函数就是
l o s s = 1 2 ( x − u G ) 2 + K L ( N ( u z , σ z 2 ) ∣ ∣ N ( 0 , 1 ) ) loss = \frac{1}{2}(x-u_G)^2+KL(N(u_z,\sigma_z^2)||N(0,1)) loss=21(xuG)2+KL(N(uz,σz2)N(0,1))
至于KL怎么算呢,有个公式
K L ( N ( u , σ 2 ) ∣ ∣ N ( 0 , 1 ) ) = 1 2 ∑ i = 1 d ( μ 2 + σ 2 − log ⁡ σ 2 − 1 ) KL(N(u,\sigma^2)||N(0,1)) = \frac{1}{2} \sum_{i=1}^{d}\left(\mu^{2}+\sigma^{2}-\log \sigma^{2}-1\right) KL(N(u,σ2)N(0,1))=21i=1d(μ2+σ2logσ21)
至于怎么推出来的,有点麻烦,在此省略,感兴趣的可参考两个多变量高斯分布之间的KL散度 - 知乎 (zhihu.com)

从EM算法角度

前面我们从模型训练的角度导出了变分自编码器的网络结构和目标函数,但是有些地方还是有些牵强,缺乏逻辑。所以我们还是从数学的角度推导出上述结构

首先,我们先回顾一下原始的EM算法的过程:

E步:固定 θ \theta θ,求 q ( z ∣ x ) = a r g m a x q ( z ∣ x ) E L B O = P θ ( z ∣ x ) q(z|x) = argmax_{q(z|x)} ELBO = P_\theta(z|x) q(zx)=argmaxq(zx)ELBO=Pθ(zx)

M步:固定 q ( z ∣ x ) q(z|x) q(zx),求 θ = a r g m a x θ E L B O \theta = argmax_\theta ELBO θ=argmaxθELBO

由于 P ( z ∣ x ; θ ) P(z|x;\theta) P(zx;θ)很多时候是无法得到或难以求解的,所以提出了变分推断的方法,变分推断不要求E步的时候计算出后验分布,而是只要给出一个接近后验分布的q(z|x)就够了。也就是说,E步这个过程我们可以使用基于迭代的方法求近似解。比如基于平均场理论的变分推断利用坐标上升法迭代求解,SGVI通过梯度上升法求解。

变分自编码器就是一个基于梯度的方法。

首先,E步由于找的是一个函数,ELBO没办法对函数求梯度,所以我们将q(z|x)参数化,基于我们前面提到的假设,q(z|x)是一个高斯分布,参数未知,我们假设为 ϕ \phi ϕ

于是变分EM算法如下:

E步:固定 θ \theta θ,求 ϕ = a r g m a x ϕ E z ∼ q ϕ ( z ∣ x ) l o g p θ ( x ∣ z ) p ( z ) q ϕ ( z ∣ x ) \phi = argmax_\phi E_{z\sim q_\phi(z|x)}log\frac{p_\theta(x|z)p(z)}{q_\phi(z|x)} ϕ=argmaxϕEzqϕ(zx)logqϕ(zx)pθ(xz)p(z)

M步:固定 ϕ \phi ϕ,求 θ = a r g m a x θ E z ∼ q ϕ ( z ∣ x ) l o g p θ ( x ∣ z ) p ( z ) q ϕ ( z ∣ x ) \theta = argmax_\theta E_{z\sim q_\phi(z|x)}log\frac{p_\theta(x|z)p(z)}{q_\phi(z|x)} θ=argmaxθEzqϕ(zx)logqϕ(zx)pθ(xz)p(z)

其中
E z ∼ q ϕ ( z ∣ x ) l o g p θ ( x ∣ z ) p ( z ) q ϕ ( z ∣ x ) = ∫ q ϕ ( z ∣ x ) l o g p θ ( x ∣ z ) p ( z ) q ϕ ( z ∣ x ) d z = ∫ q ϕ ( z ∣ x ) l o g   p θ ( x ∣ z ) d z + ∫ q ϕ ( z ∣ x ) l o g p ( z ) q ϕ ( z ∣ x ) d z = ∫ q ϕ ( z ∣ x ) l o g   p θ ( x ∣ z ) d z + K L ( q ϕ ( z ∣ x ) ∣ ∣ p ( z ) ) \begin{aligned} E_{z\sim q_\phi(z|x)}log\frac{p_\theta(x|z)p(z)}{q_\phi(z|x)} &= \int q_\phi(z|x)log\frac{p_\theta(x|z)p(z)}{q_\phi(z|x)}dz \\ &= \int q_\phi(z|x)log\ p_\theta(x|z)dz+\int q_\phi(z|x)log\frac{p(z)}{q_\phi(z|x)} dz\\ &= \int q_\phi(z|x)log\ p_\theta(x|z)dz+KL(q_\phi(z|x)||p(z)) \end{aligned} Ezqϕ(zx)logqϕ(zx)pθ(xz)p(z)=qϕ(zx)logqϕ(zx)pθ(xz)p(z)dz=qϕ(zx)log pθ(xz)dz+qϕ(zx)logqϕ(zx)p(z)dz=qϕ(zx)log pθ(xz)dz+KL(qϕ(zx)p(z))
前一项为在后验分布 q ϕ ( z ∣ x ) q_\phi(z|x) qϕ(zx)的期望,可以通过采样近似。后一项是两个高斯分布 N ( μ ϕ ( x ) , σ ϕ 2 ( x ) ) N(\mu_\phi(x),\sigma_\phi^2(x)) N(μϕ(x),σϕ2(x)) N ( 0 , 1 ) N(0,1) N(0,1)的KL距离

简单起见,我们假设q(z|x)的方差是对角矩阵,p(x|z)的方差是 λ I \lambda I λI

E z ∼ q ϕ ( z ∣ x ) l o g p θ ( x ∣ z ) p ( z ) q ϕ ( z ∣ x ) = 1 M ∑ i = 1 M l o g   p θ ( x ∣ z ( i ) ) + K L ( q ϕ ( z ∣ x ) ∣ ∣ p ( z ) ) = 1 M ∑ i = 1 M 1 2 π ( x − μ θ ( i ) ) 2 + 1 2 ∑ i = 1 d ( μ ϕ i 2 + σ ϕ i 2 − log ⁡ σ ϕ i 2 − 1 ) \begin{aligned} E_{z\sim q_\phi(z|x)}log\frac{p_\theta(x|z)p(z)}{q_\phi(z|x)} &= \frac{1}{M}\sum_{i=1}^M log\ p_\theta(x|z^{(i)})+KL(q_\phi(z|x)||p(z)) \\ &= \frac{1}{M}\sum_{i=1}^M \frac{1}{\sqrt{2\pi}}(x-\mu_\theta^{(i)})^2+\frac{1}{2} \sum_{i=1}^{d}\left(\mu_{\phi i}^{2}+\sigma_{\phi i}^{2}-\log \sigma_{\phi i}^{2}-1\right) \end{aligned} Ezqϕ(zx)logqϕ(zx)pθ(xz)p(z)=M1i=1Mlog pθ(xz(i))+KL(qϕ(zx)p(z))=M1i=1M2π 1(xμθ(i))2+21i=1d(μϕi2+σϕi2logσϕi21)
我们再简化一下,假设只采样一个z近似期望,那么
E z ∼ q ϕ ( z ∣ x ) l o g p θ ( x ∣ z ) p ( z ) q ϕ ( z ∣ x ) = [ 1 2 ( x − μ θ ( z ) ) 2 − l o g ( 2 π ) ] + 1 2 ∑ i = 1 d ( μ i 2 + σ ϕ i 2 − log ⁡ σ ϕ i 2 − 1 ) \begin{aligned} E_{z\sim q_\phi(z|x)}log\frac{p_\theta(x|z)p(z)}{q_\phi(z|x)} &= [\frac{1}{2}(x-\mu_\theta(z))^2-log(\sqrt{2\pi})]+\frac{1}{2} \sum_{i=1}^{d}\left(\mu_{i}^{2}+\sigma_{\phi i}^{2}-\log \sigma_{\phi i}^{2}-1\right) \end{aligned} Ezqϕ(zx)logqϕ(zx)pθ(xz)p(z)=[21(xμθ(z))2log(2π )]+21i=1d(μi2+σϕi2logσϕi21)
常数项我们可以忽略,因为对argmax不影响。且由于E步和M步的目标一致,我们可以同时优化
ϕ , θ = a r g m a x ϕ , θ 1 2 ( x − μ θ ( z ) ) 2 + 1 2 ∑ i = 1 d ( μ i 2 + σ ϕ i 2 − log ⁡ σ ϕ i 2 − 1 ) \phi,\theta = argmax_{\phi,\theta}\frac{1}{2}(x-\mu_\theta(z))^2+\frac{1}{2} \sum_{i=1}^{d}\left(\mu_{i}^{2}+\sigma_{\phi i}^{2}-\log \sigma_{\phi i}^{2}-1\right) ϕ,θ=argmaxϕ,θ21(xμθ(z))2+21i=1d(μi2+σϕi2logσϕi21)
我们使用神经网络来预测 θ \theta θ ϕ \phi ϕ,这样一来,将EM算法过程抽象为网络结构即为:

E步:输入x,通过神经网络求得 ϕ \phi ϕ,然后根据 q ϕ ( z ∣ x ) q_\phi(z|x) qϕ(zx)采样得到 z 1 , z 2 , . . . z M z_1,z_2,...z_M z1,z2,...zM

M步:输入z,通过神经网络求得 θ \theta θ

更新:计算ELBO的关于 θ \theta θ ϕ \phi ϕ的梯度,更新参数

到这里我们就可以看出,这里的E步就对应上面提到的推断网络,M步则对应生成网络,也解释了为什么只对z采样而不对x采样(因为EM过程不需要对x采样)。目标函数的前一项对应重构损失,后一项对应q(z|x)与p(z)的KL距离

到此,整个变分自编码器的介绍就结束了。

现在我们再看变分自编码器这个名字,所谓变分指的是该模型使用的近似后验分布q(z|x)代替p(z|x),而自编码器则是因为它借鉴了自编码器的思路从而EM过程抽象成了类似自编码器的编码+解码的结构。二者的差别就是变分自编码器的编码器和解码器都输出分布,而自编码器输出具体的编码。显然变分自编码器具有更强的鲁棒性。

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号