Generative Adversarial Networks
This tutorial describes how to implement a vanilla Generative Adversarial Network (GAN) using Flux and how to train it on the MNIST dataset. It is based on a PyTorch tutorial. The original GAN paper by Goodfellow et al. is a great resource that describes the motivation and theory behind GANs:
In the proposed adversarial nets framework, the generative model is pitted against an adversary: a discriminative model that learns to determine whether a sample is from the model distribution or the data distribution. The generative model can be thought of as analogous to a team of counterfeiters, trying to produce fake currency and use it without detection, while the discriminative model is analogous to the police, trying to detect the counterfeit currency. Competition in this game drives both teams to improve their methods until the counterfeits are indistinguishable from the genuine articles.
Let’s implement a GAN in Flux. To get started we first import a few useful packages:
using MLDatasets: MNIST
using Flux.Data: DataLoader
using Flux
using CUDA
using Zygote
using UnicodePlots
To install a package from the Julia REPL, type ] to enter package mode and then add MLDatasets, or perform the same operation with the Pkg module:

julia> import Pkg
julia> Pkg.add("MLDatasets")
While UnicodePlots is not necessary, it can be used to plot generated samples into the terminal during training. Having direct feedback in the terminal, instead of looking at plots in a separate window, is fantastic for debugging.
Next, let us define values for the learning rate, batch size, number of epochs, and other hyper-parameters. While we are at it, we also define optimizers for the generator and discriminator networks. More on what these are later.
lr_g = 2e-4           # Learning rate of the generator network
lr_d = 2e-4           # Learning rate of the discriminator network
batch_size = 128      # Batch size
num_epochs = 1000     # Number of epochs to train for
output_period = 100   # Period length for plots of generator samples
n_features = 28 * 28  # Number of pixels in each sample of the MNIST dataset
latent_dim = 100      # Dimension of latent space
opt_dscr = ADAM(lr_d) # Optimizer for the discriminator
opt_gen = ADAM(lr_g)  # Optimizer for the generator
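The two optimizers are Flux's implementation of Adam. As a rough picture of what Flux.update! later does with them, here is a minimal Base-Julia sketch of a single Adam step on a plain parameter vector. The names AdamState and adam_step! are hypothetical, and the β1, β2 and ϵ defaults are the standard values from the Adam paper rather than values read out of Flux; Flux's real implementation works on whole parameter collections and keeps the per-parameter state internally.

```julia
# Hypothetical optimizer state: first/second moment estimates and a step counter.
mutable struct AdamState
    m::Vector{Float64}
    v::Vector{Float64}
    t::Int
end

AdamState(n::Integer) = AdamState(zeros(n), zeros(n), 0)

# One Adam step applied in place to the parameters `w`, given the gradient `g`.
# η defaults to 2e-4 to match lr_g and lr_d above.
function adam_step!(w, g, s::AdamState; η = 2e-4, β1 = 0.9, β2 = 0.999, ϵ = 1e-8)
    s.t += 1
    s.m .= β1 .* s.m .+ (1 - β1) .* g         # running mean of the gradients
    s.v .= β2 .* s.v .+ (1 - β2) .* g .^ 2    # running mean of the squared gradients
    m_hat = s.m ./ (1 - β1^s.t)               # bias-corrected first moment
    v_hat = s.v ./ (1 - β2^s.t)               # bias-corrected second moment
    w .-= η .* m_hat ./ (sqrt.(v_hat) .+ ϵ)
    return w
end

w = [1.0, -2.0, 0.5]
g = [0.3, -4.0, 1e-3]
s = AdamState(length(w))
adam_step!(w, g, s)
```

On the first step the bias-corrected moments are m̂ = g and v̂ = g², so every parameter moves by almost exactly η against the sign of its gradient, however large or small that gradient is.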
In this tutorial I’m assuming that a CUDA-enabled GPU is available on the system where the script is running. If this is not the case, simply remove the |> gpu pipes from the code.
The MNIST data set is available from MLDatasets. The first time you load it, you will be asked whether you want to download it; you should agree to this.
GANs can be trained in an unsupervised fashion, so we only keep the images from the training set and discard the labels.
After loading the training data, we rescale it from values in [0:1] to values in [-1:1]. GANs are notoriously tricky to train, and this rescaling is a recommended GAN hack. The rescaled data is used to define a data loader, which handles batching and shuffling the data.
# Load the dataset
train_x, _ = MNIST.traindata(Float32);
# This dataset has pixel values ∈ [0:1]. Map these to [-1:1]
train_x = 2f0 * reshape(train_x, 28, 28, 1, :) .- 1f0 |> gpu;
# DataLoader lets us access the data batch-wise and handles shuffling.
train_loader = DataLoader(train_x, batchsize=batch_size, shuffle=true);
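A quick plain-Julia check of the two numbers this loading step produces: the affine map from [0:1] to [-1:1] (plus its inverse, which is handy when plotting generated samples), and how many mini-batches one epoch contains. The helper names are hypothetical, and the batch count assumes the 60,000-image MNIST training split and a loader that keeps the final, smaller batch (the partial = true behaviour).

```julia
# Rescaling used above: [0, 1] → [-1, 1], and its inverse.
to_gan_range(x) = 2f0 * x - 1f0
from_gan_range(y) = (y + 1f0) / 2f0

pixels = Float32[0.0, 0.25, 0.5, 1.0]
scaled = to_gan_range.(pixels)       # Float32[-1.0, -0.5, 0.0, 1.0]
restored = from_gan_range.(scaled)   # the original pixel values again

# Mini-batches per epoch, assuming 60_000 training images and that the
# last, partial batch is kept.
n_train = 60_000
batch_size = 128
n_batches = cld(n_train, batch_size)                 # 469
last_batch = n_train - (n_batches - 1) * batch_size  # 96
```

With 469 batches per epoch, the 1000 epochs configured above amount to 469,000 updates of each network.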
Defining the Networks
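The discriminator is a plain feed-forward multilayer perceptron that uses leaky rectified linear units (leakyrelu) as non-linearities. Here, the coefficient α in the leakyrelu below is set to 0.2. Empirically (based on prior experiments), this value allows for good training of the network. Dropout has also been found to help the learned network generalize well, so we will use it below. As a final non-linearity, we use the sigmoid function, which maps the output to a probability in (0, 1) that the input sample is real.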
discriminator = Chain(
    Dense(n_features, 1024, x -> leakyrelu(x, 0.2f0)),
    Dropout(0.3),
    Dense(1024, 512, x -> leakyrelu(x, 0.2f0)),
    Dropout(0.3),
    Dense(512, 256, x -> leakyrelu(x, 0.2f0)),
    Dropout(0.3),
    Dense(256, 1, sigmoid)) |> gpu
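For intuition, here are hand-written Base-Julia stand-ins for the three ingredients of the discriminator. They are illustrative sketches with hypothetical names (leaky, sig, dropout_sketch), not Flux's own leakyrelu, sigmoid and Dropout, although they follow the same definitions: leaky ReLU keeps a slope α for negative inputs, and dropout is applied in its "inverted" form, rescaling the surviving activations by 1/(1 - p) during training so that nothing needs to change at inference time.

```julia
using Random

# Leaky ReLU: identity for x ≥ 0, slope α for x < 0 (α = 0.2 as in the tutorial).
leaky(x, α = 0.2f0) = max(α * x, x)

# Logistic sigmoid: squashes any real number into (0, 1).
sig(x) = 1 / (1 + exp(-x))

# Inverted dropout: zero each entry with probability p and scale the survivors
# by 1 / (1 - p), so the expected value of every activation is unchanged.
function dropout_sketch(x, p; rng = Random.default_rng())
    keep = rand(rng, Float32, size(x)) .> p
    return x .* keep ./ (1 - p)
end

leaky(-1f0)    # -0.2f0
leaky(3f0)     # 3.0f0
sig(0.0)       # 0.5
dropout_sketch(ones(Float32, 8), 0.3f0)   # every entry is either 0 or 1/0.7
```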
Let’s define the generator in a similar fashion. This network maps a latent variable (a variable that is not directly observed but instead inferred) to the image space, and we set the input and output dimensions accordingly. A tanh activation squashes the output of the final layer to values in [-1:1], the same range that we mapped the training data onto.
generator = Chain(
    Dense(latent_dim, 256, x -> leakyrelu(x, 0.2f0)),
    Dense(256, 512, x -> leakyrelu(x, 0.2f0)),
    Dense(512, 1024, x -> leakyrelu(x, 0.2f0)),
    Dense(1024, n_features, tanh)) |> gpu
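Both networks are plain stacks of dense layers, so their size is easy to tally: a Dense(n_in, n_out) layer holds an n_out × n_in weight matrix plus a bias vector of length n_out, and Dropout layers add no parameters. A small sketch that counts them, assuming every Dense layer has a bias (Flux's default); the helper names are hypothetical.

```julia
# Parameters of one dense layer: weight matrix plus bias vector.
dense_params(n_in, n_out) = n_in * n_out + n_out

# Sum over consecutive layer widths of a multilayer perceptron.
mlp_params(widths) = sum(dense_params(widths[i], widths[i + 1]) for i in 1:length(widths) - 1)

n_features = 28 * 28
latent_dim = 100

disc_widths = [n_features, 1024, 512, 256, 1]
gen_widths = [latent_dim, 256, 512, 1024, n_features]

mlp_params(disc_widths)   # 1_460_225
mlp_params(gen_widths)    # 1_486_352
```

Each network therefore holds roughly 1.5 million parameters, and in both cases more than half of them sit in the layer that touches the 784-pixel image space.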
Training functions for the networks
To train the discriminator, we present it with real data from the MNIST data set and with fake data from the generator, and reward it for predicting the correct label for each sample. The correct labels are, of course, 1 for in-distribution (real) data and 0 for out-of-distribution data coming from the generator.
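Binary cross entropy is the loss function of choice. While the Flux documentation suggests using logit binary cross entropy for numerical stability (which would also require removing the final sigmoid from the discriminator), the GAN seems to be difficult to train with this loss function.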
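To make the loss discussion concrete, here is a minimal Base-Julia sketch of both formulations, averaged over samples. bce works on probabilities, like Flux.Losses.binarycrossentropy; logit_bce works directly on the pre-sigmoid output, like Flux.Losses.logitbinarycrossentropy. The function names and the clamping constant ϵ are assumptions for illustration, not Flux's exact code.

```julia
# Binary cross entropy on probabilities, averaged over samples. The small ϵ
# keeps log() finite when a prediction is exactly 0 or 1.
bce(ŷ, y; ϵ = eps(eltype(ŷ))) = sum(@.(-y * log(ŷ + ϵ) - (1 - y) * log(1 - ŷ + ϵ))) / length(y)

# Numerically stable log σ(z) = -log(1 + exp(-z)).
logsig(z) = -(max(-z, zero(z)) + log1p(exp(-abs(z))))

# The same loss written in terms of logits z instead of probabilities σ(z).
logit_bce(z, y) = sum(@.((1 - y) * z - logsig(z))) / length(y)

sigmoid(z) = 1 / (1 + exp(-z))
target = [1.0 1.0 0.0 0.0]   # two real samples followed by two fake ones

# A discriminator that outputs 0.5 for everything scores log(2) ≈ 0.693,
# a useful reference point when reading the logged losses.
bce(fill(0.5, 1, 4), target)

# For moderate logits both formulations agree.
z = [2.0 -1.0 0.5 -3.0]
bce(sigmoid.(z), target) ≈ logit_bce(z, target)   # true

# For a confidently wrong logit, sigmoid(40.0) rounds to exactly 1.0, so the
# probability-space loss is capped at about 36.04 (that is, -log(eps())),
# while the logit form still returns the exact value of about 40.
bce(sigmoid.([40.0]), [0.0])
logit_bce([40.0], [0.0])
```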
The training function below returns the discriminator loss for logging purposes. Instead of calling Flux.train! on the model, we get the pullback directly from Zygote, which lets us calculate the loss in the same call. To calculate the gradients of the loss function with respect to the parameters of the discriminator, we then only have to evaluate the pullback with a seed gradient of 1.0. These gradients are used to update the model parameters.
function train_dscr!(discriminator, real_data, fake_data)
    this_batch = size(real_data)[end]  # Number of samples in the batch
    # Concatenate real and fake data into one big matrix
    all_data = hcat(real_data, fake_data)

    # Target vector for predictions: 1 for real data, 0 for fake data.
    all_target = [ones(eltype(real_data), 1, this_batch) zeros(eltype(fake_data), 1, this_batch)] |> gpu;

    ps = Flux.params(discriminator)
    # Return the loss from the closure instead of assigning to the outer `loss`
    loss, back = Zygote.pullback(ps) do
        preds = discriminator(all_data)
        Flux.Losses.binarycrossentropy(preds, all_target)
    end
    # To get the gradients we evaluate the pullback with 1.0 as a seed gradient.
    grads = back(1f0)

    # Update the parameters of the discriminator with the gradients we calculated above
    Flux.update!(opt_dscr, ps, grads)

    return loss
end
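The pullback idea can be shown without any AD package. The sketch below hand-writes a pullback for a mean squared error, following the convention of the explicit form Zygote.pullback(f, x): it returns the value together with a closure back that maps a sensitivity Δ of the output to a tuple of input gradients. The function pullback_mse is hypothetical; seeding back with 1.0 yields the plain gradient, and any other seed simply scales it.

```julia
# f(w) = mean((w .* x .- y).^2); returns f(w) and a closure for its gradient.
function pullback_mse(w, x, y)
    r = w .* x .- y                               # residuals
    value = sum(abs2, r) / length(r)              # forward pass
    back(Δ) = (Δ * 2 .* r .* x ./ length(r),)     # Δ times ∂f/∂w, as a 1-tuple
    return value, back
end

w = [1.0, 2.0]
x = [3.0, 4.0]
y = [2.0, 5.0]
loss, back = pullback_mse(w, x, y)   # loss == 5.0
grads = back(1.0)                    # ([3.0, 12.0],)
back(2.0)                            # ([6.0, 24.0],), a doubled seed doubles the gradient
```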
Now we need to define a function to train the generator network. The job of the generator is to fool the discriminator, so we reward the generator when the discriminator predicts a high probability that its samples are real data; concretely, the loss is the binary cross entropy of the discriminator's predictions against a target of 1. In the training function we first need to sample some noise, i.e. normally distributed data. This has to be done outside the pullback, since we want gradients with respect to the generator parameters, not the noise. Then we evaluate the pullback, call it with a seed gradient of 1.0 as above, update the parameters of the generator network, and return the loss.
function train_gen!(discriminator, generator)
    # Sample Float32 noise outside the pullback
    noise = randn(Float32, latent_dim, batch_size) |> gpu;

    # Define parameters and get the pullback
    ps = Flux.params(generator)
    # Evaluate the loss function while calculating the pullback. We get the loss for free
    loss, back = Zygote.pullback(ps) do
        preds = discriminator(generator(noise));
        Flux.Losses.binarycrossentropy(preds, 1f0)
    end
    # Evaluate the pullback with a seed-gradient of 1.0 to get the gradients for
    # the parameters of the generator
    grads = back(1.0f0)
    Flux.update!(opt_gen, ps, grads)
    return loss
end
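Training the generator against a target of 1 minimizes -log D(G(z)), which Goodfellow et al. suggest in place of minimizing log(1 - D(G(z))) because it gives much stronger gradients early in training. The Base-Julia sketch below compares the two gradients with respect to the discriminator's logit z for a fake sample that the discriminator confidently rejects; the function names are hypothetical.

```julia
sigmoid(z) = 1 / (1 + exp(-z))

# Gradients with respect to the discriminator's logit z on a fake sample.
# Loss used in train_gen!:   L(z) = -log σ(z)      ⇒  dL/dz = σ(z) - 1
grad_nonsaturating(z) = sigmoid(z) - 1
# Original minimax loss:     L(z) = log(1 - σ(z))  ⇒  dL/dz = -σ(z)
grad_minimax(z) = -sigmoid(z)

# Early in training the discriminator easily spots fakes, so z is very negative.
z = -5.0
grad_nonsaturating(z)   # ≈ -0.993
grad_minimax(z)         # ≈ -0.0067, about 150 times weaker
```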
Now we are ready to train the GAN. In the training loop we keep track of the average loss of the generator and the discriminator in each epoch, computed from the mean batch losses returned by the two training functions defined above. In each epoch we iterate over the mini-batches given by the data loader. Only minimal data processing needs to be done before the training functions can be called.
lossvec_gen = zeros(num_epochs)
lossvec_dscr = zeros(num_epochs)

for n in 1:num_epochs
    loss_sum_gen = 0.0f0
    loss_sum_dscr = 0.0f0

    for x in train_loader
        # - Flatten the images from 28x28x1xbatchsize to 784xbatchsize
        real_data = flatten(x);

        # Train the discriminator
        noise = randn(Float32, latent_dim, size(x)[end]) |> gpu
        fake_data = generator(noise)
        loss_dscr = train_dscr!(discriminator, real_data, fake_data)
        loss_sum_dscr += loss_dscr

        # Train the generator
        loss_gen = train_gen!(discriminator, generator)
        loss_sum_gen += loss_gen
    end

    # Store the average batch loss of the generator and discriminator for this epoch
    lossvec_gen[n] = loss_sum_gen / length(train_loader)
    lossvec_dscr[n] = loss_sum_dscr / length(train_loader)

    if n % output_period == 0
        @show n
        noise = randn(Float32, latent_dim, 4) |> gpu;
        # Move the samples to the CPU and place the four 28x28 images side by side
        fake_data = reshape(generator(noise) |> cpu, 28, 4*28);
        p = heatmap(fake_data, colormap=:inferno)
        print(p)
    end
end
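The plotting branch relies on Julia's column-major layout: reshaping the 784 × 4 matrix of generated samples to 28 × 112 places the four 28 × 28 images side by side. A small plain-Julia check of that layout, using synthetic data and a hypothetical tile helper:

```julia
# Four flattened 28×28 "images" as the columns of a 784×4 matrix, filled with
# distinct values so the layout can be checked exactly.
imgs = reshape(Float32.(1:784 * 4), 784, 4)

# Column-major reshape into a 28×112 panel, as in the plotting branch.
panel = reshape(imgs, 28, 4 * 28)

# Hypothetical helper: columns 28(k-1)+1 through 28k of the panel hold image k.
tile(panel, k) = panel[:, 28 * (k - 1) + 1 : 28 * k]

tile(panel, 2) == reshape(imgs[:, 2], 28, 28)   # true
```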
For the hyper-parameters shown in this example, the generator produces useful images after about 1000 epochs, and after about 5000 epochs the results look indistinguishable from real MNIST data. Using an NVIDIA V100 GPU on a system with a 2.7 GHz Power9 CPU with 32 hardware threads, training 100 epochs takes about 80 seconds. The GPU utilization is between 30 and 40%.
To observe the network more frequently during training, you can for example set output_period = 20. Training the GAN on the CPU takes about 10 minutes per epoch and is not recommended.
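For planning purposes, the timings quoted above translate into the following wall-clock estimates. This is simple arithmetic on the quoted figures, not a new measurement; the helper name is hypothetical.

```julia
# Back-of-the-envelope training times from the figures quoted above.
gpu_seconds_per_epoch = 80 / 100   # about 80 s per 100 epochs on the V100
cpu_seconds_per_epoch = 10 * 60    # about 10 minutes per epoch on the CPU

hours(epochs, s_per_epoch) = epochs * s_per_epoch / 3600

hours(1000, gpu_seconds_per_epoch)   # ≈ 0.22 h for the 1000 epochs configured above
hours(5000, gpu_seconds_per_epoch)   # ≈ 1.1 h until samples look like real MNIST
hours(1000, cpu_seconds_per_epoch)   # ≈ 167 h, roughly a week on the CPU
```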
Below you can see what some of the generated images may look like after different numbers of epochs.