Maybe-Ray

All GANS No Brakes

This is part 8, the final entry in my series. In the previous post, Generating Human Faces with Variational Autoencoders, we took a deep dive into Variational Autoencoders (VAEs) and trained one to generate human faces.

GANs are conceptually easier to get your head around than VAEs, and they have been used to generate some of the most realistic images around, images that can fool human beings. Check out This Person Does Not Exist to see just how convincing GAN-generated images can be.

In this post, we will first go over the basic concepts needed to understand GANs. Then we will walk through the GAN training algorithm and the problems associated with training these models. Lastly, we will train two GAN models: one that generates MNIST digits and another that generates human faces.

GAN Overview

GANs can be seen as one big game between two models that are trying to outcompete each other. A useful mental model is a criminal who is counterfeiting money and a police officer who is trying to distinguish real money from fake. The criminal is trying to produce money convincing enough to fool the police officer, while the police officer is trying to get better at recognising real and fake money.

In GAN terms, the criminal is the generator, trying to produce realistic-looking data, images in our case, while the police officer is the discriminator, trying to figure out which data is real and which is fake. GANs address a big problem in image generation: pixel-wise loss functions, such as Mean Squared Error (MSE), do not capture human perception of image quality. In the previous post, we saw that VAEs generated blurry images, and this was partly due to MSE being used as a loss function. The discriminator instead acts as a learned loss function that reflects perceptual quality far better than a pixel-wise comparison, which is why GANs can produce sharp, good-quality images.
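
Formally, this game can be written as the minimax objective from the original GAN paper, where the discriminator D tries to maximise the value function while the generator G tries to minimise it:

```latex
\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

The first term rewards the discriminator for scoring real samples highly; the second rewards it for scoring generated samples G(z) low, while the generator pushes in the opposite direction.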

The Training Loop of a GAN

Step 1: Train the Discriminator

  1. Generate a random noise vector z (usually sampled from a Gaussian)
  2. Generate fake images by passing z into the Generator
  3. Pass the fake images into the Discriminator
  4. Calculate the binary cross-entropy loss between the discriminator's output and labels of zero ("fake")
  5. Get a batch of real images x from the dataset
  6. Pass the real images into the Discriminator
  7. Calculate the binary cross-entropy loss between the discriminator's output and labels of one ("real")
  8. Add the real and fake losses
  9. Backpropagate and update the discriminator's weights

Step 2: Train the Generator

  1. Generate a random noise vector z
  2. Generate fake images by passing z into the Generator
  3. Pass the fake images into the Discriminator
  4. Calculate the binary cross-entropy loss between the discriminator's output and labels of one (since the generator is trying to fool the discriminator)
  5. Backpropagate and update the generator's weights
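
A note on step 4: labelling the fakes as "one" means the generator minimises the so-called non-saturating loss,

```latex
\mathcal{L}_G = -\,\mathbb{E}_{z \sim p(z)}\big[\log D(G(z))\big]
```

rather than directly maximising log(1 - D(G(z))) from the minimax objective. The two have the same fixed point, but the non-saturating version gives much stronger gradients early in training, when the discriminator confidently rejects every fake.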

Why are GANs hard to train?

GANs are notoriously hard to train; even a slight misconfiguration of the hyperparameters or model architecture can cause training to fail. This happens for several reasons.

  1. Training is a balancing act: both models need to learn at roughly the same pace. If one model becomes too powerful, the other receives extreme, uninformative gradients and stops improving. Both models need to stay close to equilibrium for training to remain stable.

  2. Mode collapse: this happens when the Generator starts outputting only a few similar images that are good enough to fool the Discriminator. The Generator never learns the diverse distribution of the images; it just produces whatever output minimises its loss.

Generating MNIST Digits

For this part, we will train a simple GAN that generates digits from the MNIST dataset. All the code for this section can be found in this Kaggle notebook. The Generator and Discriminator are both fully connected (linear) models, with a 100-dimensional latent space.

latent_space = 100

discriminator = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 128),
    nn.LeakyReLU(0.2),
    nn.Linear(128, 64),
    nn.LeakyReLU(0.2),
    nn.Linear(64, 32),
    nn.LeakyReLU(0.2),
    nn.Linear(32, 1),
).to(device)

Generator = nn.Sequential(
    nn.Linear(latent_space, 32),
    nn.BatchNorm1d(32),
    nn.LeakyReLU(0.2),
    nn.Linear(32, 64),
    nn.BatchNorm1d(64),
    nn.LeakyReLU(0.2),
    nn.Linear(64, 128),
    nn.BatchNorm1d(128),
    nn.LeakyReLU(0.2),
    nn.Linear(128, 28*28),
    nn.Tanh(),
).to(device)

The discriminator model returns raw logits because I used nn.BCEWithLogitsLoss() as the loss function. If you prefer, you could add a Sigmoid activation at the end of the model and use the plain BCELoss instead.
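
As a quick sanity check on why raw logits are fine, BCEWithLogitsLoss applies the sigmoid internally (and is more numerically stable than a separate Sigmoid followed by BCELoss). The tensors below are stand-ins, not the real models:

```python
import torch
import torch.nn as nn

loss_fn = nn.BCEWithLogitsLoss()

# A logit of 0 means sigmoid(0) = 0.5, so BCE against a label of 1 is
# -log(0.5) = log(2) ≈ 0.6931
logits = torch.zeros(4, 1)       # raw discriminator outputs (stand-in)
targets = torch.ones(4, 1)       # "real" labels
print(loss_fn(logits, targets))  # tensor(0.6931)
```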

The training functions for both models are given below. Remember that in the training loop the discriminator is trained first, followed by the generator, as outlined above.

def train_Generator(Generator, discriminator, loss_fn, optim, latent_space, batch_shape):
    Generator.train()
    discriminator.eval()

    rand_noise = torch.normal(0, 1, size=(batch_shape, latent_space)).to(device)
    fake_images = Generator(rand_noise).view(-1, 1, 28, 28).to(device)

    pred = discriminator(fake_images)
    y_values = torch.ones(batch_shape, 1).to(device)
    loss = loss_fn(pred, y_values)

    optim.zero_grad()
    loss.backward()
    optim.step()

    return loss


def train_Discriminator(Generator, discriminator, loss_fn, x, optim, latent_space, batch_shape):
    Generator.eval()
    discriminator.train()

    # Generate fake images (detached so the generator's weights are untouched)
    rand_noise = torch.normal(0, 1, size=(batch_shape, latent_space)).to(device)
    fake_images = Generator(rand_noise).view(-1, 1, 28, 28).detach().to(device)
    y_fake = torch.zeros(batch_shape, 1).to(device)

    # Loss on fake images
    fake_pred = discriminator(fake_images)
    fake_loss = loss_fn(fake_pred, y_fake)

    # Loss on real images
    y_true = torch.ones(batch_shape, 1).to(device)
    pred = discriminator(x)
    real_loss = loss_fn(pred, y_true)

    # Combined loss
    loss = real_loss + fake_loss

    optim.zero_grad()
    loss.backward()
    optim.step()

    return loss
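
To make the overall flow concrete, here is a minimal, self-contained sketch of the outer loop with the two steps inlined. The tiny stand-in models and random "batches" are mine so the sketch runs on its own; the real run uses the functions above with the MNIST data loader:

```python
import torch
import torch.nn as nn

device = "cpu"
latent_space, batch_shape = 100, 16

# Tiny stand-ins for the real models defined earlier
Generator = nn.Sequential(nn.Linear(latent_space, 28 * 28), nn.Tanh()).to(device)
discriminator = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1)).to(device)

loss_fn = nn.BCEWithLogitsLoss()
gen_optim = torch.optim.Adam(Generator.parameters(), lr=2e-4)
disc_optim = torch.optim.Adam(discriminator.parameters(), lr=2e-4)

# Random stand-in "batches" in [-1, 1], shaped like MNIST images
batches = [torch.rand(batch_shape, 1, 28, 28) * 2 - 1 for _ in range(2)]

for x in batches:
    # --- Step 1: train the discriminator ---
    rand_noise = torch.normal(0, 1, size=(batch_shape, latent_space))
    fake_images = Generator(rand_noise).view(-1, 1, 28, 28).detach()
    fake_loss = loss_fn(discriminator(fake_images), torch.zeros(batch_shape, 1))
    real_loss = loss_fn(discriminator(x), torch.ones(batch_shape, 1))
    d_loss = fake_loss + real_loss
    disc_optim.zero_grad()
    d_loss.backward()
    disc_optim.step()

    # --- Step 2: train the generator (labels of one: try to fool D) ---
    rand_noise = torch.normal(0, 1, size=(batch_shape, latent_space))
    pred = discriminator(Generator(rand_noise).view(-1, 1, 28, 28))
    g_loss = loss_fn(pred, torch.ones(batch_shape, 1))
    gen_optim.zero_grad()
    g_loss.backward()
    gen_optim.step()
```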

Sample Digits

Below are some example images produced by the generator. As you can see, the images generated by the GAN are of higher quality than the ones generated by the VAE in the previous post.

GAN Generated Digits

Human Faces with DCGANS

In this section, we will generate human faces using GANs. We will use a Deep Convolutional GAN (DCGAN), introduced in the paper Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks. All the code for this section can be found in this Kaggle notebook.

Our Models

This is what our discriminator looks like:

class Discriminator(nn.Module):

  def __init__(self):
    super().__init__()
    self.features = nn.Sequential(
        nn.Conv2d(3, 128, (4, 4), stride=2, padding=1),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(0.2,  inplace=True),
        nn.Conv2d(128, 256,(4, 4), stride=2, padding=1),
        nn.BatchNorm2d(256),
        nn.LeakyReLU(0.2,  inplace=True),
        nn.Conv2d(256, 512,(4, 4), stride=2, padding=1),
        nn.BatchNorm2d(512),
        nn.LeakyReLU(0.2,  inplace=True),
        nn.Conv2d(512, 1, (4, 4)),
    )

  def forward(self, x):
    return self.features(x)

This is what our generator will look like:

class Generator(nn.Module):

  def __init__(self):
    super().__init__()
    self.features = nn.Sequential(
        nn.ConvTranspose2d(100, 1024, (4, 4), stride=2),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(1024, 512,(4, 4), stride=2, padding=1),
        nn.BatchNorm2d(512),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(512, 256, (4, 4), stride=2, padding=1),
        nn.BatchNorm2d(256),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(256, 128,(4, 4), stride=2, padding=1),
        nn.BatchNorm2d(128),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(128, 3, (4, 4), stride=2, padding=1),
        nn.Tanh()
    )

  def forward(self, x):
    return self.features(x)
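
Before training, it helps to trace the shapes through both networks. Here is a standalone sketch of the same two conv stacks with the BatchNorm layers dropped (they don't change shapes). Starting from a 1×1 noise "image" with 100 channels, each transposed convolution doubles the spatial size up to 64×64, and the discriminator's final convolution leaves a 5×5 grid of patch logits rather than a single scalar per image, which is why the training code builds labels with torch.ones_like(pred) instead of a fixed shape:

```python
import torch
import torch.nn as nn

# Same conv stacks as the classes above, BatchNorm omitted (shape-neutral)
gen = nn.Sequential(
    nn.ConvTranspose2d(100, 1024, (4, 4), stride=2),             # 1x1   -> 4x4
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(1024, 512, (4, 4), stride=2, padding=1),  # 4x4   -> 8x8
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(512, 256, (4, 4), stride=2, padding=1),   # 8x8   -> 16x16
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(256, 128, (4, 4), stride=2, padding=1),   # 16x16 -> 32x32
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(128, 3, (4, 4), stride=2, padding=1),     # 32x32 -> 64x64
    nn.Tanh(),
)
disc = nn.Sequential(
    nn.Conv2d(3, 128, (4, 4), stride=2, padding=1),    # 64x64 -> 32x32
    nn.LeakyReLU(0.2),
    nn.Conv2d(128, 256, (4, 4), stride=2, padding=1),  # 32x32 -> 16x16
    nn.LeakyReLU(0.2),
    nn.Conv2d(256, 512, (4, 4), stride=2, padding=1),  # 16x16 -> 8x8
    nn.LeakyReLU(0.2),
    nn.Conv2d(512, 1, (4, 4)),                         # 8x8   -> 5x5
)

z = torch.randn(2, 100, 1, 1)
img = gen(z)     # shape (2, 3, 64, 64)
out = disc(img)  # shape (2, 1, 5, 5): a grid of patch logits, not one scalar
```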

Our training loops are structured as follows:

def train_Generator(Generator, discriminator, loss_fn, optim, batch_shape):
    x = torch.randn(batch_shape, 100, 1, 1).to(device)

    fake_images = Generator(x)
    pred = discriminator(fake_images)
    y_values = torch.ones_like(pred).to(device)

    gen_loss = loss_fn(pred, y_values)

    optim.zero_grad()
    gen_loss.backward()
    optim.step()

    return gen_loss


def train_Discriminator(Generator, discriminator, real, loss_fn, optim, batch_shape):
    # Generate fake images without tracking gradients for the generator
    with torch.no_grad():
        x = torch.randn(batch_shape, 100, 1, 1).to(device)
        fake_images = Generator(x)

    # Loss on fake images
    fake_pred = discriminator(fake_images)
    fake_labels = torch.zeros_like(fake_pred)
    fake_loss = loss_fn(fake_pred, fake_labels)

    # Loss on real images
    pred = discriminator(real)
    real_labels = torch.ones_like(pred)
    real_loss = loss_fn(pred, real_labels)

    # Combined loss
    loss = real_loss + fake_loss

    optim.zero_grad()
    loss.backward()
    optim.step()

    return loss

Sample of Faces

Here are some example faces produced by the generator. As you can see, these are higher-quality images than what we got from our VAE implementation in the previous post. The images from the GAN are not blurry and look somewhat realistic.

DCGAN Generated Faces

Playing around with latent space

In this section, we will generate two faces and play around with their latent representations by interpolating between the vectors that produced them.

So our first step is to generate two latent vectors and feed them into our Generator. Below are the two generated faces we will be working with:

DCGAN Face samples

For now, let's interpolate between the two vectors and generate the images that lie between them. As you can see below, the generator slowly morphs Face 1 into Face 2.

DCGAN Linspace Faces
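
The interpolation itself can be sketched like this: walk a parameter t from 0 to 1 and decode the blended latent vector at each step. The one-layer gen below is an untrained stand-in so the snippet runs on its own; in the notebook, the trained DCGAN Generator takes its place:

```python
import torch
import torch.nn as nn

gen = nn.Sequential(nn.ConvTranspose2d(100, 3, (4, 4)), nn.Tanh())  # stand-in

z1 = torch.randn(1, 100, 1, 1)  # latent vector for Face 1
z2 = torch.randn(1, 100, 1, 1)  # latent vector for Face 2

steps = 8
with torch.no_grad():
    # Linear interpolation: t=0 gives Face 1, t=1 gives Face 2
    frames = [gen((1 - t) * z1 + t * z2) for t in torch.linspace(0, 1, steps)]
```

Adding the two vectors, as in the next experiment, is the same idea without the blend weight, although the sum has a larger norm than the latents the generator was trained on, which is one reason spherical interpolation (slerp) is sometimes preferred for this kind of exploration.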

Now, let's see what happens when we add the two vectors together and feed the result to the generator.

DCGAN Faces Combined

Conclusion

GANs are amazing at generating realistic-looking images, but they fall short when it comes to training stability. Researchers have proposed variants such as LSGAN and WGAN that stabilise the training process, but instability is still a major concern. Regardless, GANs are still a great model architecture, and I recommend reading up on pix2pix and CycleGAN to see what these types of models can do.

This wraps up the series in which I explored vision models. I have learned a lot, but there is still a lot left to learn. While researching these topics, I came across other types of models, such as Vision Transformers, Diffusion models, and Energy-based models, as well as improvements to the models I implemented in this series. I plan to keep looking into them, just not in the same structured way, so this series might still expand in the coming months or years.

Anyways, thank you, guys, for being on this journey with me. You're the best :)

Series

  1. Demystifying LeNet-5: A Deep Dive into CNN Foundations
  2. Exploring pre-trained Convolutional layers and Kernels
  3. AlexNet: My introduction to Deep Computer Vision models
  4. VGG vs Inception: Just how deep can they go?
  5. ResNet and Skip Connections
  6. Denoising Images of Cats and Dogs with Autoencoders
  7. Generating Human Faces with Variational Autoencoders
  8. Generative Adversarial Networks (you are here right now)