Explaining Diffusion without jumping steps
See my VAE blog for a detailed explanation of the VAE logic.
We have a generative model \(p_{\theta}(x \mid z)\) that generates data given a latent variable, and we evaluate the model by the expected log-likelihood \(\mathbb{E}_{x \sim p_{data}} \log p_{\theta}(x)\), where \(p_{\theta}(x) = \int_Z p(x,z) dz\). However, we don’t know the prior distribution \(q_{prior}(z)\) of the latent variable \(z\), so we can’t evaluate \(p(x,z) = p_{\theta}(x \mid z)q_{prior}(z)\). How do we solve this problem? We can only assume an arbitrary tractable prior \(q(z)\); for instance, \(q(z)\) can be \(\mathcal{N}(0,1)\). Now we can use the technique called importance sampling, because we can only sample from \(q(z)\) but we want to evaluate an integral involving the unknown \(q_{prior}(z)\), which is exactly the purpose of importance sampling:
\[\log p_{\theta}(x) = \log \int_Z p(x,z) dz = \log \int_Z p(x,z)\frac{q(z)}{q(z)} dz = \log \mathbb{E}_{z \sim q(z)} \left[ \frac{p(x,z)}{q(z)} \right]\]Recall that we then use Jensen’s inequality to make our objective function estimable through Monte Carlo. Also recall that the best prior guess turns out to be \(p_{\theta}(z \mid x)\). Finally, we parametrize the intractable \(p_{\theta}(z \mid x)\) by \(q_\phi (z \mid x)\), and we co-optimize \(\theta\) and \(\phi\) to push up the lower bound and tighten it simultaneously.
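To make the importance-sampling step concrete, here is a minimal NumPy sketch, assuming a toy model where both the true prior and the likelihood are one-dimensional Gaussians (so the exact \(\log p(x)\) is known and we can check the estimates). The toy distributions and all names below are illustrative assumptions, not anything prescribed by the derivation above.

```python
import numpy as np

# Toy model (illustrative assumption):
#   true prior:       z ~ N(0, 1)
#   generative model: x | z ~ N(z, 1)
# Marginally x ~ N(0, 2), so log p(x) is known exactly and we can check the
# importance-sampling estimate and the Jensen lower bound against it.

rng = np.random.default_rng(0)
x = 1.5                                    # a single observed data point

def log_normal(v, mean, var):
    return -0.5 * (np.log(2 * np.pi * var) + (v - mean) ** 2 / var)

# Arbitrary tractable proposal q(z) = N(0, 1); any density we can sample from
# and evaluate would do.
z = rng.normal(0.0, 1.0, size=100_000)
log_q = log_normal(z, 0.0, 1.0)

# log p(x, z) = log p(x | z) + log q_prior(z)
log_joint = log_normal(x, z, 1.0) + log_normal(z, 0.0, 1.0)
log_ratio = log_joint - log_q

# Importance sampling: log p(x) = log E_{z ~ q}[ p(x, z) / q(z) ]
log_px_is = np.log(np.mean(np.exp(log_ratio)))

# Jensen's inequality: E_{z ~ q}[ log p(x, z) / q(z) ] is a lower bound.
elbo = np.mean(log_ratio)

print("exact log p(x):              ", log_normal(x, 0.0, 2.0))
print("importance-sampling estimate:", log_px_is)
print("Jensen lower bound:          ", elbo)
```

The gap between the exact value and the Jensen lower bound is precisely the \(D_{KL}(q(z) \| p_{\theta}(z \mid x))\) term mentioned above.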
The whole logical starting point is to generate \(x\) from a latent variable whose distribution is unknown.
Previously, we had a single latent variable \(z\), and \(x\) was generated given \(z\). Now we consider the case with two latent variables \(z_1, z_2\). The generative model is now \(p_{\theta}(x,z_1,z_2) = p_{\theta}(x \mid z_1)q_{prior}(z_1 \mid z_2)q_{prior}(z_2)\). Notice that here we have two unknown priors: \(q_{prior}(z_1 \mid z_2)\) and \(q_{prior}(z_2)\). The generative model \(p_{\theta}(x \mid z_1)\) is only responsible for observing \(z_1\) and generating \(x\) accordingly. We can read it as “first sample \(z_2\) from its prior, then sample \(z_1\) given \(z_2\), then sample (generate) \(x\) given \(z_1\)”. As you can see, here we assume the Markov property, that is, \(x\) is independent of \(z_2\) given \(z_1\).
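As a quick illustration of this reading, here is a minimal sketch of ancestral sampling from the two-latent model; the Gaussian conditionals are toy assumptions made only for illustration, not part of the derivation.

```python
import numpy as np

# Ancestral sampling from p(x, z1, z2) = p(x | z1) q_prior(z1 | z2) q_prior(z2).
# The Gaussian conditionals below are toy choices for illustration only.

rng = np.random.default_rng(0)

z2 = rng.normal(0.0, 1.0, size=5)     # z2 ~ q_prior(z2)
z1 = rng.normal(0.5 * z2, 1.0)        # z1 | z2 ~ q_prior(z1 | z2)
x  = rng.normal(2.0 * z1, 0.1)        # x | z1 ~ p_theta(x | z1); by the Markov
                                      # property, x never looks at z2 directly.
print(x)
```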
Similar to the single-latent VAE, we evaluate the model by the expected log-likelihood \(\mathbb{E}_{x \sim p_{data}} \log p_{\theta}(x)\), where \(p_{\theta}(x) = \int_{Z_1,Z_2} p(x,z_1,z_2) dz_1dz_2\).
Since we don’t know the priors, we can do importance sampling by guessing two tractable priors \(q(z_1 \mid z_2)\) and \(q(z_2)\). You can think of them as a joint prior \(q(z_1 \mid z_2)q(z_2) = q(z_1,z_2)\):
\[\begin{equation} \begin{split} p_{\theta}(x) &= \int_{Z_1,Z_2} p(x,z_1,z_2) dz_1dz_2 \\ &= \int_{Z_1,Z_2} \frac{p(x,z_1,z_2)q(z_1,z_2)}{q(z_1,z_2)} dz_1dz_2 \\ &= \mathbb{E}_{Z_1,Z_2 \sim q(z_1,z_2)} \left[ \frac{p(x,z_1,z_2)}{q(z_1,z_2)} \right]. \end{split} \end{equation}\]Similarly, according to Jensen’s inequality:
\[\begin{equation} \log p_{\theta}(x) = \log \mathbb{E}_{Z_1,Z_2 \sim q(z_1,z_2)} \left[ \frac{p(x,z_1,z_2)}{q(z_1,z_2)}\right] \geq \mathbb{E}_{Z_1,Z_2 \sim q(z_1,z_2)} \left[ \log \frac{p(x,z_1,z_2)}{q(z_1,z_2)} \right]. \end{equation}\]Similarly, we can derive the best prior guess as follows:
\[\begin{equation} \log p_{\theta}(x) - \mathbb{E}_{Z_1,Z_2 \sim q(z_1,z_2)} \left[ \log \frac{p(x,z_1,z_2)}{q(z_1,z_2)} \right] = D_{KL} (q(z_1,z_2) \| p_{\theta}(z_1,z_2 \mid x)). \end{equation}\]Therefore, the best prior guess is the posterior of the latent variables given the generator and the data evidence \(x\), that is, \(p_{\theta}(z_1,z_2 \mid x)\). We will parametrize and learn an encoder \(q_{\phi}(z_1,z_2 \mid x)\) to approximate \(p_{\theta}(z_1,z_2 \mid x)\). This can be read as “first encode \(x\) as \(z_1\) according to \(q_\phi(z_1 \mid x)\), then encode \(z_1\) as \(z_2\) according to \(q_\phi(z_2 \mid z_1)\)”.
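To make the hierarchical bound concrete, here is a minimal single-sample Monte Carlo sketch of it, again with toy Gaussian conditionals standing in for the encoder and decoder networks; every numerical choice here is an illustrative assumption.

```python
import numpy as np

# Single-sample estimate of the two-latent lower bound
#   E_{q_phi(z1, z2 | x)}[ log p(x, z1, z2) - log q_phi(z1, z2 | x) ].
# In a real hierarchical VAE the means/variances below would come from networks.

rng = np.random.default_rng(0)

def log_normal(v, mean, var):
    return float(np.sum(-0.5 * (np.log(2 * np.pi * var) + (v - mean) ** 2 / var)))

def elbo_one_sample(x):
    # "Encoder": first encode x as z1, then encode z1 as z2.
    z1 = rng.normal(0.5 * x, 1.0)
    z2 = rng.normal(0.5 * z1, 1.0)
    log_q = log_normal(z1, 0.5 * x, 1.0) + log_normal(z2, 0.5 * z1, 1.0)
    # Generative model: p(x, z1, z2) = p(x | z1) q_prior(z1 | z2) q_prior(z2).
    log_p = (log_normal(x, 2.0 * z1, 0.1)
             + log_normal(z1, 0.5 * z2, 1.0)
             + log_normal(z2, 0.0, 1.0))
    return log_p - log_q

print(elbo_one_sample(np.array([1.5])))
```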
The starting point of diffusion is to diffuse the data with multiple steps of random noise. The noise scheduling policy specifies how we add noise at each timestep and is a separate, extensively studied problem. Suppose we use the standard noise scheduling policy; then we have \(x_t \sim \mathcal{N}(x_{t};\sqrt{\alpha_t}x_{t-1}, (1-\alpha_t)\mathbb{I})\). Equivalently, \(x_t = \sqrt{\alpha_t}x_{t-1} + \sqrt{1-\alpha_t}\epsilon_{t-1},\ \epsilon_{t-1} \sim \mathcal{N}(0,\mathbb{I})\). Here we suppose we add noise \(T\) times, so \(t \in [1,T]\).
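As a minimal sketch (assuming the common linear \(\beta_t\) schedule with \(T=1000\), a frequent but by no means required choice), one forward noising step looks like this:

```python
import numpy as np

# One forward (noising) step under the schedule above. The linear beta
# schedule and T = 1000 are assumed here for illustration; alpha_t = 1 - beta_t.

rng = np.random.default_rng(0)
T = 1000
betas = np.linspace(1e-4, 0.02, T)          # beta_1, ..., beta_T
alphas = 1.0 - betas

def diffuse_one_step(x_prev, t):
    """x_t = sqrt(alpha_t) * x_{t-1} + sqrt(1 - alpha_t) * eps, eps ~ N(0, I)."""
    eps = rng.standard_normal(x_prev.shape)
    return np.sqrt(alphas[t - 1]) * x_prev + np.sqrt(1.0 - alphas[t - 1]) * eps

x = rng.standard_normal((8, 8))             # pretend this is a data sample x_0
for t in range(1, T + 1):
    x = diffuse_one_step(x, t)
# After T steps, x is approximately a sample from N(0, I).
```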
So the encoder is given (scheduled) rather than learned as in the VAE. What are the consequences of a given encoder?
From a training perspective, we don’t need to parametrize and learn the ‘best’ encoder with respect to the decoder; we only have a given encoder. So we don’t need to co-optimize \(\phi\) and \(\theta\), which leads to more stable and efficient training.
From a math perspective, given the unperturbed data \(x_0\), the perturbed data distribution at any timestep, \(X_t \mid X_0=x_0,\ \forall t \in [1,T]\), is known and Gaussian: \(q(x_t \mid x_0) = \mathcal{N}(\sqrt{\bar{\alpha}_t}x_0, (1-\bar{\alpha}_t)\mathbb{I})\) with \(\bar{\alpha}_t = \prod_{s=1}^t \alpha_s\), so we can jump to any timestep in one shot.
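A minimal sketch of this “jump to any timestep” property, under the same assumed linear schedule as in the previous snippet:

```python
import numpy as np

# Sample x_t directly from x_0 via the closed-form marginal
#   q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I).

rng = np.random.default_rng(0)
T = 1000
alphas = 1.0 - np.linspace(1e-4, 0.02, T)
alpha_bars = np.cumprod(alphas)             # alpha_bar_t = prod_{s <= t} alpha_s

def diffuse_to_t(x0, t):
    """Sample x_t | x_0 in one shot instead of t sequential noising steps."""
    eps = rng.standard_normal(x0.shape)
    a_bar = alpha_bars[t - 1]
    return np.sqrt(a_bar) * x0 + np.sqrt(1.0 - a_bar) * eps, eps

x0 = rng.standard_normal((8, 8))            # pretend this is a data sample
x_500, eps = diffuse_to_t(x0, 500)          # one call instead of 500 steps
```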
Note that our goal is still to generate data \(x_0\) given latent variables and to evaluate our generative model through the log-likelihood. Now our latent variables are \(x_{1:T}\). So we can write the log-likelihood in the same way as before:
\[\begin{equation} \log p_{\theta}(x_0) = \log \int \cdots \int p_{\theta}(x_0,x_1,\cdots,x_T)dx_{1}\cdots dx_T. \end{equation}\]Now, given \(x_0\), we know the latent distribution at every timestep, so we can do importance sampling with the known \(q(x_{1:T} \mid x_0)\):
\[\begin{equation} \begin{split} \int \cdots \int p_{\theta}(x_0,x_1,\cdots,x_T)dx_{1}\cdots dx_T &= \int \cdots \int \frac{p_{\theta}(x_0,x_1,\cdots,x_T)q(x_{1:T} \mid x_0)}{q(x_{1:T} \mid x_0)} dx_{1} \cdots dx_T \\ &= \mathbb{E}_{x_{1:T} \sim q(X_{1:T} \mid X_0=x_0)} \left[ \frac{p_{\theta}(x_0,x_1,\cdots,x_T)}{q(x_{1:T} \mid x_0)} \right]. \end{split} \end{equation}\]So, by Jensen’s inequality, we have \(\log p_{\theta}(x_0) \geq \mathbb{E}_{x_{1:T} \sim q(X_{1:T} \mid X_0=x_0)} \left[ \log \frac{p_{\theta}(x_0,x_1,\cdots,x_T)}{q(x_{1:T} \mid x_0)} \right]\), and
\[\begin{equation} \begin{split} &\log p_{\theta}(x_0) - \mathbb{E}_{x_{1:T} \sim q(X_{1:T} \mid X_0=x_0)} \left[ \log \frac{p_{\theta}(x_0,x_1,\cdots,x_T)}{q(x_{1:T} \mid x_0)} \right] \\ &= D_{KL}\left( q(x_{1:T} \mid X_0=x_0) \| p_{\theta}(x_{1:T} \mid X_0=x_0) \right). \end{split} \end{equation}\]Therefore, the expected log-likelihood (our objective function) is
\[\begin{equation} \begin{split} \mathbb{E}_{x_0} \log p_{\theta}(x_0) &\geq \mathbb{E}_{x_0 \sim p_{data}} \left\{\mathbb{E}_{x_{1:T} \sim q(X_{1:T} \mid X_0=x_0)} \left[ \log \frac{p_{\theta}(x_0,x_1,\cdots,x_T)}{q(x_{1:T} \mid x_0)} \right] \right\} = \mathcal{L}(\theta) \end{split} \end{equation}\]In theory, eq(7) can already be used for training. The training process would be: first sample a data point \(x_0\), then diffuse \(x_0\) \(K\) times to get \(K\) probability paths \(x_{1:T} \sim q(X_{1:T} \mid X_0=x_0)\). For each path, \(q(x_{1:T} \mid x_0)\) can be easily calculated, and \(p_{\theta}(x_0,x_1,\cdots,x_T) = p_\theta(X_0=x_0 \mid x_1)\left[ \prod_{t=2}^{T}{p_\theta(x_{t-1} \mid x_t)} \right] p_\theta(x_T)\). Repeating this for many sampled \(x_0\), we get a batch of training data and can do backprop. Therefore, for one \(x_0\),
\[\begin{equation} \begin{split} \log p_{\theta}(x_0) &\geq \mathbb{E}_{x_{1:T} \sim q(X_{1:T} \mid X_0=x_0)} \left[ \log \frac{p_{\theta}(x_0,x_1,\cdots,x_T)}{q(x_{1:T} \mid x_0)} \right] \\ &\approx \frac{1}{K} \sum_{i=1}^K \left[ \log \frac{p_{\theta}(x_0,x^i_1,\cdots,x^i_T)}{q(x^i_{1:T} \mid x_0)} \right] \\ &= \frac{1}{K} \sum_{i=1}^K \left[ \log \frac{p_\theta(X_0=x_0 \mid x^i_1)\left[ \prod_{t=2}^{T}{p_\theta(x^i_{t-1} \mid x^i_t)} \right] p_\theta (x^i_T)}{q(x^i_{1:T} \mid x_0)} \right]. \end{split} \end{equation}\]However, the issue is that this requires us to sample the whole probability path for each \(x_0\), and the probability path can be very long if \(T\) is big (typically, \(T=1000\)). But think about the sampling process: we have a generator \(p_\theta(x_{t-1} \mid x_t)\) for each timestep, and we have a given diffusion encoder \(q(x_t \mid x_{t-1})\) for each timestep as well. The above eq(8) is just saying that the whole probability path of the generative model should not go wrong at any timestep, or \(\log p_{\theta}(x_0)\) will be low. So our intuition is that we can break the huge equation into small pieces, each responsible for matching one timestep. If we can do so, we only need to sample two timesteps rather than the whole path, which makes our training loop much more lightweight. Let’s write \(\log p_{\theta}(x_0)\) in finer granularity:
\[\begin{equation} \begin{split} \log p_{\theta}(x_0) &\geq \mathbb{E}_{x_{1:T} \sim q(X_{1:T} \mid X_0=x_0)} \left[ \log \frac{p_{\theta}(x_0,x_1,\cdots,x_T)}{q(x_{1:T} \mid x_0)} \right] \\ &= \mathbb{E}_q \left[ \log \frac{p_\theta(X_0=x_0 \mid x_1)\left[ \prod_{t=2}^{T}{p_\theta(x_{t-1} \mid x_t)} \right] p_\theta (x_T)}{q(x_{1:T} \mid x_0)} \right] \\ &= \mathbb{E}_q \left[ \log \frac{p_\theta(X_0=x_0 \mid x_1)\left[ \prod_{t=2}^{T}{p_\theta(x_{t-1} \mid x_t)} \right] p_\theta (x_T)}{q(x_1 \mid x_0)\prod_{t=2}^T q(x_t \mid x_{t-1},x_0)} \right] \\ &= \mathbb{E}_q \left[ \log \frac{p_\theta(X_0=x_0 \mid x_1)p_\theta (x_T)}{q(x_1 \mid x_0)} \left[ \prod_{t=2}^{T}\frac{p_\theta(x_{t-1} \mid x_t)}{q(x_t \mid x_{t-1},x_0)} \right] \right] \\ &=\mathbb{E}_{q(\mathbf{x}_1\mid\mathbf{x}_0)}\left[\log p_\theta(\mathbf{x}_0\mid\mathbf{x}_1)\right] - \mathcal{D}_{\mathrm{KL}}(q(\mathbf{x}_T\mid\mathbf{x}_0) \| p(\mathbf{x}_T)) \\ &\quad - \sum_{t=2}^T \mathbb{E}_{q(\mathbf{x}_t\mid\mathbf{x}_0)} \left[\mathcal{D}_{\mathrm{KL}}(q(\mathbf{x}_{t-1}\mid\mathbf{x}_t, \mathbf{x}_0) \| p_\theta(\mathbf{x}_{t-1}\mid\mathbf{x}_t))\right]. \end{split} \end{equation}\]Above is just some tedious manipulation of symbols (the key trick is rewriting \(q(x_t \mid x_{t-1},x_0)\) with Bayes’ rule as \(\frac{q(x_{t-1} \mid x_t,x_0)q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)}\), after which the marginal ratios telescope). But the result does match our intuition: \(\mathcal{D}_{\mathrm{KL}}(q(\mathbf{x}_{t-1}\mid\mathbf{x}_t, \mathbf{x}_0) \| p_\theta(\mathbf{x}_{t-1}\mid\mathbf{x}_t))\) says that timestep \(t,\ t\in[2,T]\), of the generative path does not go wrong. \(\mathbb{E}_{q(\mathbf{x}_1\mid\mathbf{x}_0)}\left[\log p_\theta(\mathbf{x}_0\mid\mathbf{x}_1)\right]\) says the last generative step does not go wrong. And \(\mathcal{D}_{\mathrm{KL}}(q(\mathbf{x}_T\mid\mathbf{x}_0) \| p(\mathbf{x}_T))\) says the first generative step (sampling \(x_T\)) does not go wrong. Each term is responsible for one timestep!
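Here is a minimal sketch of what one such per-timestep term looks like in practice: pick a single random \(t\), sample \(x_t \mid x_0\) in closed form, and evaluate one KL term. The closed-form Gaussian posterior \(q(x_{t-1} \mid x_t, x_0)\) follows from the fixed schedule, while the model prediction below is a hypothetical placeholder, anticipating the Gaussian parametrization introduced next.

```python
import numpy as np

# One per-timestep term: KL( q(x_{t-1} | x_t, x_0) || p_theta(x_{t-1} | x_t) ).
# The posterior mean/variance are the standard closed forms implied by the
# fixed schedule; "p_mean, p_var" stand in for a neural network's prediction.

rng = np.random.default_rng(0)
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)

def gaussian_kl(mu1, var1, mu2, var2):
    """KL( N(mu1, var1*I) || N(mu2, var2*I) ) for isotropic Gaussians."""
    return float(np.sum(0.5 * (np.log(var2 / var1)
                               + (var1 + (mu1 - mu2) ** 2) / var2 - 1.0)))

def posterior_q(x_t, x0, t):
    """Mean and variance of q(x_{t-1} | x_t, x_0) under the schedule above (t >= 2)."""
    a_bar_t, a_bar_prev = alpha_bars[t - 1], alpha_bars[t - 2]
    beta_t, alpha_t = betas[t - 1], alphas[t - 1]
    mean = (np.sqrt(a_bar_prev) * beta_t / (1.0 - a_bar_t)) * x0 \
         + (np.sqrt(alpha_t) * (1.0 - a_bar_prev) / (1.0 - a_bar_t)) * x_t
    var = (1.0 - a_bar_prev) / (1.0 - a_bar_t) * beta_t
    return mean, var

x0 = rng.standard_normal((8, 8))                 # pretend data sample
t = int(rng.integers(2, T + 1))                  # one random timestep, not a whole path
eps = rng.standard_normal(x0.shape)
x_t = np.sqrt(alpha_bars[t - 1]) * x0 + np.sqrt(1.0 - alpha_bars[t - 1]) * eps

q_mean, q_var = posterior_q(x_t, x0, t)
p_mean, p_var = 0.9 * q_mean, betas[t - 1]       # hypothetical model prediction
print(t, gaussian_kl(q_mean, q_var, p_mean, p_var))
```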
We know that the KL divergence between two Gaussians can be written in closed form. We take advantage of that by parametrizing each generative step as a Gaussian:
\[\begin{equation} p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1}; \mu_\theta(\mathbf{x}_t, t), \Sigma_\theta(\mathbf{x}_t, t)). \end{equation}\]Following DDPM, we reparametrize \(\mu_\theta\) in terms of a noise-prediction network \(\epsilon_\theta\); dropping the timestep-dependent weighting factors, the training loss can be written as:
\[\begin{equation} \begin{split} L_t^{\text{simple}} &= \mathbb{E}_{t \sim [1, T], \mathbf{x}_0, \boldsymbol{\epsilon}_t} \left[ \left\| \boldsymbol{\epsilon}_t - \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t) \right\|^2 \right] \\ &= \mathbb{E}_{t \sim [1, T], \mathbf{x}_0, \boldsymbol{\epsilon}_t} \left[ \left\| \boldsymbol{\epsilon}_t - \boldsymbol{\epsilon}_\theta\left(\sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t} \boldsymbol{\epsilon}_t, t \right) \right\|^2 \right]. \end{split} \end{equation}\]Aside from the tedious symbol-manipulation skills, what can we learn by comparing the VAE and the diffusion model? In the VAE, we parametrize and learn the best prior, that is, the inverse of the generative model. This requires us to co-optimize two networks, one to push up the bound and one to tighten the bound. In the diffusion model, however, we fix the encoding process (the diffusion process, the noise scheduling policy, ...) and learn a generative model to invert that process. From this perspective, the encoder is more like a data augmentation process than an encoder that captures semantic structure and learns representations. More importantly, we don’t need to co-optimize two networks. Also, the lower bound is tightened during training as well, because in eq(6) we see that a tight bound needs the model posterior \(p_{\theta}(x_{1:T} \mid x_0)\) to match the fixed diffusion process, which is exactly the training objective. That is to say, the two goals of pushing up the bound and tightening the bound are aligned.
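To close, here is a minimal sketch of one training step under the simplified loss above. The noise-prediction network is replaced by a hypothetical placeholder so the snippet is self-contained, and the schedule is the same assumed linear one as before.

```python
import numpy as np

# One training step of L_simple: sample a data point, a timestep, and a noise
# vector, then regress the model's noise prediction onto the true noise.
# `eps_model` is a hypothetical stand-in for epsilon_theta(x_t, t).

rng = np.random.default_rng(0)
T = 1000
alphas = 1.0 - np.linspace(1e-4, 0.02, T)
alpha_bars = np.cumprod(alphas)

def eps_model(x_t, t):
    # Placeholder for epsilon_theta(x_t, t); a real model would be a neural
    # network taking the noisy sample and a timestep embedding.
    return np.zeros_like(x_t)

def l_simple(x0):
    t = int(rng.integers(1, T + 1))              # t ~ Uniform{1, ..., T}
    eps = rng.standard_normal(x0.shape)          # eps_t ~ N(0, I)
    a_bar = alpha_bars[t - 1]
    x_t = np.sqrt(a_bar) * x0 + np.sqrt(1.0 - a_bar) * eps
    return np.mean((eps - eps_model(x_t, t)) ** 2)

x0 = rng.standard_normal((8, 8))                 # pretend data sample
print(l_simple(x0))
```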