The overarching goal of diffusion models is to predict the noise that was added to an image between any two consecutive time steps $t-1$ and $t$.
We want to add noise to the image such that the final noisy image $X_T$ is standard Gaussian, i.e. $\mathcal{N}(0, 1)$. This is required to create a learnable loss function, and standard Gaussians give us nice mathematical properties and, thus, are easy to work with.
Noise is added gradually, instead of one large addition of Gaussian noise. Just a bit in the beginning, then more towards the end when the image is almost fully Gaussian anyway. We do this, because we want to learn to undo the noise that was added between step t and t−1.
The reasoning is that it’s easier to learn to undo a bit of noise, rather than a lot of noise across many time steps.
Forward Process
In order to get the noisy images to train on, we need to generate them. This is a naive, flawed approach:
$$X_t = X_{t-1} + \sqrt{\beta}\,\epsilon_t, \qquad \epsilon_t \sim \mathcal{N}(0, 1)$$
The $\beta$ term is added to scale down the Gaussian noise. This is required because the input images are usually normalised, typically to values between 0 and 1. Standard Gaussian noise can return values like 0.8, which would quickly overwhelm the pixel values of the input image. Therefore, the Gaussian noise needs to be scaled down.
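To make this concrete, here is a minimal NumPy sketch (the image stand-in and the value of β are arbitrary assumptions on my part) comparing unscaled and scaled noise on a [0, 1]-normalised input:

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.uniform(0.0, 1.0, size=(64, 64))  # stand-in for a normalised image in [0, 1]
beta = 0.01                                   # arbitrary small value for illustration

unscaled = image + rng.standard_normal(image.shape)                  # noise ~ N(0, 1)
scaled   = image + np.sqrt(beta) * rng.standard_normal(image.shape)  # noise ~ N(0, beta)

# The unscaled result is dominated by the noise; the scaled one still resembles the image.
print(unscaled.std(), scaled.std())
```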
With this current setup, we need to have the previous noisy image $X_{t-1}$ to generate the next noisy image $X_t$. But we can rewrite this to get any $X_t$ just from the starting image $X_0$. E.g.:
$$X_3 = X_2 + \sqrt{\beta}\,\epsilon_3 = X_1 + \sqrt{\beta}\,\epsilon_2 + \sqrt{\beta}\,\epsilon_3 = X_0 + \sqrt{\beta}\,\epsilon_1 + \sqrt{\beta}\,\epsilon_2 + \sqrt{\beta}\,\epsilon_3$$
If you have zero mean, scaled Gaussians, you can sum the variances, which gives you a “new” Gaussian:
$$X_3 = X_0 + \sqrt{3\beta}\,\epsilon', \qquad \text{i.e.} \qquad X_3 \sim \mathcal{N}(X_0, 3\beta)$$
But unfortunately, $X_t \sim \mathcal{N}(X_0, t\beta) \neq \mathcal{N}(0, 1)$, because in general $X_0 \neq 0$ and $t\beta \neq 1$.
We can even see this visually if we add the noise to an image in this iterative way and compare this to what standard Gaussian noise would look like.
This is what the iterative approach would look like (which also takes a lot of time to run). Using the direct approach, we can get the final noisy image faster:
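For illustration, here is a minimal NumPy sketch of both variants (assuming the naive per-step update $X_t = X_{t-1} + \sqrt{\beta}\,\epsilon_t$ from above and an arbitrary β); both produce the same distribution, the direct version just does it in one step:

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.uniform(0.0, 1.0, size=(64, 64))  # stand-in for the clean image X_0
beta, T = 0.01, 1000

# Iterative (naive) forward process: T sequential additions of scaled Gaussian noise.
x = x0.copy()
for t in range(T):
    x = x + np.sqrt(beta) * rng.standard_normal(x.shape)

# Direct version: summing T independent N(0, beta) variances gives N(0, T * beta),
# so the same X_T can be sampled in a single step.
x_direct = x0 + np.sqrt(T * beta) * rng.standard_normal(x0.shape)

# Neither converges to a standard Gaussian: the mean is still X_0 and the variance is T * beta.
print(x.var(), x_direct.var())
```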
But as can be seen, the noisy image looks much different from what true Gaussian noise looks like. This further shows that our naive approach does not converge to a standard Gaussian.
A better approach is needed. The authors came up with this formula:
$$X_t = \sqrt{1-\beta_t}\,X_{t-1} + \sqrt{\beta_t}\,\epsilon_t$$
where $\beta_t$ is the scalar at time step $t$ (following a schedule, for example). However, the issue still persists that we need to compute $X_{t-1}$ to compute $X_t$. But we'd much rather have a function that takes the initial $X_0$ and $t$ as input and outputs the correct $X_t$ directly. Let's have a look at what $X_{t-1}$ looks like:
$$X_{t-1} = \sqrt{1-\beta_{t-1}}\,X_{t-2} + \sqrt{\beta_{t-1}}\,\epsilon_{t-1}$$
Substituting this into the equation for $X_t$, and writing $\alpha_t := 1 - \beta_t$, gives:
$$X_t = \sqrt{\alpha_t\alpha_{t-1}}\,X_{t-2} + \sqrt{\alpha_t(1-\alpha_{t-1})}\,\epsilon_{t-1} + \sqrt{1-\alpha_t}\,\epsilon_t$$
Now we can re-use our neat trick from before, where we said that the variances of zero-mean Gaussians ($\epsilon_{t-1}$ and $\epsilon_t$ in our case) can be summed up. Remember that in this equation we are working with standard deviations (i.e. $\sigma$), but variances are the squares of standard deviations (i.e. $\sigma^2$). This means we need to square our standard deviations to get the variances first, i.e.:
$$\alpha_t(1-\alpha_{t-1}) + (1-\alpha_t) = 1 - \alpha_t\alpha_{t-1}$$
which lets us replace the two noise terms with a single one:
$$X_t = \sqrt{\alpha_t\alpha_{t-1}}\,X_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\,\epsilon'$$
I denoted the new Gaussian noise as $\epsilon'$. From here, we can notice a pattern! We managed to describe $X_t$ as a combination of the noisy image from two steps before, and this has added another $\alpha$ term into the square roots. If we repeat this process until we get to $X_0$, we have this:
$$X_t = \sqrt{\alpha_t\alpha_{t-1}\alpha_{t-2}\cdots\alpha_1}\,X_0 + \sqrt{1-\alpha_t\alpha_{t-1}\alpha_{t-2}\cdots\alpha_1}\,\epsilon'$$
We can simplify this a bit further by defining $\bar{\alpha}_t = \prod_{i=1}^{t}\alpha_i$, which gives us our final statement:
$$X_t = \sqrt{\bar{\alpha}_t}\,X_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon'$$
And there we go! Now we can generate the final noisy image at time step $t$ just by having the initial image $X_0$. Using this new formula, we can compare the generated noisy image with what we had generated before:
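Here is a minimal NumPy sketch of this closed-form forward process (the linear β schedule and its range are assumptions for illustration; any valid schedule works):

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.uniform(0.0, 1.0, size=(64, 64))  # stand-in for the clean image X_0

T = 1000
betas = np.linspace(1e-4, 0.02, T)   # assumed linear schedule for beta_1 ... beta_T
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)      # alpha_bar_t = prod_{i=1..t} alpha_i

def q_sample(x0, t, rng):
    """Sample X_t directly from X_0: X_t = sqrt(alpha_bar_t) X_0 + sqrt(1 - alpha_bar_t) eps."""
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bars[t - 1]) * x0 + np.sqrt(1.0 - alpha_bars[t - 1]) * eps

x_T = q_sample(x0, T, rng)           # for large t this is close to a standard Gaussian
print(x_T.mean(), x_T.std())
```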
Now this looks much closer to Gaussian noise than what we had before. It's a bit darker due to the larger value for β that I used. However, mathematically, this does converge to a standard Gaussian:
$$X_t = \sqrt{\bar{\alpha}_t}\,X_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon'$$
This is because as we increase $t$, we keep multiplying more and more numbers that are between 0 and 1 (remember that $\beta_t$, and therefore $\alpha_t = 1 - \beta_t$, lies in that range), making $\sqrt{\bar{\alpha}_t}$ smaller and smaller; in the limit it becomes 0, and $0 \cdot X_0 = 0$. Similarly, as $\bar{\alpha}_t$ moves towards 0, $\sqrt{1-\bar{\alpha}_t}$ converges towards 1, which then just adds regular Gaussian noise via $\epsilon'$.
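A quick numerical check of this convergence (again assuming a linear β schedule): $\sqrt{\bar{\alpha}_t}$ shrinks towards 0 while $\sqrt{1-\bar{\alpha}_t}$ grows towards 1.

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)   # assumed linear schedule, each beta in (0, 1)
alpha_bars = np.cumprod(1.0 - betas)

# Coefficient of X_0 shrinks towards 0, coefficient of the noise grows towards 1.
for t in (1, 10, 100, 500, 1000):
    print(t, np.sqrt(alpha_bars[t - 1]), np.sqrt(1.0 - alpha_bars[t - 1]))
```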
With this, we have the forward method covered. Now it’s time to derive the loss function and the learning objective.
Deriving the Loss Function
So far, we have described the forward process as a step-by-step sampling procedure. This is great for building intuition. However, to create a loss function, we need to shift from the language of single samples ($X_t$) to the language of probability distributions. This is important because, when we try to predict the noise that was added at a particular step $t$, there is no single fixed answer for what that noise must be. Think of it this way:
You have some $X_t$ and sample some noise $\epsilon_t$; you add that in and then get $X_{t+1}$. But this amount of noise and this specific $X_t$ are just single samples. There are many more combinations of $X_t$ and $\epsilon_t$ that could result in $X_{t+1}$; it's like a whole cloud. These specific values might have come from the edge of the cloud. If you just computed the MSE from that, you wouldn't learn the true average, which you would find in the center of the cloud. By thinking about this as a probability distribution, you instead measure the loss between your model and the center of the cloud.
The formal name for the distribution defined by our single-step sampling process is q(Xt∣Xt−1).
The Markov Chain Property
A crucial property of our forward process is that it’s a Markov Chain. This simply means that the distribution for the next state, Xt, depends only on the immediately preceding state, Xt−1. It does not depend on any other previous states like Xt−2, Xt−3, or the original X0.
The Joint Probability Distribution
We want to find the probability of an entire sequence of noisy images, X1,…,XT, given our starting image X0. This is the joint probability distribution, q(X1:T∣X0).
Using the general chain rule of probability, we would write this as:
$$q(X_1,\dots,X_T \mid X_0) = q(X_1 \mid X_0)\cdot q(X_2 \mid X_1, X_0)\cdots q(X_T \mid X_{T-1},\dots,X_0)$$
This looks complicated because each step seems to depend on the entire history.
However, we can now apply our Markov assumption. The assumption that Xt only depends on Xt−1 allows us to simplify each term:
q(X2∣X1,X0) simplifies to just q(X2∣X1).
q(XT∣XT−1,…,X0) simplifies to just q(XT∣XT−1).
Applying this simplification across the entire chain gives us a much cleaner result:
$$q(X_1,\dots,X_T \mid X_0) = q(X_1 \mid X_0)\cdot q(X_2 \mid X_1)\cdots q(X_T \mid X_{T-1})$$
Finally, we can write this long product in a compact form using the product symbol, $\prod$:
$$q(X_{1:T} \mid X_0) = \prod_{t=1}^{T} q(X_t \mid X_{t-1})$$
This final equation is the formal definition of our entire forward process.
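For example, with $T = 3$ the factorisation reads:
$$q(X_1, X_2, X_3 \mid X_0) = q(X_1 \mid X_0)\,q(X_2 \mid X_1)\,q(X_3 \mid X_2)$$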
For the reverse process, we have our model pθ. To get a clean image back, we have to compute the joint probability, i.e.:
$$p_\theta(X_{0:T}) = p_\theta(X_T)\prod_{t=1}^{T} p_\theta(X_{t-1} \mid X_t)$$
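To make this concrete, here is a minimal PyTorch-style sketch of ancestral sampling from the reverse chain. The names `model_mean` and `sigmas` are hypothetical placeholders: `model_mean(x, t)` is assumed to return the mean of $p_\theta(X_{t-1}\mid X_t)$, and `sigmas[t]` its (fixed) standard deviation.

```python
import torch

@torch.no_grad()
def sample(model_mean, sigmas, shape, T):
    """Ancestral sampling from p_theta(X_{0:T}): start from pure noise and denoise step by step."""
    x = torch.randn(shape)                        # X_T ~ p(X_T) = N(0, I)
    for t in range(T, 0, -1):                     # t = T, T-1, ..., 1
        noise = torch.randn_like(x) if t > 1 else torch.zeros_like(x)
        x = model_mean(x, t) + sigmas[t] * noise  # sample X_{t-1} ~ p_theta(X_{t-1} | X_t)
    return x                                      # the generated image X_0
```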
From here we can start with the loss function derivation, because as stated before, we want to sample a clean image using our model. To do that, we need to minimise the negative log likelihood:
$$\mathcal{L} = -\log p_\theta(X_0)$$
If we can minimise this, then we are learning the true data distribution, i.e. our dataset.
From here, the first step is marginalisation: $p_\theta(X_0)$ is obtained by integrating the joint distribution (which we already factorised with the chain rule above) over all possible trajectories $X_{1:T}$:
$$p_\theta(X_0) = \int p_\theta(X_{0:T})\,dX_{1:T}$$
This loss function is numerically intractable, because we can't integrate over all possible $X$ in existence. We need to insert something into this intractable integral to make it tractable. We can use the "divide-by-one" trick (I'm not sure if that's what it's called - I just like to call it that): multiply and divide by $q(X_{1:T}\mid X_0)$ inside the integral, which turns it into an expectation, and then apply Jensen's inequality to move the $\log$ inside:
$$-\log p_\theta(X_0) = -\log \int q(X_{1:T}\mid X_0)\,\frac{p_\theta(X_{0:T})}{q(X_{1:T}\mid X_0)}\,dX_{1:T} \le -\mathbb{E}_q\!\left[\log\frac{p_\theta(X_{0:T})}{q(X_{1:T}\mid X_0)}\right]$$
Now we have a bound: the expectation on the right is an upper bound on the negative log likelihood (equivalently, a lower bound on the log likelihood, often called the ELBO). If we can make the expectation smaller, we will also minimise the negative log likelihood, which is exactly what we want. Thus, our new loss function to minimise is
$$\mathcal{L}_n = -\mathbb{E}_q\!\left[\log\frac{p_\theta(X_{0:T})}{q(X_{1:T}\mid X_0)}\right]$$
We can apply the log rule $\log(A/B) = -\log(B/A)$ to get rid of the negative sign:
$$\mathcal{L}_n = \mathbb{E}_q\!\left[\log\frac{q(X_{1:T}\mid X_0)}{p_\theta(X_{0:T})}\right]$$
Plugging in the factorised forms of $q$ and $p_\theta$ from above, and using the log rule that turns products into sums, this becomes:
$$\mathcal{L}_n = \mathbb{E}_q\!\left[-\log p_\theta(X_T) + \sum_{t=2}^{T}\log\frac{q(X_t\mid X_{t-1})}{p_\theta(X_{t-1}\mid X_t)} + \log\frac{q(X_1\mid X_0)}{p_\theta(X_0\mid X_1)}\right]$$
You will also notice that I separated the $t=1$ step. This is required in order to apply Bayes' rule.
$$p(A\mid B) = \frac{p(B\mid A)\cdot p(A)}{p(B)}$$
But if you instead did this:
$$p(A\mid B, A) = \frac{p(B\mid A, A)\cdot p(A\mid A)}{p(B\mid A)}$$
You would get a so-called Dirac delta function $p(a\mid a) = \delta(0)$, which is a point mass "probability distribution" that isn't really a probability distribution, because there is no uncertainty. It's like saying "What card do I hold in my hand given that I hold the ace of hearts?". It's a tautology. This means you can't apply Bayes' rule here in a manner that makes sense, therefore you exclude this part.
Now we can rewrite the fraction in the sum using Bayes' rule (and the log rule where a product becomes a sum). Because the forward process is Markov, $q(X_t\mid X_{t-1}) = q(X_t\mid X_{t-1}, X_0)$, and Bayes' rule (with everything conditioned on $X_0$) gives:
$$q(X_t\mid X_{t-1}, X_0) = \frac{q(X_{t-1}\mid X_t, X_0)\,q(X_t\mid X_0)}{q(X_{t-1}\mid X_0)}$$
Inside the sum, each term therefore becomes:
$$\log\frac{q(X_{t-1}\mid X_t, X_0)}{p_\theta(X_{t-1}\mid X_t)} + \log q(X_t\mid X_0) - \log q(X_{t-1}\mid X_0)$$
What we have now is a so-called telescoping sum, which comes from this part:
$$\sum_{t=2}^{T}\big[\log q(X_t\mid X_0) - \log q(X_{t-1}\mid X_0)\big]$$
If you were to write out the terms, you would get something like this (I will simplify this to just tuples, where a tuple is a placeholder, e.g. $(1,0) \equiv \log q(X_1\mid X_0)$ or $(2,1) \equiv \log q(X_2\mid X_1)$):
$$\sum_{t=2}^{T} = (2,0)-(1,0)+(3,0)-(2,0)+(4,0)-(3,0)+\dots+(T,0)-(T-1,0)$$
As you can see, all the intermediate pairs cancel out; only the first negative term and the last positive term survive. This simplifies the sum to:
$$\log q(X_T\mid X_0) - \log q(X_1\mid X_0)$$
and the $-\log q(X_1\mid X_0)$ in turn cancels against the $+\log q(X_1\mid X_0)$ hiding in the separated $t=1$ term.
Ok, so now we have 3 terms, and whenever you see a log of a quotient, you have to think "KL divergence", which measures the distance between two probability distributions, i.e.:
$$\mathcal{L}_n = \mathbb{E}_q\Big[\,D_{KL}\big(q(X_T\mid X_0)\,\|\,p(X_T)\big) + \sum_{t=2}^{T} D_{KL}\big(q(X_{t-1}\mid X_t, X_0)\,\|\,p_\theta(X_{t-1}\mid X_t)\big) - \log p_\theta(X_0\mid X_1)\Big]$$
If we have a closer look, we can see that the left KL divergence has no trainable parameters $\theta$ in it, so we can safely ignore it, as it will become 0 as soon as we compute the gradient. As for the rightmost part, the authors of the DDPM paper found that, empirically, it makes no difference to leave that part out, so we drop it as well. This leaves us with this:
$$\mathcal{L}_n = \mathbb{E}_q\left[\sum_{t=2}^{T} D_{KL}\big(q(X_{t-1}\mid X_t, X_0)\,\|\,p_\theta(X_{t-1}\mid X_t)\big)\right]$$
The good thing is that if your probability distributions are Gaussian, the KL divergence simplifies to a nice closed form. In general, for two multivariate Gaussians $\mathcal{N}(\mu_1, \Sigma_1)$ and $\mathcal{N}(\mu_2, \Sigma_2)$, you have:
$$D_{KL} = \frac{1}{2}\left[\log\frac{|\Sigma_2|}{|\Sigma_1|} - d + \mathrm{tr}\big(\Sigma_2^{-1}\Sigma_1\big) + (\mu_2-\mu_1)^T\Sigma_2^{-1}(\mu_2-\mu_1)\right]$$
In the above formula, ∣Σ∣ is the determinant, d is the dimensionality and tr is the trace. Crucially, the authors set the covariances (variances) of both Gaussians to be equal, which means their determinants are also equal, which means two things:
First:
$$\log\frac{|\Sigma|}{|\Sigma|} = \log 1 = 0$$
And second:
$$\mathrm{tr}\big(\Sigma^{-1}\Sigma\big) = \mathrm{tr}(I) = d$$
And therefore, you have this 0−d+d=0, which means that only the means survive:
$$D_{KL} = \frac{1}{2}(\mu_2-\mu_1)^T\,\Sigma^{-1}(\mu_2-\mu_1)$$
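As a quick sanity check, here is a small NumPy sketch (with arbitrary made-up means and a shared covariance) verifying that the full KL formula collapses to this quadratic form when the covariances are equal:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
mu1, mu2 = rng.standard_normal(d), rng.standard_normal(d)
A = rng.standard_normal((d, d))
Sigma = A @ A.T + d * np.eye(d)     # a shared, positive-definite covariance

Sigma_inv = np.linalg.inv(Sigma)
full_kl = 0.5 * (np.log(np.linalg.det(Sigma) / np.linalg.det(Sigma))  # log-det term -> 0
                 - d
                 + np.trace(Sigma_inv @ Sigma)                        # trace term -> d
                 + (mu2 - mu1) @ Sigma_inv @ (mu2 - mu1))
means_only = 0.5 * (mu2 - mu1) @ Sigma_inv @ (mu2 - mu1)

print(np.isclose(full_kl, means_only))  # True: only the term with the means survives
```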
Because we set the variance to a constant, $\Sigma^{-1}$ is just a constant factor, which plays no role when we minimise the loss function, so we can leave it out. In our case, the two means are the mean $\tilde{\mu}_t$ of the true posterior $q(X_{t-1}\mid X_t, X_0)$ and the mean $\mu_\theta$ predicted by our model, which finally brings us to this part:
$$\big\|\tilde{\mu}_t(X_t, X_0) - \mu_\theta(X_t, t)\big\|^2$$
Doing a bit of algebra (using the closed-form forward formula to express $X_0$ as $X_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}\big(X_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon\big)$) and collecting the $X_t$ terms, we arrive at:
$$\tilde{\mu}_t(X_t, \epsilon) = \frac{1}{\sqrt{\alpha_t}}\left(X_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon\right)$$
In this form, the true mean is now expressed as a combination of Xt and the true noise ϵ. If we now change our neural network to output not the mean, but rather the noise directly, we get this:
$$\mu_\theta(X_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(X_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(X_t, t)\right)$$
If we insert this into the above loss expression, the $X_t$ terms cancel and only the noise terms remain, so (dropping the remaining constant factors) we are left with:
$$\big\|\epsilon - \epsilon_\theta(X_t, t)\big\|^2 = \Big\|\epsilon - \epsilon_\theta\big(\sqrt{\bar{\alpha}_t}\,X_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\; t\big)\Big\|^2$$
In other words, the model simply learns to predict the noise $\epsilon$ that was used to corrupt $X_0$.
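Putting the whole derivation together, here is a minimal PyTorch sketch of one training step with this objective. The model `eps_model`, the linear β schedule, and the image tensor shape are illustrative assumptions, not the exact DDPM training code:

```python
import torch
import torch.nn.functional as F

T = 1000
betas = torch.linspace(1e-4, 0.02, T)            # assumed linear schedule
alpha_bars = torch.cumprod(1.0 - betas, dim=0)   # alpha_bar_t for t = 1..T

def training_step(eps_model, x0, optimizer):
    """One optimisation step of the simplified loss ||eps - eps_theta(X_t, t)||^2."""
    b = x0.shape[0]                                               # x0 assumed to be (B, C, H, W)
    t = torch.randint(1, T + 1, (b,))                             # one random time step per image
    a_bar = alpha_bars[t - 1].view(b, 1, 1, 1)                    # broadcast over (C, H, W)
    eps = torch.randn_like(x0)                                    # the true noise
    x_t = torch.sqrt(a_bar) * x0 + torch.sqrt(1.0 - a_bar) * eps  # closed-form forward sample
    loss = F.mse_loss(eps_model(x_t, t), eps)                     # predict the noise, compare via MSE
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```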