Blog by Jason Brownlee. Jason started this blog because he is passionate about helping professional developers to get started and confidently apply machine learning to address complex problems.
The Wasserstein Generative Adversarial Network, or Wasserstein GAN, is an extension to the generative adversarial network that both improves the stability when training the model and provides a loss function that correlates with the quality of generated images.
The development of the WGAN has a dense mathematical motivation, although in practice it requires only a few minor modifications to the established standard deep convolutional generative adversarial network, or DCGAN.
In this tutorial, you will discover how to implement the Wasserstein generative adversarial network from scratch.
After completing this tutorial, you will know:
The differences between the standard deep convolutional GAN and the new Wasserstein GAN.
How to implement the specific details of the Wasserstein GAN from scratch.
How to develop a WGAN for image generation and interpret the dynamic behavior of the model.
Discover how to develop DCGANs, conditional GANs, Pix2Pix, CycleGANs, and more with Keras in my new GANs book, with 29 step-by-step tutorials and full source code.
Let’s get started.
How to Code a Wasserstein Generative Adversarial Network (WGAN) From Scratch. Photo by Feliciano Guimarães, some rights reserved.
Tutorial Overview
This tutorial is divided into three parts; they are:
Wasserstein Generative Adversarial Network
Wasserstein GAN Implementation Details
How to Train a Wasserstein GAN Model
Wasserstein Generative Adversarial Network
The Wasserstein GAN, or WGAN for short, was introduced by Martin Arjovsky, et al. in their 2017 paper titled “Wasserstein GAN.”
It is an extension of the GAN that seeks an alternate way of training the generator model to better approximate the distribution of data observed in a given training dataset.
Instead of using a discriminator to classify or predict the probability of generated images as being real or fake, the WGAN changes or replaces the discriminator model with a critic that scores the realness or fakeness of a given image.
This change is motivated by a theoretical argument that training the generator should seek a minimization of the distance between the distribution of the data observed in the training dataset and the distribution observed in generated examples.
The benefit of the WGAN is that the training process is more stable and less sensitive to model architecture and choice of hyperparameter configurations. Perhaps most importantly, the loss of the discriminator appears to relate to the quality of images created by the generator.
Wasserstein GAN Implementation Details
Although the theoretical grounding for the WGAN is dense, the implementation of a WGAN requires a few minor changes to the standard Deep Convolutional GAN, or DCGAN.
The image below provides a summary of the main training loop for training a WGAN, taken from the paper. Note the listing of recommended hyperparameters used in the model.
Algorithm for the Wasserstein Generative Adversarial Networks. Taken from: Wasserstein GAN.
The differences in implementation for the WGAN are as follows:
Use a linear activation function in the output layer of the critic model (instead of sigmoid).
Use -1 labels for real images and 1 labels for fake images (instead of 1 and 0).
Use Wasserstein loss to train the critic and generator models.
Constrain critic model weights to a limited range after each mini-batch update (e.g. [-0.01,0.01]).
Update the critic model more times than the generator each iteration (e.g. 5).
Use the RMSProp version of gradient descent with a small learning rate and no momentum (e.g. 0.00005).
Using the standard DCGAN model as a starting point, let’s take a look at each of these implementation details in turn.
1. Linear Activation in the Critic Output Layer
The DCGAN uses the sigmoid activation function in the output layer of the discriminator to predict the likelihood of a given image being real.
In the WGAN, the critic model requires a linear activation to predict the score of “realness” for a given image.
This can be achieved by setting the ‘activation‘ argument to ‘linear‘ in the output layer of the critic model.
# define output layer of the critic model
...
model.add(Dense(1, activation='linear'))
The linear activation is the default activation for a layer, so we can, in fact, leave the activation unspecified to achieve the same result.
# define output layer of the critic model
...
model.add(Dense(1))
2. Class Labels for Real and Fake Images
The DCGAN uses class label 0 for fake images and class label 1 for real images, and these class labels are used to train the GAN.
In the DCGAN, these are precise labels that the discriminator is expected to achieve. The WGAN does not have precise labels for the critic. Instead, it encourages the critic to output scores that are different for real and fake images.
This is achieved via the Wasserstein loss function that cleverly makes use of positive and negative class labels.
The WGAN can be implemented where -1 class labels are used for real images and +1 class labels are used for fake or generated images.
...
# generate class labels, -1 for 'real'
y = -ones((n_samples, 1))
...
# create class labels with 1.0 for 'fake'
y = ones((n_samples, 1))
3. Wasserstein Loss Function
The DCGAN trains the discriminator as a binary classification model to predict the probability that a given image is real.
To train this model, the discriminator is optimized using the binary cross entropy loss function. The same loss function is used to update the generator model.
The primary contribution of the WGAN model is the use of a new loss function that encourages the discriminator to predict a score of how real or fake a given input looks. This transforms the role of the discriminator from a classifier into a critic for scoring the realness or fakeness of images, where the difference between the scores is as large as possible.
We can implement the Wasserstein loss as a custom function in Keras that calculates the average score for real or fake images.
The score is maximized for real examples and minimized for fake examples. Given that stochastic gradient descent is a minimization algorithm, we can multiply each score by its class label (-1 for real, which flips the sign, and +1 for fake, which has no effect), so that the loss for both real and fake images is something the network can minimize.
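To make the arithmetic concrete, the sketch below uses NumPy rather than Keras tensors, with made-up critic scores, to show how multiplying by the class labels turns both cases into a quantity to be minimized:

```python
import numpy as np

# made-up critic scores for a mini-batch of real images
real_scores = np.array([2.0, 3.0, 4.0])
# real images use the label -1, so minimizing the loss pushes real scores up
real_labels = -np.ones(3)
loss_real = np.mean(real_labels * real_scores)  # -3.0

# made-up critic scores for a mini-batch of fake images
fake_scores = np.array([1.0, 2.0, 3.0])
# fake images use the label +1, so minimizing the loss pushes fake scores down
fake_labels = np.ones(3)
loss_fake = np.mean(fake_labels * fake_scores)  # 2.0

print(loss_real, loss_fake)
```

Note that the larger the real scores become, the lower (more negative) the real loss becomes, which is exactly the behavior gradient descent will encourage.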
An efficient implementation of this loss function for Keras is listed below.
from keras import backend

# implementation of wasserstein loss
def wasserstein_loss(y_true, y_pred):
	return backend.mean(y_true * y_pred)
This loss function can be used to train a Keras model by specifying the function name when compiling the model.
For example:
...
# compile the model
model.compile(loss=wasserstein_loss, ...)
4. Critic Weight Clipping
The DCGAN does not constrain the model weights, whereas the WGAN requires weight clipping for the critic model.
We can implement weight clipping as a Keras constraint.
This is a class that must extend the Constraint class and define an implementation of the __call__() function for applying the operation and the get_config() function for returning any configuration.
We can also define an __init__() function to set the configuration, in this case, the symmetrical size of the bounding box for the weight hypercube, e.g. 0.01.
The ClipConstraint class is defined below.
from keras import backend
from keras.constraints import Constraint

# clip model weights to a given hypercube
class ClipConstraint(Constraint):
	# set clip value when initialized
	def __init__(self, clip_value):
		self.clip_value = clip_value

	# clip model weights to hypercube
	def __call__(self, weights):
		return backend.clip(weights, -self.clip_value, self.clip_value)

	# get the config
	def get_config(self):
		return {'clip_value': self.clip_value}
To use the constraint, the class can be constructed, then used in a layer by setting the kernel_constraint argument; for example:
...
# define the constraint
const = ClipConstraint(0.01)
...
# use the constraint in a layer
model.add(Conv2D(..., kernel_constraint=const))
The constraint is only required when updating the critic model.
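The clipping operation itself is just an element-wise clamp. A minimal NumPy illustration, using arbitrary example weights, shows the effect of the 0.01 bound:

```python
import numpy as np

clip_value = 0.01
# arbitrary example weights, some of which fall outside the hypercube
weights = np.array([0.5, -0.02, 0.003, -0.5])
# the same operation performed by backend.clip() in the constraint
clipped = np.clip(weights, -clip_value, clip_value)
print(clipped)
```

Weights already inside [-0.01, 0.01] pass through unchanged; everything else is snapped to the nearest boundary.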
5. Update Critic More Than Generator
In the DCGAN, the generator and the discriminator model must be updated in equal amounts.
Specifically, the discriminator is updated with a half batch of real and a half batch of fake samples each iteration, whereas the generator is updated with a single batch of generated samples.
For example:
...
# main gan training loop
for i in range(n_steps):
	# update the discriminator
	# get randomly selected 'real' samples
	X_real, y_real = generate_real_samples(dataset, half_batch)
	# update critic model weights
	c_loss1 = c_model.train_on_batch(X_real, y_real)
	# generate 'fake' examples
	X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
	# update critic model weights
	c_loss2 = c_model.train_on_batch(X_fake, y_fake)
	# update generator
	# prepare points in latent space as input for the generator
	X_gan = generate_latent_points(latent_dim, n_batch)
	# create inverted labels for the fake samples
	y_gan = -ones((n_batch, 1))
	# update the generator via the critic's error
	g_loss = gan_model.train_on_batch(X_gan, y_gan)
In the WGAN model, the critic model must be updated more than the generator model.
Specifically, a new hyperparameter, called n_critic and set to 5, is defined to control the number of times that the critic is updated for each update to the generator model.
This can be implemented as a new loop within the main GAN update loop; for example:
...
# main gan training loop
for i in range(n_steps):
	# update the critic more than the generator
	for _ in range(n_critic):
		# get randomly selected 'real' samples
		X_real, y_real = generate_real_samples(dataset, half_batch)
		# update critic model weights
		c_loss1 = c_model.train_on_batch(X_real, y_real)
		# generate 'fake' examples
		X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
		# update critic model weights
		c_loss2 = c_model.train_on_batch(X_fake, y_fake)
	# update generator
	# prepare points in latent space as input for the generator
	X_gan = generate_latent_points(latent_dim, n_batch)
	# create inverted labels for the fake samples
	y_gan = -ones((n_batch, 1))
	# update the generator via the critic's error
	g_loss = gan_model.train_on_batch(X_gan, y_gan)
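The bookkeeping of this schedule can be checked in isolation. The toy loop below, using the paper's recommended n_critic of 5 and an arbitrary number of steps, simply counts updates instead of training real models:

```python
# count critic and generator updates under the WGAN schedule
n_steps = 10
n_critic = 5
critic_updates = 0
generator_updates = 0
for i in range(n_steps):
    # the critic is updated n_critic times per outer iteration
    for _ in range(n_critic):
        critic_updates += 1
    # the generator is updated once per outer iteration
    generator_updates += 1
print(critic_updates, generator_updates)  # 50 10
```

The critic receives five times as many weight updates as the generator, matching the 5:1 ratio from the paper.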
6. Use RMSProp Stochastic Gradient Descent
The DCGAN uses the Adam version of stochastic gradient descent with a small learning rate and modest momentum. The WGAN recommends the use of RMSProp instead, with a small learning rate of 0.00005.
This can be implemented in Keras when the model is compiled. For example:
...
# compile model
opt = RMSprop(lr=0.00005)
model.compile(loss=wasserstein_loss, optimizer=opt)
How to Train a Wasserstein GAN Model
Now that we know the specific implementation details for the WGAN, we can implement the model for image generation.
In this section, we will develop a WGAN to generate a single handwritten digit (‘7’) from the MNIST dataset. This is a good test problem for the WGAN as it is a small dataset requiring a modest model that is quick to train.
The first step is to define the models.
The critic model takes as input one 28×28 grayscale image and outputs a score for the realness or fakeness of the image. It is implemented as a modest convolutional neural network using best practices for DCGAN design such as using the LeakyReLU activation function with a slope of 0.2, batch normalization, and using a 2×2 stride to downsample.
The critic model makes use of the new ClipConstraint weight constraint to clip model weights after mini-batch updates and is optimized using the custom wasserstein_loss() function and the RMSProp version of stochastic gradient descent with a learning rate of 0.00005.
The define_critic() function below implements this, defining and compiling the critic model and returning it. The input shape of the image is parameterized as a default function argument to make it clear.
# define the standalone critic model
def define_critic(in_shape=(28,28,1)):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# weight constraint
	const = ClipConstraint(0.01)
	# define model
	model = Sequential()
	# downsample to 14x14
	model.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const, input_shape=in_shape))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	# downsample to 7x7
	model.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	# scoring, linear activation
	model.add(Flatten())
	model.add(Dense(1))
	# compile model
	opt = RMSprop(lr=0.00005)
	model.compile(loss=wasserstein_loss, optimizer=opt)
	return model
The generator model takes as input a point in the latent space and outputs a single 28×28 grayscale image.
This is achieved by using a fully connected layer to interpret the point in the latent space and provide sufficient activations that can be reshaped into many copies (in this case, 128) of a low-resolution version of the output image (e.g. 7×7). This is then upsampled two times, doubling the size and quadrupling the area of the activations each time using transpose convolutional layers.
The model uses best practices such as the LeakyReLU activation, a kernel size that is a factor of the stride size, and a hyperbolic tangent (tanh) activation function in the output layer.
The define_generator() function below defines the generator model but intentionally does not compile it as it is not trained directly, then returns the model. The size of the latent space is parameterized as a function argument.
# define the standalone generator model
def define_generator(latent_dim):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# define model
	model = Sequential()
	# foundation for 7x7 image
	n_nodes = 128 * 7 * 7
	model.add(Dense(n_nodes, kernel_initializer=init, input_dim=latent_dim))
	model.add(LeakyReLU(alpha=0.2))
	model.add(Reshape((7, 7, 128)))
	# upsample to 14x14
	model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	# upsample to 28x28
	model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	# output 28x28x1
	model.add(Conv2D(1, (7,7), activation='tanh', padding='same', kernel_initializer=init))
	return model
Next, a GAN model can be defined that combines both the generator model and the critic model into one larger model.
This larger model will be used to train the model weights in the generator, using the output and error calculated by the critic model. The critic model is trained separately, and as such, the model weights are marked as not trainable in this larger GAN model to ensure that only the weights of the generator model are updated. This change to the trainability of the critic weights only has an effect when training the combined GAN model, not when training the critic standalone.
This larger GAN model takes as input a point in the latent space, uses the generator model to generate an image, which is fed as input to the critic model and scored as real or fake. The model is fit using RMSProp with the custom wasserstein_loss() function.
The define_gan() function below implements this, taking the already defined generator and critic models as input.
# define the combined generator and critic model, for updating the generator
def define_gan(generator, critic):
	# make weights in the critic not trainable
	critic.trainable = False
	# connect them
	model = Sequential()
	# add generator
	model.add(generator)
	# add the critic
	model.add(critic)
	# compile model
	opt = RMSprop(lr=0.00005)
	model.compile(loss=wasserstein_loss, optimizer=opt)
	return model
Now that we have defined the GAN model, we need to train it. But, before we can train the model, we require input data.
The first step is to load and scale the MNIST dataset. The whole dataset is loaded via a call to the load_data() Keras function, then a subset of the images is selected (about 5,000) that belong to class 7, i.e. are handwritten depictions of the number seven. The pixel values must then be scaled to the range [-1,1] to match the output of the generator model.
The load_real_samples() function below implements this, returning the loaded and scaled subset of the MNIST training dataset ready for modeling.
# load images
def load_real_samples():
	# load dataset
	(trainX, trainy), (_, _) = load_data()
	# select all of the examples for a given class
	selected_ix = trainy == 7
	X = trainX[selected_ix]
	# expand to 3d, e.g. add channels
	X = expand_dims(X, axis=-1)
	# convert from ints to floats
	X = X.astype('float32')
	# scale from [0,255] to [-1,1]
	X = (X - 127.5) / 127.5
	return X
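The selection and scaling steps can be sanity-checked on a small synthetic array standing in for MNIST pixel data, so no dataset download is required:

```python
import numpy as np

# synthetic stand-in for MNIST: two 2x2 'images' with pixel values in [0,255]
trainX = np.array([[[0, 255], [127, 128]],
                   [[255, 0], [64, 191]]])
trainy = np.array([7, 3])

# boolean indexing keeps only the images labelled with class 7
selected = trainX[trainy == 7]

# scale from [0,255] to [-1,1], as in load_real_samples()
X = (selected.astype('float32') - 127.5) / 127.5
print(X.min(), X.max())  # -1.0 1.0
```

Only the first "image" survives the class filter, and after scaling, a pixel of 0 maps to -1.0 and a pixel of 255 maps to 1.0, matching the output range of the tanh generator.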
We will require one batch (or a half batch) of real images from the dataset each update to the GAN model. A simple way to achieve this is to select a random sample of images from the dataset each time.
The generate_real_samples() function below implements this, taking the prepared dataset as an argument, selecting and returning a random sample of images and their corresponding label for the critic, specifically target=-1 indicating that they are real images.
# select real samples
def generate_real_samples(dataset, n_samples):
	# choose random instances
	ix = randint(0, dataset.shape[0], n_samples)
	# select images
	X = dataset[ix]
	# generate class labels, -1 for 'real'
	y = -ones((n_samples, 1))
	return X, y
Next, we need input for the generator model: points in the latent space. The generate_latent_points() function below implements this, taking the size of the latent space and the number of points required as arguments, and returning them as a batch of input samples for the generator model.
# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):
	# generate points in the latent space
	x_input = randn(latent_dim * n_samples)
	# reshape into a batch of inputs for the network
	x_input = x_input.reshape(n_samples, latent_dim)
	return x_input
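A NumPy-only version of the same function (using numpy's randn rather than the one imported in the post) can be used to confirm the output shape for an arbitrary latent size and batch size:

```python
import numpy as np

def generate_latent_points(latent_dim, n_samples):
    # draw standard Gaussian values and reshape into a batch of inputs
    x_input = np.random.randn(latent_dim * n_samples)
    return x_input.reshape(n_samples, latent_dim)

# e.g. 16 points in a 50-dimensional latent space
points = generate_latent_points(50, 16)
print(points.shape)  # (16, 50)
```

Each row of the result is one latent vector, ready to be fed to the generator as a batch.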
Next, we need to use the points in the latent space as input to the generator in order to generate new images.
The generate_fake_samples() function below implements this, taking the generator model and size of the latent space as arguments, then generating points in the latent space and using them as input to the generator model.
The function returns the generated images and their corresponding label for the critic model, specifically target=1 to indicate they are fake or generated.
# use the generator to generate n fake examples, with class labels
def generate_fake_samples(generator, latent_dim, n_samples):
	# generate points in latent space
	x_input = generate_latent_points(latent_dim, n_samples)
	# predict outputs
	X = generator.predict(x_input)
	# create class labels with 1.0 for 'fake'
	y = ones((n_samples, 1))
	return X, y
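The plumbing of this function can be exercised without a trained Keras model by substituting a stub object whose predict() returns blank images; this is purely illustrative and the StubGenerator class is not part of the tutorial's model:

```python
import numpy as np

class StubGenerator:
    # stand-in for a Keras model: predict() returns blank 28x28 'images'
    def predict(self, x_input):
        return np.zeros((len(x_input), 28, 28, 1))

def generate_fake_samples(generator, latent_dim, n_samples):
    # points in latent space as input for the generator
    x_input = np.random.randn(n_samples, latent_dim)
    # generate 'images' and label them +1 for 'fake'
    X = generator.predict(x_input)
    y = np.ones((n_samples, 1))
    return X, y

X, y = generate_fake_samples(StubGenerator(), 50, 4)
print(X.shape, y[0, 0])  # (4, 28, 28, 1) 1.0
```

The returned images have the same shape as the generator output, and every label is +1, the 'fake' target under this post's convention.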
We need to record the performance of the model. Perhaps the most reliable way to evaluate the performance of a GAN is to use the generator to generate images, and then review and subjectively evaluate them.
The summarize_performance() function below takes the generator model at a given point during training and uses it to generate 100 images in a 10×10 grid, that are then plotted and saved to file. The model is also saved to file at this time, in case we would like to use it later to generate more images.
# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, latent_dim, n_samples=100):
	# prepare fake examples
	X, _ = generate_fake_samples(g_model, latent_dim, n_samples)
	# scale from [-1,1] to [0,1]
	X = (X + 1) / 2.0
	# plot images
	for i in range(10 * 10):
		# define subplot
		pyplot.subplot(10, 10, 1 + i)
		# turn off axis
		pyplot.axis('off')
		# plot raw pixel data
		pyplot.imshow(X[i, :, :, 0], cmap='gray_r')
	# save plot to file
	filename1 = 'generated_plot_%04d.png' % (step+1)
	pyplot.savefig(filename1)
	pyplot.close()
	# save the generator model
	filename2 = 'model_%04d.h5' % (step+1)
	g_model.save(filename2)
	print('>Saved: %s and %s' % (filename1, filename2))
In addition to image quality, it is a good idea to keep track of the loss of the models over time.
The loss for the critic for real and fake samples can be tracked for each model update, as can the loss for the generator for each update. These can then be used to create line plots of loss at the end of the training run, e.g. via a plot_history() helper function.
The Wasserstein Generative Adversarial Network, or Wasserstein GAN, is an extension to the generative adversarial network that both improves the stability when training the model and provides a loss function that correlates with the quality of generated images.
It is an important extension to the GAN model and requires a conceptual shift away from a discriminator that predicts the probability of a generated image being “real” and toward the idea of a critic model that scores the “realness” of a given image.
This conceptual shift is motivated mathematically using the earth mover distance, or Wasserstein distance, to train the GAN that measures the distance between the data distribution observed in the training dataset and the distribution observed in the generated examples.
In this post, you will discover how to implement Wasserstein loss for Generative Adversarial Networks.
After reading this post, you will know:
The conceptual shift in the WGAN from discriminator predicting a probability to a critic predicting a score.
The implementation details for the WGAN as minor changes to the standard deep convolutional GAN.
The intuition behind the Wasserstein loss function and how to implement it from scratch.
Discover how to develop DCGANs, conditional GANs, Pix2Pix, CycleGANs, and more with Keras in my new GANs book, with 29 step-by-step tutorials and full source code.
Let’s get started.
How to Implement Wasserstein Loss for Generative Adversarial Networks. Photo by Brandon Levinger, some rights reserved.
Overview
This tutorial is divided into five parts; they are:
GAN Stability and the Discriminator
What Is a Wasserstein GAN?
Implementation Details of the Wasserstein GAN
How to Implement Wasserstein Loss
Common Point of Confusion With Expected Labels
GAN Stability and the Discriminator
Generative Adversarial Networks, or GANs, are challenging to train.
The discriminator model must classify a given input image as real (from the dataset) or fake (generated), and the generator model must generate new and plausible images.
The reason GANs are difficult to train is that the architecture involves the simultaneous training of a generator and a discriminator model in a zero-sum game. Stable training requires finding and maintaining an equilibrium between the capabilities of the two models.
The discriminator model is a neural network that learns a binary classification problem, using a sigmoid activation function in the output layer, and is fit using a binary cross entropy loss function. As such, the model predicts a probability that a given input is real (or fake, as one minus the predicted probability) as a value between 0 and 1.
The loss function has the effect of penalizing the model proportionally to how far the predicted probability distribution differs from the expected probability distribution for a given image. This provides the basis for the error that is back propagated through the discriminator and the generator in order to perform better on the next batch.
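As a concrete reminder of the quantities involved, the snippet below computes the sigmoid probability and the binary cross entropy penalty for a single made-up discriminator output:

```python
import numpy as np

# raw discriminator output (logit) for one image, a made-up value
logit = 2.0
# sigmoid squashes the logit into a probability in (0, 1)
prob = 1.0 / (1.0 + np.exp(-logit))

# binary cross entropy penalty if the true label is real (1)
loss_real = -np.log(prob)
# binary cross entropy penalty if the true label is fake (0)
loss_fake = -np.log(1.0 - prob)
print(prob, loss_real, loss_fake)
```

A confident prediction of "real" incurs a small penalty when the image really is real, and a much larger penalty when it is fake, which is exactly the proportional penalization described above.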
The WGAN relaxes the role of the discriminator when training a GAN and proposes the alternative of a critic.
What Is a Wasserstein GAN?
The Wasserstein GAN, or WGAN for short, was introduced by Martin Arjovsky, et al. in their 2017 paper titled “Wasserstein GAN.”
It is an extension of the GAN that seeks an alternate way of training the generator model to better approximate the distribution of data observed in a given training dataset.
Instead of using a discriminator to classify or predict the probability of generated images as being real or fake, the WGAN changes or replaces the discriminator model with a critic that scores the realness or fakeness of a given image.
This change is motivated by a mathematical argument that training the generator should seek a minimization of the distance between the distribution of the data observed in the training dataset and the distribution observed in generated examples. The argument contrasts different distribution distance measures, such as Kullback-Leibler (KL) divergence, Jensen-Shannon (JS) divergence, and the Earth-Mover (EM) distance, referred to as Wasserstein distance.
The most fundamental difference between such distances is their impact on the convergence of sequences of probability distributions.
They demonstrate that a critic neural network can be trained to approximate the Wasserstein distance, and, in turn, used to effectively train a generator model.
… we define a form of GAN called Wasserstein-GAN that minimizes a reasonable and efficient approximation of the EM distance, and we theoretically show that the corresponding optimization problem is sound.
Importantly, the Wasserstein distance has the properties that it is continuous and differentiable and continues to provide a linear gradient, even after the critic is well trained.
The fact that the EM distance is continuous and differentiable a.e. means that we can (and should) train the critic till optimality. […] the more we train the critic, the more reliable gradient of the Wasserstein we get, which is actually useful by the fact that Wasserstein is differentiable almost everywhere.
This is unlike the discriminator model that, once trained, may fail to provide useful gradient information for updating the generator model.
The discriminator learns very quickly to distinguish between fake and real, and as expected provides no reliable gradient information. The critic, however, can’t saturate, and converges to a linear function that gives remarkably clean gradients everywhere.
The benefit of the WGAN is that the training process is more stable and less sensitive to model architecture and choice of hyperparameter configurations.
… training WGANs does not require maintaining a careful balance in training of the discriminator and the generator, and does not require a careful design of the network architecture either. The mode dropping phenomenon that is typical in GANs is also drastically reduced.
Perhaps most importantly, the loss of the discriminator appears to relate to the quality of images created by the generator.
Specifically, the lower the loss of the critic when evaluating generated images, the higher the expected quality of the generated images. This is important as unlike other GANs that seek stability in terms of finding an equilibrium between two models, the WGAN seeks convergence, lowering generator loss.
To our knowledge, this is the first time in GAN literature that such a property is shown, where the loss of the GAN shows properties of convergence. This property is extremely useful when doing research in adversarial networks as one does not need to stare at the generated samples to figure out failure modes and to gain information on which models are doing better over others.
Although the theoretical grounding for the WGAN is dense, the implementation of a WGAN requires a few minor changes to the standard deep convolutional GAN, or DCGAN.
Those changes are as follows:
Use a linear activation function in the output layer of the critic model (instead of sigmoid).
Use Wasserstein loss to train the critic and generator models, which promotes a larger difference between scores for real and generated images.
Constrain critic model weights to a limited range after each mini batch update (e.g. [-0.01,0.01]).
In order to have parameters w lie in a compact space, something simple we can do is clamp the weights to a fixed box (say W = [−0.01, 0.01]l ) after each gradient update.
The image below provides a summary of the main training loop for training a WGAN, taken from the paper. Note the listing of recommended hyperparameters used in the model.
Algorithm for the Wasserstein Generative Adversarial Networks. Taken from: Wasserstein GAN.
How to Implement Wasserstein Loss
The Wasserstein loss function seeks to increase the gap between the scores for real and generated images.
We can summarize the function as it is described in the paper as follows:
Critic Loss = [average critic score on real images] – [average critic score on fake images]
Generator Loss = -[average critic score on fake images]
Where the average scores are calculated across a mini-batch of samples.
This is precisely how the loss is implemented for graph-based deep learning frameworks such as PyTorch and TensorFlow.
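These two quantities can be computed directly with NumPy on made-up mini-batch scores, just to make the definitions concrete:

```python
import numpy as np

# made-up critic scores across a mini-batch
real_scores = np.array([10.0, 20.0, 30.0])   # average 20
fake_scores = np.array([40.0, 50.0, 60.0])   # average 50

# Critic Loss = [average critic score on real images] - [average critic score on fake images]
critic_loss = real_scores.mean() - fake_scores.mean()
# Generator Loss = -[average critic score on fake images]
generator_loss = -fake_scores.mean()
print(critic_loss, generator_loss)  # -30.0 -50.0
```

These are the same averages used in the worked examples that follow.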
The calculations are straightforward to interpret once we recall that stochastic gradient descent seeks to minimize loss.
In the case of the generator, a larger score from the critic will result in a smaller loss for the generator, encouraging the generator to produce images that receive larger scores from the critic. For example, an average score of 10 becomes -10, an average score of 50 becomes -50, which is smaller, and so on.
In the case of the critic, a larger score for real images results in a larger resulting loss for the critic, penalizing the model. This encourages the critic to output smaller scores for real images. For example, an average score of 20 for real images and 50 for fake images results in a loss of -30; an average score of 10 for real images and 50 for fake images results in a loss of -40, which is better, and so on.
The sign of the scores does not matter in this case, as long as the scores for real images are relatively small and the scores for fake images are relatively large. The Wasserstein loss simply encourages the critic to separate these numbers.
We can also reverse the situation and encourage the critic to output a large score for real images and a small score for fake images and achieve the same result. Some implementations make this change.
In the Keras deep learning library (and some others), we cannot implement the Wasserstein loss function directly as described in the paper and as implemented in PyTorch and TensorFlow. Instead, we can achieve the same effect without having the calculation of the loss for the critic dependent upon the loss calculated for real and fake images.
A good way to think about this is a negative score for real images and a positive score for fake images, although this negative/positive split of scores learned during training is not required; just larger and smaller is sufficient.
Small Critic Score (e.g. < 0): Real
Large Critic Score (e.g. > 0): Fake
We can multiply the average predicted score by -1 in the case of fake images so that larger averages become smaller averages and the gradient is in the correct direction, i.e. minimizing loss. For example, average scores on fake images of [0.5, 0.8, and 1.0] across three batches of fake images would become [-0.5, -0.8, and -1.0] when calculating weight updates.
Loss For Fake Images = -1 * Average Critic Score
No change is needed for the case of real scores, as we want to encourage smaller average scores for real images.
Loss For Real Images = Average Critic Score
This can be implemented consistently by assigning an expected outcome target of -1 for fake images and 1 for real images and implementing the loss function as the expected label multiplied by the average score. The -1 label will be multiplied by the average score for fake images and encourage a larger predicted average, and the +1 label will be multiplied by the average score for real images and have no effect, encouraging a smaller predicted average.
Wasserstein Loss = Label * Average Critic Score
Or
Wasserstein Loss(Real Images) = 1 * Average Predicted Score
Wasserstein Loss(Fake Images) = -1 * Average Predicted Score
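A quick NumPy check, on the same made-up scores used earlier, confirms that the label-multiplied form recovers the difference between the average real and fake scores:

```python
import numpy as np

# made-up critic scores for a mini-batch of real and fake images
real_scores = np.array([10.0, 20.0, 30.0])
fake_scores = np.array([40.0, 50.0, 60.0])

# labels under this post's convention: +1 for real, -1 for fake
loss_real = np.mean(1 * real_scores)    # 20.0
loss_fake = np.mean(-1 * fake_scores)   # -50.0

# summing the two batch losses recovers [avg real] - [avg fake]
total = loss_real + loss_fake
print(total)  # -30.0
```

Multiplying by the label does all the work: the +1 leaves real scores unchanged, and the -1 flips the sign of the fake scores so that gradient descent pushes them in the right direction.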
We can implement this in Keras by assigning the expected labels of -1 and 1 for fake and real images respectively. The inverse labels could be used to the same effect, e.g. -1 for real and +1 for fake to encourage small scores for fake images and large scores for real images. Some developers do implement the WGAN in this alternate way, which is just as correct.
The loss function can be implemented by multiplying the expected label for each sample by the predicted score (element-wise), then calculating the mean.
In Keras, this can be implemented using the backend API, which ensures the mean is calculated across the samples in the provided tensors; for example:
from keras import backend
# implementation of wasserstein loss
def wasserstein_loss(y_true, y_pred):
	return backend.mean(y_true * y_pred)
Now that we know how to implement the Wasserstein loss function in Keras, let’s clarify one common point of misunderstanding.
Common Point of Confusion With Expected Labels
Recall we are using the expected labels of -1 for fake images and +1 for real images.
A common point of confusion is the expectation that a perfect critic model will output -1 for every fake image and +1 for every real image.
This is incorrect.
Again, recall we are using stochastic gradient descent to find the set of weights in the critic (and generator) models that minimize the loss function.
We have established that we want the critic model to output larger scores on average for fake images and smaller scores on average for real images. We then designed a loss function to encourage this outcome.
This is the key point about loss functions used to train neural network models. They encourage a desired model behavior, and they do not have to achieve this by providing the expected outcomes. In this case, we defined our Wasserstein loss function to interpret the average score predicted by the critic model and used labels for the real and fake cases to help with this interpretation.
So what is a good loss for real and fake images under Wasserstein loss?
Wasserstein loss is not an absolute value that can be compared across GAN models. Instead, it is relative and depends on your model configuration and dataset. What is important is that it is consistent for a given critic model, and that convergence of the generator (better loss) does correlate with better generated image quality.
It could be negative scores for real images and positive scores for fake images, but this is not required. All scores could be positive or all scores could be negative.
The loss function only encourages a separation between scores for fake and real images as larger and smaller, not necessarily positive and negative.
Further Reading
This section provides more resources on the topic if you are looking to go deeper.
The Generative Adversarial Network, or GAN for short, is an architecture for training a generative model.
The architecture is comprised of two models: the generator that we are interested in, and a discriminator model that is used to assist in the training of the generator. Initially, both the generator and discriminator models were implemented as Multilayer Perceptrons (MLPs), although more recently, the models are implemented as deep convolutional neural networks.
It can be challenging to understand how a GAN is trained and exactly how to understand and implement the loss function for the generator and discriminator models.
In this tutorial, you will discover how to implement the generative adversarial network training algorithm and loss functions.
After completing this tutorial, you will know:
How to implement the training algorithm for a generative adversarial network.
How the loss function for the discriminator and generator work.
How to implement weight updates for the discriminator and generator models in practice.
Discover how to develop DCGANs, conditional GANs, Pix2Pix, CycleGANs, and more with Keras in my new GANs book, with 29 step-by-step tutorials and full source code.
Let’s get started.
How to Code the Generative Adversarial Network Training Algorithm and Loss Functions Photo by Hilary Charlotte, some rights reserved.
Tutorial Overview
This tutorial is divided into three parts; they are:
How to Implement the GAN Training Algorithm
Understanding the GAN Loss Function
How to Train GAN Models in Practice
Note: The code examples in this tutorial are snippets only, not standalone runnable examples. They are designed to help you develop an intuition for the algorithm and they can be used as the starting point for implementing the GAN training algorithm on your own project.
How to Implement the GAN Training Algorithm
The GAN training algorithm involves training both the discriminator and the generator model in parallel.
The algorithm is summarized in the figure below, taken from the original 2014 paper by Goodfellow, et al. titled “Generative Adversarial Networks.”
Summary of the Generative Adversarial Network Training Algorithm. Taken from: Generative Adversarial Networks.
Let’s take some time to unpack and get comfortable with this algorithm.
The outer loop of the algorithm involves iterating over steps to train the models in the architecture. One cycle through this loop is not an epoch: it is a single update comprised of specific batch updates to the discriminator and generator models.
An epoch is defined as one cycle through a training dataset, where the samples in a training dataset are used to update the model weights in mini-batches. For example, a training dataset of 100 samples used to train a model with a mini-batch size of 10 samples would involve 10 mini batch updates per epoch. The model would be fit for a given number of epochs, such as 500.
This is often hidden from you by the automated training of a model via a call to the fit() function, specifying the number of epochs and the size of each mini-batch.
In the case of the GAN, the number of training iterations must be defined based on the size of your training dataset and batch size. In the case of a dataset with 100 samples, a batch size of 10, and 500 training epochs, we would first calculate the number of batches per epoch and use this to calculate the total number of training iterations using the number of epochs.
In the case of a dataset of 100 samples, a batch size of 10, and 500 epochs, the GAN would be trained for floor(100 / 10) * 500, or 5,000, total iterations.
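This calculation can be sketched directly in Python:

```python
# derive the total number of training iterations from dataset size, batch size, and epochs
n_samples = 100
n_batch = 10
n_epochs = 500
batches_per_epoch = n_samples // n_batch  # floor(100 / 10) = 10
n_steps = batches_per_epoch * n_epochs
print(n_steps)
```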
Next, we can see that one iteration of training results in possibly multiple updates to the discriminator and one update to the generator, where the number of updates to the discriminator is a hyperparameter (set to 1 in the paper).
The training process consists of simultaneous SGD. On each step, two minibatches are sampled: a minibatch of x values from the dataset and a minibatch of z values drawn from the model’s prior over latent variables. Then two gradient steps are made simultaneously …
We can therefore summarize the training algorithm with Python pseudocode as follows:
# gan training algorithm
def train_gan(dataset, n_epochs, n_batch):
	# calculate the number of batches per epoch
	batches_per_epoch = int(len(dataset) / n_batch)
	# calculate the number of training iterations
	n_steps = batches_per_epoch * n_epochs
	# gan training algorithm
	for i in range(n_steps):
		# update the discriminator model
		# ...
		# update the generator model
		# ...
An alternative approach may involve enumerating the number of training epochs and splitting the training dataset into batches for each epoch.
Updating the discriminator model involves a few steps.
First, a batch of random points from the latent space must be selected for use as input to the generator model to provide the basis for the generated or ‘fake‘ samples. Then a batch of samples from the training dataset must be selected for input to the discriminator as the ‘real‘ samples.
Next, the discriminator model must make predictions for the real and fake samples and the weights of the discriminator must be updated proportional to how correct or incorrect those predictions were. The predictions are probabilities and we will get into the nature of the predictions and the loss function that is minimized in the next section. For now, we can outline what these steps actually look like in practice.
We need a generator and a discriminator model, e.g. such as a Keras model. These can be provided as arguments to the training function.
Next, we must generate points from the latent space and then use the generator model in its current form to generate some fake images. For example:
...
# generate points in the latent space
z = randn(latent_dim * n_batch)
# reshape into a batch of inputs for the network
z = z.reshape(n_batch, latent_dim)
# generate fake images
fake = generator.predict(z)
Note that the size of the latent dimension is also provided as a hyperparameter to the training algorithm.
We then must select a batch of real samples, and this too will be wrapped into a function.
...
# select a batch of random real images
ix = randint(0, len(dataset), n_batch)
# retrieve real images
real = dataset[ix]
The discriminator model must then make a prediction for each of the generated and real images and the weights must be updated.
# gan training algorithm
def train_gan(generator, discriminator, dataset, latent_dim, n_epochs, n_batch):
	# calculate the number of batches per epoch
	batches_per_epoch = int(len(dataset) / n_batch)
	# calculate the number of training iterations
	n_steps = batches_per_epoch * n_epochs
	# gan training algorithm
	for i in range(n_steps):
		# generate points in the latent space
		z = randn(latent_dim * n_batch)
		# reshape into a batch of inputs for the network
		z = z.reshape(n_batch, latent_dim)
		# generate fake images
		fake = generator.predict(z)
		# select a batch of random real images
		ix = randint(0, len(dataset), n_batch)
		# retrieve real images
		real = dataset[ix]
		# update weights of the discriminator model
		# ...
		# update the generator model
		# ...
Next, the generator model must be updated.
Again, a batch of random points from the latent space must be selected and passed to the generator to generate fake images, and then passed to the discriminator to classify.
...
# generate points in the latent space
z = randn(latent_dim * n_batch)
# reshape into a batch of inputs for the network
z = z.reshape(n_batch, latent_dim)
# generate fake images
fake = generator.predict(z)
# classify as real or fake
result = discriminator.predict(fake)
The response can then be used to update the weights of the generator model.
# gan training algorithm
def train_gan(generator, discriminator, dataset, latent_dim, n_epochs, n_batch):
	# calculate the number of batches per epoch
	batches_per_epoch = int(len(dataset) / n_batch)
	# calculate the number of training iterations
	n_steps = batches_per_epoch * n_epochs
	# gan training algorithm
	for i in range(n_steps):
		# generate points in the latent space
		z = randn(latent_dim * n_batch)
		# reshape into a batch of inputs for the network
		z = z.reshape(n_batch, latent_dim)
		# generate fake images
		fake = generator.predict(z)
		# select a batch of random real images
		ix = randint(0, len(dataset), n_batch)
		# retrieve real images
		real = dataset[ix]
		# update weights of the discriminator model
		# ...
		# generate points in the latent space
		z = randn(latent_dim * n_batch)
		# reshape into a batch of inputs for the network
		z = z.reshape(n_batch, latent_dim)
		# generate fake images
		fake = generator.predict(z)
		# classify as real or fake
		result = discriminator.predict(fake)
		# update weights of the generator model
		# ...
It is interesting that the discriminator is updated with two batches of samples each training iteration whereas the generator is only updated with a single batch of samples per training iteration.
Now that we have defined the training algorithm for the GAN, we need to understand how the model weights are updated. This requires understanding the loss function used to train the GAN.
Understanding the GAN Loss Function
The discriminator is trained to correctly classify real and fake images.
This is achieved by maximizing the log of predicted probability of real images and the log of the inverted probability of fake images, averaged over each mini-batch of examples.
Recall that adding log probabilities is equivalent to multiplying probabilities, without the product vanishing into very small numbers. Therefore, we can understand this loss function as seeking probabilities close to 1.0 for real images and probabilities close to 0.0 for fake images, where the fake probabilities are inverted (1 - D(G(z))) to become larger numbers. Because we are maximizing, larger average values of this objective mean better performance of the discriminator.
Inverting this to a minimization problem should not be surprising if you are familiar with developing neural networks for binary classification, as this is exactly the approach used.
This is just the standard cross-entropy cost that is minimized when training a standard binary classifier with a sigmoid output. The only difference is that the classifier is trained on two minibatches of data; one coming from the dataset, where the label is 1 for all examples, and one coming from the generator, where the label is 0 for all examples.
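As a sketch of this objective with hypothetical predicted probabilities, the discriminator's loss over one mini-batch of real and one mini-batch of fake images can be computed in NumPy as:

```python
import numpy as np

# hypothetical discriminator predictions: close to 1 for real, close to 0 for fake
d_real = np.array([0.9, 0.8, 0.95])
d_fake = np.array([0.1, 0.2, 0.05])

# maximize log(D(x)) + log(1 - D(G(z))), i.e. minimize the negated average
d_loss = -(np.mean(np.log(d_real)) + np.mean(np.log(1.0 - d_fake)))
print(d_loss)
```

The better the discriminator's predictions, the closer this loss is to zero.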
The GAN algorithm defines the generator model’s loss as minimizing the log of the inverted probability of the discriminator’s prediction of fake images, averaged over a mini-batch.
This is straightforward, but according to the authors, it is not effective in practice when the generator is poor and the discriminator is good at rejecting fake images with high confidence. The loss function no longer gives good gradient information that the generator can use to adjust weights and instead saturates.
In this case, log(1 − D(G(z))) saturates. Rather than training G to minimize log(1 − D(G(z))) we can train G to maximize log D(G(z)). This objective function results in the same fixed point of the dynamics of G and D but provides much stronger gradients early in learning.
Instead, the authors recommend maximizing the log of the discriminator’s predicted probability for fake images.
The change is subtle.
In the first case, the generator is trained to minimize the probability of the discriminator being correct. With this change to the loss function, the generator is trained to maximize the probability of the discriminator being incorrect.
In the minimax game, the generator minimizes the log-probability of the discriminator being correct. In this game, the generator maximizes the log probability of the discriminator being mistaken.
The sign of this loss function can then be inverted to give a familiar minimizing loss function for training the generator. As such, this is sometimes referred to as the -log D trick for training GANs.
Our baseline comparison is DCGAN, a GAN with a convolutional architecture trained with the standard GAN procedure using the −log D trick.
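A small numeric sketch makes the saturation concrete. Differentiating each objective with respect to the discriminator's logit (its pre-sigmoid output), the gradient of log(1 - D(G(z))) is -D(G(z)) and the gradient of -log D(G(z)) is D(G(z)) - 1, so when the discriminator confidently rejects a fake image the first gradient vanishes while the second stays strong:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# a confident discriminator: a large negative logit means D(G(z)) is near zero
logit = -7.0
p = sigmoid(logit)

# gradient of log(1 - sigmoid(x)) with respect to the logit x is -sigmoid(x)
grad_saturating = -p
# gradient of -log(sigmoid(x)) with respect to the logit x is sigmoid(x) - 1
grad_nonsaturating = p - 1.0

print(abs(grad_saturating), abs(grad_nonsaturating))
```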
Now that we understand the GAN loss function, we can look at how the discriminator and the generator model can be updated in practice.
How to Train GAN Models in Practice
The practical implementation of the GAN loss function and model updates is straightforward.
We will look at examples using the Keras library.
We can implement the discriminator directly by configuring the discriminator model to predict a probability of 1 for real images and 0 for fake images and minimizing the cross-entropy loss, specifically the binary cross-entropy loss.
For example, a snippet of our model definition with Keras for the discriminator might look as follows for the output layer and the compilation of the model with the appropriate loss function.
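A minimal sketch of such an output layer and compilation might be as follows. Note this is a sketch, not the full model: the Input shape is a placeholder standing in for the preceding convolutional feature-extraction layers, and the Adam settings follow common GAN practice.

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.optimizers import Adam

# sketch of a discriminator's output layer and compilation
model = Sequential()
model.add(Input(shape=(64,)))  # placeholder for the preceding feature-extraction layers
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
```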
The defined model can be trained for each batch of real and fake samples providing arrays of 1s and 0s for the expected outcome.
The ones() and zeros() NumPy functions can be used to create these target labels, and the Keras function train_on_batch() can be used to update the model for each batch of samples.
...
X_fake = ...
X_real = ...
# define target labels for fake images
y_fake = zeros((n_batch, 1))
# update the discriminator for fake images
discriminator.train_on_batch(X_fake, y_fake)
# define target labels for real images
y_real = ones((n_batch, 1))
# update the discriminator for real images
discriminator.train_on_batch(X_real, y_real)
The discriminator model will be trained to predict the probability of “realness” of a given input image that can be interpreted as a class label of class=0 for fake and class=1 for real.
The generator is trained to maximize the discriminator predicting a high probability of “realness” for generated images.
This is achieved by updating the generator via the discriminator with the class label of 1 for the generated images. The discriminator is not updated in this operation but provides the gradient information required to update the weights of the generator model.
For example, if the discriminator predicts a low average probability for the batch of generated images, then this will result in a large error signal propagated backward into the generator given the “expected probability” for the samples was 1.0 for real. This large error signal, in turn, results in relatively large changes to the generator to hopefully improve its ability at generating fake samples on the next batch.
This can be implemented in Keras by creating a composite model that combines the generator and discriminator models, allowing the output images from the generator to flow into the discriminator directly, and, in turn, allowing the error signals from the predicted probabilities of the discriminator to flow back through the weights of the generator model.
For example:
# define a composite gan model for the generator and discriminator
def define_gan(generator, discriminator):
	# make weights in the discriminator not trainable
	discriminator.trainable = False
	# connect them
	model = Sequential()
	# add generator
	model.add(generator)
	# add the discriminator
	model.add(discriminator)
	# compile model
	model.compile(loss='binary_crossentropy', optimizer='adam')
	return model
The composite model can then be updated using fake images and real class labels.
...
# generate points in the latent space
z = randn(latent_dim * n_batch)
# reshape into a batch of inputs for the network
z = z.reshape(n_batch, latent_dim)
# define target labels for real images
y_real = ones((n_batch, 1))
# update generator model
gan_model.train_on_batch(z, y_real)
That completes our tour of the GAN training algorithm, loss function, and weight update details for the discriminator and generator models.
Further Reading
This section provides more resources on the topic if you are looking to go deeper.
The study and application of GANs are only a few years old, yet the results achieved have been nothing short of remarkable. Because the field is so young, it can be challenging to know how to get started, what to focus on, and how to best use the available techniques.
In this crash course, you will discover how you can get started and confidently develop deep learning Generative Adversarial Networks using Python in seven days.
Note: This is a big and important post. You might want to bookmark it.
Discover how to develop DCGANs, conditional GANs, Pix2Pix, CycleGANs, and more with Keras in my new GANs book, with 29 step-by-step tutorials and full source code.
Let’s get started.
How to Get Started With Generative Adversarial Networks (7-Day Mini-Course) Photo by Matthias Ripp, some rights reserved.
Who Is This Crash-Course For?
Before we get started, let’s make sure you are in the right place.
The list below provides some general guidelines as to who this course was designed for.
Don’t panic if you don’t match these points exactly; you might just need to brush up in one area or another to keep up.
You need to know:
Your way around basic Python, NumPy, and Keras for deep learning.
You do NOT need to be:
A math wiz!
A deep learning expert!
A computer vision researcher!
This crash course will take you from a developer that knows a little machine learning to a developer who can bring GANs to your own computer vision project.
Note: This crash course assumes you have a working Python 2 or 3 SciPy environment with at least NumPy, Pandas, scikit-learn, and Keras 2 installed. If you need help with your environment, you can follow the step-by-step tutorial here:
This crash course is broken down into seven lessons.
You could complete one lesson per day (recommended) or complete all of the lessons in one day (hardcore). It really depends on the time you have available and your level of enthusiasm.
Below are the seven lessons that will get you started and productive with Generative Adversarial Networks in Python:
Lesson 01: What Are Generative Adversarial Networks?
Lesson 02: GAN Tips, Tricks and Hacks
Lesson 03: Discriminator and Generator Models
Lesson 04: GAN Loss Functions
Lesson 05: GAN Training Algorithm
Lesson 06: GANs for Image Translation
Lesson 07: Advanced GANs
Each lesson could take you anywhere from 60 seconds up to 30 minutes. Take your time and complete the lessons at your own pace. Ask questions and even post results in the comments below.
The lessons might expect you to go off and find out how to do things. I will give you hints, but part of the point of each lesson is to force you to learn where to go to look for help on and about deep learning and GANs (hint: I have all of the answers on this blog; just use the search box).
Post your results in the comments; I’ll cheer you on!
Lesson 01: What Are Generative Adversarial Networks?
In this lesson, you will discover what GANs are and the basic model architecture.
Generative Adversarial Networks, or GANs for short, are an approach to generative modeling using deep learning methods, such as convolutional neural networks.
GANs are a clever way of training a generative model by framing the problem as a supervised learning problem with two sub-models: the generator model that we train to generate new examples, and the discriminator model that tries to classify examples as either real (from the domain) or fake (generated).
Generator. Model that is used to generate new plausible examples from the problem domain.
Discriminator. Model that is used to classify examples as real (from the domain) or fake (generated).
The two models are trained together in an adversarial, zero-sum game until the discriminator model is fooled about half the time, meaning the generator model is generating plausible examples.
The Generator
The generator model takes a fixed-length random vector as input and generates an image in the domain.
The vector is drawn randomly from a Gaussian distribution (called the latent space), and the vector is used to seed the generative process.
After training, the generator model is kept and used to generate new samples.
The Discriminator
The discriminator model takes an example from the domain as input (real or generated) and predicts a binary class label of real or fake (generated).
The real example comes from the training dataset. The generated examples are output by the generator model.
The discriminator is a normal (and well understood) classification model.
After the training process, the discriminator model is discarded as we are interested in the generator.
GAN Training
The two models, the generator and discriminator, are trained together.
A single training cycle involves first selecting a batch of real images from the problem domain. A batch of latent points is generated and fed to the generator model to synthesize a batch of images.
The discriminator is then updated using the batch of real and generated images, minimizing binary cross-entropy loss used in any binary classification problem.
The generator is then updated via the discriminator model. This means that generated images are presented to the discriminator as though they are real (not generated) and the error is propagated back through the generator model. This has the effect of updating the generator model toward generating images that are more likely to fool the discriminator.
This process is then repeated for a given number of training iterations.
Your Task
Your task in this lesson is to list three possible applications for Generative Adversarial Networks. You may get ideas from looking at recently published research papers.
Post your findings in the comments below. I would love to see what you discover.
In the next lesson, you will discover tips and tricks for the successful training of GAN models.
Lesson 02: GAN Tips, Tricks, and Hacks
In this lesson, you will discover the tips, tricks, and hacks that you need to know to successfully train GAN models.
Generative Adversarial Networks are challenging to train.
This is because the architecture involves both a generator and a discriminator model that compete in a zero-sum game. Improvements to one model come at the cost of a degrading of performance in the other model. The result is a very unstable training process that can often lead to failure, e.g. a generator that generates the same image all the time or generates nonsense.
As such, there are a number of heuristics or best practices (called “GAN hacks“) that can be used when configuring and training your GAN models.
Perhaps one of the most important steps forward in the design and training of stable GAN models is the approach that became known as the Deep Convolutional GAN, or DCGAN.
This architecture involves seven best practices to consider when implementing your GAN model:
Downsample Using Strided Convolutions (e.g. don't use pooling layers).
Upsample Using Strided Convolutions (e.g. use the transpose convolutional layer).
Use LeakyReLU (e.g. don't use the standard ReLU).
Use Batch Normalization (e.g. standardize layer outputs after the activation).
Use Gaussian Weight Initialization (e.g. a mean of 0.0 and a standard deviation of 0.02).
Use Adam Stochastic Gradient Descent (e.g. a learning rate of 0.0002 and beta1 of 0.5).
Scale Images to the Range [-1,1] (e.g. use tanh in the output of the generator).
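For example, the image-scaling practice can be sketched as a small helper function, assuming pixel values in [0,255]:

```python
import numpy as np

# scale pixel values from [0,255] to [-1,1] to match the tanh output of the generator
def scale_images(images):
    return (images.astype('float32') - 127.5) / 127.5

print(scale_images(np.array([0.0, 127.5, 255.0])))
```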
These heuristics have been hard won by practitioners testing and evaluating hundreds or thousands of combinations of configuration operations on a range of problems.
Your Task
Your task in this lesson is to list three additional GAN tips or hacks that can be used during training.
Post your findings in the comments below. I would love to see what you discover.
In the next lesson, you will discover how to implement simple discriminator and generator models.
Lesson 03: Discriminator and Generator Models
In this lesson, you will discover how to implement a simple discriminator and generator model using the Keras deep learning library.
We will assume the images in our domain are 28×28 pixels in size and color, meaning they have three color channels.
Discriminator Model
The discriminator model accepts an image with the size 28x28x3 pixels and must classify it as real (1) or fake (0) via the sigmoid activation function.
Our model has two convolutional layers with 64 filters each and uses same padding. Each convolutional layer will downsample the input using a 2×2 stride, which is a best practice for GANs, instead of using a pooling layer.
Also following best practice, the convolutional layers are followed by a LeakyReLU activation with a slope of 0.2 and a batch normalization layer.
...
# define the discriminator model
model = Sequential()
# downsample to 14x14
model.add(Conv2D(64, (3,3), strides=(2, 2), padding='same', input_shape=(28,28,3)))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization())
# downsample to 7x7
model.add(Conv2D(64, (3,3), strides=(2, 2), padding='same'))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization())
# classify
model.add(Flatten())
model.add(Dense(1, activation='sigmoid'))
Generator Model
The generator model takes a 100-dimensional point in the latent space as input and generates a 28x28x3 image as output.
The point in latent space is a vector of Gaussian random numbers. This is projected using a Dense layer to the basis of 64 tiny 7×7 images.
The small images are then upsampled twice using two transpose convolutional layers with a 2×2 stride, each followed by LeakyReLU and BatchNormalization layers, which are a best practice for GANs.
The output is a three channel image with pixel values in the range [-1,1] via the tanh activation function.
...
# define the generator model
model = Sequential()
# foundation for 7x7 image
n_nodes = 64 * 7 * 7
model.add(Dense(n_nodes, input_dim=100))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization())
model.add(Reshape((7, 7, 64)))
# upsample to 14x14
model.add(Conv2DTranspose(64, (3,3), strides=(2,2), padding='same'))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization())
# upsample to 28x28
model.add(Conv2DTranspose(64, (3,3), strides=(2,2), padding='same'))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization())
model.add(Conv2D(3, (3,3), activation='tanh', padding='same'))
Your Task
Your task in this lesson is to implement both the discriminator and generator models and summarize their structure.
For bonus points, update the models to support an image with the size 64×64 pixels.
Post your findings in the comments below. I would love to see what you discover.
In the next lesson, you will discover how to configure the loss functions for training the GAN models.
Lesson 04: GAN Loss Functions
In this lesson, you will discover how to configure the loss functions used for training the GAN model weights.
Discriminator Loss
The discriminator model is optimized to maximize the probability of correctly identifying real images from the dataset and fake or synthetic images output by the generator.
This can be implemented as a binary classification problem where the discriminator outputs a probability between 0 (fake) and 1 (real) for a given image.
The model can then be trained on batches of real and fake images directly and minimize the negative log likelihood, most commonly implemented as the binary cross-entropy loss function.
As is best practice, the model can be optimized using the Adam version of stochastic gradient descent with a small learning rate and conservative momentum.
...
# compile model
model.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))
Generator Loss
The generator is not updated directly and there is no loss for this model.
Instead, the discriminator is used to provide a learned or indirect loss function for the generator.
This is achieved by creating a composite model where the generator outputs an image that feeds directly into the discriminator for classification.
The composite model can then be trained by providing random points in latent space as input and indicating to the discriminator that the generated images are, in fact, real. This has the effect of updating the weights of the generator to output images that are more likely to be classified as real by the discriminator.
Importantly, the discriminator weights are not updated during this process and are marked as not trainable.
The composite model uses the same binary cross-entropy loss as the standalone discriminator model and the same Adam version of stochastic gradient descent to perform the optimization.
# create the composite model for training the generator
generator = ...
discriminator = ...
...
# make weights in the discriminator not trainable
discriminator.trainable = False
# connect them
model = Sequential()
# add generator
model.add(generator)
# add the discriminator
model.add(discriminator)
# compile model
model.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))
Your Task
Your task in this lesson is to research and summarize three additional types of loss function that can be used to train the GAN models.
Post your findings in the comments below. I would love to see what you discover.
In the next lesson, you will discover the training algorithm used to update the model weights for the GAN.
Lesson 05: GAN Training Algorithm
In this lesson, you will discover the GAN training algorithm.
Defining the GAN models is the hard part. The GAN training algorithm is relatively straightforward.
One cycle of the algorithm involves first selecting a batch of real images and using the current generator model to generate a batch of fake images. You can develop small functions to perform these two operations.
These real and fake images are then used to update the discriminator model directly via a call to the train_on_batch() Keras function.
Next, points in latent space can be generated as input for the composite generator-discriminator model and labels of “real” (class=1) can be provided to update the weights of the generator model.
The training process is then repeated thousands of times.
The generator model can be saved periodically and later loaded to check the quality of the generated images.
The example below demonstrates the GAN training algorithm.
...
# gan training algorithm
discriminator = ...
generator = ...
gan_model = ...
n_batch = 16
latent_dim = 100
for i in range(10000):
	# get randomly selected 'real' samples
	X_real, y_real = select_real_samples(dataset, n_batch)
	# generate 'fake' examples
	X_fake, y_fake = generate_fake_samples(generator, latent_dim, n_batch)
	# create training set for the discriminator
	X, y = vstack((X_real, X_fake)), vstack((y_real, y_fake))
	# update discriminator model weights
	d_loss = discriminator.train_on_batch(X, y)
	# prepare points in latent space as input for the generator
	X_gan = generate_latent_points(latent_dim, n_batch)
	# create inverted labels for the fake samples
	y_gan = ones((n_batch, 1))
	# update the generator via the discriminator's error
	g_loss = gan_model.train_on_batch(X_gan, y_gan)
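The helper functions referenced in the loop above are not defined in this snippet. A minimal NumPy-based sketch of what they might look like is shown below; the function names follow the snippet, the label conventions are class=1 for real and class=0 for fake, and generate_fake_samples() assumes a Keras-style generator with a predict() method:

```python
from numpy import ones, zeros
from numpy.random import randint, randn

def select_real_samples(dataset, n_batch):
    # pick random images from the dataset and label them as real (class=1)
    ix = randint(0, dataset.shape[0], n_batch)
    return dataset[ix], ones((n_batch, 1))

def generate_latent_points(latent_dim, n_batch):
    # sample Gaussian random points and shape them as a batch of inputs
    return randn(latent_dim * n_batch).reshape(n_batch, latent_dim)

def generate_fake_samples(generator, latent_dim, n_batch):
    # generate images from latent points and label them as fake (class=0)
    x_input = generate_latent_points(latent_dim, n_batch)
    return generator.predict(x_input), zeros((n_batch, 1))
```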
Your Task
Your task in this lesson is to tie together the elements from this and the prior lessons and train a GAN on a small image dataset such as MNIST or CIFAR-10.
Post your findings in the comments below. I would love to see what you discover.
In the next lesson, you will discover the application of GANs for image translation.
Lesson 06: GANs for Image Translation
In this lesson, you will discover GANs used for image translation.
Image-to-image translation is the controlled conversion of a given source image to a target image. An example might be the conversion of black and white photographs to color photographs.
Image-to-image translation is a challenging problem and often requires specialized models and loss functions for a given translation task or dataset.
GANs can be trained to perform image-to-image translation and two examples include the Pix2Pix and the CycleGAN.
Pix2Pix
The Pix2Pix GAN is a general approach for image-to-image translation.
The model is trained on a dataset of paired examples, where each pair involves an example of the image before and after the desired translation.
The Pix2Pix model is based on the conditional generative adversarial network, where a target image is generated, conditional on a given input image.
The discriminator model is given an input image and a real or generated paired image and must determine whether the paired image is real or fake.
The generator model is provided with a given image as input and generates a translated version of the image. The generator model is trained to both fool the discriminator model and to minimize the loss between the generated image and the expected target image.
More sophisticated deep convolutional neural network models are used in the Pix2Pix. Specifically, a U-Net model is used for the generator model and a PatchGAN is used for the discriminator model.
The loss for the generator is comprised of a composite of both the adversarial loss of a normal GAN model and the L1 loss between the generated and expected translated image.
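The arithmetic of this composite loss can be sketched in NumPy as follows. The function name and the clip guard are illustrative; the paper weights the L1 term by 100 relative to the adversarial term, and a Keras implementation would typically express this with two loss functions and the loss_weights argument when compiling the combined model:

```python
import numpy as np

def pix2pix_generator_loss(disc_pred, gen_image, target_image, lam=100.0):
    # adversarial term: binary cross-entropy against labels of "real" (1);
    # the clip guards against log(0)
    adversarial = -np.mean(np.log(np.clip(disc_pred, 1e-7, 1.0)))
    # reconstruction term: L1 distance between generated and expected image
    l1 = np.mean(np.abs(gen_image - target_image))
    # weighted sum, with L1 weighted heavily relative to the adversarial term
    return adversarial + lam * l1
```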
CycleGAN
A limitation of the Pix2Pix model is that it requires a dataset of paired examples before and after the desired translation.
There are many image-to-image translation tasks for which paired examples of the translation may not exist, such as translating photos of zebras to horses, or paintings of landscapes to photographs.
The CycleGAN is a technique that involves the automatic training of image-to-image translation models without paired examples. The models are trained in an unsupervised manner using a collection of images from the source and target domain that do not need to be related in any way.
The CycleGAN is an extension of the GAN architecture that involves the simultaneous training of two generator models and two discriminator models.
One generator takes images from the first domain as input and outputs images for the second domain, and the other generator takes images from the second domain as input and generates images from the first domain. Discriminator models are then used to determine how plausible the generated images are and update the generator models accordingly.
The CycleGAN uses an additional extension to the architecture called cycle consistency. This is the idea that an image output by the first generator could be used as input to the second generator and the output of the second generator should match the original image. The reverse is also true: that an output from the second generator can be fed as input to the first generator and the result should match the input to the second generator.
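The cycle consistency idea reduces to a simple penalty: the L1 distance between each original image and its reconstruction after a round trip through both generators. A NumPy sketch of the arithmetic is below; the function name is illustrative, and in the full CycleGAN objective this term is weighted and added to the two adversarial losses:

```python
import numpy as np

def cycle_consistency_loss(real_a, cycled_a, real_b, cycled_b):
    # forward cycle: A -> G(A) -> F(G(A)) should reproduce the original A
    forward = np.mean(np.abs(real_a - cycled_a))
    # backward cycle: B -> F(B) -> G(F(B)) should reproduce the original B
    backward = np.mean(np.abs(real_b - cycled_b))
    # the two L1 penalties are summed
    return forward + backward
```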
Your Task
Your task in this lesson is to list five examples of image-to-image translation you might like to explore with GAN models.
Post your findings in the comments below. I would love to see what you discover.
In the next lesson, you will discover some of the recent advancements in GAN models.
Lesson 07: Advanced GANs
In this lesson, you will discover some of the more advanced GANs that are demonstrating remarkable results.
BigGAN
The BigGAN is an approach that pulls together a suite of recent best practices for training GANs and scales up the batch size and the number of model parameters.
As its name suggests, the BigGAN is focused on scaling up the GAN models. This includes GAN models with:
More model parameters (e.g. many more feature maps).
Larger Batch Sizes (e.g. hundreds or thousands of images).
The resulting BigGAN generator model is capable of generating high-quality 256×256 and 512×512 images across a wide range of image classes.
Progressive Growing GAN
Progressive Growing GAN is an extension to the GAN training process that allows for the stable training of generator models that can output large high-quality images.
It involves starting with a very small image and incrementally adding blocks of layers that increase the output size of the generator model and the input size of the discriminator model until the desired image size is reached.
Generative Adversarial Networks, or GANs, are deep-learning-based generative models that have seen wide success.
There are thousands of papers on GANs and many hundreds of named GANs, that is, models with a defined name that often includes “GAN,” such as DCGAN, as opposed to minor extensions to the method. Given the vast size of the GAN literature and the number of models, it can be, at the very least, confusing and frustrating to know which GAN models to focus on.
In this post, you will discover the Generative Adversarial Network models that you need to know to establish a useful and productive foundation in the field.
After reading this post, you will know:
The foundation GAN models that provide the basis for the field of study.
The extension GAN models that build upon what works and lead the way for more advanced models.
The advanced GAN models that push the limits of the architecture and achieve impressive results.
Let’s get started.
A Tour of Generative Adversarial Network Models and Extensions Photo by Tomek Niedzwiedz, some rights reserved.
Overview
This tutorial is divided into three parts; they are:
Foundation
Generative Adversarial Network (GAN)
Deep Convolutional Generative Adversarial Network (DCGAN)
Extensions
Conditional Generative Adversarial Network (cGAN)
Information Maximizing Generative Adversarial Network (InfoGAN)
This section summarizes the foundational GAN models upon which most, if not all, other GANs build.
Generative Adversarial Network (GAN)
The Generative Adversarial Network architecture and first empirical demonstration of the approach was described in the 2014 paper by Ian Goodfellow, et al. titled “Generative Adversarial Networks.”
The paper describes the architecture succinctly involving a generator model that takes as input points from a latent space and generates an image, and a discriminator model that classifies images as either real (from the dataset) or fake (output by the generator).
We propose a new framework for estimating generative models via an adversarial process, in which we simultaneously train two models: a generative model G that captures the data distribution, and a discriminative model D that estimates the probability that a sample came from the training data rather than G. The training procedure for G is to maximize the probability of D making a mistake.
The models are comprised of fully connected layers (MLPs) with ReLU activations in the generator and maxout activations in the discriminator, and were applied to standard image datasets such as MNIST and CIFAR-10.
We trained adversarial nets on a range of datasets including MNIST, the Toronto Face Database (TFD), and CIFAR-10. The generator nets used a mixture of rectifier linear activations and sigmoid activations, while the discriminator net used maxout activations. Dropout was applied in training the discriminator net.
Deep Convolutional Generative Adversarial Network (DCGAN)
The deep convolutional generative adversarial network, or DCGAN for short, is an extension of the GAN architecture that uses deep convolutional neural networks for both the generator and discriminator models, along with model and training configurations that result in the stable training of a generator model.
We introduce a class of CNNs called deep convolutional generative adversarial networks (DCGANs), that have certain architectural constraints, and demonstrate that they are a strong candidate for unsupervised learning.
The DCGAN is important because it suggested the constraints on the model required to effectively develop high-quality generator models in practice. This architecture, in turn, provided the basis for the rapid development of a large number of GAN extensions and applications.
We propose and evaluate a set of constraints on the architectural topology of Convolutional GANs that make them stable to train in most settings.
This section summarizes named GAN models that provide some of the more common or widely used discrete extensions to the GAN model architecture or training process.
Conditional Generative Adversarial Network (cGAN)
The conditional generative adversarial network, or cGAN for short, is an extension to the GAN architecture that makes use of information in addition to the image as input both to the generator and the discriminator models. For example, if class labels are available, they can be used as input.
Generative adversarial nets can be extended to a conditional model if both the generator and discriminator are conditioned on some extra information y. y could be any kind of auxiliary information, such as class labels or data from other modalities. We can perform the conditioning by feeding y into the both the discriminator and generator as additional input layer.
Example of the Model Architecture for a Conditional Generative Adversarial Network (cGAN). Taken from: Conditional Generative Adversarial Nets.
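One simple way to picture the conditioning is to append a one-hot encoded class label to each latent vector before it enters the generator. The NumPy sketch below is illustrative only; Keras implementations of the cGAN often use an Embedding layer for the label and concatenate an extra channel onto the image for the discriminator instead:

```python
import numpy as np

def conditional_generator_input(latent_points, labels, n_classes=10):
    # one-hot encode the integer class labels
    onehot = np.eye(n_classes)[labels]
    # append the label information onto each latent vector
    return np.concatenate([latent_points, onehot], axis=1)
```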
Information Maximizing Generative Adversarial Network (InfoGAN)
The information maximizing generative adversarial network, or InfoGAN for short, is an extension to the GAN that attempts to structure the input or latent space for the generator. Specifically, the goal is to add specific semantic meaning to the variables in the latent space.
… , when generating images from the MNIST dataset, it would be ideal if the model automatically chose to allocate a discrete random variable to represent the numerical identity of the digit (0-9), and chose to have two additional continuous variables that represent the digit’s angle and thickness of the digit’s stroke.
This is achieved by separating points in the latent space into both noise and latent codes. The latent codes are then used to condition or control specific semantic properties in the generated image.
… rather than using a single unstructured noise vector, we propose to decompose the input noise vector into two parts: (i) z, which is treated as source of incompressible noise; (ii) c, which we will call the latent code and will target the salient structured semantic features of the data distribution
Example of Using Latent Codes to vary Features in Generated Handwritten Digits With an InfoGAN. Taken from: InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets.
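The decomposition of the generator input can be sketched in NumPy. The dimensions below (62 noise variables, one 10-category discrete code, and 2 continuous codes) follow the MNIST configuration described in the InfoGAN paper; the function name is illustrative:

```python
import numpy as np

def infogan_generator_input(n_samples, noise_dim=62, n_cat=10, n_cont=2):
    # z: unstructured, incompressible noise
    z = np.random.randn(n_samples, noise_dim)
    # c1: a discrete latent code, e.g. intended digit identity (one-hot)
    cat = np.eye(n_cat)[np.random.randint(0, n_cat, n_samples)]
    # c2, c3: continuous latent codes, e.g. digit angle and stroke thickness
    cont = np.random.uniform(-1.0, 1.0, (n_samples, n_cont))
    # the generator receives the concatenation of noise and latent codes
    return np.concatenate([z, cat, cont], axis=1)
```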
The auxiliary classifier generative adversarial network, or AC-GAN, is an extension to the GAN that both changes the generator to be class conditional as with the cGAN, and adds an additional or auxiliary model to the discriminator that is trained to reconstruct the class label.
… we introduce a model that combines both strategies for leveraging side information. That is, the model proposed below is class conditional, but with an auxiliary decoder that is tasked with reconstructing class labels.
The stacked generative adversarial network, or StackGAN, is an extension to the GAN to generate images from text using a hierarchical stack of conditional GAN models.
… we propose Stacked Generative Adversarial Networks (StackGAN) to generate 256×256 photo-realistic images conditioned on text descriptions.
The architecture is comprised of a series of text- and image-conditional GAN models. The first-level generator (Stage-I GAN) is conditioned on text and generates a low-resolution image. The second-level generator (Stage-II GAN) is conditioned both on the text and on the low-resolution image output by the first level, and outputs a high-resolution image.
Low-resolution images are first generated by our Stage-I GAN. On the top of our Stage-I GAN, we stack Stage-II GAN to generate realistic high-resolution (e.g., 256×256) images conditioned on Stage-I results and text descriptions. By conditioning on the Stage-I result and the text again, Stage-II GAN learns to capture the text information that is omitted by Stage-I GAN and draws more details for the object
Example of the Architecture for the Stacked Generative Adversarial Network for Text to Image Generation. Taken from: StackGAN: Text to Photo-realistic Image Synthesis with Stacked Generative Adversarial Networks.
Context Encoders
The Context Encoders model is an encoder-decoder model for conditional image generation trained using the adversarial approach devised for GANs. Although it is not referred to in the paper as a GAN model, it has many GAN features.
By analogy with auto-encoders, we propose Context Encoders – a convolutional neural network trained to generate the contents of an arbitrary image region conditioned on its surroundings.
Example of the Context Encoders Encoder-Decoder Model Architecture. Taken from: Context Encoders: Feature Learning by Inpainting
The model is trained with a joint-loss that combines both the adversarial loss of generator and discriminator models and the reconstruction loss that calculates the vector norm distance between the predicted and expected output image.
When training context encoders, we have experimented with both a standard pixel-wise reconstruction loss, as well as a reconstruction plus an adversarial loss. The latter produces much sharper results because it can better handle multiple modes in the output.
The pix2pix model is an extension of the GAN for image-conditional image generation, referred to as the task of image-to-image translation. A U-Net model architecture is used in the generator model, and a PatchGAN model architecture is used as the discriminator model.
Our method also differs from the prior works in several architectural choices for the generator and discriminator. Unlike past work, for our generator we use a “U-Net”-based architecture, and for our discriminator we use a convolutional “PatchGAN” classifier, which only penalizes structure at the scale of image patches.
The loss for the generator model is updated to also include the vector distance from the target output image.
The discriminator’s job remains unchanged, but the generator is tasked to not only fool the discriminator but also to be near the ground truth output in an L2 sense. We also explore this option, using L1 distance rather than L2 as L1 encourages less blurring.
This section lists those GAN models that have recently led to surprising or impressive results, building upon prior GAN extensions.
These models mostly focus on developments that allow for the generation of large photorealistic images.
Wasserstein Generative Adversarial Network (WGAN)
The Wasserstein generative adversarial network, or WGAN for short, is an extension to the GAN that changes the training procedure to update the discriminator model, now called a critic, many more times than the generator model for each iteration.
Algorithm for the Wasserstein Generative Adversarial Network (WGAN). Taken from: Wasserstein GAN.
The critic is updated to output a real-valued score (linear activation) instead of a binary prediction with a sigmoid activation, and the critic and generator models are both trained using “Wasserstein loss,” which is the average of the product of the expected and predicted values from the critic, designed to provide linear gradients that are useful for updating the model.
The discriminator learns very quickly to distinguish between fake and real, and as expected provides no reliable gradient information. The critic, however, can’t saturate, and converges to a linear function that gives remarkably clean gradients everywhere. The fact that we constrain the weights limits the possible growth of the function to be at most linear in different parts of the space, forcing the optimal critic to have this behaviour.
In addition, the weights of the critic model are clipped to keep them small, e.g. within a bounding box of [-0.01, 0.01].
In order to have parameters w lie in a compact space, something simple we can do is clamp the weights to a fixed box (say W = [−0.01, 0.01]l ) after each gradient update.
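Both ideas reduce to a few lines of arithmetic, sketched here in NumPy. The function names are illustrative; a Keras implementation would define the loss with backend operations and apply the clipping via a weight constraint, or by calling get_weights()/set_weights() after each critic update:

```python
import numpy as np

def wasserstein_loss(y_true, y_pred):
    # average of the product of expected (+1 or -1) and predicted score values
    return np.mean(y_true * y_pred)

def clip_critic_weights(weight_arrays, c=0.01):
    # clamp every weight to the fixed box [-c, c] after a gradient update
    return [np.clip(w, -c, c) for w in weight_arrays]
```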
The cycle-consistent generative adversarial network, or CycleGAN for short, is an extension to the GAN for image-to-image translation without paired image data. That means that examples of the target image are not required as is the case with conditional GANs, such as Pix2Pix.
… for many tasks, paired training data will not be available. We present an approach for learning to translate an image from a source domain X to a target domain Y in the absence of paired examples.
Their approach seeks “cycle consistency” such that image translation from one domain to another is reversible, meaning it forms a consistent cycle of translation.
… we exploit the property that translation should be “cycle consistent”, in the sense that if we translate, e.g., a sentence from English to French, and then translate it back from French to English, we should arrive back at the original sentence
This is achieved by having two generator models: one for translation X to Y and another for reconstructing X given Y. In turn, the architecture has two discriminator models.
… our model includes two mappings G : X -> Y and F : Y -> X. In addition, we introduce two adversarial discriminators DX and DY , where DX aims to distinguish between images {x} and translated images {F(y)}; in the same way, DY aims to discriminate between {y} and {G(x)}.
The progressive growing generative adversarial network, or Progressive GAN for short, is a change to the architecture and training of GAN models that involves progressively increasing the model depth during the training process.
The key idea is to grow both the generator and discriminator progressively: starting from a low resolution, we add new layers that model increasingly fine details as training progresses. This both speeds the training up and greatly stabilizes it, allowing us to produce images of unprecedented quality …
How to Identify Unstable Models When Training Generative Adversarial Networks.
GANs are difficult to train.
The reason they are difficult to train is that both the generator model and the discriminator model are trained simultaneously in a zero sum game. This means that improvements to one model come at the expense of the other model.
The goal of training two models involves finding a point of equilibrium between the two competing concerns.
It also means that every time the parameters of one of the models are updated, the nature of the optimization problem that is being solved is changed. This has the effect of creating a dynamic system. In neural network terms, the technical challenge of training two competing neural networks at the same time is that they can fail to converge.
It is important to develop an intuition for both the normal convergence of a GAN model and unusual convergence of GAN models, sometimes called failure modes.
In this tutorial, we will first develop a stable GAN model for a simple image generation task in order to establish what normal convergence looks like and what to expect more generally.
We will then impair the GAN models in different ways and explore a range of failure modes that you may encounter when training GAN models. These scenarios will help you to develop an intuition for what to look for or expect when a GAN model is failing to train, and ideas for what you could do about it.
After completing this tutorial, you will know:
How to identify a stable GAN training process from the generator and discriminator loss over time.
How to identify a mode collapse by reviewing both learning curves and generated images.
How to identify a convergence failure by reviewing learning curves of generator and discriminator loss over time.
Let’s get started.
A Practical Guide to Generative Adversarial Network Failure Modes Photo by Jason Heavner, some rights reserved.
Tutorial Overview
This tutorial is divided into three parts; they are:
How To Identify a Stable Generative Adversarial Network
How To Identify a Mode Collapse in a Generative Adversarial Network
How To Identify Convergence Failure in a Generative Adversarial Network
How To Train a Stable Generative Adversarial Network
In this section, we will train a stable GAN to generate images of a handwritten digit.
Specifically, we will use the digit ‘8’ from the MNIST handwritten digit dataset.
The results of this model will establish both a stable GAN that can be used for later experimentation and a profile for what generated images and learning curves look like for a stable GAN training process.
The first step is to define the models.
The discriminator model takes as input one 28×28 grayscale image and outputs a binary prediction as to whether the image is real (class=1) or fake (class=0). It is implemented as a modest convolutional neural network using best practices for GAN design, such as the LeakyReLU activation function with a slope of 0.2, batch normalization, a 2×2 stride to downsample, and the Adam version of stochastic gradient descent with a learning rate of 0.0002 and a momentum of 0.5.
The define_discriminator() function below implements this, defining and compiling the discriminator model and returning it. The input shape of the image is parameterized as a default function argument to make it clear.
# define the standalone discriminator model
def define_discriminator(in_shape=(28,28,1)):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# define model
	model = Sequential()
	# downsample to 14x14
	model.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, input_shape=in_shape))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	# downsample to 7x7
	model.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	# classifier
	model.add(Flatten())
	model.add(Dense(1, activation='sigmoid'))
	# compile model
	opt = Adam(lr=0.0002, beta_1=0.5)
	model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
	return model
The generator model takes as input a point in the latent space and outputs a single 28×28 grayscale image. This is achieved by using a fully connected layer to interpret the point in the latent space and provide sufficient activations that can be reshaped into many copies (in this case, 128) of a low-resolution version of the output image (e.g. 7×7). This is then upsampled two times, doubling the size and quadrupling the area of the activations each time using transpose convolutional layers. The model uses best practices such as the LeakyReLU activation, a kernel size that is a factor of the stride size, and a hyperbolic tangent (tanh) activation function in the output layer.
The define_generator() function below defines the generator model, but intentionally does not compile it as it is not trained directly, then returns the model. The size of the latent space is parameterized as a function argument.
# define the standalone generator model
def define_generator(latent_dim):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# define model
	model = Sequential()
	# foundation for 7x7 image
	n_nodes = 128 * 7 * 7
	model.add(Dense(n_nodes, kernel_initializer=init, input_dim=latent_dim))
	model.add(LeakyReLU(alpha=0.2))
	model.add(Reshape((7, 7, 128)))
	# upsample to 14x14
	model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	# upsample to 28x28
	model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	# output 28x28x1
	model.add(Conv2D(1, (7,7), activation='tanh', padding='same', kernel_initializer=init))
	return model
Next, a GAN model can be defined that combines both the generator model and the discriminator model into one larger model. This larger model will be used to train the model weights in the generator, using the output and error calculated by the discriminator model. The discriminator model is trained separately, and as such, the model weights are marked as not trainable in this larger GAN model to ensure that only the weights of the generator model are updated. This change to the trainability of the discriminator weights only has an effect when training the combined GAN model, not when training the discriminator standalone.
This larger GAN model takes as input a point in the latent space, uses the generator model to generate an image, and feeds that image as input to the discriminator model, where it is classified as real or fake.
The define_gan() function below implements this, taking the already defined generator and discriminator models as input.
# define the combined generator and discriminator model, for updating the generator
def define_gan(generator, discriminator):
	# make weights in the discriminator not trainable
	discriminator.trainable = False
	# connect them
	model = Sequential()
	# add generator
	model.add(generator)
	# add the discriminator
	model.add(discriminator)
	# compile model
	opt = Adam(lr=0.0002, beta_1=0.5)
	model.compile(loss='binary_crossentropy', optimizer=opt)
	return model
Now that we have defined the GAN model, we need to train it. But, before we can train the model, we require input data.
The first step is to load and scale the MNIST dataset. The whole dataset is loaded via a call to the load_data() Keras function, then a subset of the images are selected (about 5,000) that belong to class 8, e.g. are a handwritten depiction of the number eight. Then the pixel values must be scaled to the range [-1,1] to match the output of the generator model.
The load_real_samples() function below implements this, returning the loaded and scaled subset of the MNIST training dataset ready for modeling.
# load mnist images
def load_real_samples():
	# load dataset
	(trainX, trainy), (_, _) = load_data()
	# expand to 3d, e.g. add channels
	X = expand_dims(trainX, axis=-1)
	# select all of the examples for a given class
	selected_ix = trainy == 8
	X = X[selected_ix]
	# convert from ints to floats
	X = X.astype('float32')
	# scale from [0,255] to [-1,1]
	X = (X - 127.5) / 127.5
	return X
We will require one batch (or a half batch) of real images from the dataset for each update to the GAN model. A simple way to achieve this is to select a random sample of images from the dataset each time.
The generate_real_samples() function below implements this, taking the prepared dataset as an argument, selecting and returning a random sample of digit images along with their corresponding class label for the discriminator, specifically class=1 to indicate that they are real images.
# select real samples
def generate_real_samples(dataset, n_samples):
	# choose random instances
	ix = randint(0, dataset.shape[0], n_samples)
	# select images
	X = dataset[ix]
	# generate class labels
	y = ones((n_samples, 1))
	return X, y
Next, we require random points in the latent space as input to the generator. The generate_latent_points() function below implements this, taking the size of the latent space and the number of points required as arguments, and returning them as a batch of input samples for the generator model.
# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):
	# generate points in the latent space
	x_input = randn(latent_dim * n_samples)
	# reshape into a batch of inputs for the network
	x_input = x_input.reshape(n_samples, latent_dim)
	return x_input
Next, we need to use the points in the latent space as input to the generator in order to generate new images.
The generate_fake_samples() function below implements this, taking the generator model and size of the latent space as arguments, then generating points in the latent space and using them as input to the generator model. The function returns the generated images and their corresponding class label for the discriminator model, specifically class=0 to indicate they are fake or generated.
# use the generator to generate n fake examples, with class labels
def generate_fake_samples(generator, latent_dim, n_samples):
	# generate points in latent space
	x_input = generate_latent_points(latent_dim, n_samples)
	# predict outputs
	X = generator.predict(x_input)
	# create class labels
	y = zeros((n_samples, 1))
	return X, y
We need to record the performance of the model. Perhaps the most reliable way to evaluate the performance of a GAN is to use the generator to generate images, and then review and subjectively evaluate them.
The summarize_performance() function below takes the generator model at a given point during training and uses it to generate 100 images in a 10×10 grid that are then plotted and saved to file. The model is also saved to file at this time, in case we would like to use it later to generate more images.
# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, latent_dim, n_samples=100):
	# prepare fake examples
	X, _ = generate_fake_samples(g_model, latent_dim, n_samples)
	# scale from [-1,1] to [0,1]
	X = (X + 1) / 2.0
	# plot images
	for i in range(10 * 10):
		# define subplot
		pyplot.subplot(10, 10, 1 + i)
		# turn off axis
		pyplot.axis('off')
		# plot raw pixel data
		pyplot.imshow(X[i, :, :, 0], cmap='gray_r')
	# save plot to file
	pyplot.savefig('results_baseline/generated_plot_%03d.png' % (step+1))
	pyplot.close()
	# save the generator model
	g_model.save('results_baseline/model_%03d.h5' % (step+1))
In addition to image quality, it is a good idea to keep track of the loss and accuracy of the model over time.
The loss and classification accuracy for the discriminator for real and fake samples can be tracked for each model update, as can the loss for the generator for each update. These can then be used to create line plots of loss and accuracy at the end of the training run.
The plot_history() function below implements this and saves the results to file.
# create a line plot of loss for the gan and save to file
def plot_history(d1_hist, d2_hist, g_hist, a1_hist, a2_hist):
	# plot loss
	pyplot.subplot(2, 1, 1)
	pyplot.plot(d1_hist, label='d-real')
	pyplot.plot(d2_hist, label='d-fake')
	pyplot.plot(g_hist, label='gen')
	pyplot.legend()
	# plot discriminator accuracy
	pyplot.subplot(2, 1, 2)
	pyplot.plot(a1_hist, label='acc-real')
	pyplot.plot(a2_hist, label='acc-fake')
	pyplot.legend()
	# save plot to file
	pyplot.savefig('results_baseline/plot_line_plot_loss.png')
	pyplot.close()
We are now ready to fit the GAN model.
The model is fit for 10 training epochs, which is arbitrary, as the model begins generating plausible number-8 digits after perhaps the first few epochs. A batch size of 128 samples is used, and each training epoch involves 5,851/128 or about 45 batches of real and fake samples and updates to the model. The model is therefore trained for 10 epochs of 45 batches, or 450 iterations.
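The batch and iteration arithmetic above can be checked with a few lines of Python:

```python
# check the training-run arithmetic described above
n_real = 5851   # number of '8' digits in the MNIST training set
n_batch = 128
n_epochs = 10
bat_per_epo = n_real // n_batch    # batches per epoch
n_steps = bat_per_epo * n_epochs   # total training iterations
print(bat_per_epo, n_steps)  # 45 450
```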
First, the discriminator model is updated for a half batch of real samples, then a half batch of fake samples, together forming one batch of weight updates. The generator is then updated via the composite GAN model. Importantly, the class label is set to 1, or real, for the fake samples. This has the effect of updating the generator toward getting better at generating real samples on the next batch.
The train() function below implements this, taking the defined models, dataset, and size of the latent dimension as arguments and parameterizing the number of epochs and batch size with default arguments. The generator model is saved at the end of training.
The performance of the discriminator and generator models is reported each iteration. Sample images are generated and saved every epoch, and line plots of model performance are created and saved at the end of the run.
# train the generator and discriminator
def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=10, n_batch=128):
# calculate the number of batches per epoch
bat_per_epo = int(dataset.shape[0] / n_batch)
# calculate the total iterations based on batch and epoch
n_steps = bat_per_epo * n_epochs
# calculate the number of samples in half a batch
half_batch = int(n_batch / 2)
# prepare lists for storing stats each iteration
d1_hist, d2_hist, g_hist, a1_hist, a2_hist = list(), list(), list(), list(), list()
# manually enumerate epochs
for i in range(n_steps):
# get randomly selected 'real' samples
X_real, y_real = generate_real_samples(dataset, half_batch)
# update discriminator model weights
d_loss1, d_acc1 = d_model.train_on_batch(X_real, y_real)
# generate 'fake' examples
X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
# update discriminator model weights
d_loss2, d_acc2 = d_model.train_on_batch(X_fake, y_fake)
# prepare points in latent space as input for the generator
X_gan = generate_latent_points(latent_dim, n_batch)
# create inverted labels for the fake samples
y_gan = ones((n_batch, 1))
# update the generator via the discriminator's error
g_loss = gan_model.train_on_batch(X_gan, y_gan)
# summarize loss on this batch
print('>%d, d1=%.3f, d2=%.3f g=%.3f, a1=%d, a2=%d' %
(i+1, d_loss1, d_loss2, g_loss, int(100*d_acc1), int(100*d_acc2)))
# record history
d1_hist.append(d_loss1)
d2_hist.append(d_loss2)
g_hist.append(g_loss)
a1_hist.append(d_acc1)
a2_hist.append(d_acc2)
# evaluate the model performance every 'epoch'
if (i+1) % bat_per_epo == 0:
summarize_performance(i, g_model, latent_dim)
plot_history(d1_hist, d2_hist, g_hist, a1_hist, a2_hist)
Now that all of the functions have been defined, we can create the directory where images and models will be stored (in this case ‘results_baseline‘), create the models, load the dataset, and begin the training process.
# make folder for results
makedirs('results_baseline', exist_ok=True)
# size of the latent space
latent_dim = 50
# create the discriminator
discriminator = define_discriminator()
# create the generator
generator = define_generator(latent_dim)
# create the gan
gan_model = define_gan(generator, discriminator)
# load image data
dataset = load_real_samples()
print(dataset.shape)
# train model
train(generator, discriminator, gan_model, dataset, latent_dim)
Tying all of this together, the complete example is listed below.
# example of training a stable gan for generating a handwritten digit
from os import makedirs
from numpy import expand_dims
from numpy import zeros
from numpy import ones
from numpy.random import randn
from numpy.random import randint
from keras.datasets.mnist import load_data
from keras.optimizers import Adam
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import BatchNormalization
from keras.initializers import RandomNormal
from matplotlib import pyplot
# define the standalone discriminator model
def define_discriminator(in_shape=(28,28,1)):
# weight initialization
init = RandomNormal(stddev=0.02)
# define model
model = Sequential()
# downsample to 14x14
model.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, input_shape=in_shape))
model.add(BatchNormalization())
model.add(LeakyReLU(alpha=0.2))
# downsample to 7x7
model.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init))
model.add(BatchNormalization())
model.add(LeakyReLU(alpha=0.2))
# classifier
model.add(Flatten())
model.add(Dense(1, activation='sigmoid'))
# compile model
opt = Adam(lr=0.0002, beta_1=0.5)
model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
return model
# define the standalone generator model
def define_generator(latent_dim):
# weight initialization
init = RandomNormal(stddev=0.02)
# define model
model = Sequential()
# foundation for 7x7 image
n_nodes = 128 * 7 * 7
model.add(Dense(n_nodes, kernel_initializer=init, input_dim=latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(Reshape((7, 7, 128)))
# upsample to 14x14
model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init))
model.add(BatchNormalization())
model.add(LeakyReLU(alpha=0.2))
# upsample to 28x28
model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init))
model.add(BatchNormalization())
model.add(LeakyReLU(alpha=0.2))
# output 28x28x1
model.add(Conv2D(1, (7,7), activation='tanh', padding='same', kernel_initializer=init))
return model
# define the combined generator and discriminator model, for updating the generator
def define_gan(generator, discriminator):
# make weights in the discriminator not trainable
discriminator.trainable = False
# connect them
model = Sequential()
# add generator
model.add(generator)
# add the discriminator
model.add(discriminator)
# compile model
opt = Adam(lr=0.0002, beta_1=0.5)
model.compile(loss='binary_crossentropy', optimizer=opt)
return model
# load mnist images
def load_real_samples():
# load dataset
(trainX, trainy), (_, _) = load_data()
# expand to 3d, e.g. add channels
X = expand_dims(trainX, axis=-1)
# select all of the examples for a given class
selected_ix = trainy == 8
X = X[selected_ix]
# convert from ints to floats
X = X.astype('float32')
# scale from [0,255] to [-1,1]
X = (X - 127.5) / 127.5
return X
# select real samples
def generate_real_samples(dataset, n_samples):
# choose random instances
ix = randint(0, dataset.shape[0], n_samples)
# select images
X = dataset[ix]
# generate class labels
y = ones((n_samples, 1))
return X, y
# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):
# generate points in the latent space
x_input = randn(latent_dim * n_samples)
# reshape into a batch of inputs for the network
x_input = x_input.reshape(n_samples, latent_dim)
return x_input
# use the generator to generate n fake examples, with class labels
def generate_fake_samples(generator, latent_dim, n_samples):
# generate points in latent space
x_input = generate_latent_points(latent_dim, n_samples)
# predict outputs
X = generator.predict(x_input)
# create class labels
y = zeros((n_samples, 1))
return X, y
# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, latent_dim, n_samples=100):
# prepare fake examples
X, _ = generate_fake_samples(g_model, latent_dim, n_samples)
# scale from [-1,1] to [0,1]
X = (X + 1) / 2.0
# plot images
for i in range(10 * 10):
# define subplot
pyplot.subplot(10, 10, 1 + i)
# turn off axis
pyplot.axis('off')
# plot raw pixel data
pyplot.imshow(X[i, :, :, 0], cmap='gray_r')
# save plot to file
pyplot.savefig('results_baseline/generated_plot_%03d.png' % (step+1))
pyplot.close()
# save the generator model
g_model.save('results_baseline/model_%03d.h5' % (step+1))
# create a line plot of loss for the gan and save to file
def plot_history(d1_hist, d2_hist, g_hist, a1_hist, a2_hist):
# plot loss
pyplot.subplot(2, 1, 1)
pyplot.plot(d1_hist, label='d-real')
pyplot.plot(d2_hist, label='d-fake')
pyplot.plot(g_hist, label='gen')
pyplot.legend()
# plot discriminator accuracy
pyplot.subplot(2, 1, 2)
pyplot.plot(a1_hist, label='acc-real')
pyplot.plot(a2_hist, label='acc-fake')
pyplot.legend()
# save plot to file
pyplot.savefig('results_baseline/plot_line_plot_loss.png')
pyplot.close()
# train the generator and discriminator
def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=10, n_batch=128):
# calculate the number of batches per epoch
bat_per_epo = int(dataset.shape[0] / n_batch)
# calculate the total iterations based on batch and epoch
n_steps = bat_per_epo * n_epochs
# calculate the number of samples in half a batch
half_batch = int(n_batch / 2)
# prepare lists for storing stats each iteration
d1_hist, d2_hist, g_hist, a1_hist, a2_hist = list(), list(), list(), list(), list()
# manually enumerate epochs
for i in range(n_steps):
# get randomly selected 'real' samples
X_real, y_real = generate_real_samples(dataset, half_batch)
# update discriminator model weights
d_loss1, d_acc1 = d_model.train_on_batch(X_real, y_real)
# generate 'fake' examples
X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
# update discriminator model weights
d_loss2, d_acc2 = d_model.train_on_batch(X_fake, y_fake)
# prepare points in latent space as input for the generator
Generative Adversarial Networks, or GANs, are an architecture for training generative models, such as deep convolutional neural networks for generating images.
Although GAN models are capable of generating new random plausible examples for a given dataset, there is no way to control the types of images that are generated other than trying to figure out the complex relationship between the latent space input to the generator and the generated images.
The conditional generative adversarial network, or cGAN for short, is a type of GAN that involves the conditional generation of images by a generator model. Image generation can be conditional on a class label, if available, allowing the targeted generation of images of a given type.
In this tutorial, you will discover how to develop a conditional generative adversarial network for the targeted generation of items of clothing.
After completing this tutorial, you will know:
The limitations of generating random samples with a GAN that can be overcome with a conditional generative adversarial network.
How to develop and evaluate an unconditional generative adversarial network for generating photos of items of clothing.
How to develop and evaluate a conditional generative adversarial network for generating photos of items of clothing.
Let’s get started.
How to Develop a Conditional Generative Adversarial Network From Scratch Photo by Big Cypress National Preserve, some rights reserved.
Tutorial Overview
This tutorial is divided into five parts; they are:
Conditional Generative Adversarial Networks
Fashion-MNIST Clothing Photograph Dataset
Unconditional GAN for Fashion-MNIST
Conditional GAN for Fashion-MNIST
Conditional Clothing Generation
Conditional Generative Adversarial Networks
A generative adversarial network, or GAN for short, is an architecture for training deep learning-based generative models.
The architecture is comprised of a generator and a discriminator model. The generator model is responsible for generating new plausible examples that ideally are indistinguishable from real examples in the dataset. The discriminator model is responsible for classifying a given image as either real (drawn from the dataset) or fake (generated).
The models are trained together in a zero-sum or adversarial manner, such that improvements in the discriminator come at the cost of a reduced capability of the generator, and vice versa.
GANs are effective at image synthesis, that is, generating new examples of images for a target dataset. Some datasets have additional information, such as a class label, and it is desirable to make use of this information.
For example, the MNIST handwritten digit dataset has class labels of the corresponding integers, the CIFAR-10 small object photograph dataset has class labels for the corresponding objects in the photographs, and the Fashion-MNIST clothing dataset has class labels for the corresponding items of clothing.
There are two motivations for making use of the class label information in a GAN model.
Improve the GAN.
Targeted Image Generation.
Additional information that is correlated with the input images, such as class labels, can be used to improve the GAN. This improvement may come in the form of more stable training, faster training, and/or generated images that have better quality.
Class labels can also be used for the deliberate or targeted generation of images of a given type.
A limitation of a GAN model is that it may generate a random image from the domain. There is a relationship between points in the latent space and the generated images, but this relationship is complex and hard to map.
Alternately, a GAN can be trained in such a way that both the generator and the discriminator models are conditioned on the class label. This means that when the trained generator model is used as a standalone model to generate images in the domain, images of a given type, or class label, can be generated.
Generative adversarial nets can be extended to a conditional model if both the generator and discriminator are conditioned on some extra information y. […] We can perform the conditioning by feeding y into the both the discriminator and generator as additional input layer.
For example, in the case of MNIST, specific handwritten digits can be generated, such as the number 9; in the case of CIFAR-10, specific object photographs can be generated such as ‘frogs‘; and in the case of the Fashion MNIST dataset, specific items of clothing can be generated, such as ‘dress.’
This type of model is called a Conditional Generative Adversarial Network, CGAN or cGAN for short.
The cGAN was first described by Mehdi Mirza and Simon Osindero in their 2014 paper titled “Conditional Generative Adversarial Nets.” In the paper, the authors motivate the approach based on the desire to direct the image generation process of the generator model.
… by conditioning the model on additional information it is possible to direct the data generation process. Such conditioning could be based on class labels
Their approach is demonstrated in the MNIST handwritten digit dataset where the class labels are one hot encoded and concatenated with the input to both the generator and discriminator models.
The image below provides a summary of the model architecture.
Example of a Conditional Generator and a Conditional Discriminator in a Conditional Generative Adversarial Network. Taken from Conditional Generative Adversarial Nets, 2014.
There have been many advancements in the design and training of GAN models, most notably the deep convolutional GAN, or DCGAN for short, which outlines the model configuration and training procedures that reliably result in the stable training of GAN models for a wide variety of problems. The conditional training of DCGAN-based models may be referred to as CDCGAN or cDCGAN for short.
There are many ways to encode and incorporate the class labels into the discriminator and generator models. A best practice involves using an embedding layer followed by a fully connected layer with a linear activation that scales the embedding to the size of the image before concatenating it in the model as an additional channel or feature map.
… we also explore a class conditional version of the model, where a vector c encodes the label. This is integrated into Gk & Dk by passing it through a linear layer whose output is reshaped into a single plane feature map which is then concatenated with the 1st layer maps.
This recommendation was later added to the ‘GAN Hacks‘ list of heuristic recommendations when designing and training GAN models, summarized as:
16: Discrete variables in Conditional GANs
– Use an Embedding layer
– Add as additional channels to images
– Keep embedding dimensionality low and upsample to match image channel size
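As a rough illustration of the shape manipulation these recommendations describe, the sketch below uses plain NumPy stand-ins for the trainable Embedding and Dense (linear activation) layers; the sizes (10 classes, a 50-dimensional embedding, 28×28×1 images) are chosen for illustration only:

```python
import numpy as np

# hypothetical sizes for illustration
n_classes, embed_dim, h, w = 10, 50, 28, 28
rng = np.random.default_rng(0)
embedding = rng.normal(size=(n_classes, embed_dim))  # stand-in Embedding table
linear = rng.normal(size=(embed_dim, h * w))         # stand-in Dense weights

def condition_image(image, label):
	# look up the label embedding and scale it up to the image size
	vec = embedding[label] @ linear      # shape (h * w,)
	plane = vec.reshape(h, w, 1)         # a single 28x28 feature map
	# concatenate the label plane as an additional channel
	return np.concatenate([image, plane], axis=-1)

img = rng.normal(size=(h, w, 1))
print(condition_image(img, 3).shape)  # (28, 28, 2)
```

In a real model, the embedding table and linear weights would be learned layers; the point here is only that the label becomes one extra channel alongside the image.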
Although GANs can be conditioned on the class label, so-called class-conditional GANs, they can also be conditioned on other inputs, such as an image, in the case where a GAN is used for image-to-image translation tasks.
In this tutorial, we will develop a GAN, specifically a DCGAN, then update it to use class labels in a cGAN, specifically a cDCGAN model architecture.
Fashion-MNIST Clothing Photograph Dataset
The Fashion-MNIST dataset is proposed as a more challenging replacement dataset for the MNIST dataset.
It is a dataset comprised of 60,000 small square 28×28 pixel grayscale images of items of 10 types of clothing, such as shoes, t-shirts, dresses, and more.
Keras provides access to the Fashion-MNIST dataset via the fashion_mnist.load_data() function. It returns two tuples, one with the input and output elements for the standard training dataset, and another with the input and output elements for the standard test dataset.
The example below loads the dataset and summarizes the shape of the loaded dataset.
Note: the first time you load the dataset, Keras will automatically download a compressed version of the images and save them under your home directory in ~/.keras/datasets/. The download is fast as the dataset is only about 25 megabytes in its compressed form.
# example of loading the fashion_mnist dataset
from keras.datasets.fashion_mnist import load_data
# load the images into memory
(trainX, trainy), (testX, testy) = load_data()
# summarize the shape of the dataset
print('Train', trainX.shape, trainy.shape)
print('Test', testX.shape, testy.shape)
Running the example loads the dataset and prints the shape of the input and output components of the train and test splits of images.
We can see that there are 60K examples in the training set and 10K in the test set and that each image is a square of 28 by 28 pixels.
Train (60000, 28, 28) (60000,)
Test (10000, 28, 28) (10000,)
The images are grayscale with a black background (0 pixel value) and the items of clothing are in white (pixel values near 255). This means that if the images were plotted, they would be mostly black with a white item of clothing in the middle.
We can plot some of the images from the training dataset using the matplotlib library with the imshow() function and specify the color map via the ‘cmap‘ argument as ‘gray‘ to show the pixel values correctly.
# plot raw pixel data
pyplot.imshow(trainX[i], cmap='gray')
Alternately, the images are easier to review when we reverse the colors and plot the background as white and the clothing in black.
They are easier to view as most of the image is now white with the area of interest in black. This can be achieved using a reverse grayscale color map, as follows:
# plot raw pixel data
pyplot.imshow(trainX[i], cmap='gray_r')
The example below plots the first 100 images from the training dataset in a 10 by 10 square.
# example of loading the fashion_mnist dataset
from keras.datasets.fashion_mnist import load_data
from matplotlib import pyplot
# load the images into memory
(trainX, trainy), (testX, testy) = load_data()
# plot images from the training dataset
for i in range(100):
# define subplot
pyplot.subplot(10, 10, 1 + i)
# turn off axis
pyplot.axis('off')
# plot raw pixel data
pyplot.imshow(trainX[i], cmap='gray_r')
pyplot.show()
Running the example creates a figure with a plot of 100 images from the Fashion-MNIST training dataset, arranged in a 10×10 square.
Plot of the First 100 Items of Clothing From the Fashion MNIST Dataset.
We will use the images in the training dataset as the basis for training a Generative Adversarial Network.
Specifically, the generator model will learn how to generate new plausible items of clothing using a discriminator that will try to distinguish between real images from the Fashion MNIST training dataset and new images output by the generator model.
This is a relatively simple problem that does not require a sophisticated generator or discriminator model, although it does require the generation of a grayscale output image.
Unconditional GAN for Fashion-MNIST
In this section, we will develop an unconditional GAN for the Fashion-MNIST dataset.
The first step is to define the models.
The discriminator model takes as input one 28×28 grayscale image and outputs a binary prediction as to whether the image is real (class=1) or fake (class=0). It is implemented as a modest convolutional neural network using best practices for GAN design, such as the LeakyReLU activation function with a slope of 0.2, a 2×2 stride to downsample, and the Adam version of stochastic gradient descent with a learning rate of 0.0002 and a momentum of 0.5.
The define_discriminator() function below implements this, defining and compiling the discriminator model and returning it. The input shape of the image is parameterized as a default function argument in case you want to re-use the function for your own image data later.
# define the standalone discriminator model
def define_discriminator(in_shape=(28,28,1)):
model = Sequential()
# downsample
model.add(Conv2D(128, (3,3), strides=(2,2), padding='same', input_shape=in_shape))
model.add(LeakyReLU(alpha=0.2))
# downsample
model.add(Conv2D(128, (3,3), strides=(2,2), padding='same'))
model.add(LeakyReLU(alpha=0.2))
# classifier
model.add(Flatten())
model.add(Dropout(0.4))
model.add(Dense(1, activation='sigmoid'))
# compile model
opt = Adam(lr=0.0002, beta_1=0.5)
model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
return model
The generator model takes as input a point in the latent space and outputs a single 28×28 grayscale image. This is achieved by using a fully connected layer to interpret the point in the latent space and provide sufficient activations that can be reshaped into many copies (in this case 128) of a low-resolution version of the output image (e.g. 7×7). This is then upsampled twice, doubling the size and quadrupling the area of the activations each time using transpose convolutional layers. The model uses best practices such as the LeakyReLU activation, a kernel size that is a multiple of the stride size, and a hyperbolic tangent (tanh) activation function in the output layer.
The define_generator() function below defines the generator model, but intentionally does not compile it as it is not trained directly, then returns the model. The size of the latent space is parameterized as a function argument.
# define the standalone generator model
def define_generator(latent_dim):
model = Sequential()
# foundation for 7x7 image
n_nodes = 128 * 7 * 7
model.add(Dense(n_nodes, input_dim=latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(Reshape((7, 7, 128)))
# upsample to 14x14
model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
model.add(LeakyReLU(alpha=0.2))
# upsample to 28x28
model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
model.add(LeakyReLU(alpha=0.2))
# generate
model.add(Conv2D(1, (7,7), activation='tanh', padding='same'))
return model
Next, a GAN model can be defined that combines both the generator model and the discriminator model into one larger model. This larger model will be used to train the model weights in the generator, using the output and error calculated by the discriminator model. The discriminator model is trained separately, and as such, the model weights are marked as not trainable in this larger GAN model to ensure that only the weights of the generator model are updated. This change to the trainability of the discriminator weights only has an effect when training the combined GAN model, not when training the discriminator standalone.
This larger GAN model takes as input a point in the latent space, uses the generator model to generate an image, and feeds that image as input to the discriminator model, where it is classified as real or fake.
The define_gan() function below implements this, taking the already-defined generator and discriminator models as input.
# define the combined generator and discriminator model, for updating the generator
def define_gan(generator, discriminator):
# make weights in the discriminator not trainable
discriminator.trainable = False
# connect them
model = Sequential()
# add generator
model.add(generator)
# add the discriminator
model.add(discriminator)
# compile model
opt = Adam(lr=0.0002, beta_1=0.5)
model.compile(loss='binary_crossentropy', optimizer=opt)
return model
Now that we have defined the GAN model, we need to train it. But, before we can train the model, we require input data.
The first step is to load and prepare the Fashion MNIST dataset. We only require the images in the training dataset. The images are black and white; therefore, we must add an additional channel dimension to transform them to be three-dimensional, as expected by the convolutional layers of our models. Finally, the pixel values must be scaled to the range [-1,1] to match the output of the generator model.
The load_real_samples() function below implements this, returning the loaded and scaled Fashion MNIST training dataset ready for modeling.
# load fashion mnist images
def load_real_samples():
# load dataset
(trainX, _), (_, _) = load_data()
# expand to 3d, e.g. add channels
X = expand_dims(trainX, axis=-1)
# convert from ints to floats
X = X.astype('float32')
# scale from [0,255] to [-1,1]
X = (X - 127.5) / 127.5
return X
We will require one batch (or half a batch) of real images from the dataset for each update to the GAN model. A simple way to achieve this is to select a random sample of images from the dataset each time.
The generate_real_samples() function below implements this, taking the prepared dataset as an argument, selecting and returning a random sample of Fashion MNIST images and their corresponding class label for the discriminator, specifically class=1, indicating that they are real images.
# select real samples
def generate_real_samples(dataset, n_samples):
# choose random instances
ix = randint(0, dataset.shape[0], n_samples)
# select images
X = dataset[ix]
# generate class labels
y = ones((n_samples, 1))
return X, y
Next, we need inputs for the generator model. These are random points from the latent space, specifically Gaussian distributed random variables.
The generate_latent_points() function implements this, taking the size of the latent space and the number of points required as arguments and returning them as a batch of input samples for the generator model.
# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):
# generate points in the latent space
x_input = randn(latent_dim * n_samples)
# reshape into a batch of inputs for the network
x_input = x_input.reshape(n_samples, latent_dim)
return x_input
Next, we need to use the points in the latent space as input to the generator in order to generate new images.
The generate_fake_samples() function below implements this, taking the generator model and size of the latent space as arguments, then generating points in the latent space and using them as input to the generator model. The function returns the generated images and their corresponding class label for the discriminator model, specifically class=0 to indicate they are fake or generated.
# use the generator to generate n fake examples, with class labels
def generate_fake_samples(generator, latent_dim, n_samples):
# generate points in latent space
x_input = generate_latent_points(latent_dim, n_samples)
# predict outputs
X = generator.predict(x_input)
# create class labels
y = zeros((n_samples, 1))
return X, y
We are now ready to fit the GAN models.
The model is fit for 100 training epochs, which is arbitrary, as the model begins generating plausible items of clothing after perhaps 20 epochs. A batch size of 128 samples is used, and each training epoch involves 60,000/128, or about 468 batches of real and fake samples and updates to the model.
First, the discriminator model is updated for a half batch of real samples, then a half batch of fake samples, together forming one batch of weight updates. The generator is then updated via the composite GAN model. Importantly, the class label is set to 1, or real, for the fake samples. This has the effect of updating the generator toward getting better at generating real samples on the next batch.
The train() function below implements this, taking the defined models, dataset, and size of the latent dimension as arguments and parameterizing the number of epochs and batch size with default arguments.
Generative Adversarial Networks, or GANs, are an architecture for training generative models, such as deep convolutional neural networks for generating images.
The generative model in the GAN architecture learns to map points in the latent space to generated images. The latent space has no meaning other than the meaning applied to it via the generative model. Yet, the latent space has structure that can be explored, such as by interpolating between points and performing vector arithmetic between points in latent space which have meaningful and targeted effects on the generated images.
In this tutorial, you will discover how to develop a generative adversarial network for face generation and explore the structure of latent space and the effect on generated faces.
After completing this tutorial, you will know:
How to develop a generative adversarial network for generating faces.
How to interpolate between points in latent space and generate images that morph from one face to another.
How to perform vector arithmetic in latent space and achieve targeted results in the resulting generated faces.
Let’s get started.
How to Interpolate and Perform Vector Arithmetic With Faces Using a Generative Adversarial Network. Photo by Intermountain Forest Service, some rights reserved.
Tutorial Overview
This tutorial is divided into five parts; they are:
Vector Arithmetic in Latent Space
Large-Scale CelebFaces Dataset (CelebA)
How to Prepare CelebA Faces Dataset
How to Develop a Generative Adversarial Network
How to Explore the Latent Space for Generated Faces
Vector Arithmetic in Latent Space
The generator model in the GAN architecture takes a point from the latent space as input and generates a new image.
The latent space itself has no meaning. Typically, it is a 100-dimensional hypersphere with each variable drawn from a Gaussian distribution with a mean of zero and a standard deviation of one. Through training, the generator learns to map points in the latent space to specific output images, and this mapping will be different each time the model is trained.
The latent space has structure when interpreted by the generator model, and this structure can be queried and navigated for a given model.
Typically, new images are generated using random points in the latent space. Taking this a step further, points in the latent space can be constructed (e.g. all 0s, all 0.5s, or all 1s) and used as input or a query to generate a specific image.
A series of points can be created on a linear path between two points in the latent space, such as two generated images. These points can be used to generate a series of images that show a transition between the two generated images.
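A minimal sketch of this linear interpolation, assuming a 100-dimensional latent space; interpolate_points() is a hypothetical helper, not part of the models defined in this tutorial:

```python
import numpy as np

def interpolate_points(p1, p2, n_steps=10):
	# evenly spaced interpolation ratios, including both endpoints
	ratios = np.linspace(0, 1, num=n_steps)
	# one latent vector per ratio along the straight line from p1 to p2
	return np.asarray([(1.0 - r) * p1 + r * p2 for r in ratios])

# two random latent points, as produced by generate_latent_points()
p1, p2 = np.random.randn(100), np.random.randn(100)
points = interpolate_points(p1, p2)
print(points.shape)  # (10, 100)
# each row could then be passed to the generator to render the transition
```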
Finally, the points in the latent space can be kept and used in simple vector arithmetic to create new points in the latent space that, in turn, can be used to generate images. This is an interesting idea, as it allows for the intuitive and targeted generation of images.
In the paper, the authors explored the latent space for GANs fit on a number of different training datasets, most notably a dataset of celebrity faces. They demonstrated two interesting aspects.
The first was the vector arithmetic with faces. For example, a face of a smiling woman minus the face of a neutral woman plus the face of a neutral man resulted in the face of a smiling man.
smiling woman - neutral woman + neutral man = smiling man
Specifically, the arithmetic was performed on the points in the latent space for the resulting faces, and in fact on the average of the latent points for multiple faces with a given characteristic, to provide a more robust result.
Example of Vector Arithmetic on Points in the Latent Space for Generating Faces With a GAN. Taken from Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks.
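In NumPy terms the operation is a simple average-and-combine; the sketch below uses random vectors as stand-ins for latent points that would be collected for faces with each attribute:

```python
from numpy import mean
from numpy.random import randn

# stand-ins: three latent points (rows) per attribute group
smiling_woman = randn(3, 100)
neutral_woman = randn(3, 100)
neutral_man = randn(3, 100)

# average each group for a more robust attribute vector
avg_sw = mean(smiling_woman, axis=0)
avg_nw = mean(neutral_woman, axis=0)
avg_nm = mean(neutral_man, axis=0)

# smiling woman - neutral woman + neutral man = smiling man
result = avg_sw - avg_nw + avg_nm
print(result.shape)  # (100,)
```

The result is a single latent point that a trained generator would decode into the new face.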
The second demonstration was the transition between two generated faces, specifically by creating a linear path through the latent dimension between the points that generated two faces and then generating all of the faces for the points along the path.
Example of Faces on a Path Between Two GAN Generated Faces. Taken from Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks.
Exploring the structure of the latent space for a GAN model is both interesting for the problem domain and helps to develop an intuition for what has been learned by the generator model.
In this tutorial, we will develop a GAN for generating photos of faces, then explore the latent space for the model with vector arithmetic.
Large-Scale CelebFaces Dataset (CelebA)
The Large-scale CelebFaces Attributes Dataset, or CelebA for short, provides about 200,000 photographs of celebrity faces along with annotations for what appears in given photos, such as glasses, face shape, hats, hair type, etc. As part of the dataset, the authors provide a version of each photo centered on the face and cropped to the portrait, 178 pixels wide and 218 pixels tall. We will use this as the basis for developing our GAN model.
The dataset can be easily downloaded from the Kaggle webpage. Note: this may require an account with Kaggle.
Specifically, download the file “img_align_celeba.zip” which is about 1.3 gigabytes. To do this, click on the filename on the Kaggle website and then click the download icon.
The download might take a while depending on the speed of your internet connection.
After downloading, unzip the archive.
This will create a new directory named “img_align_celeba” that contains all of the images with filenames like 202599.jpg and 202598.jpg.
Next, we can look at preparing the raw images for modeling.
How to Prepare CelebA Faces Dataset
The first step is to develop code to load the images.
We can use the Pillow library to load a given image file, convert it to RGB format (if needed) and return an array of pixel data. The load_image() function below implements this.
# load an image as an rgb numpy array
def load_image(filename):
    # load image from file
    image = Image.open(filename)
    # convert to RGB, if needed
    image = image.convert('RGB')
    # convert to array
    pixels = asarray(image)
    return pixels
Next, we can enumerate the directory of images, load each as an array of pixels in turn, and return an array with all of the images.
There are 200K images in the dataset, which is probably more than we need, so we can also limit the number of images to load with an argument. The load_faces() function below implements this.
# load images and extract faces for all images in a directory
def load_faces(directory, n_faces):
    faces = list()
    # enumerate files
    for filename in listdir(directory):
        # load the image
        pixels = load_image(directory + filename)
        # store
        faces.append(pixels)
        # stop once we have enough
        if len(faces) >= n_faces:
            break
    return asarray(faces)
Finally, once the images are loaded, we can plot them using the imshow() function from the matplotlib library.
The plot_faces() function below does this, plotting the images arranged in a square.
# plot a list of loaded faces
def plot_faces(faces, n):
    for i in range(n * n):
        # define subplot
        pyplot.subplot(n, n, 1 + i)
        # turn off axis
        pyplot.axis('off')
        # plot raw pixel data
        pyplot.imshow(faces[i])
    pyplot.show()
Tying this together, the complete example is listed below.
# load and plot faces
from os import listdir
from numpy import asarray
from PIL import Image
from matplotlib import pyplot

# load an image as an rgb numpy array
def load_image(filename):
    # load image from file
    image = Image.open(filename)
    # convert to RGB, if needed
    image = image.convert('RGB')
    # convert to array
    pixels = asarray(image)
    return pixels

# load images and extract faces for all images in a directory
def load_faces(directory, n_faces):
    faces = list()
    # enumerate files
    for filename in listdir(directory):
        # load the image
        pixels = load_image(directory + filename)
        # store
        faces.append(pixels)
        # stop once we have enough
        if len(faces) >= n_faces:
            break
    return asarray(faces)

# plot a list of loaded faces
def plot_faces(faces, n):
    for i in range(n * n):
        # define subplot
        pyplot.subplot(n, n, 1 + i)
        # turn off axis
        pyplot.axis('off')
        # plot raw pixel data
        pyplot.imshow(faces[i])
    pyplot.show()

# directory that contains all images
directory = 'img_align_celeba/'
# load and extract all faces
faces = load_faces(directory, 25)
print('Loaded: ', faces.shape)
# plot faces
plot_faces(faces, 5)
Running the example loads a total of 25 images from the directory, then summarizes the size of the returned array.
Loaded: (25, 218, 178, 3)
Finally, the 25 images are plotted in a 5×5 square.
Plot of a Sample of 25 Faces from the Celebrity Faces Dataset
When working with a GAN, it is easier to model a dataset if all of the images are small and square in shape.
Further, as we are only interested in the face in each photo, and not the background, we can perform face detection and extract only the face before resizing the result to a fixed size.
There are many ways to perform face detection. In this case, we will use a pre-trained Multi-Task Cascaded Convolutional Neural Network, or MTCNN. This is a state-of-the-art deep learning model for face detection, described in the 2016 paper titled “Joint Face Detection and Alignment Using Multitask Cascaded Convolutional Networks.”
We will use the implementation provided by Iván de Paz Centeno in the ipazc/mtcnn project. This library can be installed via pip as follows:
sudo pip install mtcnn
We can confirm that the library was installed correctly by importing the library and printing the version; for example:
# confirm mtcnn was installed correctly
import mtcnn
# print version
print(mtcnn.__version__)
Running the example prints the current version of the library.
0.0.8
The MTCNN model is very easy to use.
First, an instance of the MTCNN model is created, then the detect_faces() function can be called passing in the pixel data for one image. The result is a list of detected faces, with a bounding box defined in pixel offset values.
...
# prepare model
model = MTCNN()
# detect face in the image
faces = model.detect_faces(pixels)
# extract details of the face
x1, y1, width, height = faces[0]['box']
We can update our example to extract the face from each loaded photo and resize the extracted face pixels to a fixed size. In this case, we will use the square shape of 80×80 pixels.
The extract_face() function below implements this, taking the MTCNN model and pixel values for a single photograph as arguments and returning an 80x80x3 array of pixel values with just the face, or None if no face was detected (which can happen rarely).
# extract the face from a loaded image and resize
def extract_face(model, pixels, required_size=(80, 80)):
    # detect face in the image
    faces = model.detect_faces(pixels)
    # skip cases where we could not detect a face
    if len(faces) == 0:
        return None
    # extract details of the face
    x1, y1, width, height = faces[0]['box']
    # force detected pixel values to be positive (bug fix)
    x1, y1 = abs(x1), abs(y1)
    # convert into coordinates
    x2, y2 = x1 + width, y1 + height
    # retrieve face pixels
    face_pixels = pixels[y1:y2, x1:x2]
    # resize pixels to the model size
    image = Image.fromarray(face_pixels)
    image = image.resize(required_size)
    face_array = asarray(image)
    return face_array
We can now update the load_faces() function to extract the face from the loaded photo and store that in the list of faces returned.
# load images and extract faces for all images in a directory
def load_faces(directory, n_faces):
    # prepare model
    model = MTCNN()
    faces = list()
    # enumerate files
    for filename in listdir(directory):
        # load the image
        pixels = load_image(directory + filename)
        # get face
        face = extract_face(model, pixels)
        if face is None:
            continue
        # store
        faces.append(face)
        print(len(faces), face.shape)
        # stop once we have enough
        if len(faces) >= n_faces:
            break
    return asarray(faces)
Tying this together, the complete example is listed below.
In this case, we increase the total number of loaded faces to 50,000 to provide a good training dataset for our GAN model.
# example of extracting and resizing faces into a new dataset
from os import listdir
from numpy import asarray
from numpy import savez_compressed
from PIL import Image
from mtcnn.mtcnn import MTCNN
from matplotlib import pyplot

# load an image as an rgb numpy array
def load_image(filename):
    # load image from file
    image = Image.open(filename)
    # convert to RGB, if needed
    image = image.convert('RGB')
    # convert to array
    pixels = asarray(image)
    return pixels

# extract the face from a loaded image and resize
def extract_face(model, pixels, required_size=(80, 80)):
    # detect face in the image
    faces = model.detect_faces(pixels)
    # skip cases where we could not detect a face
    if len(faces) == 0:
        return None
    # extract details of the face
    x1, y1, width, height = faces[0]['box']
    # force detected pixel values to be positive (bug fix)
    x1, y1 = abs(x1), abs(y1)
    # convert into coordinates
    x2, y2 = x1 + width, y1 + height
    # retrieve face pixels
    face_pixels = pixels[y1:y2, x1:x2]
    # resize pixels to the model size
    image = Image.fromarray(face_pixels)
    image = image.resize(required_size)
    face_array = asarray(image)
    return face_array

# load images and extract faces for all images in a directory
def load_faces(directory, n_faces):
    # prepare model
    model = MTCNN()
    faces = list()
    # enumerate files
    for filename in listdir(directory):
        # load the image
        pixels = load_image(directory + filename)
        # get face
        face = extract_face(model, pixels)
        if face is None:
            continue
        # store
        faces.append(face)
        print(len(faces), face.shape)
        # stop once we have enough
        if len(faces) >= n_faces:
            break
    return asarray(faces)

# directory that contains all images
directory = 'img_align_celeba/'
# load and extract all faces
all_faces = load_faces(directory, 50000)
print('Loaded: ', all_faces.shape)
# save in compressed format
savez_compressed('img_align_celeba.npz', all_faces)
Running the example may take a few minutes given the larger number of faces to be loaded.
At the end of the run, the array of extracted and resized faces is saved as a compressed NumPy array with the filename ‘img_align_celeba.npz‘.
The prepared dataset can then be loaded any time, as follows.
# load the prepared dataset
from numpy import load
# load the face dataset
data = load('img_align_celeba.npz')
faces = data['arr_0']
print('Loaded: ', faces.shape)
Loading the dataset summarizes the shape of the array, showing 50K images with the size of 80×80 pixels and three color channels.
Loaded: (50000, 80, 80, 3)
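Before training a GAN on these faces, the pixel values would typically be scaled from [0,255] to [-1,1] to match the tanh output of the generator (the same transform used for real images throughout this post); a minimal sketch:

```python
from numpy import asarray

# scale pixel values from [0,255] to [-1,1]
def scale_pixels(images):
    # convert from unsigned ints to floats
    X = images.astype('float32')
    # center on zero and scale to [-1,1]
    return (X - 127.5) / 127.5

# e.g. scaled = scale_pixels(faces) after loading 'img_align_celeba.npz'
scaled = scale_pixels(asarray([0, 127, 255], dtype='uint8'))
print(scaled)  # values near -1.0, 0.0 and 1.0
```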
We are now ready to develop a GAN model to generate faces using this dataset.
How to Develop a Generative Adversarial Network
In this section, we will develop a GAN for the faces dataset that we have prepared.
The first step is to define the models.
The discriminator model takes as input one 80×80 color image and outputs a binary prediction as to whether the image is real (class=1) or fake (class=0). It is implemented as a modest convolutional neural network using best practices for GAN design, such as using the LeakyReLU activation function with a slope of 0.2, using a 2×2 stride to downsample, and the Adam version of stochastic gradient descent with a learning rate of 0.0002 and a momentum of 0.5.
The define_discriminator() function below implements this, defining and compiling the discriminator model and returning it. The input shape of the image is parameterized as a default function argument in case you want to re-use the function for your own image data later.
# define the standalone discriminator model
def define_discriminator(in_shape=(80,80,3)):
    model = Sequential()
    # normal
    model.add(Conv2D(128, (5,5), padding='same', input_shape=in_shape))
    model.add(LeakyReLU(alpha=0.2))
    # downsample to 40x40
    model.add(Conv2D(128, (5,5), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # downsample to 20x20
    model.add(Conv2D(128, (5,5), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # downsample to 10x10
    model.add(Conv2D(128, (5,5), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # downsample to 5x5
    model.add(Conv2D(128, (5,5), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # classifier
    model.add(Flatten())
    model.add(Dropout(0.4))
    model.add(Dense(1, activation='sigmoid'))
    # compile model
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model
The generator model takes as input a point in the latent space and outputs a single 80×80 color image.
This is achieved by using a fully connected layer to interpret the point in the latent space and provide sufficient activations that can be reshaped into many copies (in this case 128) of a low-resolution version of the output image (e.g. 5×5). This is then upsampled four times, doubling the size and quadrupling the area of the activations each time using transpose convolutional layers. The model uses best practices such as the LeakyReLU activation, a kernel size that is a factor of the stride size, and a hyperbolic tangent (tanh) activation function in the output layer.
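The size arithmetic behind this design can be checked directly: the fully connected layer must supply 128 * 5 * 5 activations, and each stride-2 transpose convolution doubles the width and height (a small sanity-check sketch):

```python
# foundation: 128 feature maps of size 5x5
n_maps, size = 128, 5
n_nodes = n_maps * size * size
print(n_nodes)  # 3200

# four upsampling steps, each doubling the spatial size
sizes = [size]
for _ in range(4):
    sizes.append(sizes[-1] * 2)
print(sizes)  # [5, 10, 20, 40, 80]
```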
The define_generator() function below defines the generator model but intentionally does not compile it as it is not trained directly, then returns the model. The size of the latent space is parameterized as a function argument.
# define the standalone generator model
def define_generator(latent_dim):
    model = Sequential()
    # foundation for 5x5 feature maps
    n_nodes = 128 * 5 * 5
    model.add(Dense(n_nodes, input_dim=latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Reshape((5, 5, 128)))
    # upsample to 10x10
    model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # upsample to 20x20
    model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # upsample to 40x40
    model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # upsample to 80x80
    model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # output layer 80x80x3
    model.add(Conv2D(3, (5,5), activation='tanh', padding='same'))
    return model
Generative Adversarial Networks, or GANs, are an architecture for training generative models, such as deep convolutional neural networks for generating images.
Developing a GAN for generating images requires both a discriminator convolutional neural network model for classifying whether a given image is real or generated and a generator model that uses inverse convolutional layers to transform an input to a full two-dimensional image of pixel values.
It can be challenging to understand both how GANs work and how deep convolutional neural network models can be trained in a GAN architecture for image generation. A good starting point for beginners is to practice developing and using GANs on standard image datasets used in the field of computer vision, such as the CIFAR small object photograph dataset. Using small and well-understood datasets means that smaller models can be developed and trained quickly, allowing focus to be put on the model architecture and image generation process itself.
In this tutorial, you will discover how to develop a generative adversarial network with deep convolutional networks for generating small photographs of objects.
After completing this tutorial, you will know:
How to define and train the standalone discriminator model for learning the difference between real and fake images.
How to define the standalone generator model and train the composite generator and discriminator model.
How to evaluate the performance of the GAN and use the final standalone generator model to generate new images.
Let’s get started.
How to Develop a Generative Adversarial Network for a CIFAR-10 Small Object Photographs From Scratch Photo by hiGorgeous, some rights reserved.
Tutorial Overview
This tutorial is divided into seven parts; they are:
CIFAR-10 Small Object Photograph Dataset
How to Define and Train the Discriminator Model
How to Define and Use the Generator Model
How to Train the Generator Model
How to Evaluate GAN Model Performance
Complete Example of GAN for CIFAR-10
How to Use the Final Generator Model to Generate Images
CIFAR-10 Small Object Photograph Dataset
CIFAR is an acronym that stands for the Canadian Institute For Advanced Research, and the CIFAR-10 dataset was developed along with the CIFAR-100 dataset by researchers at the CIFAR institute.
The dataset is comprised of 60,000 32×32 pixel color photographs of objects from 10 classes, such as frogs, birds, cats, ships, airplanes, etc.
These are very small images, much smaller than a typical photograph, and the dataset is intended for computer vision research.
Keras provides access to the CIFAR-10 dataset via the cifar10.load_data() function. It returns two tuples, one with the input and output elements for the standard training dataset, and another with the input and output elements for the standard test dataset.
The example below loads the dataset and summarizes the shape of the loaded dataset.
Note: the first time you load the dataset, Keras will automatically download a compressed version of the images and save them under your home directory in ~/.keras/datasets/. The download is fast as the dataset is only about 163 megabytes in its compressed form.
# example of loading the cifar10 dataset
from keras.datasets.cifar10 import load_data
# load the images into memory
(trainX, trainy), (testX, testy) = load_data()
# summarize the shape of the dataset
print('Train', trainX.shape, trainy.shape)
print('Test', testX.shape, testy.shape)
Running the example loads the dataset and prints the shape of the input and output components of the train and test splits of images.
We can see that there are 50K examples in the training set and 10K in the test set and that each image is a square of 32 by 32 pixels.
The images are color with the object centered in the middle of the frame.
We can plot some of the images from the training dataset with the matplotlib library using the imshow() function.
# plot raw pixel data
pyplot.imshow(trainX[i])
The example below plots the first 49 images from the training dataset in a 7 by 7 square.
# example of loading and plotting the cifar10 dataset
from keras.datasets.cifar10 import load_data
from matplotlib import pyplot

# load the images into memory
(trainX, trainy), (testX, testy) = load_data()
# plot images from the training dataset
for i in range(49):
    # define subplot
    pyplot.subplot(7, 7, 1 + i)
    # turn off axis
    pyplot.axis('off')
    # plot raw pixel data
    pyplot.imshow(trainX[i])
pyplot.show()
Running the example creates a figure with a plot of 49 images from the CIFAR10 training dataset, arranged in a 7×7 square.
In the plot, you can see small photographs of planes, trucks, horses, cars, frogs, and so on.
Plot of the First 49 Small Object Photographs From the CIFAR10 Dataset.
We will use the images in the training dataset as the basis for training a Generative Adversarial Network.
Specifically, the generator model will learn how to generate new plausible photographs of objects using a discriminator that will try to distinguish between real images from the CIFAR-10 training dataset and new images output by the generator model.
This is a non-trivial problem that requires modest generator and discriminator models that are probably most effectively trained on GPU hardware.
How to Define and Train the Discriminator Model
The first step is to define the discriminator model.
The model must take a sample image from our dataset as input and output a classification prediction as to whether the sample is real or fake. This is a binary classification problem.
Inputs: Image with three color channels and 32×32 pixels in size.
Outputs: Binary classification, likelihood the sample is real (or fake).
The discriminator model has a normal convolutional layer followed by three convolutional layers using a stride of 2×2 to downsample the input image. The model has no pooling layers and a single node in the output layer with the sigmoid activation function to predict whether the input sample is real or fake. The model is trained to minimize the binary cross entropy loss function, appropriate for binary classification.
The define_discriminator() function below defines the discriminator model and parametrizes the size of the input image.
# define the standalone discriminator model
def define_discriminator(in_shape=(32,32,3)):
    model = Sequential()
    # normal
    model.add(Conv2D(64, (3,3), padding='same', input_shape=in_shape))
    model.add(LeakyReLU(alpha=0.2))
    # downsample
    model.add(Conv2D(128, (3,3), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # downsample
    model.add(Conv2D(128, (3,3), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # downsample
    model.add(Conv2D(256, (3,3), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # classifier
    model.add(Flatten())
    model.add(Dropout(0.4))
    model.add(Dense(1, activation='sigmoid'))
    # compile model
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model
We can use this function to define the discriminator model and summarize it.
The complete example is listed below.
# example of defining the discriminator model
from keras.models import Sequential
from keras.optimizers import Adam
from keras.layers import Dense
from keras.layers import Conv2D
from keras.layers import Flatten
from keras.layers import Dropout
from keras.layers import LeakyReLU
from keras.utils.vis_utils import plot_model

# define the standalone discriminator model
def define_discriminator(in_shape=(32,32,3)):
    model = Sequential()
    # normal
    model.add(Conv2D(64, (3,3), padding='same', input_shape=in_shape))
    model.add(LeakyReLU(alpha=0.2))
    # downsample
    model.add(Conv2D(128, (3,3), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # downsample
    model.add(Conv2D(128, (3,3), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # downsample
    model.add(Conv2D(256, (3,3), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # classifier
    model.add(Flatten())
    model.add(Dropout(0.4))
    model.add(Dense(1, activation='sigmoid'))
    # compile model
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model

# define model
model = define_discriminator()
# summarize the model
model.summary()
# plot the model
plot_model(model, to_file='discriminator_plot.png', show_shapes=True, show_layer_names=True)
Running the example first summarizes the model architecture, showing the input and output from each layer.
We can see that the aggressive 2×2 stride acts to down-sample the input image, first from 32×32 to 16×16, then to 8×8, and finally to 4×4 before the model makes an output prediction.
This pattern is by design as we do not use pooling layers and use the large stride to achieve a similar downsampling effect. We will see a similar pattern, but in reverse in the generator model in the next section.
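The downsampling arithmetic can be verified in a few lines: a stride-2 convolution with 'same' padding halves each spatial dimension, rounding up (a sketch):

```python
from math import ceil

# output size of a stride-2 convolution with 'same' padding
def output_size(n, stride=2):
    return ceil(n / stride)

# three stride-2 layers in the discriminator
sizes = [32]
for _ in range(3):
    sizes.append(output_size(sizes[-1]))
print(sizes)  # [32, 16, 8, 4]
```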
A plot of the model is also created and we can see that the model expects one input and will predict a single output.
Note: creating this plot assumes that the pydot and graphviz libraries are installed. If this is a problem, you can comment out the import statement and the call to the plot_model() function.
Plot of the Discriminator Model in the CIFAR10 Generative Adversarial Network
We could start training this model now with real examples with a class label of one and randomly generated samples with a class label of zero.
The development of these elements will be useful later, and it helps to see that the discriminator is just a normal neural network model for binary classification.
First, we need a function to load and prepare the dataset of real images.
We will use the cifar10.load_data() function to load the CIFAR-10 dataset and just use the input part of the training dataset as the real images.
We must scale the pixel values from the range of unsigned integers in [0,255] to the normalized range of [-1,1].
The generator model will generate images with pixel values in the range [-1,1] as it will use the tanh activation function, a best practice.
It is also a good practice for the real images to be scaled to the same range.
...
# convert from unsigned ints to floats
X = trainX.astype('float32')
# scale from [0,255] to [-1,1]
X = (X - 127.5) / 127.5
The load_real_samples() function below implements the loading and scaling of real CIFAR-10 photographs.
# load and prepare cifar10 training images
def load_real_samples():
    # load cifar10 dataset
    (trainX, _), (_, _) = load_data()
    # convert from unsigned ints to floats
    X = trainX.astype('float32')
    # scale from [0,255] to [-1,1]
    X = (X - 127.5) / 127.5
    return X
The model will be updated in batches, specifically with a collection of real samples and a collection of generated samples. In training, an epoch is defined as one pass through the entire training dataset.
We could systematically enumerate all samples in the training dataset, and that is a good approach, but good training via stochastic gradient descent requires that the training dataset be shuffled prior to each epoch. A simpler approach is to select random samples of images from the training dataset.
The generate_real_samples() function below will take the training dataset as an argument and will select a random subsample of images; it will also return class labels for the sample, specifically a class label of 1, to indicate real images.
# select real samples
def generate_real_samples(dataset, n_samples):
    # choose random instances
    ix = randint(0, dataset.shape[0], n_samples)
    # retrieve selected images
    X = dataset[ix]
    # generate 'real' class labels (1)
    y = ones((n_samples, 1))
    return X, y
Now, we need a source of fake images.
We don’t have a generator model yet, so instead, we can generate images comprised of random pixel values, specifically random pixel values in the range [0,1], then scaled to the range [-1, 1] like our scaled real images.
The generate_fake_samples() function below implements this behavior and generates images of random pixel values and their associated class label of 0, for fake.
# generate n fake samples with class labels
def generate_fake_samples(n_samples):
    # generate uniform random numbers in [0,1]
    X = rand(32 * 32 * 3 * n_samples)
    # update to have the range [-1, 1]
    X = -1 + X * 2
    # reshape into a batch of color images
    X = X.reshape((n_samples, 32, 32, 3))
    # generate 'fake' class labels (0)
    y = zeros((n_samples, 1))
    return X, y
Finally, we need to train the discriminator model.
This involves repeatedly retrieving samples of real images and samples of generated images and updating the model for a fixed number of iterations.
We will ignore the idea of epochs for now (e.g. complete passes through the training dataset) and fit the discriminator model for a fixed number of batches. The model will learn to discriminate between real and fake (randomly generated) images rapidly, therefore not many batches will be required before it learns to discriminate perfectly.
The train_discriminator() function implements this, using a batch size of 128 images, where 64 are real and 64 are fake each iteration.
We update the discriminator separately for real and fake examples so that we can calculate the accuracy of the model on each sample prior to the update. This gives insight into how the discriminator model is performing over time.
# train the discriminator model
def train_discriminator(model, dataset, n_iter=20, n_batch=128):
    half_batch = int(n_batch / 2)
    # manually enumerate batches
    for i in range(n_iter):
        # get randomly selected 'real' samples
        X_real, y_real = generate_real_samples(dataset, half_batch)
        # update discriminator on real samples
        _, real_acc = model.train_on_batch(X_real, y_real)
        # generate 'fake' examples
        X_fake, y_fake = generate_fake_samples(half_batch)
        # update discriminator on fake samples
        _, fake_acc = model.train_on_batch(X_fake, y_fake)
        # summarize performance
        print('>%d real=%.0f%% fake=%.0f%%' % (i+1, real_acc*100, fake_acc*100))
Tying all of this together, the complete example of training an instance of the discriminator model on real and randomly generated (fake) images is listed below.
# example of training the discriminator model on real and random cifar10 images
from numpy import expand_dims
from numpy import ones
from numpy import zeros
from numpy.random import rand
from numpy.random import randint
from keras.datasets.cifar10 import load_data
from keras.optimizers import Adam
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Conv2D
from keras.layers import Flatten
from keras.layers import Dropout
from keras.layers import LeakyReLU

# define the standalone discriminator model
def define_discriminator(in_shape=(32,32,3)):
    model = Sequential()
    # normal
    model.add(Conv2D(64, (3,3), padding='same', input_shape=in_shape))
    model.add(LeakyReLU(alpha=0.2))
    # downsample
    model.add(Conv2D(128, (3,3), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # downsample
    model.add(Conv2D(128, (3,3), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # downsample
    model.add(Conv2D(256, (3,3), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # classifier
    model.add(Flatten())
    model.add(Dropout(0.4))
    model.add(Dense(1, activation='sigmoid'))
    # compile model
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model

# load and prepare cifar10 training images
def load_real_samples():
    # load cifar10 dataset
    (trainX, _), (_, _) = load_data()
    # convert from unsigned ints to floats
    X = trainX.astype('float32')
    # scale from [0,255] to [-1,1]
    X = (X - 127.5) / 127.5
    return X

# select real samples
def generate_real_samples(dataset, n_samples):
    # choose random instances
    ix = randint(0, dataset.shape[0], n_samples)
    # retrieve selected images
    X = dataset[ix]
    # generate 'real' class labels (1)
    y = ones((n_samples, 1))
    return X, y

# generate n fake samples with class labels
def generate_fake_samples(n_samples):
    # generate uniform random numbers in [0,1]
    X = rand(32 * 32 * 3 * n_samples)
    # update to have the range [-1, 1]
    X = -1 + X * 2
    # reshape into a batch of color images
    X = X.reshape((n_samples, 32, 32, 3))
    # generate 'fake' class labels (0)
    y = zeros((n_samples, 1))
    return X, y

# train the discriminator model
def train_discriminator(model, dataset, n_iter=20, n_batch=128):
    half_batch = int(n_batch / 2)
    # manually enumerate batches
    for i in range(n_iter):
        # get randomly selected 'real' samples
        X_real, y_real = generate_real_samples(dataset, half_batch)
        # update discriminator on real samples
        _, real_acc = model.train_on_batch(X_real, y_real)
        # generate 'fake' examples
        X_fake, y_fake = generate_fake_samples(half_batch)
        # update discriminator on fake samples
        _, fake_acc = model.train_on_batch(X_fake, y_fake)
        # summarize performance
        print('>%d real=%.0f%% fake=%.0f%%' % (i+1, real_acc*100, fake_acc*100))

# define the discriminator model
model = define_discriminator()
# load image data
dataset = load_real_samples()
# fit the model
train_discriminator(model, dataset)
Running the example first defines the model, loads the CIFAR-10 dataset, then trains the discriminator model.
Note: your specific results may vary given the stochastic nature of the learning algorithm. Consider running the example a few times and comparing the average outcome.
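If you want repeated runs to be comparable, one option (not part of the tutorial itself) is to fix the random seeds before any random draws are made. A minimal sketch using only the Python standard library and NumPy is shown below; note that full reproducibility of Keras model training also requires seeding the backend, e.g. TensorFlow's own seed function.

```python
# sketch: fixing random seeds so repeated runs produce the same draws
# (TensorFlow/Keras have their own seed function for model weights,
# e.g. tensorflow.random.set_seed, which is not shown here)
import random
import numpy as np

random.seed(1)
np.random.seed(1)
a = np.random.rand(3)

# reseeding reproduces the identical sequence of random values
np.random.seed(1)
b = np.random.rand(3)
print(np.array_equal(a, b))  # True
```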
Generative Adversarial Networks, or GANs, are an architecture for training generative models, such as deep convolutional neural networks for generating images.
Developing a GAN for generating images requires both a discriminator convolutional neural network model for classifying whether a given image is real or generated and a generator model that uses inverse convolutional layers to transform an input to a full two-dimensional image of pixel values.
It can be challenging to understand both how GANs work and how deep convolutional neural network models can be trained in a GAN architecture for image generation. A good starting point for beginners is to practice developing and using GANs on standard image datasets used in the field of computer vision, such as the MNIST handwritten digit dataset. Using small and well-understood datasets means that smaller models can be developed and trained quickly, allowing the focus to be put on the model architecture and image generation process itself.
In this tutorial, you will discover how to develop a generative adversarial network with deep convolutional networks for generating handwritten digits.
After completing this tutorial, you will know:
How to define and train the standalone discriminator model for learning the difference between real and fake images.
How to define the standalone generator model and train the composite generator and discriminator model.
How to evaluate the performance of the GAN and use the final standalone generator model to generate new images.
Let’s get started.
How to Develop a Generative Adversarial Network for an MNIST Handwritten Digits From Scratch in Keras Photo by jcookfisher, some rights reserved.
Tutorial Overview
This tutorial is divided into seven parts; they are:
MNIST Handwritten Digit Dataset
How to Define and Train the Discriminator Model
How to Define and Use the Generator Model
How to Train the Generator Model
How to Evaluate GAN Model Performance
Complete Example of GAN for MNIST
How to Use the Final Generator Model to Generate Images
MNIST Handwritten Digit Dataset
MNIST is an acronym that stands for the Modified National Institute of Standards and Technology dataset.
It is a dataset of 70,000 small square 28×28 pixel grayscale images of handwritten single digits between 0 and 9.
The task is to classify a given image of a handwritten digit into one of 10 classes representing integer values from 0 to 9, inclusively.
Keras provides access to the MNIST dataset via the mnist.load_data() function. It returns two tuples, one with the input and output elements for the standard training dataset, and another with the input and output elements for the standard test dataset.
The example below loads the dataset and summarizes the shape of the loaded dataset.
Note: the first time you load the dataset, Keras will automatically download a compressed version of the images and save them under your home directory in ~/.keras/datasets/. The download is fast as the dataset is only about eleven megabytes in its compressed form.
# example of loading the mnist dataset
from keras.datasets.mnist import load_data
# load the images into memory
(trainX, trainy), (testX, testy) = load_data()
# summarize the shape of the dataset
print('Train', trainX.shape, trainy.shape)
print('Test', testX.shape, testy.shape)
Running the example loads the dataset and prints the shape of the input and output components of the train and test splits of images.
We can see that there are 60K examples in the training set and 10K in the test set and that each image is a square of 28 by 28 pixels.
Train (60000, 28, 28) (60000,)
Test (10000, 28, 28) (10000,)
The images are grayscale with a black background (0 pixel value) and the handwritten digits in white (pixel values near 255). This means if the images were plotted, they would be mostly black with a white digit in the middle.
We can plot some of the images from the training dataset using the matplotlib imshow() function, specifying the color map via the ‘cmap‘ argument as ‘gray‘ to show the pixel values correctly.
# plot raw pixel data
pyplot.imshow(trainX[i], cmap='gray')
Alternately, the images are easier to review when we reverse the colors and plot the background as white and the handwritten digits in black.
They are easier to view as most of the image is now white with the area of interest in black. This can be achieved using a reverse grayscale color map, as follows:
# plot raw pixel data
pyplot.imshow(trainX[i], cmap='gray_r')
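An equivalent way to get the reversed view (a small aside, not from the tutorial) is to invert the uint8 pixel values themselves and then plot with the normal ‘gray‘ map; the tiny array below is a stand-in for a real MNIST image.

```python
# invert pixel values directly instead of using cmap='gray_r'
# (the 1x3 'img' array is an illustrative stand-in for a 28x28 image)
import numpy as np

img = np.array([[0, 128, 255]], dtype=np.uint8)
inverted = 255 - img  # black background becomes white, and vice versa
print(inverted)  # [[255 127   0]]
```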
The example below plots the first 25 images from the training dataset in a 5 by 5 square.
# example of loading the mnist dataset
from keras.datasets.mnist import load_data
from matplotlib import pyplot
# load the images into memory
(trainX, trainy), (testX, testy) = load_data()
# plot images from the training dataset
for i in range(25):
    # define subplot
    pyplot.subplot(5, 5, 1 + i)
    # turn off axis
    pyplot.axis('off')
    # plot raw pixel data
    pyplot.imshow(trainX[i], cmap='gray_r')
pyplot.show()
Running the example creates a plot of 25 images from the MNIST training dataset, arranged in a 5×5 square.
Plot of the First 25 Handwritten Digits From the MNIST Dataset.
We will use the images in the training dataset as the basis for training a Generative Adversarial Network.
Specifically, the generator model will learn how to generate new plausible handwritten digits between 0 and 9, using a discriminator that will try to distinguish between real images from the MNIST training dataset and new images output by the generator model.
This is a relatively simple problem that does not require a sophisticated generator or discriminator model, although it does require the generation of a grayscale output image.
How to Define and Train the Discriminator Model
The first step is to define the discriminator model.
The model must take a sample image from our dataset as input and output a classification prediction as to whether the sample is real or fake.
This is a binary classification problem:
Inputs: Image with one channel and 28×28 pixels in size.
Outputs: Binary classification, likelihood the sample is real (or fake).
The discriminator model has two convolutional layers with 64 filters each, a small kernel size of 3, and larger than normal stride of 2. The model has no pooling layers and a single node in the output layer with the sigmoid activation function to predict whether the input sample is real or fake. The model is trained to minimize the binary cross entropy loss function, appropriate for binary classification.
We will use some best practices in defining the discriminator model, such as the use of LeakyReLU instead of ReLU, using Dropout, and using the Adam version of stochastic gradient descent with a learning rate of 0.0002 and a momentum of 0.5.
The function define_discriminator() below defines the discriminator model and parametrizes the size of the input image.
# define the standalone discriminator model
def define_discriminator(in_shape=(28,28,1)):
    model = Sequential()
    model.add(Conv2D(64, (3,3), strides=(2, 2), padding='same', input_shape=in_shape))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))
    model.add(Conv2D(64, (3,3), strides=(2, 2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))
    # compile model
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model
We can use this function to define the discriminator model and summarize it.
The complete example is listed below.
# example of defining the discriminator model
from keras.models import Sequential
from keras.optimizers import Adam
from keras.layers import Dense
from keras.layers import Conv2D
from keras.layers import Flatten
from keras.layers import Dropout
from keras.layers import LeakyReLU
from keras.utils.vis_utils import plot_model
# define the standalone discriminator model
def define_discriminator(in_shape=(28,28,1)):
    model = Sequential()
    model.add(Conv2D(64, (3,3), strides=(2, 2), padding='same', input_shape=in_shape))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))
    model.add(Conv2D(64, (3,3), strides=(2, 2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))
    # compile model
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model
# define model
model = define_discriminator()
# summarize the model
model.summary()
# plot the model
plot_model(model, to_file='discriminator_plot.png', show_shapes=True, show_layer_names=True)
Running the example first summarizes the model architecture, showing the input and output from each layer.
This pattern is by design, as we do not use pooling layers and instead use the large stride to achieve a similar downsampling effect. We will see a similar pattern, but in reverse, in the generator model in the next section.
A plot of the model is also created and we can see that the model expects a single input image and will predict a single output.
Note: creating this plot assumes that the pydot and graphviz libraries are installed. If this is a problem, you can comment out the import statement for the plot_model function and the call to the plot_model() function.
Plot of the Discriminator Model in the MNIST GAN
We could start training this model now with real examples with a class label of one, and randomly generated samples with a class label of zero.
The development of these elements will be useful later, and it helps to see that the discriminator is just a normal neural network model for binary classification.
First, we need a function to load and prepare the dataset of real images.
We will use the mnist.load_data() function to load the MNIST dataset and just use the input part of the training dataset as the real images.
The images are 2D arrays of pixels and convolutional neural networks expect 3D arrays of images as input, where each image has one or more channels.
We must update the images to have an additional dimension for the grayscale channel. We can do this using the expand_dims() NumPy function and specify the final dimension for the channels-last image format.
# expand to 3d, e.g. add channels dimension
X = expand_dims(trainX, axis=-1)
Finally, we must scale the pixel values from the range of unsigned integers in [0,255] to the normalized range of [0,1].
# convert from unsigned ints to floats
X = X.astype('float32')
# scale from [0,255] to [0,1]
X = X / 255.0
The load_real_samples() function below implements this.
# load and prepare mnist training images
def load_real_samples():
    # load mnist dataset
    (trainX, _), (_, _) = load_data()
    # expand to 3d, e.g. add channels dimension
    X = expand_dims(trainX, axis=-1)
    # convert from unsigned ints to floats
    X = X.astype('float32')
    # scale from [0,255] to [0,1]
    X = X / 255.0
    return X
The model will be updated in batches, specifically with a collection of real samples and a collection of generated samples. In training, an epoch is defined as one pass through the entire training dataset.
We could systematically enumerate all samples in the training dataset, and that is a good approach, but good training via stochastic gradient descent requires that the training dataset be shuffled prior to each epoch. A simpler approach is to select random samples of images from the training dataset.
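The systematic alternative mentioned above can be sketched as follows: shuffle the dataset indices once per epoch and walk through them in batches. The toy dataset array and batch size below are illustrative assumptions; the random-sampling approach used in the tutorial is simpler but samples with replacement, so some images may repeat within an epoch.

```python
# per-epoch shuffling: every sample is visited exactly once per epoch,
# unlike random sampling with replacement via randint()
import numpy as np

dataset = np.arange(10)  # stand-in for the array of training images
n_batch = 4
ix = np.random.permutation(len(dataset))
batches = [dataset[ix[i:i + n_batch]] for i in range(0, len(dataset), n_batch)]
# every element of the dataset appears in exactly one batch
```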
The generate_real_samples() function below will take the training dataset as an argument and will select a random subsample of images; it will also return class labels for the sample, specifically a class label of 1, to indicate real images.
# select real samples
def generate_real_samples(dataset, n_samples):
    # choose random instances
    ix = randint(0, dataset.shape[0], n_samples)
    # retrieve selected images
    X = dataset[ix]
    # generate 'real' class labels (1)
    y = ones((n_samples, 1))
    return X, y
Now, we need a source of fake images.
We don’t have a generator model yet, so instead, we can generate images comprised of random pixel values, specifically random pixel values in the range [0,1] like our scaled real images.
The generate_fake_samples() function below implements this behavior and generates images of random pixel values and their associated class label of 0, for fake.
# generate n fake samples with class labels
def generate_fake_samples(n_samples):
    # generate uniform random numbers in [0,1]
    X = rand(28 * 28 * n_samples)
    # reshape into a batch of grayscale images
    X = X.reshape((n_samples, 28, 28, 1))
    # generate 'fake' class labels (0)
    y = zeros((n_samples, 1))
    return X, y
Finally, we need to train the discriminator model.
This involves repeatedly retrieving samples of real images and samples of generated images and updating the model for a fixed number of iterations.
We will ignore the idea of epochs for now (e.g. complete passes through the training dataset) and fit the discriminator model for a fixed number of batches. The model will learn to discriminate between real and fake (randomly generated) images rapidly; therefore, not many batches will be required before it learns to discriminate perfectly.
The train_discriminator() function implements this, using a batch size of 256 images where 128 are real and 128 are fake each iteration.
We update the discriminator separately for real and fake examples so that we can calculate the accuracy of the model on each sample prior to the update. This gives insight into how the discriminator model is performing over time.
# train the discriminator model
def train_discriminator(model, dataset, n_iter=100, n_batch=256):
    half_batch = int(n_batch / 2)
    # manually enumerate training iterations
    for i in range(n_iter):
        # get randomly selected 'real' samples
        X_real, y_real = generate_real_samples(dataset, half_batch)
        # update discriminator on real samples
        _, real_acc = model.train_on_batch(X_real, y_real)
        # generate 'fake' examples
        X_fake, y_fake = generate_fake_samples(half_batch)
        # update discriminator on fake samples
        _, fake_acc = model.train_on_batch(X_fake, y_fake)
        # summarize performance
        print('>%d real=%.0f%% fake=%.0f%%' % (i+1, real_acc*100, fake_acc*100))
Tying all of this together, the complete example of training an instance of the discriminator model on real and randomly generated (fake) images is listed below.
# example of training the discriminator model on real and random mnist images
from numpy import expand_dims
from numpy import ones
from numpy import zeros
from numpy.random import rand
from numpy.random import randint
from keras.datasets.mnist import load_data
from keras.optimizers import Adam
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Conv2D
from keras.layers import Flatten
from keras.layers import Dropout
from keras.layers import LeakyReLU
# define the standalone discriminator model
def define_discriminator(in_shape=(28,28,1)):
    model = Sequential()
    model.add(Conv2D(64, (3,3), strides=(2, 2), padding='same', input_shape=in_shape))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))
    model.add(Conv2D(64, (3,3), strides=(2, 2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))
    # compile model
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model

# load and prepare mnist training images
def load_real_samples():
    # load mnist dataset
    (trainX, _), (_, _) = load_data()
    # expand to 3d, e.g. add channels dimension
    X = expand_dims(trainX, axis=-1)
    # convert from unsigned ints to floats
    X = X.astype('float32')
    # scale from [0,255] to [0,1]
    X = X / 255.0
    return X

# select real samples
def generate_real_samples(dataset, n_samples):
    # choose random instances
    ix = randint(0, dataset.shape[0], n_samples)
    # retrieve selected images
    X = dataset[ix]
    # generate 'real' class labels (1)
    y = ones((n_samples, 1))
    return X, y

# generate n fake samples with class labels
def generate_fake_samples(n_samples):
    # generate uniform random numbers in [0,1]
    X = rand(28 * 28 * n_samples)
    # reshape into a batch of grayscale images
    X = X.reshape((n_samples, 28, 28, 1))
    # generate 'fake' class labels (0)
    y = zeros((n_samples, 1))
    return X, y

# train the discriminator model
def train_discriminator(model, dataset, n_iter=100, n_batch=256):
    half_batch = int(n_batch / 2)
    # manually enumerate training iterations
    for i in range(n_iter):
        # get randomly selected 'real' samples
        X_real, y_real = generate_real_samples(dataset, half_batch)
        # update discriminator on real samples
        _, real_acc = model.train_on_batch(X_real, y_real)
        # generate 'fake' examples
        X_fake, y_fake = generate_fake_samples(half_batch)
        # update discriminator on fake samples
        _, fake_acc = model.train_on_batch(X_fake, y_fake)
        # summarize performance
        print('>%d real=%.0f%% fake=%.0f%%' % (i+1, real_acc*100, fake_acc*100))
# define the discriminator model
model = define_discriminator()
# load image data
dataset = load_real_samples()
# fit the model
train_discriminator(model, dataset)
Running the example first defines the model, loads the MNIST dataset, then trains the discriminator model.
Note: your specific results may vary given the stochastic nature of the learning algorithm. Consider running the example a few times.
In this case, the discriminator model learns to tell the difference between real and randomly generated MNIST images very quickly, in about 50 batches.
Now that we know how to define and train the discriminator model, we need to look at developing the generator model.
How to Define and Use the Generator Model
The generator model is responsible for creating new, fake but plausible images of handwritten digits.
It does this by taking a point from the latent space as input and outputting a square grayscale image.
The latent space is an arbitrarily defined vector space of Gaussian-distributed values, e.g. 100 dimensions. It has no meaning, but by drawing points from this space randomly and providing them to the generator model during training, the generator model will assign..
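A common way to implement this sampling, sketched here under the 100-dimension assumption above (the function name is illustrative), is to draw standard-normal values with NumPy and reshape them into a batch of generator inputs:

```python
# draw Gaussian latent points and shape them as a batch of generator inputs
from numpy.random import randn

def generate_latent_points(latent_dim, n_samples):
    # sample latent_dim * n_samples standard-normal values
    x_input = randn(latent_dim * n_samples)
    # reshape into a batch of inputs for the generator network
    return x_input.reshape(n_samples, latent_dim)

z = generate_latent_points(100, 16)
print(z.shape)  # (16, 100)
```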