VAE Explained

Explaining the VAE without skipping steps

See my diffusion blog post for a summary of the VAE logic.

The starting point is that we want to generate data \(X \sim p_{data}(x)\). The data distribution can be arbitrarily complicated, so we are motivated to approximate it by combining a few simpler distributions. A Gaussian mixture model is an example. In a Gaussian mixture model, to generate a sample from \(p_{data}(x)\), we first sample \(Z \sim \text{Categorical}(z_1, z_2, \cdots, z_n)\) and then sample \(X \mid Z \sim \mathcal{N}(\mu(Z), \sigma(Z))\). In this case, we actually sample from the joint \(p(x, z)\). \(Z\) is called the latent variable.
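As a concrete illustration, here is a minimal sketch of this two-step (ancestral) sampling for a Gaussian mixture model; the component weights, means, and standard deviations are made-up numbers for the example:

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up mixture parameters: 3 components.
weights = np.array([0.5, 0.3, 0.2])   # p(Z = k)
means   = np.array([-2.0, 0.0, 3.0])  # mu(Z)
stds    = np.array([0.5, 1.0, 0.8])   # sigma(Z)

def sample_gmm(n):
    # Step 1: sample the latent component index Z ~ Categorical(weights).
    z = rng.choice(len(weights), size=n, p=weights)
    # Step 2: sample X | Z from N(mu(Z), sigma(Z)).
    x = rng.normal(means[z], stds[z])
    return x, z

samples, latents = sample_gmm(5)
print(samples, latents)
```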

The whole idea of the VAE is to approximate \(p_{data}(x)\) using a latent variable \(Z\) as well. To generate a data sample \(x\), we first sample a \(z\) and then sample from \(X \mid Z = z\).

Let us parameterize our generator given the latent variable \(Z\) as \(p_{\theta}(x \mid z)\). One way to evaluate our generative model is to check the log-likelihood it assigns to a data point \(x\):

\[\begin{equation} \log p_{\theta}(x) = \log \int p_{\theta}(x, z) \, dz = \log \int \frac{p_{\theta}(x, z)}{q(z)} q(z) \, dz = \log \mathbb{E}_{z \sim q(z)}\left[ \frac{p_{\theta}(x, z)}{q(z)} \right]. \end{equation}\]

We can read \(\mathbb{E}_{z \sim q(z)}\left[ \frac{p_{\theta}(x, z)}{q(z)} \right]\) as importance sampling: draw \(z\) from an arbitrary proposal distribution \(q(z)\) and weight each sample by \(\frac{p_{\theta}(x, z)}{q(z)}\). Now, to do gradient ascent to maximize the log-likelihood, we first need to estimate \(\log \mathbb{E}_{z \sim q(z)}\left[ \frac{p_{\theta}(x, z)}{q(z)} \right]\). However, since the expectation operator sits inside the log, we can’t directly use a Monte Carlo estimator. But by Jensen’s inequality and the concavity of the \(\log\) function, we have:

\[\begin{equation} \log p_\theta(x) = \log \mathbb{E}_{z \sim q(z)}\left[ \frac{p_{\theta}(x, z)}{q(z)} \right] \geq \mathbb{E}_{z \sim q(z)}\left[ \log \frac{p_{\theta}(x, z)}{q(z)} \right]. \end{equation}\]

We can estimate the right-hand side of eq(2) using the Monte Carlo estimator:

\[\begin{equation} \mathbb{E}_{z \sim q(z)}\left[ \log \frac{p_{\theta}(x, z)}{q(z)} \right] \approx \frac{1}{K} \sum_{i=1}^{K}\log \frac{p_{\theta}(x, z_i)}{q(z_i)}. \end{equation}\]
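To make eq(3) concrete, here is a sketch of the estimator on a toy model where every density is tractable; the specific Gaussians (\(p(z) = \mathcal{N}(0,1)\), \(p_\theta(x \mid z) = \mathcal{N}(z, 1)\), \(q(z) = \mathcal{N}(0,1)\)) are illustrative assumptions, not part of the text. For this toy model \(\log p_\theta(x)\) is known in closed form, so we can check that the estimate indeed lower-bounds it:

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

def log_p_joint(x, z):
    # Toy generator: p(z) = N(0, 1), p_theta(x | z) = N(z, 1).
    return norm.logpdf(z, 0.0, 1.0) + norm.logpdf(x, z, 1.0)

def lower_bound_estimate(x, K=10_000):
    # Arbitrary proposal q(z) = N(0, 1); eq(3): average of log [p(x, z) / q(z)].
    z = rng.normal(0.0, 1.0, size=K)
    log_q = norm.logpdf(z, 0.0, 1.0)
    return np.mean(log_p_joint(x, z) - log_q)

x = 0.5
print(lower_bound_estimate(x))            # roughly -1.54
print(norm.logpdf(x, 0.0, np.sqrt(2.0)))  # true log p(x), roughly -1.33
```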

Now we can actually train something. We choose an arbitrary proposal \(q(z)\), and we estimate the lower bound \(\mathbb{E}_{z \sim q(z)}\left[ \log \frac{p_{\theta}(x, z)}{q(z)} \right]\) on the log-likelihood by sampling \(z \sim q(z)\) for each data point \(x\). In this way, we can do gradient ascent on \(\theta\) to optimize our generator \(p_{\theta}(x \mid z)\). However, two questions remain:

  1. We currently choose the proposal \(q(z)\) arbitrarily. How can we choose it wisely?
  2. How good is our lower bound? Intuitively, when we optimize a lower bound, we don’t necessarily optimize the true objective, because the bound may be loose. So we ask: how tight is our lower bound?

We have no idea about question 1 yet, but we can try to answer question 2:

\[\begin{equation} \begin{split} \log p_\theta(x) - \mathbb{E}_{z \sim q(z)}\left[ \log \frac{p_{\theta}(x, z)}{q(z)} \right] & = \mathbb{E}_{z \sim q(z)} \left[ \log p_\theta(x) - \log \frac{p_{\theta}(x, z)}{q(z)} \right] \\ & = \mathbb{E}_{z \sim q(z)} \left[ \log \frac{p_\theta(x) q(z)}{p_\theta(x, z)} \right] \\ & = \mathbb{E}_{z \sim q(z)} \left[ \log \frac{q(z)}{p_\theta(z \mid x)} \right] \\ & = D_{KL}(q(z) \ || \ p_\theta(z \mid x)). \end{split} \end{equation}\]

Fortunately, eq(4) answers both questions (the third equality uses \(p_\theta(x, z) = p_\theta(z \mid x) \, p_\theta(x)\)). Since the gap is a KL divergence, the best choice for \(q(z)\) is \(p_\theta(z \mid x)\), and the closer \(q(z)\) is to \(p_\theta(z \mid x)\), the tighter our bound. Unfortunately, we don’t know \(p_\theta(z \mid x)\)! Computing it would force us to ‘invert’ our generator, and if our generator is a neural network, we certainly don’t know the probability of different latents \(z\) underlying a result \(x\). Fortunately, however, we can approximate \(p_\theta(z \mid x)\) using another neural network! Let’s use \(q_{\phi}(z \mid x)\) to approximate \(p_\theta(z \mid x)\): it takes an input \(x\) and outputs a distribution over the latent \(z\) that could underlie \(x\) according to the generator. Note that \(\theta\) and \(\phi\) are highly related: for different generators (parameterized by \(\theta\)), we will of course have different guesses of the posterior over \(z\) (parameterized by \(\phi\)) given the data evidence \(x\).
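For concreteness, here is a minimal PyTorch sketch (assuming PyTorch, with arbitrary layer sizes) of the usual parameterization of \(q_\phi(z \mid x)\) as a diagonal Gaussian whose mean and log-variance come from a neural network:

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """q_phi(z | x): map x to the mean and log-variance of a diagonal Gaussian."""

    def __init__(self, x_dim=784, hidden_dim=400, z_dim=20):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(x_dim, hidden_dim), nn.ReLU())
        self.mu = nn.Linear(hidden_dim, z_dim)       # mean of q_phi(z | x)
        self.log_var = nn.Linear(hidden_dim, z_dim)  # log-variance of q_phi(z | x)

    def forward(self, x):
        h = self.net(x)
        return self.mu(h), self.log_var(h)
```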

Since the parameters are highly related, we update them together. By updating \(\phi\), we want a tighter lower bound; by updating \(\theta\), we want to increase the value of the lower bound and, hopefully, the log-likelihood.

According to eq(2), our objective (the evidence lower bound, or ELBO, which we maximize) is:

\[\begin{equation} \mathcal{L}(x; \theta, \phi) = \mathbb{E}_{z \sim q_{\phi}(z \mid x)}\left[ \log \frac{p_{\theta}(x, z)}{q_{\phi}(z \mid x)} \right]. \end{equation}\]

Freezing \(\phi\), obtaining the gradient w.r.t. \(\theta\) is easy, because the sampling distribution \(q_\phi(z \mid x)\) does not depend on \(\theta\), so the gradient can move inside the expectation:

\[\begin{equation} \begin{split} \nabla_{\theta} \mathcal{L}(x; \theta, \phi) & = \nabla_{\theta} \mathbb{E}_{z \sim q_{\phi}(z \mid x)}\left[ \log \frac{p_{\theta}(x, z)}{q_{\phi}(z \mid x)} \right] \\ & = \mathbb{E}_{z \sim q_{\phi}(z \mid x)} \left[ \nabla_{\theta} \log \frac{p_{\theta}(x, z)}{q_{\phi}(z \mid x)} \right] \\ & \approx \frac{1}{K} \sum_{i=1}^{K} \nabla_{\theta} \log \frac{p_{\theta}(x, z_i)}{q_\phi(z_i \mid x)}. \end{split} \end{equation}\]

However, freezing \(\theta\), obtaining the gradient w.r.t. \(\phi\) is not easy. The problem is that \(\phi\) now appears in the sampling distribution itself: updating \(\phi\) changes \(q_\phi(z \mid x)\), and therefore changes which samples we draw. We can no longer move the gradient inside the expectation as in eq(6), and a naive Monte Carlo estimate based on samples from \(q_\phi(z \mid x)\) misses the effect of \(\phi\) on the sampling process. How can we solve this problem?

Here is the well-known reparameterization trick (see Wikipedia for more explanation). Write the integrand as \(f_\phi(z) = \log \frac{p_{\theta}(x, z)}{q_{\phi}(z \mid x)}\), so our objective is \(\mathbb{E}_{z \sim q_{\phi}(z \mid x)} \left[ f_{\phi}(z) \right]\). Suppose we can express \(z = g_\phi(\epsilon)\), where \(g_\phi\) is deterministic and \(\epsilon \sim p(\epsilon)\) is a fixed noise distribution that does not depend on \(\phi\). For instance, if \(q_\phi(z \mid x) = \mathcal{N}(\mu, \sigma^2)\), let \(\epsilon \sim \mathcal{N}(0, 1)\) and \(g_\phi(\epsilon) = \mu + \sigma \epsilon\). In this way, we know the effect of any change to the parameters, because the randomness in generating samples has nothing to do with the parameters we want to update. We can write

\[\begin{equation} \mathbb{E}_{z \sim q_{\phi}(z \mid x)}\left[ \log \frac{p_{\theta}(x, z)}{q_{\phi}(z \mid x)} \right] = \mathbb{E}_{z \sim q_{\phi}(z \mid x)} \left[ f_{\phi}(z) \right] = \mathbb{E}_{\epsilon \sim p(\epsilon)} \left[ f_\phi(g_\phi(\epsilon)) \right], \end{equation}\]

and

\[\begin{equation} \begin{split} \nabla_\phi \mathbb{E}_{z \sim q_{\phi}(z \mid x)}\left[ \log \frac{p_{\theta}(x, z)}{q_{\phi}(z \mid x)} \right] & = \nabla_\phi \mathbb{E}_{\epsilon \sim p(\epsilon)} \left[ f_\phi(g_\phi(\epsilon)) \right] \\ & = \mathbb{E}_{\epsilon \sim p(\epsilon)} \nabla_\phi \left[ \log \frac{p_{\theta}(x, z(\epsilon; \phi))}{q_\phi(z(\epsilon; \phi) \mid x)} \right] \\ & \approx \frac{1}{K} \sum_{i=1}^{K} \nabla_\phi \left[ \log \frac{p_{\theta}(x, z(\epsilon_i; \phi))}{q_\phi(z(\epsilon_i; \phi) \mid x)} \right], \end{split} \end{equation}\]

where \(z(\epsilon; \phi) = g_\phi(\epsilon)\).
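A minimal PyTorch sketch of the trick for the diagonal Gaussian case (an assumed but standard choice): the randomness lives entirely in \(\epsilon\), so autograd can push gradients through \(z = g_\phi(\epsilon)\) back to the parameters:

```python
import torch

def reparameterize(mu, log_var):
    # z(eps; phi) = mu + sigma * eps, with eps ~ N(0, I) independent of phi.
    eps = torch.randn_like(mu)
    return mu + torch.exp(0.5 * log_var) * eps

mu = torch.zeros(3, requires_grad=True)
log_var = torch.zeros(3, requires_grad=True)
z = reparameterize(mu, log_var)
z.sum().backward()                    # gradients reach mu and log_var through z
print(mu.grad, log_var.grad)
```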

Now we’re good: we have the objective and the gradients w.r.t. both sets of parameters. For each batch, we sample \(x_i \sim p_{data}(x)\), and for each sample in the batch we estimate the gradients using eq(6) and eq(8). Then we average the gradients across the batch and update the parameters.
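Putting it all together, here is a sketch of one training step under common but assumed modeling choices: a standard normal prior \(p(z)\), a Bernoulli decoder \(p_\theta(x \mid z)\), \(K = 1\), and the Gaussian KL term of eq(5) computed in closed form instead of by sampling (a standard variance-reduction simplification). We minimize the negative of eq(5) with a single optimizer over both \(\theta\) and \(\phi\):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, x_dim=784, hidden_dim=400, z_dim=20):
        super().__init__()
        # Encoder q_phi(z | x): mean and log-variance of a diagonal Gaussian.
        self.enc = nn.Sequential(nn.Linear(x_dim, hidden_dim), nn.ReLU())
        self.mu = nn.Linear(hidden_dim, z_dim)
        self.log_var = nn.Linear(hidden_dim, z_dim)
        # Decoder p_theta(x | z): Bernoulli logits for each pixel.
        self.dec = nn.Sequential(
            nn.Linear(z_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, x_dim)
        )

    def forward(self, x):
        h = self.enc(x)
        mu, log_var = self.mu(h), self.log_var(h)
        z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)  # reparameterization
        return self.dec(z), mu, log_var

def negative_elbo(x, logits, mu, log_var):
    # -log p_theta(x | z) for a Bernoulli decoder, single (K = 1) sample.
    recon = F.binary_cross_entropy_with_logits(logits, x, reduction="sum")
    # D_KL(q_phi(z | x) || N(0, I)), closed form for diagonal Gaussians.
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon + kl

model = VAE()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.rand(32, 784)  # stand-in batch; real x would come from a dataloader
logits, mu, log_var = model(x)
loss = negative_elbo(x, logits, mu, log_var)
opt.zero_grad()
loss.backward()
opt.step()
```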