Decoding the math behind Diffusion Models: A breakthrough in Generative AI

Diffusion models are a class of state-of-the-art generative models that produce diverse, high-resolution images. Several prominent systems are built on them, including OpenAI's DALL-E 2 and GLIDE, Google's Imagen, and Stability AI's Stable Diffusion. In this blog post, we will work our way up from the basic principles behind the most prominent formulation, Denoising Diffusion Probabilistic Models (DDPM), introduced by Sohl-Dickstein et al. in 2015 and then improved by Ho et al. in 2020.

Images produced by DALL-E 2

The basic idea behind diffusion models is rather simple: take an input image $x_0$ and gradually add Gaussian noise to it through a series of $T$ time steps. We will call this the forward process. A network is then trained to recover the original image by reversing the noising process. Once we can model the reverse process, we can start from random noise and denoise it step by step to generate new data.


Forward Diffusion Process

Consider an image $x_0$ sampled from the real data distribution (or the training set); the subscript denotes the time step. The forward process, denoted by $q$, is modeled as a Markov chain, where the distribution at a particular time step depends only on the sample from the previous step. The joint distribution of the corrupted samples can be written as

$$q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1}) \tag{1}$$

At each step of the Markov chain, we add Gaussian noise to $x_{t-1}$, producing a new latent variable $x_t$. The transition distribution is a unimodal diagonal Gaussian,

$$q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\; \mu_t = \sqrt{1-\beta_t}\, x_{t-1},\; \Sigma_t = \beta_t \mathbf{I}\right) \tag{2}$$

where $\beta_t$ is the variance of the Gaussian at time step $t$. It is a hyperparameter that follows a fixed schedule such that it increases with time and lies in the range $[0, 1]$.

Ho et al. set a linear schedule for the variance, starting from $\beta_1 = 10^{-4}$ and increasing to $\beta_T = 0.02$, with $T = 1000$.
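As a concrete illustration, here is a minimal PyTorch sketch of this schedule (the variable names are my own):

```python
import torch

# Linear variance schedule from Ho et al.: beta_1 = 1e-4 up to beta_T = 0.02.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
```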



A latent variable $x_t$ can be sampled from the distribution $q(x_t \mid x_{t-1})$ using the reparameterization trick:

$$x_t = \sqrt{1-\beta_t}\, x_{t-1} + \sqrt{\beta_t}\, \epsilon_t \tag{3}$$

where $\epsilon_t \sim \mathcal{N}(0, \mathbf{I})$.

Equation 3 shows that we need to compute all the previous samples $x_{t-1}, \ldots, x_0$ in order to obtain $x_t$, which is expensive. To solve this problem, we define

$$\alpha_t = 1 - \beta_t, \qquad \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$$

and rewrite Equation 3 recursively. In the third line below, the two independent zero-mean Gaussian terms are merged into a single Gaussian whose variance is the sum of their variances:

$$\begin{aligned} x_t &= \sqrt{\alpha_t}\, x_{t-1} + \sqrt{1-\alpha_t}\, \epsilon_{t-1} \\ &= \sqrt{\alpha_t}\left[\sqrt{\alpha_{t-1}}\, x_{t-2} + \sqrt{1-\alpha_{t-1}}\, \epsilon_{t-2}\right] + \sqrt{1-\alpha_t}\, \epsilon_{t-1} \\ &= \sqrt{\alpha_t \alpha_{t-1}}\, x_{t-2} + \sqrt{\alpha_t(1-\alpha_{t-1}) + (1-\alpha_t)}\, \epsilon \\ &= \sqrt{\alpha_t \alpha_{t-1}}\, x_{t-2} + \sqrt{1-\alpha_t \alpha_{t-1}}\, \epsilon \\ &\;\;\vdots \\ &= \sqrt{\alpha_t \alpha_{t-1} \cdots \alpha_1}\, x_0 + \sqrt{1-\alpha_t \alpha_{t-1} \cdots \alpha_1}\, \epsilon_t \\ &= \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon_t \end{aligned} \tag{4}$$

The closed-form sampling at any arbitrary time step can thus be carried out using the following distribution:

$$x_t \sim q(x_t \mid x_0) = \mathcal{N}\!\left(x_t;\; \mu_t = \sqrt{\bar{\alpha}_t}\, x_0,\; \Sigma_t = (1-\bar{\alpha}_t)\mathbf{I}\right) \tag{5}$$

Since $\beta_t$ is a hyperparameter that is fixed beforehand, we can precompute $\alpha_t$ and $\bar{\alpha}_t$ for all time steps and use Equation 4 to sample the latent variable $x_t$ in one go.
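A sketch of this one-shot sampling, reusing the schedule above (the helper name `q_sample` is hypothetical, and `t` is a batch of integer time steps):

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # schedule as above
alphas = 1.0 - betas                          # alpha_t = 1 - beta_t
alphas_bar = torch.cumprod(alphas, dim=0)     # alpha_bar_t = product of alpha_s up to t

def q_sample(x0, t, noise=None):
    """Sample x_t ~ q(x_t | x_0) in a single step using Equation 4."""
    if noise is None:
        noise = torch.randn_like(x0)
    # Broadcast alpha_bar_t over all non-batch dimensions of x0.
    a_bar = alphas_bar[t].view(-1, *([1] * (x0.dim() - 1)))
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise
```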



Reverse Diffusion Process

As $T \to \infty$, $\bar{\alpha}_T \to 0$ and the distribution $q(x_T \mid x_0)$ approaches $\mathcal{N}(0, \mathbf{I})$ (also called an isotropic Gaussian), losing all information about the original sample. Therefore, if we manage to learn the reverse distribution, we can sample $x_T \sim \mathcal{N}(0, \mathbf{I})$ and run the denoising process step-wise to generate a new sample.

With a small enough step size ($\beta_t \ll 1$), the reverse process has the same functional form as the forward process, so the reverse distribution can also be modeled as a unimodal diagonal Gaussian. Unfortunately, it is not straightforward to estimate $q(x_{t-1} \mid x_t)$: it is intractable, since computing this conditional probability requires knowing the distribution of all possible images (i.e., the entire data distribution).

Hence, we use a network to learn this Gaussian by parameterizing its mean and variance:

$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\!\left(x_{t-1};\; \mu_\theta(x_t, t),\; \Sigma_\theta(x_t, t)\right) \tag{6}$$

Apart from the latent sample $x_t$, the model also takes the time step $t$ as input. Different time steps are associated with different noise levels, and the model learns to undo each of them individually.


Like the forward process, the reverse process can also be set up as a Markov chain. We can write the joint probability of the sequence of samples as

$$p_\theta(x_{0:T}) = p(x_T) \prod_{t=1}^{T} p_\theta(x_{t-1} \mid x_t) \tag{7}$$

Here, $p(x_T) = \mathcal{N}(0, \mathbf{I})$, as generation starts from a sample of pure noise.



Training Objective (Loss function)

The forward process is fixed; it is only the reverse process that we learn. Diffusion models can be seen as latent variable models, similar to variational autoencoders (VAEs), where $x_0$ is an observed variable and $x_1, x_2, \ldots, x_T$ are latent variables.

Maximizing the variational lower bound (also called the evidence lower bound, ELBO) on the marginal log-likelihood forms the objective in VAEs. For an observed variable $x$ and latent variable $z$, this lower bound can be written as

$$\log p_\theta(x) \geq \mathbb{E}_{q(z|x)}\left[\log p_\theta(x|z)\right] - D_{KL}\left(q(z|x)\,\|\,p_\theta(z)\right)$$

Rewriting it in the diffusion model framework, we get

$$\begin{aligned} \text{ELBO} &= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log p_\theta(x_0|x_{1:T})\right] - D_{KL}\left(q(x_{1:T}|x_0)\,\|\,p_\theta(x_{1:T})\right) \\ &= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log p_\theta(x_0|x_{1:T})\right] - \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log \frac{q(x_{1:T}|x_0)}{p_\theta(x_{1:T})}\right] \\ &= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log \frac{p_\theta(x_0|x_{1:T})\, p_\theta(x_{1:T})}{q(x_{1:T}|x_0)}\right] \\ &= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log \frac{p_\theta(x_{0:T})}{q(x_{1:T}|x_0)}\right] \end{aligned}$$

Using Equations 1 and 7,

$$= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log p(x_T) + \sum_{t=1}^{T}\log \frac{p_\theta(x_{t-1}|x_t)}{q(x_t|x_{t-1})}\right]$$

Taking the edge case $t = 1$ out,

$$= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log p(x_T) + \log \frac{p_\theta(x_0|x_1)}{q(x_1|x_0)} + \sum_{t=2}^{T}\log \frac{p_\theta(x_{t-1}|x_t)}{q(x_t|x_{t-1})}\right]$$

Using the Markov property and then Bayes' rule, we can write $q(x_t|x_{t-1}) = q(x_t|x_{t-1}, x_0) = \frac{q(x_{t-1}|x_t, x_0)\, q(x_t|x_0)}{q(x_{t-1}|x_0)}$ and substitute,

$$\begin{aligned} &= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log p(x_T) + \log \frac{p_\theta(x_0|x_1)}{q(x_1|x_0)} + \sum_{t=2}^{T}\log \frac{p_\theta(x_{t-1}|x_t)\, q(x_{t-1}|x_0)}{q(x_{t-1}|x_t, x_0)\, q(x_t|x_0)}\right] \\ &= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log p(x_T) + \log \frac{p_\theta(x_0|x_1)}{q(x_1|x_0)} + \sum_{t=2}^{T}\log \frac{p_\theta(x_{t-1}|x_t)}{q(x_{t-1}|x_t, x_0)} + \sum_{t=2}^{T}\log \frac{q(x_{t-1}|x_0)}{q(x_t|x_0)}\right] \\ &= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log p(x_T) + \log \frac{p_\theta(x_0|x_1)}{q(x_1|x_0)} + \sum_{t=2}^{T}\log \frac{p_\theta(x_{t-1}|x_t)}{q(x_{t-1}|x_t, x_0)} + \log \frac{q(x_1|x_0)\, q(x_2|x_0) \cdots q(x_{T-1}|x_0)}{q(x_2|x_0)\, q(x_3|x_0) \cdots q(x_T|x_0)}\right] \\ &= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log p(x_T) + \log \frac{p_\theta(x_0|x_1)}{q(x_1|x_0)} + \sum_{t=2}^{T}\log \frac{p_\theta(x_{t-1}|x_t)}{q(x_{t-1}|x_t, x_0)} + \log \frac{q(x_1|x_0)}{q(x_T|x_0)}\right] \\ &= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log p_\theta(x_0|x_1) + \log \frac{p(x_T)}{q(x_T|x_0)} + \sum_{t=2}^{T}\log \frac{p_\theta(x_{t-1}|x_t)}{q(x_{t-1}|x_t, x_0)}\right] \\ &= \mathbb{E}_{q(x_1|x_0)}\left[\log p_\theta(x_0|x_1)\right] + \mathbb{E}_{q(x_T|x_0)}\left[\log \frac{p(x_T)}{q(x_T|x_0)}\right] + \sum_{t=2}^{T}\mathbb{E}_{q(x_{t-1}, x_t|x_0)}\left[\log \frac{p_\theta(x_{t-1}|x_t)}{q(x_{t-1}|x_t, x_0)}\right] \\ &= \mathbb{E}_{q(x_1|x_0)}\left[\log p_\theta(x_0|x_1)\right] - D_{KL}\left(q(x_T|x_0)\,\|\,p(x_T)\right) - \sum_{t>1}\mathbb{E}_{q(x_t|x_0)}\left[D_{KL}\left(q(x_{t-1}|x_t, x_0)\,\|\,p_\theta(x_{t-1}|x_t)\right)\right] \end{aligned}$$

The objective of maximizing this lower bound is equivalent to minimizing a loss function that is its negation:

$$L_{\text{vlb}} = \underbrace{D_{KL}\left(q(x_T|x_0)\,\|\,p(x_T)\right)}_{L_T} + \sum_{t>1} \underbrace{\mathbb{E}_{q(x_t|x_0)}\left[D_{KL}\left(q(x_{t-1}|x_t, x_0)\,\|\,p_\theta(x_{t-1}|x_t)\right)\right]}_{L_{t-1}} \underbrace{-\, \mathbb{E}_{q(x_1|x_0)}\left[\log p_\theta(x_0|x_1)\right]}_{L_0} \tag{8}$$

The term $L_T$ has no trainable parameters, so it is ignored during training. Furthermore, since we have assumed a $T$ large enough that the final distribution is Gaussian, this term is effectively zero.

$L_0$ can be interpreted as a reconstruction term (similar to the one in a VAE).

The term $L_{t-1}$ formulates the difference between the predicted denoising step $p_\theta(x_{t-1} \mid x_t)$ and the reverse diffusion step $q(x_{t-1} \mid x_t, x_0)$, which is given as a target to the model. The latter is explicitly conditioned on the original sample $x_0$ in the loss function so that it takes the form of a Gaussian:

$$q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\!\left(x_{t-1};\; \tilde{\mu}_t(x_t, x_0),\; \tilde{\beta}_t \mathbf{I}\right)$$

But why do we need it to be Gaussian?

Since the model output $p_\theta(x_{t-1} \mid x_t)$ is already parameterized as a Gaussian, every KL term compares two Gaussian distributions and can therefore be computed in closed form. This makes the loss function tractable.

Intuitively, a painter (our generative model) needs a reference image ($x_0$) to slowly draw (reverse diffusion step $q(x_{t-1} \mid x_t)$) an image. Thus, we can take a small step backwards, from noise toward the image, only when we have $x_0$ as a reference.


Using Bayes' rule, we can write

$$q(x_{t-1}|x_t, x_0) = \frac{q(x_t|x_{t-1}, x_0)\, q(x_{t-1}|x_0)}{q(x_t|x_0)}$$

Using Equations 2 and 5,

$$= \frac{\mathcal{N}\!\left(x_t;\, \sqrt{1-\beta_t}\, x_{t-1},\, \beta_t \mathbf{I}\right)\, \mathcal{N}\!\left(x_{t-1};\, \sqrt{\bar{\alpha}_{t-1}}\, x_0,\, (1-\bar{\alpha}_{t-1})\mathbf{I}\right)}{\mathcal{N}\!\left(x_t;\, \sqrt{\bar{\alpha}_t}\, x_0,\, (1-\bar{\alpha}_t)\mathbf{I}\right)}$$

Replacing $1-\beta_t = \alpha_t$,

$$\begin{aligned} &= \frac{\mathcal{N}\!\left(x_t;\, \sqrt{\alpha_t}\, x_{t-1},\, \beta_t \mathbf{I}\right)\, \mathcal{N}\!\left(x_{t-1};\, \sqrt{\bar{\alpha}_{t-1}}\, x_0,\, (1-\bar{\alpha}_{t-1})\mathbf{I}\right)}{\mathcal{N}\!\left(x_t;\, \sqrt{\bar{\alpha}_t}\, x_0,\, (1-\bar{\alpha}_t)\mathbf{I}\right)} \\ &\propto \exp\left[-\frac{1}{2}\left(\frac{(x_t - \sqrt{\alpha_t}\, x_{t-1})^2}{\beta_t} + \frac{(x_{t-1} - \sqrt{\bar{\alpha}_{t-1}}\, x_0)^2}{1-\bar{\alpha}_{t-1}} - \frac{(x_t - \sqrt{\bar{\alpha}_t}\, x_0)^2}{1-\bar{\alpha}_t}\right)\right] \\ &= \exp\left[-\frac{1}{2}\left(\frac{x_t^2 + \alpha_t x_{t-1}^2 - 2 x_t \sqrt{\alpha_t}\, x_{t-1}}{\beta_t} + \frac{x_{t-1}^2 + \bar{\alpha}_{t-1} x_0^2 - 2 x_{t-1}\sqrt{\bar{\alpha}_{t-1}}\, x_0}{1-\bar{\alpha}_{t-1}} - \frac{(x_t - \sqrt{\bar{\alpha}_t}\, x_0)^2}{1-\bar{\alpha}_t}\right)\right] \\ &= \exp\left[-\frac{1}{2}\left(x_{t-1}^2\left(\frac{\alpha_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}}\right) - 2 x_{t-1}\left(\frac{\sqrt{\alpha_t}\, x_t}{\beta_t} + \frac{\sqrt{\bar{\alpha}_{t-1}}\, x_0}{1-\bar{\alpha}_{t-1}}\right) + C(x_t, x_0)\right)\right] \end{aligned}$$

where $C$ is a function that does not involve $x_{t-1}$, omitted for readability. This quadratic in $x_{t-1}$ can be written in the Gaussian form

$$\propto \exp\left[-\frac{\left(x_{t-1} - \tilde{\mu}_t(x_t, x_0)\right)^2}{2\tilde{\beta}_t}\right]$$

Such that the variance is

$$\tilde{\beta}_t = \left(\frac{\alpha_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}}\right)^{-1} = \frac{(1-\bar{\alpha}_{t-1})\,\beta_t}{\alpha_t - \alpha_t\bar{\alpha}_{t-1} + \beta_t} = \frac{(1-\bar{\alpha}_{t-1})\,\beta_t}{1 - \alpha_t\bar{\alpha}_{t-1}} = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t \tag{9}$$

and the mean is

$$\begin{aligned} \tilde{\mu}_t(x_t, x_0) &= \left(\frac{\sqrt{\alpha_t}\, x_t}{\beta_t} + \frac{\sqrt{\bar{\alpha}_{t-1}}\, x_0}{1-\bar{\alpha}_{t-1}}\right)\left(\frac{\alpha_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}}\right)^{-1} \\ &= \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})\, x_t + \beta_t\sqrt{\bar{\alpha}_{t-1}}\, x_0}{\alpha_t - \alpha_t\bar{\alpha}_{t-1} + \beta_t} \\ &= \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\, x_t + \frac{\beta_t\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_t}\, x_0 \end{aligned}$$

Equation 4 is $x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon_t$, which can be rewritten as $x_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\, \epsilon_t}{\sqrt{\bar{\alpha}_t}}$. Substituting,

$$\begin{aligned} \tilde{\mu}_t &= \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\, x_t + \frac{\beta_t\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_t} \cdot \frac{x_t - \sqrt{1-\bar{\alpha}_t}\, \epsilon_t}{\sqrt{\bar{\alpha}_t}} \\ &= \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\, x_t + \frac{\beta_t}{1-\bar{\alpha}_t} \cdot \frac{x_t - \sqrt{1-\bar{\alpha}_t}\, \epsilon_t}{\sqrt{\alpha_t}} \\ &= \frac{\left(\alpha_t(1-\bar{\alpha}_{t-1}) + \beta_t\right) x_t - \beta_t\sqrt{1-\bar{\alpha}_t}\, \epsilon_t}{\sqrt{\alpha_t}\,(1-\bar{\alpha}_t)} \\ &= \frac{(1-\bar{\alpha}_t)\, x_t - \beta_t\sqrt{1-\bar{\alpha}_t}\, \epsilon_t}{\sqrt{\alpha_t}\,(1-\bar{\alpha}_t)} \\ &= \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\, \epsilon_t\right) \end{aligned} \tag{10}$$
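Assuming the same schedule tensors as before, these closed-form posterior quantities can be precomputed; in this sketch the names are mine, and `t` is a single integer step:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_bar = torch.cumprod(alphas, dim=0)
# alpha_bar_{t-1}, taking alpha_bar_0 = 1 for the first step.
alphas_bar_prev = torch.cat([torch.ones(1), alphas_bar[:-1]])

# Equation 9: the (fixed) posterior variance beta_tilde_t for every step.
posterior_var = (1.0 - alphas_bar_prev) / (1.0 - alphas_bar) * betas

def posterior_mean(x_t, t, eps):
    """Equation 10: mu_tilde_t written in terms of the noise eps instead of x_0."""
    return (x_t - betas[t] / (1.0 - alphas_bar[t]).sqrt() * eps) / alphas[t].sqrt()
```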


Therefore, the term $L_{t-1}$ estimates the KL divergence between

$$\text{Predicted: } p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\!\left(x_{t-1};\; \mu_\theta(x_t, t),\; \Sigma_\theta(x_t, t)\right)$$
$$\text{Target: } q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\!\left(x_{t-1};\; \tilde{\mu}_t(x_t),\; \tilde{\beta}_t \mathbf{I}\right)$$

Recall that we learn a neural network that predicts the mean and diagonal variance of the Gaussian distribution of the reverse process. Ho et al. decided to keep the variances fixed to time-dependent constants, because they found that learning them leads to unstable training and poorer sample quality. They set $\Sigma_\theta(x_t, t) = \sigma_t^2 \mathbf{I}$, where $\sigma_t^2 = \beta_t$ or $\tilde{\beta}_t$ (both gave similar results).

Because $x_t$ is available as input at training time, instead of predicting the mean (Equation 10) directly, we make the network predict the noise term $\epsilon_t$ via $\epsilon_\theta$. We can then write the predicted mean as

$$\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right) \tag{11}$$
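To make the interface concrete, here is a deliberately tiny stand-in for $\epsilon_\theta$. The paper itself uses a U-Net with sinusoidal time embeddings; this MLP (a made-up class for illustration) only shows that the model maps $(x_t, t)$ to a noise prediction of the same shape as $x_t$:

```python
import torch
import torch.nn as nn

class TinyEpsModel(nn.Module):
    """Toy stand-in for eps_theta(x_t, t); DDPM uses a U-Net instead."""

    def __init__(self, dim, T=1000):
        super().__init__()
        self.T = T
        self.net = nn.Sequential(
            nn.Linear(dim + 1, 256), nn.SiLU(),
            nn.Linear(256, 256), nn.SiLU(),
            nn.Linear(256, dim),
        )

    def forward(self, x_t, t):
        # Crude time embedding: the normalized step index as one extra feature.
        t_emb = (t.float() / self.T).view(-1, 1)
        return self.net(torch.cat([x_t, t_emb], dim=-1))
```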

and the predicted denoised sample can be written using the reparameterization trick as

$$\begin{aligned} x_{t-1} &= \mu_\theta(x_t, t) + \sqrt{\Sigma_\theta(x_t, t)}\, z_t \\ &= \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z_t \end{aligned} \tag{12}$$

where $z_t \sim \mathcal{N}(0, \mathbf{I})$ at each time step.
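Putting Equations 11 and 12 together gives a minimal sampling loop. This is a sketch under the same assumptions as before ($\sigma_t^2 = \beta_t$; `p_sample_loop` is a hypothetical name), callable for example as `p_sample_loop(TinyEpsModel(dim=784), (16, 784))` in a toy flattened-image setting:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_bar = torch.cumprod(alphas, dim=0)

@torch.no_grad()
def p_sample_loop(model, shape):
    """Generate a sample by iterating Equation 12 from the last step down to the first."""
    x = torch.randn(shape)                                    # x_T ~ N(0, I)
    for t in reversed(range(T)):
        eps = model(x, torch.full((shape[0],), t))            # predicted noise
        mean = (x - betas[t] / (1.0 - alphas_bar[t]).sqrt() * eps) / alphas[t].sqrt()
        if t > 0:
            x = mean + betas[t].sqrt() * torch.randn_like(x)  # add sigma_t * z_t
        else:
            x = mean                                          # final step: no noise
    return x
```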


Thus, the network predicts only the noise term at each time step $t$. Let's simplify the term $L_{t-1}$, given that the variances of the two Gaussians are equal. The KL divergence between two Gaussians is given as

$$D_{KL}(p\,\|\,q) = \frac{1}{2}\left[\log\frac{|\Sigma_q|}{|\Sigma_p|} - k + (\mu_p - \mu_q)^T \Sigma_q^{-1}(\mu_p - \mu_q) + \mathrm{tr}\left\{\Sigma_q^{-1}\Sigma_p\right\}\right]$$

where $k$ is the number of dimensions.
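For diagonal covariances, this closed form is only a few lines of code (a sanity-check sketch with names of my own choosing):

```python
import torch

def kl_diag_gaussians(mu_p, var_p, mu_q, var_q):
    """Closed-form KL(p || q) for diagonal Gaussians; matches the formula above.
    Each argument has shape (..., k), with the variances on the diagonal."""
    k = mu_p.shape[-1]
    log_det_ratio = torch.log(var_q).sum(-1) - torch.log(var_p).sum(-1)
    quad = ((mu_p - mu_q) ** 2 / var_q).sum(-1)
    trace = (var_p / var_q).sum(-1)
    return 0.5 * (log_det_ratio - k + quad + trace)
```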

$$\begin{aligned} L_{t-1} &= D_{KL}\left(q(x_{t-1}|x_t, x_0)\,\|\,p_\theta(x_{t-1}|x_t)\right) \\ &= \frac{1}{2}\left[-k + \frac{(\tilde{\mu}_t - \mu_\theta(x_t, t))^T(\tilde{\mu}_t - \mu_\theta(x_t, t))}{\sigma_t^2} + \mathrm{tr}\{\mathbf{I}\}\right] \\ &= \frac{1}{2}\left[-k + \frac{\|\tilde{\mu}_t - \mu_\theta(x_t, t)\|^2}{\sigma_t^2} + k\right] \\ &= \frac{1}{2\sigma_t^2}\left\|\tilde{\mu}_t - \mu_\theta(x_t, t)\right\|^2 \\ &= \frac{1}{2\sigma_t^2}\left\|\frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_t\right) - \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right)\right\|^2 \\ &= \frac{\beta_t^2}{2\sigma_t^2\, \alpha_t (1-\bar{\alpha}_t)}\left\|\epsilon_t - \epsilon_\theta(x_t, t)\right\|^2 \end{aligned} \tag{13}$$

(The log-determinant ratio vanishes because both covariances equal $\sigma_t^2 \mathbf{I}$.) The objective reduces to a weighted L2 loss between the true and predicted noise, and the second term of the loss function becomes

$$\mathbb{E}_q[L_{t-1}] = \mathbb{E}_{x_t, \epsilon_t}\left[w_t \left\|\epsilon_t - \epsilon_\theta(x_t, t)\right\|^2\right]$$


Empirically, Ho et al. found that training the diffusion model works better with a simplified objective that ignores the weighting term $w_t$ in $L_{t-1}$. They also got rid of the term $L_0$ by altering the sampling procedure: at the end of sampling ($t = 1$), no noise is added and we output $x_0 = \mu_\theta(x_1, t{=}1)$.

The simplified loss function for DDPM is given as

$$L_{\text{simple}} = \mathbb{E}_{x_t, \epsilon_t}\left[\left\|\epsilon_t - \epsilon_\theta(x_t, t)\right\|^2\right] \tag{14.1}$$
$$= \mathbb{E}_{x_0, \epsilon_t}\left[\left\|\epsilon_t - \epsilon_\theta\!\left(\sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon_t,\; t\right)\right\|^2\right] \tag{14.2}$$
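Assuming the schedule and the noise-prediction model sketched earlier, the whole training step collapses to a few lines (`ddpm_loss` is a hypothetical helper):

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_bar = torch.cumprod(1.0 - betas, dim=0)

def ddpm_loss(model, x0):
    """Simplified DDPM objective (Equation 14.2): MSE between the true noise
    and the predicted noise at a uniformly sampled time step."""
    t = torch.randint(0, T, (x0.shape[0],))                  # one random step per sample
    noise = torch.randn_like(x0)
    a_bar = alphas_bar[t].view(-1, *([1] * (x0.dim() - 1)))
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise   # Equation 4
    return ((noise - model(x_t, t)) ** 2).mean()
```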


The training and sampling algorithms in the DDPM paper (Ho et al.)


Credits: While writing this blog, I referred to Lilian Weng's blog post, AI Summer's post, and Luo et al. (2022).
