VAE and Diffusion Explained

A comprehensive explanation of Variational Autoencoders and Diffusion models.

Part 1: Variational Autoencoder (VAE)

The starting point is that we want to generate data $x$ where $x \sim p(x)$. The data distribution $p(x)$ can be arbitrarily complicated, so we are motivated to approximate it by combining a few simpler distributions. A Gaussian mixture model is an example. In a Gaussian mixture model, to generate samples from $p(x)$, we first sample a component $z \sim p(z)$ and then sample $x \sim p(x \mid z) = \mathcal{N}(\mu_z, \Sigma_z)$. In this case, we actually sample from $p(x) = \sum_z p(z)\, p(x \mid z)$. $z$ is called the latent variable.
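To make the two-step sampling concrete, here is a minimal NumPy sketch of ancestral sampling from a 1-D Gaussian mixture. The weights, means, and standard deviations are hypothetical values chosen for illustration, and `sample_gmm` / `gmm_density` are illustrative helper names, not from any library.

```python
import numpy as np

# Hypothetical 1-D mixture: p(z) weights, component means and std devs.
weights = np.array([0.3, 0.7])
means = np.array([-2.0, 3.0])
stds = np.array([0.5, 1.0])

def sample_gmm(n, rng):
    # Step 1: sample the latent component index z ~ p(z) (categorical).
    z = rng.choice(len(weights), size=n, p=weights)
    # Step 2: sample x ~ p(x | z) = N(means[z], stds[z]^2).
    x = rng.normal(means[z], stds[z])
    return x, z

def gmm_density(x):
    # Marginal p(x) = sum_z p(z) p(x | z), evaluated pointwise.
    x = np.asarray(x)[..., None]
    comp = np.exp(-0.5 * ((x - means) / stds) ** 2) / (stds * np.sqrt(2 * np.pi))
    return comp @ weights

rng = np.random.default_rng(0)
x, z = sample_gmm(100_000, rng)
print(x.mean())  # close to 0.3 * (-2) + 0.7 * 3 = 1.5
```

We never touch $p(x)$ directly when sampling; we only need the simple $p(z)$ and $p(x \mid z)$.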

The whole idea of VAE is to approximate $p(x)$ using some latent variable $z$ as well. To generate a data sample $x$, we first sample $z \sim p(z)$ and then sample $x$ from $p(x \mid z)$.

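Let us parameterize our generator given the latent variable as $p_\theta(x \mid z)$. One way to evaluate our generative model is to check the log-likelihood of the data point we generate:

$$\log p_\theta(x) = \log \int p_\theta(x \mid z)\, p(z)\, dz = \log \mathbb{E}_{z \sim q(z)}\left[\frac{p_\theta(x \mid z)\, p(z)}{q(z)}\right]$$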

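We can read $\frac{p(z)}{q(z)}$ as “weight according to the arbitrary prior $q(z)$”: it corrects for drawing $z$ from $q(z)$ instead of $p(z)$. Now, to do gradient ascent to maximize the log-likelihood, we need to estimate $\log p_\theta(x)$ first. However, since the expectation operator is inside the log, we can’t directly use the Monte Carlo estimator (the log of a sample average is a biased estimate of the log of the expectation). But according to Jensen’s inequality and the concavity of the $\log$ function, we have:

$$\log p_\theta(x) \ge \mathbb{E}_{z \sim q(z)}\left[\log \frac{p_\theta(x \mid z)\, p(z)}{q(z)}\right]$$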

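We can estimate the right-hand side using the Monte Carlo estimator:

$$\mathbb{E}_{z \sim q(z)}\left[\log \frac{p_\theta(x \mid z)\, p(z)}{q(z)}\right] \approx \frac{1}{N} \sum_{i=1}^{N} \log \frac{p_\theta(x \mid z^{(i)})\, p(z^{(i)})}{q(z^{(i)})}, \qquad z^{(i)} \sim q(z)$$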
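Here is a small numerical sketch of this estimator on a toy model we can solve exactly, assuming $p(z) = \mathcal{N}(0, 1)$ and $p_\theta(x \mid z) = \mathcal{N}(z, 1)$, so that $p_\theta(x) = \mathcal{N}(0, 2)$. The choice $q(z) = \mathcal{N}(0, 1)$ is an arbitrary prior picked for illustration; `log_normal` and `elbo_mc` are hypothetical helper names.

```python
import numpy as np

def log_normal(x, mean, var):
    # Log-density of a 1-D Gaussian N(mean, var).
    return -0.5 * (np.log(2 * np.pi * var) + (x - mean) ** 2 / var)

def elbo_mc(x, q_mean, q_var, n_samples, rng):
    # Draw z^(i) ~ q(z) and average log[p(x|z) p(z) / q(z)].
    z = rng.normal(q_mean, np.sqrt(q_var), size=n_samples)
    log_w = log_normal(x, z, 1.0) + log_normal(z, 0.0, 1.0) - log_normal(z, q_mean, q_var)
    return log_w.mean()

rng = np.random.default_rng(0)
x = 1.3
true_log_px = log_normal(x, 0.0, 2.0)  # exact, since p(x) = N(0, 2) in this toy model
bound = elbo_mc(x, q_mean=0.0, q_var=1.0, n_samples=200_000, rng=rng)
print(true_log_px, bound)              # the estimate sits below the true value
```

The estimate is a valid lower bound, but with this arbitrary $q(z)$ it is noticeably loose, which is exactly the concern raised next.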

Now, we can actually train something. We can choose an arbitrary prior $q(z)$, and we can estimate a lower bound of the log-likelihood by sampling $x$ from the data and $z \sim q(z)$. In this way, we can do gradient ascent on $\theta$ to optimize our generator $p_\theta(x \mid z)$. However, two questions remain:

  1. So far we have chosen the prior $q(z)$ arbitrarily. How can we choose it wisely?
  2. How good is our lower bound? Intuitively, when we optimize a lower bound, we don’t necessarily optimize our true objective, because the bound may be loose. So: how tight is our lower bound?

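We have no idea yet about question 1. But we can try to answer question 2. Using Bayes’ rule, $p_\theta(x) = \frac{p_\theta(x \mid z)\, p(z)}{p_\theta(z \mid x)}$, so the gap between the log-likelihood and the bound is:

$$\log p_\theta(x) - \mathbb{E}_{z \sim q(z)}\left[\log \frac{p_\theta(x \mid z)\, p(z)}{q(z)}\right] = \mathbb{E}_{z \sim q(z)}\left[\log \frac{q(z)}{p_\theta(z \mid x)}\right] = D_{\mathrm{KL}}\big(q(z) \,\|\, p_\theta(z \mid x)\big) \ge 0$$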

This solves both questions. The best guess for $q(z)$ is $p_\theta(z \mid x)$, and the closer $q(z)$ is to $p_\theta(z \mid x)$, the tighter our bound will be. Unfortunately, we don’t know $p_\theta(z \mid x)$! Computing it would force us to ‘invert’ our generator, and if our generator is a neural network, we have no closed form for the probability of the different latents $z$ underlying the result $x$. However, we can approximate $p_\theta(z \mid x)$ using another neural network! Let’s use $q_\phi(z \mid x)$ to approximate $p_\theta(z \mid x)$. $q_\phi$ takes $x$ as input and outputs a guess of the underlying latent value according to the generator (in practice, the parameters of a distribution over $z$). Note that $\theta$ and $\phi$ are highly related: for different generators (parameterized by $\theta$), we will of course have different guesses of the prior (parameterized by $\phi$) given the data evidence $x$.
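We can check the gap identity numerically on the same toy model as above ($p(z) = \mathcal{N}(0, 1)$, $p_\theta(x \mid z) = \mathcal{N}(z, 1)$), whose exact posterior is $p_\theta(z \mid x) = \mathcal{N}(x/2, 1/2)$. This is an illustrative sketch; the helper names and the “poor” guess $q(z) = \mathcal{N}(-1, 2)$ are assumptions made for the example.

```python
import numpy as np

def log_normal(x, mean, var):
    return -0.5 * (np.log(2 * np.pi * var) + (x - mean) ** 2 / var)

def kl_normal(m1, v1, m2, v2):
    # KL( N(m1, v1) || N(m2, v2) ) for 1-D Gaussians.
    return 0.5 * (np.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)

rng = np.random.default_rng(1)
x = 1.3
log_px = log_normal(x, 0.0, 2.0)
post_mean, post_var = x / 2, 0.5   # exact posterior p(z | x) for this toy model

def log_weights(q_mean, q_var, n):
    # log[p(x|z) p(z) / q(z)] for z ~ q(z) = N(q_mean, q_var).
    z = rng.normal(q_mean, np.sqrt(q_var), size=n)
    return log_normal(x, z, 1.0) + log_normal(z, 0.0, 1.0) - log_normal(z, q_mean, q_var)

# q = posterior: every sample equals log p(x), so the bound is tight.
lw_post = log_weights(post_mean, post_var, 1000)
# q = a poorer guess: the gap should match KL(q || p(z | x)).
lw_bad = log_weights(-1.0, 2.0, 200_000)
gap = log_px - lw_bad.mean()
print(np.allclose(lw_post, log_px), gap, kl_normal(-1.0, 2.0, post_mean, post_var))
```

With the exact posterior, each log weight is $\log p_\theta(x, z) - \log p_\theta(z \mid x) = \log p_\theta(x)$, so the bound has zero variance as well as zero gap.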

Since the parameters are highly related, we need to update them together. By updating $\phi$, we want to get a tighter lower bound, and by updating $\theta$, we want to increase the value of the lower bound and (hopefully) the log-likelihood.

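Our objective (the evidence lower bound, or ELBO, which we maximize; the loss is its negative) is:

$$\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{z \sim q_\phi(z \mid x)}\left[\log \frac{p_\theta(x \mid z)\, p(z)}{q_\phi(z \mid x)}\right] = \mathbb{E}_{z \sim q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z)\big)$$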
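The second form is convenient because, for a Gaussian $q_\phi(z \mid x) = \mathcal{N}(\mu, \sigma^2)$ and a standard normal $p(z)$, the KL term has a closed form. The sketch below assumes the same toy decoder $p_\theta(x \mid z) = \mathcal{N}(z, 1)$ and hypothetical encoder outputs `mu`, `var`; it estimates only the reconstruction term by Monte Carlo and compares the result with the exact ELBO for this toy model.

```python
import numpy as np

def log_normal(x, mean, var):
    return -0.5 * (np.log(2 * np.pi * var) + (x - mean) ** 2 / var)

def kl_to_standard_normal(mu, var):
    # Closed-form KL( N(mu, var) || N(0, 1) ), summed over latent dimensions.
    return 0.5 * np.sum(var + mu ** 2 - 1.0 - np.log(var))

def elbo_split(x, mu, var, n, rng):
    # Reconstruction term by Monte Carlo, KL term in closed form.
    z = rng.normal(mu, np.sqrt(var), size=n)
    recon = log_normal(x, z, 1.0).mean()   # toy decoder p(x | z) = N(z, 1)
    return recon - kl_to_standard_normal(mu, var)

rng = np.random.default_rng(0)
x, mu, var = 1.3, 0.4, 0.6               # hypothetical encoder output for this x
estimate = elbo_split(x, mu, var, 200_000, rng)
# For this decoder, E_q[log N(x; z, 1)] = -0.5 log(2 pi) - 0.5 ((x - mu)^2 + var).
exact = -0.5 * np.log(2 * np.pi) - 0.5 * ((x - mu) ** 2 + var) - kl_to_standard_normal(mu, var)
print(estimate, exact)
```

Computing a term analytically whenever possible removes its Monte Carlo noise, which typically makes the gradient estimates less noisy.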

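Freezing $\phi$, obtaining the gradient w.r.t. $\theta$ is easy, because $\theta$ does not affect the sampling distribution:

$$\nabla_\theta \mathcal{L} = \mathbb{E}_{z \sim q_\phi(z \mid x)}\big[\nabla_\theta \log p_\theta(x \mid z)\big] \approx \frac{1}{N} \sum_{i=1}^{N} \nabla_\theta \log p_\theta(x \mid z^{(i)}), \qquad z^{(i)} \sim q_\phi(z \mid x)$$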

However, freezing $\theta$, obtaining the gradient w.r.t. $\phi$ is not easy. The reason is that $\phi$ parameterizes the distribution we sample from, so the gradient cannot simply be moved inside the expectation:

$$\nabla_\phi\, \mathbb{E}_{z \sim q_\phi(z \mid x)}\left[\log \frac{p_\theta(x \mid z)\, p(z)}{q_\phi(z \mid x)}\right] \neq \mathbb{E}_{z \sim q_\phi(z \mid x)}\left[\nabla_\phi \log \frac{p_\theta(x \mid z)\, p(z)}{q_\phi(z \mid x)}\right]$$

When we update $\phi$, we change the potential sample results of $z$. A sample is just the outcome of a random draw, so it tells us nothing about how that outcome would move if $\phi$ changed slightly, and we can’t find the best update direction by Monte Carlo because it is impossible to try every direction.

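Here is the well-known reparameterization trick. Let $\epsilon \sim p(\epsilon)$ and $z \sim q_\phi(z \mid x)$. Now write $z = g_\phi(\epsilon, x)$, where $g_\phi$ is deterministic (and differentiable). For instance, let $\epsilon \sim \mathcal{N}(0, I)$ and $z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon$. In this way, we know the result of any change to the parameters, because the randomness of generating samples has nothing to do with the parameters we want to update. We can write:

$$\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{\epsilon \sim p(\epsilon)}\left[\log \frac{p_\theta\big(x \mid g_\phi(\epsilon, x)\big)\, p\big(g_\phi(\epsilon, x)\big)}{q_\phi\big(g_\phi(\epsilon, x) \mid x\big)}\right]$$

and

$$\nabla_\phi \mathcal{L} \approx \frac{1}{N} \sum_{i=1}^{N} \nabla_\phi \log \frac{p_\theta\big(x \mid g_\phi(\epsilon^{(i)}, x)\big)\, p\big(g_\phi(\epsilon^{(i)}, x)\big)}{q_\phi\big(g_\phi(\epsilon^{(i)}, x) \mid x\big)}, \qquad \epsilon^{(i)} \sim p(\epsilon)$$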

Now we’re good. We have the loss function and its gradients w.r.t. all parameters. For each batch, we sample data points $x$ and, for each of them, noise $\epsilon \sim p(\epsilon)$, and we estimate the gradient. Then we average the gradients across all data in the batch and update the parameters.
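The following sketch checks the reparameterized gradient against its analytic value. To keep it tiny, it assumes encoder outputs `mu`, `sigma` that are plain scalars shared across the batch (in a real VAE they would be network outputs depending on $x$), the toy decoder $p_\theta(x \mid z) = \mathcal{N}(z, 1)$, and only the reconstruction term $\mathbb{E}_q[\log p_\theta(x \mid z)]$, whose exact gradients here are $\bar{x} - \mu$ and $-\sigma$. The function name `reparam_grads` is hypothetical.

```python
import numpy as np

def reparam_grads(x_batch, mu, sigma, n_eps, rng):
    # Pathwise gradient of E_{z ~ N(mu, sigma^2)}[log N(x; z, 1)] w.r.t. mu and sigma.
    eps = rng.standard_normal((len(x_batch), n_eps))  # noise does not depend on mu, sigma
    z = mu + sigma * eps                              # z = g(eps) is deterministic given eps
    dlogp_dz = x_batch[:, None] - z                   # d/dz log N(x; z, 1) = x - z
    grad_mu = dlogp_dz.mean(axis=1)                   # chain rule: dz/dmu = 1
    grad_sigma = (dlogp_dz * eps).mean(axis=1)        # chain rule: dz/dsigma = eps
    # Average the per-example gradients across the batch.
    return grad_mu.mean(), grad_sigma.mean()

rng = np.random.default_rng(0)
x_batch = np.array([0.5, 1.0, 2.0])
g_mu, g_sigma = reparam_grads(x_batch, mu=0.3, sigma=0.8, n_eps=200_000, rng=rng)
print(g_mu, g_sigma)  # analytic values: mean(x_batch) - mu and -sigma
```

In an autodiff framework, the chain-rule lines are what backpropagation through $z = \mu + \sigma \epsilon$ computes automatically.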

Part 2: Hierarchical VAE

Previously, we only had a single latent variable $z$, and $x$ was generated given $z$. Now we consider the case with two latent variables $z_1, z_2$. The generative model is now $p(x, z_1, z_2) = p(z_2)\, p(z_1 \mid z_2)\, p_\theta(x \mid z_1)$. Notice that here we have two unknown priors: $p(z_2)$ and $p(z_1 \mid z_2)$. The generative model $p_\theta(x \mid z_1)$ is only responsible for observing $z_1$ and generating $x$ accordingly. We can read it as “first sample $z_2$ from its prior, then sample $z_1$ given $z_2$, then sample (generate) $x$ given $z_1$”. Here we assume the Markov property: $x$ is independent of $z_2$ given $z_1$.

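Similar to the single-latent VAE, we evaluate the model by the expected log-likelihood $\mathbb{E}_{x \sim p(x)}\big[\log p_\theta(x)\big]$, where $p_\theta(x) = \iint p(z_2)\, p(z_1 \mid z_2)\, p_\theta(x \mid z_1)\, dz_1\, dz_2$.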

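Since we don’t know the prior, we can do importance sampling by guessing two tractable priors $q(z_1)$ and $q(z_2 \mid z_1)$. You can think of them as a joint prior $q(z_1, z_2) = q(z_1)\, q(z_2 \mid z_1)$:

$$\log p_\theta(x) = \log \mathbb{E}_{(z_1, z_2) \sim q(z_1, z_2)}\left[\frac{p(z_2)\, p(z_1 \mid z_2)\, p_\theta(x \mid z_1)}{q(z_1)\, q(z_2 \mid z_1)}\right]$$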

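Similarly, according to Jensen’s inequality:

$$\log p_\theta(x) \ge \mathbb{E}_{(z_1, z_2) \sim q(z_1, z_2)}\left[\log \frac{p(z_2)\, p(z_1 \mid z_2)\, p_\theta(x \mid z_1)}{q(z_1)\, q(z_2 \mid z_1)}\right]$$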

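Similarly, we can derive the best prior guess from the gap:

$$\log p_\theta(x) - \mathbb{E}_{q(z_1, z_2)}\left[\log \frac{p(z_2)\, p(z_1 \mid z_2)\, p_\theta(x \mid z_1)}{q(z_1, z_2)}\right] = \mathbb{E}_{q(z_1, z_2)}\left[\log \frac{q(z_1, z_2)}{p(z_1, z_2 \mid x)}\right] = D_{\mathrm{KL}}\big(q(z_1, z_2) \,\|\, p(z_1, z_2 \mid x)\big)$$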

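Therefore, the best prior guesses the latent variables according to the generator and the data evidence $x$, that is, $q(z_1, z_2) = p(z_1, z_2 \mid x) = p(z_1 \mid x)\, p(z_2 \mid z_1)$, where the last step uses the Markov property ($z_2$ is independent of $x$ given $z_1$). We will parametrize and learn an encoder to approximate it, that is, $q_\phi(z_1, z_2 \mid x) = q_\phi(z_1 \mid x)\, q_\phi(z_2 \mid z_1)$. This can be read as “first encode $x$ as $z_1$ according to $q_\phi(z_1 \mid x)$, then encode $z_1$ as $z_2$ according to $q_\phi(z_2 \mid z_1)$”.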
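A linear-Gaussian toy chain makes this easy to check. Assume (for illustration only) $p(z_2) = \mathcal{N}(0, 1)$, $p(z_1 \mid z_2) = \mathcal{N}(z_2, 1)$, and $p_\theta(x \mid z_1) = \mathcal{N}(z_1, 1)$, so that $p_\theta(x) = \mathcal{N}(0, 3)$, $p(z_1 \mid x) = \mathcal{N}(2x/3, 2/3)$, and $p(z_2 \mid z_1) = \mathcal{N}(z_1/2, 1/2)$. The hypothetical encoder `enc = (a1, v1, a2, v2)` below samples $z_1$ then $z_2$ in the two-step order just described.

```python
import numpy as np

def log_normal(x, mean, var):
    return -0.5 * (np.log(2 * np.pi * var) + (x - mean) ** 2 / var)

def hierarchical_elbo(x, enc, n, rng):
    # enc = (a1, v1, a2, v2): hypothetical linear-Gaussian encoder
    # q(z1 | x) = N(a1 * x, v1), q(z2 | z1) = N(a2 * z1, v2).
    a1, v1, a2, v2 = enc
    z1 = rng.normal(a1 * x, np.sqrt(v1), size=n)   # first encode x -> z1
    z2 = rng.normal(a2 * z1, np.sqrt(v2))          # then encode z1 -> z2
    log_joint = log_normal(z2, 0.0, 1.0) + log_normal(z1, z2, 1.0) + log_normal(x, z1, 1.0)
    log_q = log_normal(z1, a1 * x, v1) + log_normal(z2, a2 * z1, v2)
    return (log_joint - log_q).mean()

rng = np.random.default_rng(0)
x = 0.9
log_px = log_normal(x, 0.0, 3.0)   # exact marginal for this toy chain
tight = hierarchical_elbo(x, (2 / 3, 2 / 3, 0.5, 0.5), 1000, rng)   # exact posterior
loose = hierarchical_elbo(x, (0.0, 1.0, 0.0, 1.0), 100_000, rng)   # poor guess
print(log_px, tight, loose)
```

With the exact posterior factors the bound equals $\log p_\theta(x)$; any other encoder leaves a positive KL gap.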

Part 3: Diffusion as Hierarchical VAE

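The starting point of diffusion is to diffuse the data by adding random noise over multiple steps. The noise scheduling policy determines how much noise we add at each timestep, and it is an independent and extensively studied problem. Suppose we use the naive noise scheduling policy; then we have $q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t; \sqrt{\alpha_t}\, x_{t-1}, (1 - \alpha_t) I\big)$. It is equivalent to $x_t = \sqrt{\alpha_t}\, x_{t-1} + \sqrt{1 - \alpha_t}\, \epsilon$, where $\epsilon \sim \mathcal{N}(0, I)$. Here we suppose we add noise $T$ times, so $t = 1, 2, \dots, T$.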
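A minimal sketch of this forward (noising) process, assuming a linear schedule $1 - \alpha_t$ from $10^{-4}$ to $0.02$ over $T = 1000$ steps; that schedule is a hypothetical choice for illustration, not something this text prescribes. Starting from many copies of one data point, $x_T$ ends up close to a standard Gaussian.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)   # assumed linear schedule (hypothetical choice)
alphas = 1.0 - betas                 # alpha_t = 1 - beta_t

def diffuse(x0, rng):
    # Apply x_t = sqrt(alpha_t) x_{t-1} + sqrt(1 - alpha_t) eps for t = 1, ..., T
    # and return only x_T (storing the whole path would cost T copies of x0).
    x = x0
    for a in alphas:
        eps = rng.standard_normal(x.shape)
        x = np.sqrt(a) * x + np.sqrt(1.0 - a) * eps
    return x

rng = np.random.default_rng(0)
x0 = np.full(20_000, 3.0)            # many copies of one data point
xT = diffuse(x0, rng)
print(xT.mean(), xT.std())           # close to 0 and 1: x_T is nearly N(0, 1)
```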

So, the encoder $q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1})$ is given (scheduled), not learned as in VAE. What are the consequences of a given encoder?

From a training perspective, we don’t need to parametrize and learn the ‘best’ encoder with respect to the decoder. We only have a given encoder. So we don’t need to co-optimize $\theta$ and $\phi$, which can lead to more stable and efficient training.

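From a math perspective, given the unperturbed data $x_0$, the perturbed data distribution at any timestep is known in closed form and is Gaussian:

$$q(x_t \mid x_0) = \mathcal{N}\big(x_t; \sqrt{\bar{\alpha}_t}\, x_0, (1 - \bar{\alpha}_t) I\big), \qquad \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$$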
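The closed form lets us jump straight to any timestep instead of simulating every step. The sketch below, again under a hypothetical linear schedule, compares one-shot sampling from $q(x_t \mid x_0)$ with applying $t$ single noising steps; `q_sample` and `iterate_to` are illustrative names.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)   # assumed linear schedule (hypothetical choice)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)      # alpha_bar_t = prod_{s <= t} alpha_s

def q_sample(x0, t, rng):
    # Jump straight to x_t ~ N(sqrt(alpha_bar_t) x0, (1 - alpha_bar_t) I); t is 1-indexed.
    ab = alpha_bars[t - 1]
    return np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * rng.standard_normal(np.shape(x0))

def iterate_to(x0, t, rng):
    # Reference: apply t single noising steps one after another.
    x = x0
    for a in alphas[:t]:
        x = np.sqrt(a) * x + np.sqrt(1.0 - a) * rng.standard_normal(np.shape(x))
    return x

rng = np.random.default_rng(0)
x0 = np.full(100_000, 2.0)
t = 200
a = q_sample(x0, t, rng)
b = iterate_to(x0, t, rng)
print(a.mean(), b.mean(), 2.0 * np.sqrt(alpha_bars[t - 1]))
print(a.std(), b.std(), np.sqrt(1.0 - alpha_bars[t - 1]))
```

Both samplers give matching means and standard deviations, but `q_sample` costs one step regardless of $t$.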

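Note that our goal is still to generate data $x_0$ given latent variables and to evaluate our generative model through the log-likelihood. Now our latent variables are $x_1, \dots, x_T$. So we can write the log-likelihood as before:

$$\log p_\theta(x_0) = \log \int p(x_T) \prod_{t=1}^{T} p_\theta(x_{t-1} \mid x_t)\, dx_{1:T}$$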

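Now given $x_0$, we know the distributions at every timestep, so we can do importance sampling according to our known $q(x_{1:T} \mid x_0)$:

$$\log p_\theta(x_0) = \log \mathbb{E}_{q(x_{1:T} \mid x_0)}\left[\frac{p(x_T) \prod_{t=1}^{T} p_\theta(x_{t-1} \mid x_t)}{q(x_{1:T} \mid x_0)}\right]$$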

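So we have $q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1})$, and by Jensen’s inequality:

$$\log p_\theta(x_0) \ge \mathbb{E}_{q(x_{1:T} \mid x_0)}\left[\log \frac{p(x_T) \prod_{t=1}^{T} p_\theta(x_{t-1} \mid x_t)}{\prod_{t=1}^{T} q(x_t \mid x_{t-1})}\right]$$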

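Therefore, the expected log-likelihood (objective function) is bounded by:

$$\mathbb{E}_{x_0 \sim p(x_0)}\big[\log p_\theta(x_0)\big] \ge \mathbb{E}_{x_0 \sim p(x_0)}\, \mathbb{E}_{q(x_{1:T} \mid x_0)}\left[\log \frac{p(x_T) \prod_{t=1}^{T} p_\theta(x_{t-1} \mid x_t)}{\prod_{t=1}^{T} q(x_t \mid x_{t-1})}\right]$$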

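In theory, this can be used to train something. The training process would first sample a data point $x_0$, then diffuse it $T$ times to get a path $x_{1:T}$. For each path, the denominator $\prod_{t=1}^{T} q(x_t \mid x_{t-1})$ can be easily calculated, and $p(x_T)$ is a fixed standard Gaussian $\mathcal{N}(0, I)$.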

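However, the issue is that this requires us to sample the whole path for each $x_0$, and the path can be very long if $T$ is big (typically $T = 1000$). But think about the sampling process: we have a generator $p_\theta(x_{t-1} \mid x_t)$ for each timestep, and we have a given diffusion encoder $q(x_t \mid x_{t-1})$ at each timestep as well. Our intuition is that we can break the huge expression into small pieces, each responsible for matching one timestep. If we can do so, we only need to look at two adjacent timesteps at a time rather than the whole path, which makes our training loop much lighter. Let’s write the bound at a finer granularity. For $t \ge 2$, the Markov property and Bayes’ rule give $q(x_t \mid x_{t-1}) = q(x_t \mid x_{t-1}, x_0) = \frac{q(x_{t-1} \mid x_t, x_0)\, q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)}$, and the ratios $\frac{q(x_{t-1} \mid x_0)}{q(x_t \mid x_0)}$ telescope:

$$\begin{aligned}
\log \frac{p(x_T) \prod_{t=1}^{T} p_\theta(x_{t-1} \mid x_t)}{\prod_{t=1}^{T} q(x_t \mid x_{t-1})}
&= \log \frac{p(x_T)\, p_\theta(x_0 \mid x_1)}{q(x_1 \mid x_0)} + \sum_{t=2}^{T} \log \frac{p_\theta(x_{t-1} \mid x_t)\, q(x_{t-1} \mid x_0)}{q(x_{t-1} \mid x_t, x_0)\, q(x_t \mid x_0)} \\
&= \log p_\theta(x_0 \mid x_1) + \log \frac{p(x_T)}{q(x_T \mid x_0)} + \sum_{t=2}^{T} \log \frac{p_\theta(x_{t-1} \mid x_t)}{q(x_{t-1} \mid x_t, x_0)}
\end{aligned}$$

Taking the expectation over $q(x_{1:T} \mid x_0)$:

$$\mathbb{E}_{q(x_1 \mid x_0)}\big[\log p_\theta(x_0 \mid x_1)\big] - D_{\mathrm{KL}}\big(q(x_T \mid x_0) \,\|\, p(x_T)\big) - \sum_{t=2}^{T} \mathbb{E}_{q(x_t \mid x_0)}\Big[D_{\mathrm{KL}}\big(q(x_{t-1} \mid x_t, x_0) \,\|\, p_\theta(x_{t-1} \mid x_t)\big)\Big]$$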

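The result matches our intuition: $D_{\mathrm{KL}}\big(q(x_{t-1} \mid x_t, x_0) \,\|\, p_\theta(x_{t-1} \mid x_t)\big)$ says that the $t$-th step of the generative path does not go wrong. $\mathbb{E}_{q(x_1 \mid x_0)}\big[\log p_\theta(x_0 \mid x_1)\big]$ says the last generative step does not go wrong. And $D_{\mathrm{KL}}\big(q(x_T \mid x_0) \,\|\, p(x_T)\big)$ says the first generative step (sampling $x_T$) does not go wrong. Each term is responsible for one timestep, and each only needs $x_t \sim q(x_t \mid x_0)$, which we can sample in closed form without simulating the whole path!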
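Because the decomposition is an algebraic identity, it holds for every sampled path and for any generator, not just in expectation. The sketch below checks it numerically in 1-D with a hypothetical short schedule ($T = 5$) and an arbitrary Gaussian generator $p_\theta(x_{t-1} \mid x_t) = \mathcal{N}(c_t x_t, s_t^2)$ with random coefficients. It uses the Gaussian posterior $q(x_{t-1} \mid x_t, x_0)$, whose mean and variance follow from Bayes’ rule on the two Gaussians $q(x_{t-1} \mid x_0)$ and $q(x_t \mid x_{t-1})$.

```python
import numpy as np

def log_normal(x, mean, var):
    return -0.5 * (np.log(2 * np.pi * var) + (x - mean) ** 2 / var)

rng = np.random.default_rng(0)
T = 5
alphas = np.array([0.9, 0.8, 0.7, 0.6, 0.5])   # hypothetical short schedule, alphas[t - 1] = alpha_t
betas = 1.0 - alphas
abar = np.cumprod(alphas)
# Hypothetical generator p_theta(x_{t-1} | x_t) = N(c[t-1] * x_t, s2[t-1]); any values work.
c = rng.uniform(0.5, 1.5, T)
s2 = rng.uniform(0.2, 1.0, T)

def posterior(t, xt, x0):
    # Mean and variance of q(x_{t-1} | x_t, x_0) for t >= 2 (1-indexed).
    ab_t, ab_prev = abar[t - 1], abar[t - 2]
    mean = (np.sqrt(ab_prev) * betas[t - 1] * x0 + np.sqrt(alphas[t - 1]) * (1 - ab_prev) * xt) / (1 - ab_t)
    var = (1 - ab_prev) * betas[t - 1] / (1 - ab_t)
    return mean, var

# Sample one forward path x_0 -> x_T.
x = [1.7]
for t in range(1, T + 1):
    x.append(np.sqrt(alphas[t - 1]) * x[-1] + np.sqrt(betas[t - 1]) * rng.standard_normal())

log_prior = log_normal(x[T], 0.0, 1.0)   # p(x_T) = N(0, 1)
# Full-path form: log p(x_T) + sum_t log p_theta(x_{t-1}|x_t) - sum_t log q(x_t|x_{t-1}).
full = log_prior + sum(
    log_normal(x[t - 1], c[t - 1] * x[t], s2[t - 1])
    - log_normal(x[t], np.sqrt(alphas[t - 1]) * x[t - 1], betas[t - 1])
    for t in range(1, T + 1))
# Decomposed form: reconstruction + prior matching + per-step matching terms.
decomposed = (log_normal(x[0], c[0] * x[1], s2[0])
              + log_prior - log_normal(x[T], np.sqrt(abar[T - 1]) * x[0], 1 - abar[T - 1])
              + sum(log_normal(x[t - 1], c[t - 1] * x[t], s2[t - 1])
                    - log_normal(x[t - 1], *posterior(t, x[t], x[0]))
                    for t in range(2, T + 1)))
print(np.isclose(full, decomposed))  # True
```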

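We know that the KL divergence between two Gaussians can be written explicitly, and the target $q(x_{t-1} \mid x_t, x_0)$ is itself Gaussian:

$$q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\big(x_{t-1}; \tilde{\mu}_t(x_t, x_0), \tilde{\beta}_t I\big), \quad \tilde{\mu}_t = \frac{\sqrt{\bar{\alpha}_{t-1}}\,(1 - \alpha_t)}{1 - \bar{\alpha}_t}\, x_0 + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\, x_t, \quad \tilde{\beta}_t = \frac{(1 - \bar{\alpha}_{t-1})(1 - \alpha_t)}{1 - \bar{\alpha}_t}$$

We take advantage of that by parametrizing each generative step as a Gaussian as well (here we assume the variance $\sigma_t^2$ is fixed rather than learned):

$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \sigma_t^2 I\big)$$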
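For isotropic Gaussians in $d$ dimensions the KL divergence is $\frac{1}{2}\big[d \log\frac{v_2}{v_1} + d\,\frac{v_1}{v_2} + \frac{\lVert m_1 - m_2 \rVert^2}{v_2} - d\big]$. The sketch below checks this closed form against a Monte Carlo estimate; the means and variances are arbitrary illustrative values. With both variances fixed, only the $\lVert m_1 - m_2 \rVert^2 / (2 v_2)$ term depends on the learned mean.

```python
import numpy as np

def kl_isotropic(m1, v1, m2, v2):
    # KL( N(m1, v1 I) || N(m2, v2 I) ) in d dimensions, closed form.
    d = m1.size
    return 0.5 * (d * np.log(v2 / v1) + d * v1 / v2 + np.sum((m1 - m2) ** 2) / v2 - d)

def kl_monte_carlo(m1, v1, m2, v2, n, rng):
    # E_{x ~ N(m1, v1 I)}[log N(x; m1, v1 I) - log N(x; m2, v2 I)].
    d = m1.size
    x = m1 + np.sqrt(v1) * rng.standard_normal((n, d))
    log_p = -0.5 * (d * np.log(2 * np.pi * v1) + np.sum((x - m1) ** 2, axis=1) / v1)
    log_q = -0.5 * (d * np.log(2 * np.pi * v2) + np.sum((x - m2) ** 2, axis=1) / v2)
    return np.mean(log_p - log_q)

rng = np.random.default_rng(0)
m1, m2 = np.array([0.5, -1.0, 2.0]), np.array([0.0, 0.0, 1.5])
exact = kl_isotropic(m1, 0.3, m2, 0.5)
approx = kl_monte_carlo(m1, 0.3, m2, 0.5, 200_000, rng)
print(exact, approx)
```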

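And the training loss (the negative of the bound, dropping terms that do not depend on $\theta$, such as $D_{\mathrm{KL}}(q(x_T \mid x_0) \,\|\, p(x_T))$) can be written as:

$$\mathcal{L}(\theta) = \sum_{t=2}^{T} \mathbb{E}_{x_0,\; x_t \sim q(x_t \mid x_0)}\left[\frac{1}{2\sigma_t^2}\big\lVert \tilde{\mu}_t(x_t, x_0) - \mu_\theta(x_t, t) \big\rVert^2\right] - \mathbb{E}_{x_0,\; x_1 \sim q(x_1 \mid x_0)}\big[\log p_\theta(x_0 \mid x_1)\big] + C$$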
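In practice, the sum over $t$ is usually estimated by sampling one $t$ per example. Below is a hedged sketch of that estimator for the $t \ge 2$ terms only (the reconstruction term is omitted). It assumes a hypothetical linear schedule, sets $\sigma_t^2 = \tilde{\beta}_t$ as one possible fixed choice, and treats `mu_model(xt, t)` as a stand-in for the network $\mu_\theta$; none of these names come from a specific library.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)   # assumed linear schedule (hypothetical choice)
alphas = 1.0 - betas
abar = np.cumprod(alphas)

def posterior_mean_var(xt, x0, t):
    # Mean and variance of q(x_{t-1} | x_t, x_0); t is 1-indexed and t >= 2.
    ab_t, ab_prev = abar[t - 1], abar[t - 2]
    mean = (np.sqrt(ab_prev) * betas[t - 1] * x0 + np.sqrt(alphas[t - 1]) * (1 - ab_prev) * xt) / (1 - ab_t)
    var = (1 - ab_prev) * betas[t - 1] / (1 - ab_t)
    return mean, var

def diffusion_loss(x0_batch, mu_model, rng):
    # One stochastic estimate of sum_{t=2}^T E[KL] (up to constants), one t per example.
    # mu_model(xt, t) is a hypothetical network returning the mean of p_theta(x_{t-1} | x_t).
    losses = []
    for x0 in x0_batch:
        t = rng.integers(2, T + 1)                       # t ~ Uniform{2, ..., T}
        xt = np.sqrt(abar[t - 1]) * x0 + np.sqrt(1 - abar[t - 1]) * rng.standard_normal(x0.shape)
        mean, var = posterior_mean_var(xt, x0, t)
        sigma2 = var                                     # assumption: sigma_t^2 = posterior variance
        # (T - 1) rescales the uniform-t average back to the sum over t = 2..T.
        losses.append((T - 1) * np.sum((mean - mu_model(xt, t)) ** 2) / (2 * sigma2))
    return np.mean(losses)

rng = np.random.default_rng(0)
x0 = np.ones(4)
batch = np.tile(x0, (8, 1))
oracle = lambda xt, t: posterior_mean_var(xt, x0, t)[0]   # "perfect" model for this fixed x0
zero_model = lambda xt, t: np.zeros_like(xt)
print(diffusion_loss(batch, oracle, rng), diffusion_loss(batch, zero_model, rng))
```

Each example touches a single timestep, which is exactly the saving the decomposition promised over sampling the whole path.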

Conclusion: Comparing Diffusion with VAE

Aside from the tedious symbol manipulation, what can we learn by comparing VAE and diffusion models? In VAE, we parametrize and learn the best prior, that is, the inverse of the generative model. This requires us to co-optimize two networks, one to push up the bound and one to tighten it. In a diffusion model, however, we fix the encoding process (the diffusion process, the noise scheduling policy, ...), and we learn a generative model to invert that process. From this perspective, the encoder is more like a data augmentation process than an encoder that captures semantics and learns representations. More importantly, we don’t need to co-optimize two networks. Also, the lower bound is tightened during training as well, because a tight bound needs a good inverse of the fixed encoder, which is exactly what the decoder is trained to be. In other words, the two goals of pushing up the bound and tightening it are aligned.