Generative Adversarial Networks: A Two-Player Game

Introduced in 2014 by Goodfellow et al., Generative Adversarial Networks (GANs) revolutionized the field of generative modeling. They proposed a new framework that generates very realistic synthetic data by training two networks through a minimax two-player game.

With GANs, we don't explicitly learn the data distribution $p_{data}$, but we can still sample from it. Like VAEs, GANs have two networks, a Generator and a Discriminator, which are trained simultaneously.

A latent variable is sampled from a prior $z \sim p(z)$ and passed through the Generator to obtain a fake sample $x = G(z)$. The Discriminator then performs a binary classification task: it takes in both real and fake samples and classifies each one as real (1) or fake (0). The two networks are trained together so that, while competing with each other, they make each other stronger at the same time.


To summarize,

  • A discriminator D estimates the probability of a given sample coming from the real dataset. It works as a critic and is optimized to tell the fake samples from the real ones.

  • A generator G outputs synthetic samples given a noise variable input z. It is trained to capture the real data distribution (we want $p_G \approx p_{data}$) so that its generated samples can be as real as possible, or in other words, can trick the discriminator into outputting higher probabilities.
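To make the two roles concrete, here is a minimal numpy sketch. The linear "generator" and logistic "discriminator" below are toy stand-ins chosen for illustration, not the architectures from the paper:

```python
import numpy as np

rng = np.random.default_rng(0)

def generator(z, theta=2.0):
    # Toy generator: shifts the noise (a stand-in for a neural network).
    return z + theta

def discriminator(x, w=1.0, b=-1.0):
    # Toy discriminator: logistic regression returning P(sample is real).
    return 1.0 / (1.0 + np.exp(-(w * x + b)))

z = rng.standard_normal(5)        # z ~ p(z), here a standard normal prior
x_fake = generator(z)             # fake samples x = G(z)
scores = discriminator(x_fake)    # probabilities in (0, 1)

print(scores)
```

The discriminator always emits a probability in (0, 1); training (below) is what makes those probabilities meaningful.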



Training Objective

We want to make sure that the Discriminator's decision over the real data is accurate, i.e. $D(x) = 1$. This can be achieved by maximizing the log-likelihood over the real data:

$$\max_D \; \mathbb{E}_{x \sim p_{data}}[\log D(x)] \tag{1}$$

The Discriminator should also assign low probabilities to the fake samples generated by the Generator, i.e. $D(G(z)) = 0$. In other words, minimize the log-likelihood of the fake data:

$$\min_D \; \mathbb{E}_{z \sim p(z)}[\log D(G(z))] \tag{2.1}$$

which can be re-written as

$$\max_D \; \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))] \tag{2.2}$$


On the other hand, the Generator wants to fool the Discriminator such that it classifies fake samples as real, i.e. $D(G(z)) = 1$. This can be obtained by maximizing the log-likelihood over the fake data:

$$\max_G \; \mathbb{E}_{z \sim p(z)}[\log D(G(z))] \tag{3.1}$$

which can be re-written as

$$\min_G \; \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))] \tag{3.2}$$


The full objective is given by,

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))] \tag{4}$$
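As a quick numeric check on Equation 4, assume a maximally confused discriminator that outputs 0.5 for every sample (an illustrative assumption): both expectations reduce to $\log 0.5$, so the objective evaluates to $-\log 4$, which the Optimality section shows is exactly the value at the global minimum.

```python
import math

# If D(x) = 0.5 on all inputs, both expectations in Equation 4 are log(0.5).
d_real = 0.5   # D(x) on real samples
d_fake = 0.5   # D(G(z)) on fake samples

value = math.log(d_real) + math.log(1.0 - d_fake)
print(value, -math.log(4))  # the two values coincide
```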

The objective of training GANs is therefore a minimax game: the generator G tries hard to trick the discriminator, while the critic model D tries hard not to be cheated. In this process, both networks push each other to improve. They are trained together with alternating gradient updates (gradient ascent on the discriminator and gradient descent on the generator).


However, at the start of training the generator is really bad, so $D(G(z))$ is close to zero. This creates a vanishing-gradients problem for the Generator: near $D(G(z)) = 0$ the loss term $\log(1 - D(G(z)))$ saturates (it is nearly flat, so it provides almost no gradient). Therefore, instead of $\min_G \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))]$, we use $\min_G \mathbb{E}_{z \sim p(z)}[-\log D(G(z))]$, or equivalently $\max_G \mathbb{E}_{z \sim p(z)}[\log D(G(z))]$ (Equation 3.1), as the objective for the Generator. This is known as the non-saturating loss.
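The vanishing-gradient argument is easy to verify numerically if we assume the usual sigmoid parameterization $D = \sigma(s)$ of the discriminator output, where $s$ is the logit:

```python
import math

def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))

# Early in training the generator is poor, so D(G(z)) = sigmoid(s) is close
# to 0, i.e. the logit s on fake samples is very negative.
s = -6.0
d = sigmoid(s)  # about 0.0025

# Gradient of the saturating loss  log(1 - D)  w.r.t. the logit: -sigmoid(s)
grad_saturating = -d
# Gradient of the non-saturating loss  -log(D)  w.r.t. the logit: sigmoid(s) - 1
grad_non_saturating = d - 1.0

print(abs(grad_saturating), abs(grad_non_saturating))
```

When the discriminator confidently rejects a fake, the saturating loss gives an almost-zero gradient while the non-saturating loss gives a gradient close to 1, so the generator keeps learning.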


Training Algorithm

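The alternating updates can be sketched on a one-dimensional toy problem with manual gradients. Everything here (shift generator, logistic discriminator, learning rate, batch size, step count) is an illustrative assumption, not the training recipe from the paper:

```python
import numpy as np

rng = np.random.default_rng(42)

def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))

# Toy setup: real data ~ N(4, 1); generator G(z) = z + theta with z ~ N(0, 1),
# so the generator matches the data distribution when theta = 4.
theta = 0.0          # generator parameter
a, b = 0.0, 0.0      # discriminator parameters: D(x) = sigmoid(a*x + b)
lr, batch = 0.05, 128

for step in range(2000):
    x_real = rng.normal(4.0, 1.0, batch)
    x_fake = rng.standard_normal(batch) + theta

    # --- Discriminator step: ascent on Eq. 4, written as descent on -V ---
    s_real, s_fake = a * x_real + b, a * x_fake + b
    # d(-log D)/ds = sigmoid(s) - 1;  d(-log(1 - D))/ds = sigmoid(s)
    g_real, g_fake = sigmoid(s_real) - 1.0, sigmoid(s_fake)
    a -= lr * (np.mean(g_real * x_real) + np.mean(g_fake * x_fake))
    b -= lr * (np.mean(g_real) + np.mean(g_fake))

    # --- Generator step: non-saturating loss, minimize -log D(G(z)) ---
    x_fake = rng.standard_normal(batch) + theta
    s_fake = a * x_fake + b
    # chain rule: d(-log D)/ds * ds/dx * dx/dtheta = (sigmoid(s) - 1) * a
    theta -= lr * np.mean((sigmoid(s_fake) - 1.0) * a)

print(theta)  # drifts toward the data mean of 4
```

Note the structure: one (or more) discriminator update per generator update, with the generator receiving gradients only through the discriminator's judgment of its samples.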

Optimality

Let's write Equation 4 with a change of variables (ignoring the Generator's improved objective for now):

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))]$$
$$= \min_G \max_D \; \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{x \sim p_G}[\log(1 - D(x))]$$
$$= \min_G \int_x \max_D \big( p_{data}(x)\log D(x) + p_G(x)\log(1 - D(x)) \big)\,dx$$

For the inner max, we take the derivative with respect to $D$ and set it to zero:

$$\frac{p_{data}(x)}{D(x)} - \frac{p_G(x)}{1 - D(x)} = 0 \quad\Longrightarrow\quad D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_G(x)} \quad\text{[Optimal Discriminator]}$$

Substituting the optimal discriminator back in:

$$= \min_G \int_x \left( p_{data}(x)\log\frac{p_{data}(x)}{p_{data}(x)+p_G(x)} + p_G(x)\log\frac{p_G(x)}{p_{data}(x)+p_G(x)} \right) dx$$
$$= \min_G \left( \mathbb{E}_{x \sim p_{data}}\!\left[\log\frac{p_{data}(x)}{p_{data}(x)+p_G(x)}\right] + \mathbb{E}_{x \sim p_G}\!\left[\log\frac{p_G(x)}{p_{data}(x)+p_G(x)}\right] \right)$$

Multiply and divide each ratio by 2:

$$= \min_G \left( \mathbb{E}_{x \sim p_{data}}\!\left[\log\frac{2\,p_{data}(x)}{p_{data}(x)+p_G(x)}\right] + \mathbb{E}_{x \sim p_G}\!\left[\log\frac{2\,p_G(x)}{p_{data}(x)+p_G(x)}\right] - \log 4 \right)$$

From the definition of KL divergence, $D_{KL}(p\,\|\,q) = \mathbb{E}_{x \sim p}\!\left[\log\frac{p(x)}{q(x)}\right]$:

$$= \min_G \left( D_{KL}\!\left[p_{data} \,\Big\|\, \tfrac{p_{data}+p_G}{2}\right] + D_{KL}\!\left[p_G \,\Big\|\, \tfrac{p_{data}+p_G}{2}\right] - \log 4 \right)$$

From the definition of Jensen-Shannon divergence, $D_{JS}(p\,\|\,q) = \frac{1}{2}D_{KL}\!\left(p\,\Big\|\,\frac{p+q}{2}\right) + \frac{1}{2}D_{KL}\!\left(q\,\Big\|\,\frac{p+q}{2}\right)$:

$$= \min_G \left( 2\,D_{JS}(p_{data}\,\|\,p_G) - \log 4 \right)$$

JS divergence is always non-negative and is zero only when the two distributions are equal. Hence, training a GAN with this objective, minimizes the JS divergence to zero, obtaining a global minimum when pdata=pG.

The global minimum of the minimax game therefore happens at:

Optimal Generator: $p_G = p_{data}$

Optimal Discriminator: $D^*_G(x) = \dfrac{p_{data}(x)}{p_{data}(x)+p_G(x)} = \dfrac{1}{2}$
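The closed-form result can be checked numerically on a pair of discrete distributions (the specific distributions below are arbitrary choices for illustration): plugging the optimal discriminator into the objective should give exactly $2\,D_{JS}(p_{data}\,\|\,p_G) - \log 4$.

```python
import math

# Two arbitrary discrete distributions over the same support.
p_data = [0.5, 0.3, 0.2]
p_g    = [0.2, 0.3, 0.5]

def kl(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

# Value of the inner max (Equation 4 with D*(x) = p_data / (p_data + p_g)):
v_opt = sum(
    pd * math.log(pd / (pd + pg)) + pg * math.log(pg / (pd + pg))
    for pd, pg in zip(p_data, p_g)
)

# 2 * JSD(p_data || p_g) - log 4, using the KL-based definition of JSD.
mix = [(pd + pg) / 2 for pd, pg in zip(p_data, p_g)]
jsd = 0.5 * kl(p_data, mix) + 0.5 * kl(p_g, mix)
v_via_jsd = 2 * jsd - math.log(4)

print(v_opt, v_via_jsd)  # the two agree
```

Setting `p_g` equal to `p_data` makes the JSD zero and the objective exactly $-\log 4$, matching the stated global minimum.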



Some prominent papers:

  1. Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks (DCGAN): Architectures of Generator and Discriminator created with CNNs.

  2. Wasserstein GAN (WGAN): Improvement over traditional GAN training.

  3. Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks (CycleGAN): For image-to-image translation.
