In 2014, Ian Goodfellow and his colleagues published a paper called "Generative Adversarial Networks", or GANs for short. It describes a system of two neural networks, a generator and a discriminator, that learn to produce images (or any other kind of data) resembling a given dataset, starting from random noise.
The diagram below shows the basic architecture of a GAN.
Source: https://developers.google.com/
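The generator takes random noise as input and outputs a fake sample, such as an image. The discriminator receives both these fake samples and real samples from the dataset and tries to tell which is which. Both networks are trained on the discriminator's decisions: the discriminator learns to classify real and fake samples correctly, while the generator learns to produce samples that the discriminator classifies as real.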
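For reference, the original paper [2] writes this two-player game as a minimax objective. Here D(x) is the discriminator's estimated probability that x is real, G(z) is the generator's output for a noise vector z, and p_data and p_z are the data and noise distributions. The discriminator tries to maximise V, and the generator tries to minimise it:

```latex
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
```

The paper also notes that, early in training, it works better to train G to maximise log D(G(z)) instead of minimising log(1 - D(G(z))). The training loop later in this post does exactly that by giving the generator's fake samples a target label of 1.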
As you can see, the concept of a GAN is very simple. When I started learning about GANs, I thought I could easily implement one. But boy, was I wrong. In practice, training a GAN is extremely hard: the generator and the discriminator must improve side by side, and if one overpowers the other, training breaks down. There are many other problems you might face while training a GAN, but for now let's look at a simple GAN in PyTorch. We will be using the MNIST dataset [1] for this post.
Understanding the Dataset
The MNIST dataset is a large database of handwritten digits from 0 to 9, widely used as a benchmark for handwritten digit recognition, a form of Optical Character Recognition.
It consists of 28x28 grayscale images, 60,000 for training and 10,000 for testing, where each pixel is an intensity value from 0 (background) to 255 (foreground).
Torchvision, the computer vision library for PyTorch, provides this dataset, and we can download it with the following line.
mnist = datasets.MNIST(root='datasets', train=True, transform=transformations, download=True)
But before we can run this code, we need to import some libraries.
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
You can install the torchvision library using this pip command.
pip install torchvision
The DataLoader class iterates over the dataset in mini-batches (here 32 images at a time, reshuffled every epoch), so the model is trained on one small batch at a time instead of on the whole dataset at once.
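To make the batching idea concrete, here is a minimal pure-Python sketch of what shuffled mini-batching does. The function `iterate_batches` is hypothetical and only illustrates the idea; the real DataLoader also collates each batch into tensors and can load data in parallel worker processes.

```python
import random

def iterate_batches(dataset, batch_size, shuffle=True, seed=None):
    """Yield lists of at most batch_size items, optionally in shuffled order."""
    indices = list(range(len(dataset)))
    if shuffle:
        # With seed=None the order differs on every call, like a new epoch
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]

# Toy "dataset" of 10 items with batch_size=4 gives batches of 4, 4 and 2 items
batches = list(iterate_batches(list(range(10)), batch_size=4, seed=0))
print([len(b) for b in batches])  # [4, 4, 2]
```

With MNIST's 60,000 training images and a batch size of 32, one epoch is 60000 / 32 = 1875 batches.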
The torchvision.transforms module provides image transformations, including random augmentations such as rotation, resizing and cropping. For now we will only convert the images to tensors and normalise them to the range -1 to 1, which matches the range of the generator's Tanh output.
Here is the complete function, ready to copy and paste.
def load_dataset():
    transformations = transforms.Compose([
        transforms.ToTensor(),                   # pixels 0-255 -> floats in [0, 1]
        transforms.Normalize((0.5, ), (0.5, ))   # (x - 0.5) / 0.5 -> [-1, 1]
    ])
    mnist = datasets.MNIST(root='datasets', train=True, transform=transformations, download=True)
    return DataLoader(mnist, batch_size=32, shuffle=True)
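As a quick sanity check of the normalisation, the arithmetic that ToTensor and Normalize((0.5,), (0.5,)) apply to each pixel can be written out by hand. `normalize_pixel` is a hypothetical helper for illustration only; torchvision applies the same computation to whole tensors.

```python
def normalize_pixel(value):
    """Map a raw MNIST pixel (0-255) to [-1, 1] like ToTensor + Normalize((0.5,), (0.5,))."""
    scaled = value / 255.0        # ToTensor: 0-255 -> 0.0-1.0
    return (scaled - 0.5) / 0.5   # Normalize: (x - mean) / std with mean = std = 0.5

print(normalize_pixel(0), normalize_pixel(255))  # -1.0 1.0
```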
Implementing the Generator
In this section, we will implement a generator network that takes a random noise vector of size 100 and converts it into a vector of size 784 (28 x 28), which we can then reshape into a 28x28 image.
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(100, 256),
            nn.LeakyReLU(0.01),
            nn.Linear(256, 784),
            nn.Tanh(),  # make outputs [-1, 1]
        )

    def forward(self, x):
        return self.gen(x)
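A quick way to check the generator is to push a batch of noise through it and look at the output shapes. This sketch rebuilds the same layers with nn.Sequential so it runs on its own; the batch size of 16 is an arbitrary choice, and `view` is one way to reshape the 784-dimensional outputs into 28x28 images.

```python
import torch
import torch.nn as nn

# Same layers as the Generator class, rebuilt so this snippet is self-contained
gen = nn.Sequential(
    nn.Linear(100, 256),
    nn.LeakyReLU(0.01),
    nn.Linear(256, 784),
    nn.Tanh(),
)

batch_size = 16                             # arbitrary batch size for this check
noise = torch.randn(batch_size, 100)        # one 100-dim noise vector per sample
flat = gen(noise)                           # shape (16, 784)
images = flat.view(batch_size, 1, 28, 28)   # reshape to (N, channels, height, width)

print(flat.shape)    # torch.Size([16, 784])
print(images.shape)  # torch.Size([16, 1, 28, 28])
```

Because the last layer is Tanh, every generated pixel lies in [-1, 1], the same range as the normalised real images.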
This follows the original GAN setup introduced by Ian Goodfellow and his colleagues [2], where both the generator and the discriminator are multilayer perceptrons.
Implementing the Discriminator
Next, we will implement a discriminator that takes a vector of size 784, either produced by the generator or taken from a flattened real image, and outputs a single probability that the input is real.
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(784, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1),
            nn.Sigmoid(),  # make outputs [0, 1]
        )

    def forward(self, x):
        return self.disc(x)
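Note that the discriminator expects flat 784-dimensional vectors, while the DataLoader yields real images shaped (batch, 1, 28, 28). Passing them in unflattened makes the first Linear layer fail, because their last dimension is 28 rather than 784. The sketch below uses a random tensor as a stand-in for a real MNIST batch and flattens it first.

```python
import torch
import torch.nn as nn

# Same layers as the Discriminator class, rebuilt so this snippet is self-contained
disc = nn.Sequential(
    nn.Linear(784, 128),
    nn.LeakyReLU(0.01),
    nn.Linear(128, 1),
    nn.Sigmoid(),
)

# Stand-in for one DataLoader batch: 32 single-channel 28x28 images in [-1, 1]
real = torch.rand(32, 1, 28, 28) * 2 - 1

flat_real = real.view(-1, 784)     # (32, 784): one row per image
probs = disc(flat_real).view(-1)   # (32,): probability that each image is real
print(probs.shape)  # torch.Size([32])
```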
Hyperparameters
As I mentioned earlier, GANs are extremely hard to train. One reason is that they are very sensitive to hyperparameters such as the learning rate. It is a good idea to start from values that are known to work, such as those reported in papers or reference implementations; otherwise, you might spend many days tuning the hyperparameters yourself.
config = {
    'device': "cuda" if torch.cuda.is_available() else "cpu",
    'lr': 3e-4,
    'epochs': 50,
    'batch_size': 32
}
In our case, we will use a learning rate of 0.0003 (3e-4), a batch size of 32 and 50 epochs.
What is a learning rate?
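Simple answer: it controls how big a step the model takes each time it updates its weights. The smaller the number, the smaller each update and the slower the model learns; the larger the number, the faster it learns, but if it is too large, the updates overshoot and training becomes unstable. For more info, check out our post about Perceptrons.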
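To see the effect of the learning rate in isolation, here is a toy example (not part of the GAN) that minimises f(w) = w² with plain gradient descent, where the gradient is 2w. With a small learning rate, w creeps towards the minimum at 0; with a moderate one, it gets much closer in the same number of steps; with one that is too large, every step overshoots and |w| grows instead.

```python
def gradient_descent(lr, steps=10, w=1.0):
    """Minimise f(w) = w**2 with plain gradient descent; the gradient is 2*w."""
    for _ in range(steps):
        w = w - lr * 2 * w  # step against the gradient, scaled by the learning rate
    return w

print(gradient_descent(0.01))  # small lr: slow progress towards 0
print(gradient_descent(0.1))   # moderate lr: much closer to 0
print(gradient_descent(1.1))   # too large: each step overshoots and |w| grows
```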
Training Time
Finally, it's time to actually generate some handwritten digits, that is, to train our GAN. First, let's create the Generator and Discriminator objects and their optimisers. We will be using the Adam optimiser for both.
Adam, short for adaptive moment estimation [3], is an optimisation algorithm that updates the weights using running averages of the gradients and of the squared gradients, so that the overall error goes down.
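For intuition, here is a single Adam update for one scalar weight, written out in plain Python following the update rule in the Adam paper [3]. The defaults beta1 = 0.9, beta2 = 0.999 and eps = 1e-8 are the values suggested in the paper (and PyTorch's defaults), and lr = 3e-4 matches our config. In practice, optim.Adam performs this update for every parameter tensor; `adam_step` is a hypothetical helper for illustration.

```python
import math

def adam_step(w, grad, m, v, t, lr=3e-4, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar weight; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * grad        # running mean of gradients (1st moment)
    v = beta2 * v + (1 - beta2) * grad ** 2   # running mean of squared gradients (2nd moment)
    m_hat = m / (1 - beta1 ** t)              # bias correction for starting m at 0
    v_hat = v / (1 - beta2 ** t)              # bias correction for starting v at 0
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v

w, m, v = adam_step(1.0, grad=2.0, m=0.0, v=0.0, t=1)
print(w)  # about 0.9997: the first step has size close to lr whatever the gradient's scale
```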
disc = Discriminator().to(config['device'])
gen = Generator().to(config['device'])
optimiser_g = optim.Adam(params=gen.parameters(), lr=config['lr'])
optimiser_d = optim.Adam(params=disc.parameters(), lr=config['lr'])
We also need to define a loss function before we train. We will be using the Binary Cross Entropy loss [4] to measure the error of our models.
Binary Cross Entropy (BCE) Loss
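Binary cross entropy measures how far predicted probabilities are from the true binary labels (in our case, 1 for real and 0 for fake). For a prediction p and a label y, the loss is -[y·log(p) + (1 - y)·log(1 - p)], averaged over the batch. It is small when p is close to y and grows quickly as p moves towards the wrong label.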
loss_fn = nn.BCELoss()
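The BCE formula is short enough to write by hand, which makes it easy to see how the loss reacts to confident right and wrong predictions. `bce` below is a hypothetical helper that follows the same formula as nn.BCELoss with its default mean reduction. It assumes every prediction lies strictly between 0 and 1 so the logarithms are defined (nn.BCELoss additionally clamps its log outputs so that predictions of exactly 0 or 1 do not produce infinities).

```python
import math

def bce(preds, targets):
    """Mean binary cross-entropy between predicted probabilities and 0/1 targets."""
    total = 0.0
    for p, y in zip(preds, targets):
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(preds)

print(bce([0.5], [1]))           # about 0.693 (log 2): an unsure prediction
print(bce([0.9, 0.1], [1, 0]))   # about 0.105: confident and correct
print(bce([0.1], [1]))           # about 2.303: confident and wrong
```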
Now we can write the training loop. In each step, we generate a batch of fake digits, train the discriminator to tell them apart from a batch of real digits, and then update the generator so that its fakes are more likely to be classified as real.
train_data = load_dataset()  # DataLoader over the normalised MNIST training set

for epoch in range(config['epochs']):
    for batch_idx, (real, _) in enumerate(train_data):  # labels are not needed to train a GAN
        real = real.view(-1, 784).to(config['device'])  # flatten 28x28 images into 784-dim vectors
        batch_size = real.shape[0]
        noise = torch.randn(batch_size, 100, device=config['device'])  # sample noise from a standard normal distribution
        fake = gen(noise)  # generate a batch of fake digits

        # Train the discriminator: real images should score 1, fake images 0
        disc_real = disc(real).view(-1)
        lossD_real = loss_fn(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake.detach()).view(-1)  # detach so this loss does not backpropagate into the generator
        lossD_fake = loss_fn(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2  # average of the two losses

        # update the weights for the discriminator
        disc.zero_grad()
        lossD.backward()
        optimiser_d.step()

        # Train the generator: it wants the updated discriminator to label its fakes as real (1)
        output = disc(fake).view(-1)
        lossG = loss_fn(output, torch.ones_like(output))

        # update the weights for the generator
        gen.zero_grad()
        lossG.backward()
        optimiser_g.step()
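Once training has finished, you will want to look at what the generator produces. Here is a minimal sketch, assuming a generator with the architecture above: sample noise, run the generator without tracking gradients, reshape the outputs into images and undo the [-1, 1] normalisation. `sample_images` is a hypothetical helper, and the untrained stand-in generator is only there so the snippet runs on its own.

```python
import torch
import torch.nn as nn

def sample_images(gen, num_images=16, noise_dim=100, device="cpu"):
    """Generate images with a generator and rescale them from [-1, 1] to [0, 1]."""
    gen.eval()             # inference mode (no dropout/batchnorm here, but a good habit)
    with torch.no_grad():  # no gradients are needed for sampling
        noise = torch.randn(num_images, noise_dim, device=device)
        fake = gen(noise).view(-1, 1, 28, 28)  # (N, 1, 28, 28), values in [-1, 1]
    gen.train()            # restore training mode
    return (fake + 1) / 2  # undo Normalize((0.5,), (0.5,)): [-1, 1] -> [0, 1]

# Untrained stand-in with the same layers as Generator, just to show the shapes
gen = nn.Sequential(nn.Linear(100, 256), nn.LeakyReLU(0.01), nn.Linear(256, 784), nn.Tanh())
images = sample_images(gen)
print(images.shape)  # torch.Size([16, 1, 28, 28])
```

With a trained `gen`, the result can be written to disk as an image grid with torchvision.utils.save_image(images, "samples.png", nrow=4).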
This code is heavily inspired by a YouTube video by Aladdin Persson.
Find the complete code at our GitHub repo.
If you have any questions, feel free to comment below.
Footnotes
- [1] MNIST Database: “THE MNIST DATABASE.” MNIST Handwritten Digit Database, Yann LeCun, Corinna Cortes and Chris Burges, yann.lecun.com/exdb/mnist/.
- [2] Paper by Ian Goodfellow: Goodfellow, Ian J., et al. “Generative Adversarial Networks.” ArXiv.org, 10 June 2014, arxiv.org/abs/1406.2661v1.
- [3] Adam: A Method for Stochastic Optimization: Kingma, Diederik P., and Jimmy Ba. “Adam: A Method for Stochastic Optimization.” ArXiv.org, 30 Jan. 2017, arxiv.org/abs/1412.6980.
- [4] Binary Cross-Entropy Loss: “Understanding Categorical Cross-Entropy Loss, Binary Cross-Entropy Loss, Softmax Loss, Logistic Loss, Focal Loss and All Those Confusing Names.” gombru.github.io/2018/05/23/cross_entropy_loss/.