A friendly introduction to Generative Adversarial Networks

So far, we have been talking about discriminative models, which map input features x to labels y and approximate the conditional probability P(y|x).

Generative models do the opposite: they try to predict input features given the labels. Assuming the label is y, how likely are we to see certain features x? They approximate the joint probability P(x, y).
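As a toy illustration (the 2x2 joint distribution below is a made-up example, not from the original post), here is how the joint P(x, y) relates to the conditional P(y|x) that a discriminative model approximates:

    import numpy as np

    # Hypothetical joint distribution P(x, y) over a binary feature x and
    # a binary label y. Rows index x, columns index y; entries sum to 1.
    joint = np.array([[0.30, 0.10],
                      [0.15, 0.45]])

    p_x = joint.sum(axis=1, keepdims=True)   # marginal P(x)
    p_y = joint.sum(axis=0, keepdims=True)   # marginal P(y)

    p_y_given_x = joint / p_x   # discriminative view: P(y | x)
    p_x_given_y = joint / p_y   # generative view:     P(x | y)

    print(p_y_given_x)  # what a discriminative model approximates
    print(p_x_given_y)  # what a generative model approximates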


Source: Medium / CycleGAN


Generative Adversarial Networks (GANs)

Source: O'Reilly


Components of a GAN:

1. Generator - This is roughly an inverted CNN. Instead of compressing information along a CNN chain and extracting features at the output, this network takes random noise as its input features and generates an image at its output (typically via transposed convolutions).
2. Discriminator - The discriminator is a CNN that looks at images from both the training set and the generator's output and classifies them as real (1) or fake (0). A minimal sketch of both networks follows this list.
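Here is that sketch in PyTorch. The layer sizes are assumptions for 28x28 grayscale images (e.g. MNIST-like data); none of these specific choices come from the original post:

    import torch
    import torch.nn as nn

    latent_dim = 100  # size of the random-noise input z (an assumed value)

    # Generator: maps noise z to an image via transposed convolutions
    generator = nn.Sequential(
        nn.ConvTranspose2d(latent_dim, 128, kernel_size=7, stride=1),     # 1x1  -> 7x7
        nn.ReLU(),
        nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 7x7  -> 14x14
        nn.ReLU(),
        nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),    # 14x14 -> 28x28
        nn.Tanh(),
    )

    # Discriminator: a plain CNN that outputs the probability "real"
    discriminator = nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),    # 28x28 -> 14x14
        nn.LeakyReLU(0.2),
        nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 14x14 -> 7x7
        nn.LeakyReLU(0.2),
        nn.Flatten(),
        nn.Linear(128 * 7 * 7, 1),
        nn.Sigmoid(),  # 1 = real, 0 = fake
    )

    z = torch.randn(16, latent_dim, 1, 1)  # batch of random noise
    fake_images = generator(z)             # shape (16, 1, 28, 28)
    scores = discriminator(fake_images)    # shape (16, 1), probabilities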

In general, the generator and discriminator optimize opposing objectives: the discriminator is trying to minimize the cross-entropy loss of correctly predicting real vs. fake images, while the generator is trying to maximize that same loss (trying to fool the discriminator).

Let's get into some simple math of how the Generator and Discriminator are competing against each other:

Note the loss function of the discriminator - the value function from Goodfellow et al., 2014:

min over G, max over D of V(D, G) = E_{x ~ p_data}[log D(x)] + E_{z ~ p_z}[log(1 - D(G(z)))]

1. The first term involves the output of the discriminator for a training-set image x - you want the output of the discriminator D(x) to be 1. Let's examine how the log loss behaves here -

log(D(x)) = log(1) = 0. So for a correct classification of a training-set image, the loss is 0. However, in the inverse case - D(x) = 0 for a training-set image - log(0) goes to negative infinity.

2. The second term - log(1 - D(G(z))) - you want the discriminator to be able to identify the generated image G(z) and output D(G(z)) = 0. In this case, log(1 - 0) = log(1) = 0, so the loss is 0. However, if it gets that wrong and considers the generated image real, D(G(z)) = 1 and log(0) is negative infinity.

3. The network is heavily penalized for incorrect classification in both cases (a quick numeric check follows this list).
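Here is that check in plain Python, using a small epsilon in place of an exact 0, since log(0) is undefined in floating point:

    import math

    eps = 1e-7  # stand-in for "exactly 0", which would make log blow up

    # Term 1: log(D(x)) for a real training-set image x
    print(math.log(1.0))   # D(x) = 1, correct ->  0.0
    print(math.log(eps))   # D(x) ~ 0, wrong   -> about -16.1, heading to -infinity

    # Term 2: log(1 - D(G(z))) for a generated image G(z)
    print(math.log(1.0 - 0.0))          # D(G(z)) = 0, correct ->  0.0
    print(math.log(1.0 - (1.0 - eps)))  # D(G(z)) ~ 1, wrong   -> about -16.1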

The discriminator is trying to minimize its misclassification loss while, at the same time, the generator is trying to maximize the discriminator's loss on the generated output. The setup is therefore a minimax game.
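A sketch of the resulting alternating updates with binary cross-entropy, reusing the generator, discriminator, and latent_dim from the sketch above (the learning rate and the stand-in dataloader are assumptions):

    import torch
    import torch.nn as nn

    bce = nn.BCELoss()
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4)
    opt_g = torch.optim.Adam(generator.parameters(), lr=2e-4)

    # Stand-in for a real DataLoader of training images
    dataloader = [torch.randn(16, 1, 28, 28) for _ in range(3)]

    for real_images in dataloader:
        batch = real_images.size(0)
        real_labels = torch.ones(batch, 1)
        fake_labels = torch.zeros(batch, 1)

        # Discriminator step: minimize misclassification of real vs. fake
        z = torch.randn(batch, latent_dim, 1, 1)
        fake_images = generator(z).detach()  # don't backprop into G here
        d_loss = (bce(discriminator(real_images), real_labels)
                  + bce(discriminator(fake_images), fake_labels))
        opt_d.zero_grad()
        d_loss.backward()
        opt_d.step()

        # Generator step: fool D, i.e. push D(G(z)) toward 1
        z = torch.randn(batch, latent_dim, 1, 1)
        g_loss = bce(discriminator(generator(z)), real_labels)
        opt_g.zero_grad()
        g_loss.backward()
        opt_g.step()

Note that the generator step here uses the common non-saturating form - minimizing -log D(G(z)) rather than literally maximizing log(1 - D(G(z))) - which the original GAN paper recommends because it gives stronger gradients early in training.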

What is the discriminator trying to do here?


Source - Ian Goodfellow '16 OpenAI presentation

  • The objective of the discriminator is defined in the slide above (the ratio p_data(x) / (p_data(x) + p_model(x))): the discriminator is essentially trying to tell the difference between model-generated input and actual training-set input.
  • So, if p_data(x) = p_model(x), then D(x) is 0.5: the discriminator cannot differentiate between fake and real. This is the goal of the optimization here.
  • Note: At the same time, the generator is trying to move the model distribution as close to the data distribution as possible. A small numeric check of the ratio follows this list.
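The density values below are made-up numbers for a single point x, and d_star is a hypothetical helper name:

    # Optimal discriminator for a fixed generator (Goodfellow et al., 2014):
    #   D*(x) = p_data(x) / (p_data(x) + p_model(x))
    def d_star(p_data, p_model):
        return p_data / (p_data + p_model)

    print(d_star(0.8, 0.2))  # model density too low here  -> 0.8 (confidently "real")
    print(d_star(0.2, 0.8))  # model density too high here -> 0.2 (confidently "fake")
    print(d_star(0.5, 0.5))  # distributions match         -> 0.5 (can't tell them apart)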




Stay tuned for more...


References:
https://skymind.ai/wiki/generative-adversarial-network-gan
https://medium.com/@jonathan_hui/gan-whats-generative-adversarial-networks-and-its-application-f39ed278ef09



