论文原文:Auto-Encoding Variational Bayes [OpenReview (ICLR 2014) | arXiv]
本文记录了我在学习 VAE 过程中的一些公式推导和思考。如果你希望从头开始学习 VAE,建议先看一下苏剑林的博客(本文末尾有链接)。
VAE 的整体框架VAE 认为,随机变量 \(\boldsymbol{x} \sim p(\boldsymbol{x})\) 由两个随机过程得到:
(资料图片)
根据先验分布 \(p(\boldsymbol{z})\) 生成隐变量 \(\boldsymbol{z}\)。根据条件分布 \(p(\boldsymbol{x} | \boldsymbol{z})\) 由 \(\boldsymbol{z}\) 得到 \(\boldsymbol{x}\)。于是 \(p(\boldsymbol{x}, \boldsymbol{z}) = p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})\) 就是我们所需要的生成模型。
一种朴素的想法是:先用随机数生成器生成隐变量 \(\boldsymbol{z}\),然后用 \(p(\boldsymbol{x} | \boldsymbol{z})\) 从 \(\boldsymbol{z}\) 中生成出(或者说重构出) \(\boldsymbol{x}\),通过最小化重构损失来训练模型。这个想法的问题在于:我们无法找到生成的样本与原始样本之间的对应关系,重构损失算不了,无法训练。
VAE 的做法是引入后验分布 \(p(\boldsymbol{z} | \boldsymbol{x})\),训练过程变为:
采样一批原始样本 \(\boldsymbol{x}\)。用 \(p(\boldsymbol{z} | \boldsymbol{x})\) 获得每个样本 \(\boldsymbol{x}\) 对应的隐变量 \(\boldsymbol{z}\)。用 \(p(\boldsymbol{x} | \boldsymbol{z})\) 从隐变量 \(\boldsymbol{z}\) 中重构出 \(\boldsymbol{x}\),通过最小化重构损失来训练模型。从这个角度来看,\(p(\boldsymbol{z} | \boldsymbol{x})\) 相当于编码器,\(p(\boldsymbol{x} | \boldsymbol{z})\) 相当于解码器,训练结束后只需要保留解码器 \(p(\boldsymbol{x} | \boldsymbol{z})\) 即可。
除了重构损失以外,VAE 还有一项 KL 散度损失,希望近似的后验分布 \(q(\boldsymbol{z} | \boldsymbol{x})\) 尽量接近先验分布 \(p(\boldsymbol{z})\),即最小化二者的 KL 散度。
变分下界的推导现有 \(N\) 个由分布 \(P(\boldsymbol{x}; \boldsymbol{\theta})\) 生成的样本 \(\boldsymbol{x}^{(1)}, \ldots, \boldsymbol{x}^{(N)}\),我们可以使用极大似然估计从这些样本中估计出分布的参数 \(\boldsymbol{\theta}\),即
\[\begin{aligned} \boldsymbol{\theta} & = \operatorname*{argmax}_{\boldsymbol{\theta}} p(\boldsymbol{x}^{(1)}; \boldsymbol{\theta}) \cdots p(\boldsymbol{x}^{(N)}; \boldsymbol{\theta}) \\ & = \operatorname*{argmax}_{\boldsymbol{\theta}} \ln(p(\boldsymbol{x}^{(1)}; \boldsymbol{\theta}) \cdots p(\boldsymbol{x}^{(N)}; \boldsymbol{\theta})) \\ & = \operatorname*{argmax}_{\boldsymbol{\theta}} \sum_{i=1}^n \ln p(\boldsymbol{x}^{(i)}; \boldsymbol{\theta}).\end{aligned}\]后验分布 \(p(\boldsymbol{z} | \boldsymbol{x}) = \frac{p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})}{p(\boldsymbol{x})} = \frac{p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})}{\int_{\boldsymbol{z}} p(\boldsymbol{x}, \boldsymbol{z}) \mathrm{d}\boldsymbol{z}}\) 是 intractable 的,因为分母处的边缘分布 \(p(\boldsymbol{x})\) 积不出来。具体来说,联合分布 \(p(\boldsymbol{x}, \boldsymbol{z}) = p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})\) 的表达式非常复杂,\(\int_{\boldsymbol{z}} p(\boldsymbol{x}, \boldsymbol{z}) \mathrm{d}\boldsymbol{z}\) 这个积分找不到解析解。
需要使用变分推断解决后验分布无法计算的问题。我们使用一个形式已知的分布 \(q(\boldsymbol{z}|\boldsymbol{x}^{(i)}; \boldsymbol{\phi})\) 来近似后验分布 \(p(\boldsymbol{z}|\boldsymbol{x}^{(i)}; \boldsymbol{\theta})\),于是有
\[\begin{aligned} \log p(\boldsymbol{x}^{(i)}) & = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) - \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \log p(\boldsymbol{x}^{(i)}) \cdot 1 \\ & = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\log\frac{q(\boldsymbol{z}|\boldsymbol{x}^{(i)})}{p(\boldsymbol{z}|\boldsymbol{x}^{(i)})} \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \log p(\boldsymbol{x}^{(i)}) \cdot \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\mathrm{d}\boldsymbol{z} \\ & = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\log p(\boldsymbol{x}^{(i)}) \mathrm{d}\boldsymbol{z} \\ & = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} \\ & = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log (p(\boldsymbol{z}|\boldsymbol{x}^{(i)})p(\boldsymbol{x}^{(i)}))] \mathrm{d}\boldsymbol{z} \\ & = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)}, \boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\ & = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \mathbb{E}_{\boldsymbol{z} \sim q(\boldsymbol{z}|\boldsymbol{x}^{(i)})}[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)}, \boldsymbol{z})] \\ & = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)}) \\ & \geq L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)}).\end{aligned}\]利用 KL 散度大于等于 0这一特性,我们得到了对数似然 \(\log p(\boldsymbol{x}^{(i)})\) 的一个下界\(L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)})\),于是可以将最大化对数似然改为最大化这个下界。
这个下界可以进一步写成
\[\begin{aligned} L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)}) & = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)}, \boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\ & = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log (p(\boldsymbol{z})p(\boldsymbol{x}^{(i)}|\boldsymbol{z}))] \mathrm{d}\boldsymbol{z} \\ & = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}) + \log p(\boldsymbol{x}^{(i)}|\boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\ & = -\int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) - \log p(\boldsymbol{z})] \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\log p(\boldsymbol{x}^{(i)}|\boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\ & = -\mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z})] + \mathbb{E}_{\boldsymbol{z} \sim q(\boldsymbol{z}|\boldsymbol{x}^{(i)})}[\log p(\boldsymbol{x}^{(i)}|\boldsymbol{z})]. \\\end{aligned}\]其中的第一项是 KL 散度损失,第二项是重构损失。
KL 散度损失使用标准正态分布作为先验分布,即 \(p(\boldsymbol{z}) = N(\boldsymbol{z}; \boldsymbol{0}, \boldsymbol{I})\)。
使用一个由 MLP 的输出来参数化的正态分布作为近似后验分布,即 \(q(\boldsymbol{z}|\boldsymbol{x}^{(i)}; \boldsymbol{\phi}) = N(\boldsymbol{z}; \boldsymbol{\mu}(\boldsymbol{x}^{(i)}; \boldsymbol{\phi}), \boldsymbol{\sigma}^2(\boldsymbol{x}^{(i)}; \boldsymbol{\phi})\boldsymbol{I})\)。
选择正态分布的好处在于 KL 散度的这个积分可以写出解析解,训练时直接按照公式计算即可,无需通过采样的方式来算积分。
由于我们选择的是各分量独立的多元正态分布,因此只需要推导一元正态分布的情形即可:
\[\begin{aligned} \mathrm{KL}[N(z; \mu, \sigma^2), N(z; 0, 1)] & = \int_z N(z; \mu, \sigma^2)\log\frac{N(z; \mu, \sigma^2)}{N(z; 0, 1)} \mathrm{d}z \\ & = \int_z N(z; \mu, \sigma^2) \log\frac{\frac{1}{\sqrt{2\pi}\sigma}\exp\left(-\frac{(z - \mu)^2}{2\sigma^2}\right)}{\frac{1}{\sqrt{2\pi}}\exp\left(-\frac{z^2}{2}\right)} \mathrm{d}z \\ & = \int_z N(z; \mu, \sigma^2) \log\left(\frac{1}{\sqrt{\sigma^2}}\exp\left(\frac{1}{2}\left(-\frac{(z - \mu^2)^2}{\sigma^2} + z^2\right)\right)\right) \mathrm{d}z \\ & = \frac{1}{2}\int_z N(z; \mu, \sigma^2) \left(-\log\sigma^2 - \frac{(z - \mu)^2}{\sigma^2} + z^2\right)\mathrm{d}z \\ & = \frac{1}{2}\left(-\log\sigma^2\int_z N(z; \mu, \sigma^2) \mathrm{d}z - \frac{1}{\sigma^2}\int_z N(z; \mu, \sigma^2)(z - \mu)^2\mathrm{d}z + \int_z N(z; \mu, \sigma^2)z^2\mathrm{d}z\right) \\ & = \frac{1}{2}\left(-\log\sigma^2 \cdot 1 - \frac{1}{\sigma^2} \cdot \sigma^2 + \mu^2 + \sigma^2\right) \\ & = \frac{1}{2}(-\log\sigma^2 - 1 + \mu^2 + \sigma^2).\end{aligned}\]解释一下倒数第三行的三个积分:
\(\int_z N(z; \mu, \sigma^2) \mathrm{d}z\) 是概率密度函数的积分,也就是 1。\(\int_z N(z; \mu, \sigma^2)(z - \mu)^2\mathrm{d}z\) 是方差的定义,也就是 \(\sigma^2\)。\(\int_z N(z; \mu, \sigma^2)z^2\mathrm{d}z\) 是正态分布的二阶矩,结果为 \(\mu^2 + \sigma^2\)。重构损失伯努利分布模型当 \(\boldsymbol{x}\) 是二值向量时,可以用伯努利分布(两点分布)来建模 \(p(\boldsymbol{x}|\boldsymbol{z})\),即认为向量 \(\boldsymbol{x}\) 的每个维度都服从对应的相互独立的伯努利分布。使用一个 MLP 来计算各维度所对应的伯努利分布的参数,第 \(i\) 维伯努利分布的参数为 \(y_i = \boldsymbol{y}(\boldsymbol{z})_i\),于是有
\[p(\boldsymbol{x}|\boldsymbol{z}) = \prod_{i=1}^D y_i^{x_i}(1 - y_i)^{1 - x_i},\]\[\log p(\boldsymbol{x}|\boldsymbol{z}) = \sum_{i=1}^D x_i\log y_i + (1 - x_i)\log(1 - y_i).\]其中 \(D\) 表示向量 \(\boldsymbol{x}\) 的维度。可见此时最大化 \(\log p(\boldsymbol{x}|\boldsymbol{z})\) 等价于最小化交叉熵损失。
正态分布模型当 \(\boldsymbol{x}\) 是实值向量时,可以用正态分布来建模 \(p(\boldsymbol{x}|\boldsymbol{z})\)。使用一个 MLP 来计算正态分布的参数,于是有
\[\begin{aligned} p(\boldsymbol{x}|\boldsymbol{z}) & = N(\boldsymbol{x}; \boldsymbol{\mu}, \boldsymbol{\sigma}^2\boldsymbol{I}) \\ & = \prod_{i=1}^D N(x_i; \mu_i, \sigma_i^2) \\ & = \left(\prod_{i=1}^D\frac{1}{\sqrt{2\pi}\sigma_i}\right)\exp\left(\sum_{i=1}^D-\frac{(x_i - \mu_i)^2}{2\sigma_i^2}\right),\end{aligned}\]\[\log p(\boldsymbol{x}|\boldsymbol{z}) = -\frac{D}{2}\log 2\pi - \frac{1}{2}\sum_{i=1}^D\log\sigma_i^2 - \frac{1}{2}\sum_{i=1}^D\frac{(x_i - \mu_i)^2}{\sigma_i^2}.\]很多时候我们会假设 \(\sigma_i^2\) 是一个常数,于是 MLP 只需要输出均值参数 \(\boldsymbol{\mu}\) 即可。此时有
\[\log p(\boldsymbol{x}|\boldsymbol{z}) \sim -\frac{1}{2}\sum_{i=1}^D(x_i - \mu_i)^2 = -\frac{1}{2}\|\boldsymbol{x} - \boldsymbol{\mu}(\boldsymbol{z})\|^2.\]可见此时最大化 \(\log p(\boldsymbol{x}|\boldsymbol{z})\) 等价于最小化 MSE 损失。
重参数化技巧需要使用重参数化技巧解决采样 \(z\) 时不可导的问题。解决的思路是先从无参数分布中采样一个 \(\varepsilon\),再通过变换得到 \(z\)。
从 \(N(\mu, \sigma^2)\) 中采样一个 \(z\),相当于先从 \(N(0, 1)\) 中采样一个 \(\varepsilon\),然后令 \(z = \mu + \varepsilon\cdot\sigma\)。
相关知识技巧,通过取对数把乘除变成加减:
\[\ln ab = \ln a + \ln b,\ \ln\frac{a}{b} = \ln a - \ln b.\]随机变量的函数的期望:
\[\mathbb{E}_{x \sim P(x)} g(x) = \int_x p(x)g(x) \mathrm{d}x,\]利用此公式可以将积分改写成期望的形式,这样就可以用采样的方式计算积分了(蒙特卡罗积分法)。
条件概率密度的定义:
\[p_{Y|X}(y|x) = \frac{p(x, y)}{p_X(x)},\]此处的 \(p\) 并不是概率而是概率密度函数,但是这个公式在形式上跟条件概率公式是一样的。
参考资料苏剑林的 VAE 系列博客:
变分自编码器(一):原来是这么一回事 - 科学空间变分自编码器(二):从贝叶斯观点出发 - 科学空间变分自编码器(三):这样做为什么能成? - 科学空间15 分钟了解变分推理:
【15分钟】了解变分推理 - 哔哩哔哩【15分钟】了解变分自编码器 - 哔哩哔哩