Diffusion Models (DRAFT)

The overarching goal of diffusion models is to predict the noise that was added to an image between time steps $t-1$ and $t$.

*[Image: adding noise to an image over time]*

We want to add noise to the image such that the final noisy image $X_T$ is standard Gaussian, i.e. $\mathcal{N}(0, 1)$. This is required to create a learnable loss function, and standard Gaussians have nice mathematical properties that make them easy to work with.

Noise is added gradually, instead of in one large addition of Gaussian noise: just a bit in the beginning, then more towards the end, when the image is almost fully Gaussian anyway. We do this because we want to learn to undo the noise that was added between steps $t-1$ and $t$.

*[Image: gradual noise addition across time steps]*

The reasoning is that it’s easier to learn to undo a bit of noise at a time than a lot of noise across many time steps.

Forward Process

In order to get noisy images to train on, we need to generate them. Here is a naive, flawed approach:

$$
\begin{align*}
X_1 &= X_0 + \sqrt{\beta}\,\epsilon_1 \\
X_2 &= X_1 + \sqrt{\beta}\,\epsilon_2 \\
X_3 &= X_2 + \sqrt{\beta}\,\epsilon_3 \\
&\dots \\
X_t &= X_{t-1} + \sqrt{\beta}\,\epsilon_t
\end{align*}
$$

Here, $\epsilon_t$ is the Gaussian noise added at time $t$.

The $\beta$ term is added to scale down the Gaussian noise. This is required because the input images are usually normalised, often to values between 0 and 1. Standard Gaussian noise can return numbers like $0.8$, which will quickly overwhelm the pixel values of the input image. Therefore, the Gaussian noise needs to be scaled down.
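To make this concrete, here is a minimal numpy sketch of this naive process (the toy "image", the $\beta$ value, and the step count are made-up placeholders):

```python
import numpy as np

rng = np.random.default_rng(0)

def naive_forward(x0, t, beta=0.01):
    """Naive forward process: X_t = X_{t-1} + sqrt(beta) * eps_t."""
    x = x0.copy()
    for _ in range(t):
        x = x + np.sqrt(beta) * rng.standard_normal(x.shape)
    return x

# toy "image": the variance of the accumulated noise grows like t * beta
x0 = rng.uniform(0, 1, size=(64, 64))
xt = naive_forward(x0, t=1000)
print(np.var(xt - x0))  # ~ 1000 * 0.01 = 10, nowhere near a standard Gaussian
```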

With this current setup, we need the previous noisy image $X_{t-1}$ to generate the next noisy image $X_t$. But we can rewrite this to get any $X_t$ just from the starting image $X_0$. E.g.:

$$
\begin{align*}
X_3 &= X_2 + \sqrt{\beta}\,\epsilon_3 \\
X_3 &= X_1 + \sqrt{\beta}\,\epsilon_2 + \sqrt{\beta}\,\epsilon_3 \\
X_3 &= X_0 + \sqrt{\beta}\,\epsilon_1 + \sqrt{\beta}\,\epsilon_2 + \sqrt{\beta}\,\epsilon_3
\end{align*}
$$

If you have independent zero-mean, scaled Gaussians, you can sum their variances, which gives you a “new” Gaussian:

$$
\begin{align*}
X_3 &= X_0 + \sqrt{3\beta}\,\epsilon' \\
X_3 &\sim N(X_0, 3\beta)
\end{align*}
$$
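A quick numerical sanity check of the variance-summing rule (just a sketch):

```python
import numpy as np

rng = np.random.default_rng(0)
beta, n = 0.01, 1_000_000

# three independent scaled Gaussians, each with variance beta
s = sum(np.sqrt(beta) * rng.standard_normal(n) for _ in range(3))
print(np.var(s))  # ~ 3 * beta = 0.03
```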

But unfortunately, $X_t \sim N(X_0, t\beta) \neq N(0,1)$, because $X_0 \neq 0$ and $t\beta \neq 1$.

We can even see this visually if we add the noise to an image in this iterative way and compare the result to what standard Gaussian noise would look like.

*[Image: image noised with the iterative approach]*

This is what the iterative approach looks like (it also takes a long time to run). Using the direct approach, we can get the final noisy image faster:

*[Image: image noised with the direct approach]*

But as can be seen, the noisy image looks quite different from what true Gaussian noise would look like. This further shows that our naive approach does not converge to a standard Gaussian.

A better approach is needed. The DDPM authors use this formula:

$$X_t = \sqrt{1-\beta_t}\,X_{t-1} + \sqrt{\beta_t}\,\epsilon_t$$

where $\beta_t$ is the noise scale at time step $t$ (e.g. following a schedule). However, the issue still persists that we need to compute $X_{t-1}$ to compute $X_t$. But we’d much rather have a function that takes the initial $X_0$ and $t$ as input and outputs the correct $X_t$ directly. Let’s have a look at what $X_{t-1}$ looks like.
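As a sketch, one step of this forward process might look like the following. The linear schedule with endpoints 1e-4 and 0.02 is the common DDPM choice, but any increasing schedule in $(0, 1)$ works:

```python
import numpy as np

rng = np.random.default_rng(0)

T = 1000
betas = np.linspace(1e-4, 0.02, T)  # linear schedule (common DDPM choice)

def forward_step(x_prev, t):
    """One forward step: X_t = sqrt(1 - beta_t) * X_{t-1} + sqrt(beta_t) * eps_t."""
    eps = rng.standard_normal(x_prev.shape)
    return np.sqrt(1.0 - betas[t]) * x_prev + np.sqrt(betas[t]) * eps
```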

$$
\begin{align*}
X_t &= \sqrt{1-\beta_t}\,X_{t-1} + \sqrt{\beta_t}\,\epsilon_t \\
X_{t-1} &= \sqrt{1-\beta_{t-1}}\,X_{t-2} + \sqrt{\beta_{t-1}}\,\epsilon_{t-1}
\end{align*}
$$

If we substitute $X_{t-1}$ into the $X_t$ equation, we get

$$
\begin{align*}
X_t &= \sqrt{1-\beta_t}X_{t-1} + \sqrt{\beta_t}\epsilon_t \\
X_{t-1} &= \sqrt{1-\beta_{t-1}}X_{t-2} + \sqrt{\beta_{t-1}}\epsilon_{t-1} \\
X_t &= \sqrt{1-\beta_t}\left(\sqrt{1-\beta_{t-1}}X_{t-2} + \sqrt{\beta_{t-1}}\epsilon_{t-1}\right) + \sqrt{\beta_t}\epsilon_t \\
X_t &= \sqrt{1-\beta_t}\sqrt{1-\beta_{t-1}}X_{t-2} + \sqrt{1-\beta_t}\sqrt{\beta_{t-1}}\epsilon_{t-1} + \sqrt{\beta_t}\epsilon_t
\end{align*}
$$

Now we define $\alpha_t = 1 - \beta_t$ to make this easier to read and reduce clutter:

$$
\begin{align*}
X_t &= \sqrt{1-\beta_t}\sqrt{1-\beta_{t-1}}X_{t-2} + \sqrt{1-\beta_t}\sqrt{\beta_{t-1}}\epsilon_{t-1} + \sqrt{\beta_t}\epsilon_t \\
X_t &= \sqrt{\alpha_t}\sqrt{\alpha_{t-1}}X_{t-2} + \sqrt{\alpha_t}\sqrt{1-\alpha_{t-1}}\epsilon_{t-1} + \sqrt{1-\alpha_t}\epsilon_t \\
X_t &= \sqrt{\alpha_t\alpha_{t-1}}X_{t-2} + \sqrt{\alpha_t(1-\alpha_{t-1})}\,\epsilon_{t-1} + \sqrt{1-\alpha_t}\,\epsilon_t
\end{align*}
$$

Now we can reuse our neat trick from before, where we said that the variances of independent zero-mean Gaussians ($\epsilon_{t-1}$ and $\epsilon_t$ in our case) can be summed. Remember that in this equation we are working with standard deviations (i.e. $\sigma$), and variances are the squares of standard deviations (i.e. $\sigma^2$). This means we first need to square the coefficients to get the variances:

$$\mathrm{var}\left(\sqrt{\alpha_t(1-\alpha_{t-1})}\,\epsilon_{t-1}\right) + \mathrm{var}\left(\sqrt{1-\alpha_t}\,\epsilon_t\right)$$

The variance of a standard Gaussian is 1, which gives us:

$$
\begin{align*}
&\mathrm{var}\left(\sqrt{\alpha_t(1-\alpha_{t-1})}\,\epsilon_{t-1}\right) + \mathrm{var}\left(\sqrt{1-\alpha_t}\,\epsilon_t\right) \\
=\;& \alpha_t(1-\alpha_{t-1}) + 1 - \alpha_t \\
=\;& \alpha_t - \alpha_t\alpha_{t-1} + 1 - \alpha_t \\
=\;& 1 - \alpha_t\alpha_{t-1}
\end{align*}
$$

This means that if we sum up the variances, we get a new Gaussian with mean 0 and variance $1-\alpha_t\alpha_{t-1}$. In other words:

$$
\begin{align*}
X_t &= \sqrt{1-\beta_t}\sqrt{1-\beta_{t-1}}X_{t-2} + \sqrt{1-\beta_t}\sqrt{\beta_{t-1}}\epsilon_{t-1} + \sqrt{\beta_t}\epsilon_t \\
X_t &= \sqrt{\alpha_t}\sqrt{\alpha_{t-1}}X_{t-2} + \sqrt{\alpha_t}\sqrt{1-\alpha_{t-1}}\epsilon_{t-1} + \sqrt{1-\alpha_t}\epsilon_t \\
X_t &= \sqrt{\alpha_t\alpha_{t-1}}X_{t-2} + \sqrt{\alpha_t(1-\alpha_{t-1})}\,\epsilon_{t-1} + \sqrt{1-\alpha_t}\,\epsilon_t \\
X_t &= \sqrt{\alpha_t\alpha_{t-1}}X_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\,\epsilon'
\end{align*}
$$

I denoted the new Gaussian noise as $\epsilon'$. From here, we can notice a pattern! We managed to describe $X_t$ in terms of the noisy image two steps before, which added another $\alpha$ term inside the square roots. If we repeat this process until we get to $X_0$, we have this:

$$X_t = \sqrt{\alpha_t\alpha_{t-1}\alpha_{t-2}\dots\alpha_1}\,X_0 + \sqrt{1-\alpha_t\alpha_{t-1}\alpha_{t-2}\dots\alpha_1}\,\epsilon'$$

We can simplify this a bit further by defining $\bar\alpha_t = \prod_{i=1}^{t} \alpha_i$, which gives us our final statement:

$$X_t = \sqrt{\bar\alpha_t}\,X_0 + \sqrt{1-\bar\alpha_t}\,\epsilon'$$

And there we go! Now we can generate the noisy image at any time step $t$ directly from the initial image $X_0$. Using this new formula, we can compare the generated noisy image with what we had generated before:
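In code, the closed-form version is a one-liner given the cumulative product of the $\alpha$ values (a sketch, reusing the schedule from the earlier snippets):

```python
import numpy as np

rng = np.random.default_rng(0)

T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)  # alpha_bar_t = product of alpha_1 .. alpha_t

def q_sample(x0, t):
    """Closed-form forward process: X_t = sqrt(abar_t) X_0 + sqrt(1 - abar_t) eps."""
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bars[t]) * x0 + np.sqrt(1.0 - alpha_bars[t]) * eps
```

No loop over time steps is needed any more, which is what makes training practical.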

*[Image: image noised with the closed-form formula]*

Now this looks much closer to Gaussian noise than what we had before. (It’s a bit darker due to the larger value of $\beta$ that I used.) Mathematically, this does converge to a standard Gaussian:

$$X_t = \sqrt{\bar\alpha_t}\,X_0 + \sqrt{1-\bar\alpha_t}\,\epsilon'$$

As we increase $t$, we keep multiplying more and more numbers between 0 and 1 (remember that $\beta_t$, and hence $\alpha_t = 1-\beta_t$, lies in that range), making $\bar\alpha_t$ smaller and smaller. In the limit, $\sqrt{\bar\alpha_t}$ becomes 0, and $0 \cdot X_0 = 0$. Similarly, as $\bar\alpha_t$ moves towards 0, $\sqrt{1-\bar\alpha_t}$ converges towards 1, which then just adds regular Gaussian noise via $\epsilon'$.
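You can check this numerically with the schedule from the sketches above; $\bar\alpha_t$ collapses towards 0 as $t$ grows:

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)
alpha_bars = np.cumprod(1.0 - betas)
print(alpha_bars[[0, 99, 499, 999]])  # roughly [1.0, 0.9, 0.08, 4e-5]
```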

With this, we have the forward process covered. Now it’s time to derive the loss function and the learning objective.

Deriving the Loss Function

So far, we have described the forward process as a step-by-step sampling procedure. This is great for building intuition. However, to create a loss function, we need to shift from the language of single samples ($X_t$) to the language of probability distributions. This is important because, when we try to predict the noise that was added at a particular step, we don’t actually know exactly which noise it was. Think of it this way:

You have some $X_t$ and sample some noise to get $\epsilon_t$; you add that in and get $X_{t+1}$. But this amount of noise and this specific $X_t$ are just samples. There are many more pairs of $X_t$ and $\epsilon_t$ that could result in $X_{t+1}$; it’s like a whole cloud. These specific values might have come from the edge of the cloud. If you just computed the MSE from that, you wouldn’t learn the true average, which sits at the center of the cloud. By thinking about this as a probability distribution, you compute the loss between your model and the center of the cloud.

The formal name for the distribution defined by our single-step sampling process is $q(X_t|X_{t-1})$.

The Markov Chain Property

A crucial property of our forward process is that it’s a Markov chain. This simply means that the distribution of the next state, $X_t$, depends only on the immediately preceding state, $X_{t-1}$. It does not depend on any other previous states like $X_{t-2}$, $X_{t-3}$, or the original $X_0$.

The Joint Probability Distribution

We want to find the probability of an entire sequence of noisy images, $X_1, \dots, X_T$, given our starting image $X_0$. This is the joint probability distribution $q(X_{1:T}|X_0)$.

Using the general chain rule of probability, we would write this as:

$$q(X_1, \dots, X_T | X_0) = q(X_1 | X_0) \cdot q(X_2 | X_1, X_0) \cdots q(X_T | X_{T-1}, \dots, X_0)$$

This looks complicated because each step seems to depend on the entire history.

However, we can now apply our Markov assumption. Since $X_t$ only depends on $X_{t-1}$, each term simplifies to $q(X_t|X_{t-1})$, giving:

$$q(X_{1:T}|X_0) = \prod_{t=1}^{T} q(X_t | X_{t-1})$$

For the reverse process, we have our model $p_\theta$. To get a clean image back, we have to compute the joint probability, i.e.:

$$p_\theta(X_{0:T}) = p_\theta(X_T) \prod_{t=1}^T p_\theta(X_{t-1}|X_t)$$

From here we can start with the loss function derivation, because as stated before, we want to sample a clean image using our model. To do that, we need to minimise the negative log likelihood:

$$L = -\log p_\theta(X_0)$$

If we can minimise this, then we are learning the true data distribution, i.e. the distribution of our dataset.

From here the first step is marginalisation using the Chain Rule of Probability:

$$p(A,B) = p(A|B)\,p(B)$$

$$
\begin{align*}
L &= -\log p_\theta(X_0) \\
p_\theta(X_0) &= \int p_\theta(X_0|X_{1:T})\,p_\theta(X_{1:T})\,dX_{1:T} \\
p_\theta(X_0) &= \int p_\theta(X_{0:T})\,dX_{1:T}
\end{align*}
$$

This loss function is numerically intractable, because we can’t integrate over all possible $X$ in existence. We need to insert something into this intractable integral to make it tractable. We can use the “divide-by-one” trick (I’m not sure if that’s what it’s called; I just like to call it that):

$$
\begin{align*}
p_\theta(X_0) &= \int p_\theta(X_{0:T})\,dX_{1:T} \\
&= \int \frac{q(X_{1:T} | X_0)}{q(X_{1:T} | X_0)}\, p_\theta(X_{0:T})\,dX_{1:T} \\
&= \int q(X_{1:T} | X_0)\,\frac{p_\theta(X_{0:T})}{q(X_{1:T} | X_0)}\, dX_{1:T}
\end{align*}
$$

Now, this is in the form of an expectation, i.e.:

$$E_{p(x)}[f(x)] = \int p(x)\,f(x)\,dx$$

Therefore, if we rewrite this into the form of an expectation, we get

$$
\begin{align*}
p_\theta(X_0) &= \int q(X_{1:T} | X_0)\,\frac{p_\theta(X_{0:T})}{q(X_{1:T} | X_0)}\, dX_{1:T} \\
q(X_{1:T} | X_0) &\equiv q \\
p_\theta(X_{0:T}) &\equiv p \\
p_\theta(X_0) &= E_q\left[\frac{p}{q}\right]
\end{align*}
$$

I introduced the shorthand

$$
\begin{align*}
q(X_{1:T} | X_0) &\equiv q \\
p_\theta(X_{0:T}) &\equiv p
\end{align*}
$$

just so the math is a bit more concise.
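Incidentally, this “divide-by-one” move is importance sampling in disguise: $\int p(x)\,dx = E_{x \sim q}[p(x)/q(x)]$. A quick sketch with made-up densities:

```python
import numpy as np

rng = np.random.default_rng(0)

def p(x):
    # integrand: a N(2, 1) density, so its integral over the reals is exactly 1
    return np.exp(-0.5 * (x - 2.0) ** 2) / np.sqrt(2 * np.pi)

# sample from a different distribution q = N(0, 3^2) and average p/q
x = rng.normal(0.0, 3.0, size=1_000_000)
q = np.exp(-0.5 * (x / 3.0) ** 2) / (3.0 * np.sqrt(2 * np.pi))
print(np.mean(p(x) / q))  # ~ 1.0, the value of the intractable-looking integral
```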

Now we need to apply the logarithm:

$$
\begin{align*}
L &= -\log p_\theta(X_0) \\
q(X_{1:T} | X_0) &\equiv q \\
p_\theta(X_{0:T}) &\equiv p \\
p_\theta(X_0) &= E_q\left[\frac{p}{q}\right] \\
\log p_\theta(X_0) &= \log E_q\left[\frac{p}{q}\right]
\end{align*}
$$

Because $\log$ is a concave function, we can apply Jensen’s Inequality, which states:

$$\log(E[Y]) \geq E[\log(Y)]$$

This gives us:

$$
\begin{align*}
L &= -\log p_\theta(X_0) \\
q(X_{1:T} | X_0) &\equiv q \\
p_\theta(X_{0:T}) &\equiv p \\
\log p_\theta(X_0) &= \log E_q\left[\frac{p}{q}\right] \\
\log p_\theta(X_0) &\geq E_q\left[\log \frac{p}{q}\right] \\
-\log p_\theta(X_0) &\leq -E_q\left[\log \frac{p}{q}\right]
\end{align*}
$$

Now we have an upper bound on the negative log-likelihood (equivalently, $E_q[\log\frac{p}{q}]$ is a lower bound on the log-likelihood, the ELBO). If we minimise this bound, we push the negative log-likelihood down with it, which is exactly what we want. Thus, our new loss function to minimise is

$$L_n = -E_q\left[\log\frac{p}{q}\right]$$

We can apply the log rule $\log(A/B) = -\log(B/A)$ to get rid of the negative sign:

$$
\begin{align*}
L_n &= -E_q\left[\log\frac{p}{q}\right] \\
&= E_q\left[\log\frac{q}{p}\right]
\end{align*}
$$

Now, we can use the log rule $\log(A/B) = \log A - \log B$:

$$
\begin{align*}
q(X_{1:T} | X_0) &= \prod_{t=1}^{T} q(X_t | X_{t-1}) \\
q(X_{1:T} | X_0) &\equiv q \\
p_\theta(X_{0:T}) &= p_\theta(X_T) \prod_{t=1}^T p_\theta(X_{t-1}|X_t) \\
p_\theta(X_{0:T}) &\equiv p \\
L_n &= -E_q\left[\log\frac{p}{q}\right] \\
&= E_q\left[\log\frac{q}{p}\right] \\
&= E_q[\log q - \log p] \\
L_n &= E_q\left[\log \left(\prod_{t=1}^{T} q(X_t | X_{t-1})\right) - \log\left(p_\theta(X_T) \prod_{t=1}^T p_\theta(X_{t-1}|X_t)\right)\right]
\end{align*}
$$

Because $\log(A \cdot B) = \log A + \log B$, we can turn the products into sums and expand the left log term:

$$
\begin{align*}
L_n &= E_q\left[\sum_{t=1}^{T}\log q(X_t | X_{t-1}) - \log p_\theta(X_T) - \sum_{t=1}^T \log p_\theta(X_{t-1}|X_t)\right] \\
&= E_q\left[-\log p_\theta(X_T) + \sum_{t=1}^T \log\frac{q(X_t | X_{t-1})}{p_\theta(X_{t-1}|X_t)}\right] \\
&= E_q\left[-\log p_\theta(X_T) + \sum_{t=2}^T \log\frac{q(X_t | X_{t-1})}{p_\theta(X_{t-1}|X_t)} + \log \frac{q(X_1|X_0)}{p_\theta(X_0|X_1)}\right]
\end{align*}
$$

You will also notice that I separated out the $t=1$ step. This is required in order to apply Bayes’ rule:

$$p(A|B) = \frac{p(B|A)\cdot p(A)}{p(B)}$$

But if you applied it to the $t=1$ term, where the conditioning variable is $X_0$ itself, you would get this:

$$p(A|B,A) = \frac{p(B|A,A)\cdot p(A|A)}{p(B|A)}$$

You would get a so-called Dirac delta, $p(a|a) = \delta$, a point mass “probability distribution” which in turn isn’t really a probability distribution, because there is no uncertainty. It’s like saying “What card do I hold in my hand given that I hold the ace of hearts?”. It’s a tautology. This means you can’t apply Bayes’ rule here in a manner that makes sense, therefore you exclude this term.

Now we can rewrite the fraction in the sum using Bayes’ rule (and the log rule where a product becomes a sum). Note that the Markov property lets us condition on $X_0$ for free, $q(X_t|X_{t-1}) = q(X_t|X_{t-1}, X_0)$, which is where the $X_0$-conditioned terms below come from.

$$
\begin{align*}
&= E_q\left[-\log p_\theta(X_T) + \sum_{t=2}^T \left(\log\frac{q(X_{t-1}|X_t, X_0)}{p_\theta(X_{t-1}|X_t)} + \log q(X_t|X_0) - \log q(X_{t-1}|X_0)\right) + \log\frac{q(X_1|X_0)}{p_\theta(X_0|X_1)}\right]
\end{align*}
$$

What we have now is a so-called telescoping sum, which comes from this part:

$$\sum_{t=2}^{T} \left[\log q(X_t | X_0) - \log q(X_{t-1} | X_0)\right]$$

If you were to write out the terms, you would get something like this (I will simplify the terms to tuples, e.g. $(1,0) \equiv \log q(X_1|X_0)$ or $(2,1) \equiv \log q(X_2|X_1)$):

$$\sum_{t=2}^{T} = (2,0) - (1,0) + (3,0) - (2,0) + (4,0) - (3,0) + \dots + (T,0) - (T-1,0)$$

As you can see, all the pairs cancel out except the very first negative term and the very last positive term, leaving $(T,0) - (1,0)$, i.e. $\log q(X_T|X_0) - \log q(X_1|X_0)$. This simplifies the sum:

$$
\begin{align*}
&= E_q\left[-\log p_\theta(X_T) + \sum_{t=2}^T \left(\log\frac{q(X_{t-1}|X_t, X_0)}{p_\theta(X_{t-1}|X_t)} + \log q(X_t|X_0) - \log q(X_{t-1}|X_0)\right) + \log\frac{q(X_1|X_0)}{p_\theta(X_0|X_1)}\right] \\
&= E_q\left[-\log p_\theta(X_T) + \sum_{t=2}^T \log\frac{q(X_{t-1}|X_t, X_0)}{p_\theta(X_{t-1}|X_t)} + \log q(X_T|X_0) - \cancel{\log q(X_1|X_0)} + \cancel{\log q(X_1|X_0)} - \log p_\theta(X_0|X_1)\right] \\
&= E_q\left[-\log p_\theta(X_T) + \log q(X_T|X_0) + \sum_{t=2}^T \log\frac{q(X_{t-1}|X_t, X_0)}{p_\theta(X_{t-1}|X_t)} - \log p_\theta(X_0|X_1)\right] \\
&= E_q\left[\log\frac{q(X_T|X_0)}{p(X_T)} + \sum_{t=2}^T \log\frac{q(X_{t-1}|X_t, X_0)}{p_\theta(X_{t-1}|X_t)} - \log p_\theta(X_0|X_1)\right]
\end{align*}
$$
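A toy numerical check of the telescoping cancellation (the values are arbitrary stand-ins for $\log q(X_t|X_0)$):

```python
# sum over t = 2..T of (a[t] - a[t-1]) collapses to a[T] - a[1]
T = 10
a = [float(t) ** 0.5 for t in range(T + 1)]  # arbitrary stand-in values
lhs = sum(a[t] - a[t - 1] for t in range(2, T + 1))
print(lhs, a[T] - a[1])  # identical (up to float rounding)
```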

OK, so now we have three terms, and whenever you see a log of a quotient, you should think “KL divergence”, which measures how different two probability distributions are:

$$D_{KL}(P \,\|\, Q) = \sum_{x} P(x) \log \frac{P(x)}{Q(x)}$$

Or for continuous distributions:

$$D_{KL}(P \,\|\, Q) = \int P(x) \log \frac{P(x)}{Q(x)}\, dx$$

Or if written in the form of an expectation:

$$D_{KL}(P \,\|\, Q) = E_{x \sim P}\left[\log \frac{P(x)}{Q(x)}\right] = E_{x \sim P}[\log P(x) - \log Q(x)]$$

This is exactly what we have in our loss term:

$$= E_q\left[D_{KL}(q(X_T|X_0) \,\|\, p(X_T)) + \sum_{t=2}^T D_{KL}(q(X_{t-1}|X_t, X_0) \,\|\, p_\theta(X_{t-1}|X_t)) - \log p_\theta(X_0|X_1)\right]$$

If we have a closer look, we can see that the left KL divergence contains no trainable parameters $\theta$, so we can safely ignore it; its gradient is zero. As for the rightmost term, the authors of the DDPM paper found that, empirically, it makes little difference to include it, so we can leave it out as well. This leaves us with:

$$= E_q\left[\sum_{t=2}^T D_{KL}(q(X_{t-1}|X_t, X_0) \,\|\, p_\theta(X_{t-1}|X_t))\right]$$

The good thing is that if your probability distributions are Gaussian, the KL divergence simplifies to a nice closed form. In general, for two univariate Gaussians $P = N(\mu_1, \sigma_1^2)$ and $Q = N(\mu_2, \sigma_2^2)$, you have

$$D_{KL}(P \,\|\, Q) = \log \frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$$
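As a sanity check, this closed form agrees with a Monte-Carlo estimate of the expectation form from above (a sketch with arbitrary parameters):

```python
import numpy as np

rng = np.random.default_rng(0)

def log_pdf(x, mu, sigma):
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

mu1, s1, mu2, s2 = 0.0, 1.0, 1.0, 2.0

# Monte-Carlo estimate of E_{x~P}[log P(x) - log Q(x)]
x = rng.normal(mu1, s1, size=1_000_000)
kl_mc = np.mean(log_pdf(x, mu1, s1) - log_pdf(x, mu2, s2))

# closed form for two univariate Gaussians
kl_cf = np.log(s2 / s1) + (s1**2 + (mu1 - mu2) ** 2) / (2 * s2**2) - 0.5
print(kl_mc, kl_cf)  # both ~ 0.443
```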

For multivariate Gaussians, you have this:

$$D_{KL}(P \,\|\, Q) = \frac{1}{2}\left[\log \frac{|\Sigma_2|}{|\Sigma_1|} - d + \text{tr}(\Sigma_2^{-1}\Sigma_1) + (\mu_2 - \mu_1)^T \Sigma_2^{-1}(\mu_2 - \mu_1)\right]$$

In the above formula, $|\Sigma|$ is the determinant, $d$ is the dimensionality, and $\text{tr}$ is the trace. Crucially, the authors set the covariances of both Gaussians to be equal, which means their determinants are also equal, which means two things:

First:

$$\log \frac{|\Sigma|}{|\Sigma|} = \log 1 = 0$$

And second:

$$\text{tr}(\Sigma^{-1}\Sigma) = \text{tr}(I) = d$$

And therefore you have $0 - d + d = 0$, which means that only the means survive:

$$D_{KL} = \frac{1}{2}(\mu_2 - \mu_1)^T \Sigma^{-1}(\mu_2 - \mu_1)$$

Because we set the variance to a constant, $\Sigma^{-1}$ is just a constant factor, which plays no role when we minimise the loss function, so we can leave it out. This finally brings us to:

$$
\begin{align*}
D_{KL}(q \,\|\, p_\theta) &= \frac{1}{2}(\tilde{\mu}_t - \mu_\theta)^T(\tilde{\mu}_t - \mu_\theta) \\
D_{KL}(q \,\|\, p_\theta) &\propto \|\tilde{\mu}_t - \mu_\theta\|^2 \\
L &\approx \sum_{t=2}^T E_q\left[\|\tilde{\mu}_t(X_t, X_0) - \mu_\theta(X_t, t)\|^2\right]
\end{align*}
$$

There is one last step, which is the reparameterisation trick. This was our forward step:

$$X_t = \sqrt{\bar{\alpha}_t}\,X_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$$

which we can solve for $X_0$:

$$X_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}\left(X_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon\right)$$

The true posterior mean, i.e. the mean of $q(X_{t-1}|X_t, X_0)$, has this form:

$$\tilde{\mu}_t(X_t, X_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}X_0 + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}X_t$$

If we substitute $X_0$, we get:

$$\tilde{\mu}_t = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t} \cdot \frac{1}{\sqrt{\bar{\alpha}_t}}\left(X_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon\right) + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}X_t$$

Doing a bit of algebra and collecting the $X_t$ terms, we arrive at:

$$\tilde{\mu}_t(X_t, \epsilon) = \frac{1}{\sqrt{\alpha_t}}\left(X_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon\right)$$

In this form, the true mean is expressed as a combination of $X_t$ and the true noise $\epsilon$. If we now change our neural network to output not the mean but the noise directly, we get:

$$\mu_\theta(X_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(X_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(X_t, t)\right)$$
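Putting the pieces together, a single reverse (sampling) step could look like this sketch; `eps_model` stands in for a trained noise-prediction network, and $\sigma_t = \sqrt{\beta_t}$ is one common choice for the sampling noise in DDPM:

```python
import numpy as np

rng = np.random.default_rng(0)

T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)

def p_sample_step(x_t, t, eps_model):
    """One reverse step: form mu_theta from the predicted noise, then add sigma_t * z."""
    eps_hat = eps_model(x_t, t)  # hypothetical trained network
    mu = (x_t - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps_hat) / np.sqrt(alphas[t])
    if t == 0:
        return mu  # no noise is added at the final step
    z = rng.standard_normal(x_t.shape)
    return mu + np.sqrt(betas[t]) * z
```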

If we insert this into the above loss expression, we get this:

$$
\begin{align*}
\|\tilde{\mu}_t - \mu_\theta\|^2 &= \left\|\frac{1}{\sqrt{\alpha_t}}\left(X_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon\right) - \frac{1}{\sqrt{\alpha_t}}\left(X_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta\right)\right\|^2 \\
&= \left\|\frac{1}{\sqrt{\alpha_t}}\left(\cancel{X_t} - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon - \cancel{X_t} + \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta\right)\right\|^2 \\
&= \left\|\frac{1}{\sqrt{\alpha_t}} \cdot \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}(\epsilon_\theta - \epsilon)\right\|^2 \\
&= \frac{\beta_t^2}{\alpha_t(1-\bar{\alpha}_t)}\,\|\epsilon_\theta - \epsilon\|^2
\end{align*}
$$

Dropping the time-dependent weighting factor (the DDPM authors found that the simplified, unweighted objective works well in practice), we finally arrive at

$$L = E_{t, X_0, \epsilon}\left[\|\epsilon - \epsilon_\theta(X_t, t)\|^2\right]$$
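In code, one training step of this objective is just a few lines (a sketch; `eps_model` is again a hypothetical noise-prediction network):

```python
import numpy as np

rng = np.random.default_rng(0)

T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha_bars = np.cumprod(1.0 - betas)

def training_loss(x0, eps_model):
    """One sample of L = E[|| eps - eps_theta(X_t, t) ||^2]."""
    t = rng.integers(0, T)                # sample a random time step
    eps = rng.standard_normal(x0.shape)   # sample the true noise
    x_t = np.sqrt(alpha_bars[t]) * x0 + np.sqrt(1.0 - alpha_bars[t]) * eps
    return np.mean((eps - eps_model(x_t, t)) ** 2)
```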