Conditional GAN (cGAN and pix2pix)
Introduction
A Generative Adversarial Network (GAN) is a network with two players in a game: a generator that produces fake images and a discriminator that tries to tell the real images from the fake ones. As training proceeds, the two ideally reach an equilibrium in which the generator produces images close to the real ones and the discriminator can no longer distinguish the forgeries. The result is a sophisticated generator that can be used to synthesize realistic images, even art. Many variants of GANs have appeared since the original paper.
A plain GAN cannot generate images of a requested category; it simply generates images from whatever categories it was trained on. The conditional GAN (cGAN) is a variant that feeds an additional embedded label vector y as input to both the generator and the discriminator. Trained this way, a cGAN generator can produce images of a given category on demand.
In cGAN, the value function becomes dependent on the input y (hence the word conditional):
\[\min_G \max_D V(D,G) = E_x [\log D(x \mid y)] + E_z [\log(1 - D(G(z \mid y)))]\]

Apart from generating images of a given class, the cGAN paper also demonstrates an application in multi-modal learning: generating descriptive tags (text) for images. This is multi-modal because a single image can have several tags, and those tags can be synonyms of each other.
There are several ways we can add the label y to the generator and the discriminator:
- As an embedding layer
- As an additional channel concatenated to the image
- As a low-dimensional embedding that is then upsampled (e.g. with a dense layer) to match the image size
Code example
What follows is an example of the discriminator and the generator. When you run it on MNIST, Fashion-MNIST, or CIFAR-10, you need to adapt the input and some of the layers to fit the image dimensions: MNIST and Fashion-MNIST are grayscale images of 28x28 pixels (one channel), while CIFAR-10 contains 32x32 color images with 3 channels (RGB). The discriminator and generator below are adapted for CIFAR-10. To adapt the tensor shapes between datasets, run model.summary() and fix the numbers accordingly (a sketch of the MNIST adaptation is given after the model summaries).
# imports for the cGAN example
from keras.optimizers import Adam
from keras.models import Model
from keras.layers import Input, Embedding, Dense, Reshape, Concatenate
from keras.layers import Conv2D, Conv2DTranspose, LeakyReLU, Flatten, Dropout

# define the standalone discriminator model
def define_discriminator(in_shape=(32,32,3), n_classes=10):
# label input
in_label = Input(shape=(1,))
# embedding for categorical input
li = Embedding(n_classes, 50)(in_label)
# scale up to image dimensions with linear activation
n_nodes = in_shape[0] * in_shape[1]
li = Dense(n_nodes)(li)
# reshape to additional channel
li = Reshape((in_shape[0], in_shape[1], 1))(li)
# image input
in_image = Input(shape=in_shape)
# concat label as a channel
merge = Concatenate()([in_image, li])
# downsample
fe = Conv2D(64, (3,3), strides=(2,2), padding='same')(merge)
fe = LeakyReLU(alpha=0.2)(fe)
# downsample
fe = Conv2D(128, (3,3), strides=(2,2), padding='same')(fe)
fe = LeakyReLU(alpha=0.2)(fe)
# flatten feature maps
fe = Flatten()(fe)
# dropout
fe = Dropout(0.4)(fe)
# output
out_layer = Dense(1, activation='sigmoid')(fe)
# define model
model = Model([in_image, in_label], out_layer)
# compile model
opt = Adam(lr=0.0002, beta_1=0.5)
model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
return model
# define the standalone generator model
def define_generator(latent_dim, n_classes=10):
# label input
in_label = Input(shape=(1,))
# embedding for categorical input
li = Embedding(n_classes, 50)(in_label)
# linear multiplication
n_nodes = 8 * 8
li = Dense(n_nodes)(li)
# reshape to additional channel
li = Reshape((8, 8, 1))(li)
# image generator input
in_lat = Input(shape=(latent_dim,))
# foundation for 8x8 image
n_nodes = 128 * 8 * 8
gen = Dense(n_nodes)(in_lat)
gen = LeakyReLU(alpha=0.2)(gen)
gen = Reshape((8, 8, 128))(gen)
# merge image gen and label input
merge = Concatenate()([gen, li])
# upsample to 16x16
gen = Conv2DTranspose(64, (4,4), strides=(2,2), padding='same')(merge)
gen = LeakyReLU(alpha=0.2)(gen)
# upsample to 32x32
gen = Conv2DTranspose(32, (4,4), strides=(2,2), padding='same')(gen)
gen = LeakyReLU(alpha=0.2)(gen)
# output
out_layer = Conv2D(3, (8,8), activation='tanh', padding='same')(gen)
# define model
model = Model([in_lat, in_label], out_layer)
return model
latent_dim = 100
d_model = define_discriminator()
d_model.summary()
g_model = define_generator(latent_dim)
g_model.summary()
Model: “model_10” (discriminator)
Layer (type) | Output Shape | Param # | Connected to |
---|---|---|---|
input_17 (InputLayer) | [(None, 1)] | 0 | [] |
embedding_8 (Embedding) | (None, 1, 50) | 500 | [‘input_17[0][0]’] |
dense_16 (Dense) | (None, 1, 1024) | 52224 | [‘embedding_8[0][0]’] |
input_18 (InputLayer) | [(None, 32, 32, 3)] | 0 | [] |
reshape_11 (Reshape) | (None, 32, 32, 1) | 0 | [‘dense_16[0][0]’] |
concatenate_8 (Concatenate) | (None, 32, 32, 4) | 0 | [‘input_18[0][0]’,’reshape_11[0][0]’] |
conv2d_13 (Conv2D) | (None, 16, 16, 64) | 2368 | [‘concatenate_8[0][0]’] |
leaky_re_lu_19 (LeakyReLU) | (None, 16, 16, 64) | 0 | [‘conv2d_13[0][0]’] |
conv2d_14 (Conv2D) | (None, 8, 8, 128) | 73856 | [‘leaky_re_lu_19[0][0]’] |
leaky_re_lu_20 (LeakyReLU) | (None, 8, 8, 128) | 0 | [‘conv2d_14[0][0]’] |
flatten_5 (Flatten) | (None, 8192) | 0 | [‘leaky_re_lu_20[0][0]’] |
dropout_5 (Dropout) | (None, 8192) | 0 | [‘flatten_5[0][0]’] |
dense_17 (Dense) | (None, 1) | 8193 | [‘dropout_5[0][0]’] |
Total params: 137,141
Trainable params: 137,141
Non-trainable params: 0
Model: “model_11” (generator)
Layer (type) | Output Shape | Param # | Connected to |
---|---|---|---|
input_20 (InputLayer) | [(None, 100)] | 0 | [] |
input_19 (InputLayer) | [(None, 1)] | 0 | [] |
dense_19 (Dense) | (None, 8192) | 827392 | [‘input_20[0][0]’] |
embedding_9 (Embedding) | (None, 1, 50) | 500 | [‘input_19[0][0]’] |
leaky_re_lu_21 (LeakyReLU) | (None, 8192) | 0 | [‘dense_19[0][0]’] |
dense_18 (Dense) | (None, 1, 64) | 3264 | [‘embedding_9[0][0]’] |
reshape_13 (Reshape) | (None, 8, 8, 128) | 0 | [‘leaky_re_lu_21[0][0]’] |
reshape_12 (Reshape) | (None, 8, 8, 1) | 0 | [‘dense_18[0][0]’] |
concatenate_9 (Concatenate) | (None, 8, 8, 129) | 0 | [‘reshape_13[0][0]’,’reshape_12[0][0]’] |
conv2d_transpose_6 (Conv2DTranspose) | (None, 16, 16, 64) | 132160 | [‘concatenate_9[0][0]’] |
leaky_re_lu_22 (LeakyReLU) | (None, 16, 16, 64) | 0 | [‘conv2d_transpose_6[0][0]’] |
conv2d_transpose_7 (Conv2DTranspose) | (None, 32, 32, 32) | 32800 | [‘leaky_re_lu_22[0][0]’] |
leaky_re_lu_23 (LeakyReLU) | (None, 32, 32, 32) | 0 | [‘conv2d_transpose_7[0][0]’] |
conv2d_15 (Conv2D) | (None, 32, 32, 3) | 6147 | [‘leaky_re_lu_23[0][0]’] |
Total params: 1,002,263
Trainable params: 1,002,263
Non-trainable params: 0
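The code above only defines the two networks. To train them, the usual approach is a composite model that stacks the generator and a frozen copy of the discriminator, so that the generator is updated to make the discriminator label its fakes as real. This is not shown in the snippet above; the following is a minimal sketch of that standard pattern (it mirrors the pix2pix composite model at the end of this post), reusing g_model and d_model as defined earlier:

# composite cGAN model used to update the generator (a sketch, not from the code above)
def define_gan(g_model, d_model):
    # freeze the discriminator's weights inside the composite model
    d_model.trainable = False
    # the generator's inputs: latent noise and the class label
    gen_noise, gen_label = g_model.input
    # the generator's output image
    gen_output = g_model.output
    # feed the generated image together with the same label to the discriminator
    gan_output = d_model([gen_output, gen_label])
    # composite model: [noise, label] -> real/fake score
    model = Model([gen_noise, gen_label], gan_output)
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model

gan_model = define_gan(g_model, d_model)
gan_model.summary()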
(Sample outputs after only 10 epochs of training: MNIST, Fashion-MNIST, and CIFAR-10.)
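To produce the MNIST and Fashion-MNIST samples, the discriminator above can be reused with in_shape=(28,28,1), but the generator needs its foundation and output adapted to 28x28 grayscale images. A possible adaptation is sketched below (define_generator_mnist is a name introduced here for illustration; it reuses the imports, define_discriminator, and latent_dim from above and follows the same structure as the CIFAR-10 generator):

# generator adapted for 28x28x1 images (MNIST / Fashion-MNIST)
def define_generator_mnist(latent_dim, n_classes=10):
    # label input and embedding
    in_label = Input(shape=(1,))
    li = Embedding(n_classes, 50)(in_label)
    # scale the label up to a 7x7 feature map
    li = Dense(7 * 7)(li)
    li = Reshape((7, 7, 1))(li)
    # latent input and 7x7 foundation
    in_lat = Input(shape=(latent_dim,))
    gen = Dense(128 * 7 * 7)(in_lat)
    gen = LeakyReLU(alpha=0.2)(gen)
    gen = Reshape((7, 7, 128))(gen)
    # merge image foundation and label channel
    merge = Concatenate()([gen, li])
    # upsample to 14x14, then to 28x28
    gen = Conv2DTranspose(64, (4,4), strides=(2,2), padding='same')(merge)
    gen = LeakyReLU(alpha=0.2)(gen)
    gen = Conv2DTranspose(32, (4,4), strides=(2,2), padding='same')(gen)
    gen = LeakyReLU(alpha=0.2)(gen)
    # single-channel output in [-1, 1]
    out_layer = Conv2D(1, (7,7), activation='tanh', padding='same')(gen)
    return Model([in_lat, in_label], out_layer)

d_model_mnist = define_discriminator(in_shape=(28,28,1))
g_model_mnist = define_generator_mnist(latent_dim)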
Pix2Pix GAN
Pix2Pix is a conditional GAN architecture for image-to-image translation, for example turning a sketch or edge map into a photo, a map into an aerial view, or a black-and-white image into a color one.
Where a plain GAN learns a mapping from noise, \(G: z \rightarrow y\), a conditional GAN learns a mapping from an input image and noise, \(G: \{x, z\} \rightarrow y\).
In the Pix2Pix paper, the objective function is modified: a reconstruction term is added for the generator, the L1 distance between the generated image and the target image. They use L1 rather than L2 (Euclidean distance) because minimizing L2 encourages averaging over all plausible pixel values, which blurs the images.
\[Loss_{L_1}(G) = E_{x,y,z} [\, \| y - G(x,z) \|_1 \,]\]

This term is added to the final objective with a scaling factor \(\alpha\) (the paper uses a weight of 100):
\[G^* = \arg \min_G \max_D V(G,D) + \alpha \, Loss_{L_1}(G)\]
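To make the blurring argument concrete, here is a toy numerical illustration (not from the paper; numpy is assumed to be available): when several sharp pixel values are plausible, the single prediction that minimizes L2 is their mean, a washed-out compromise, while L1 is minimized by the median, which stays on one of the sharp values.

import numpy as np

# plausible "true" values for one pixel, e.g. around a sharp black/white edge
targets = np.array([0.0, 0.0, 1.0])
# candidate single predictions
candidates = np.linspace(0.0, 1.0, 101)

l2 = [np.mean((targets - c) ** 2) for c in candidates]
l1 = [np.mean(np.abs(targets - c)) for c in candidates]

print(candidates[np.argmin(l2)])  # ~0.33, the mean: a blurry compromise
print(candidates[np.argmin(l1)])  # 0.0, the median: stays sharp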
Architecture
In the original model, they do not feed a noise vector, only dropout: in their experiments the output did not really depend on the random point in the latent space (the generator tended to ignore it), so dropout, applied at both training and test time, provides the only stochasticity. Both the generator and the discriminator are built from Convolution-BatchNorm-ReLU modules. The generator also uses skip connections: all channels at layer i are simply concatenated with those at the mirrored layer n - i, as in the U-Net architecture, where an encoder compresses the information and a symmetric decoder expands it again.
Since the L1 term already enforces low-frequency correctness (pixels at corresponding positions should roughly match), the discriminator can be restricted to modelling high-frequency structure, which is what makes an image look crisp. The authors do this by only looking at structure in local patches: the PatchGAN. The discriminator classifies each patch of the image as real or fake, and its decisions are averaged over all patches. This effectively models the image as a Markov random field, assuming that pixels farther apart than a patch diameter are independent. Fairly small patches still give high-quality results while reducing the computation.
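As a quick sanity check of what "per-patch" means in the Keras implementation below (the 256x256 input size is an assumption taken from that code): the discriminator ends in a 1-filter convolution with a sigmoid, so its output is a 2D grid of real/fake probabilities rather than a single scalar, each covering a local patch of the input pair.

# each of the four stride-2 convolutions (C64-C128-C256-C512) halves the
# spatial resolution with 'same' padding; the remaining layers keep it
size = 256
for _ in range(4):
    size = (size + 1) // 2   # ceil(size / 2) for 'same' padding
print(size)                  # 16 -> a 16x16 map of patch predictions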
A few other training notes: instead of training G to minimize \(\log(1 - D(x, G(x,z)))\), they train G to maximize \(\log D(x, G(x,z))\), and they divide the discriminator's objective by 2 so that D learns more slowly relative to G (this is the loss_weights=[0.5] in the discriminator's compile call below). They use minibatch SGD with the Adam optimizer and a learning rate of 0.0002. One difficulty with evaluating synthetic images is that the usual Euclidean distance does not work well: mean squared error only measures an overall per-pixel difference and does not capture spatial structure, even though images are two-dimensional and each pixel's location matters.
The discriminator, and the generator built from encoder and decoder blocks, are defined below. The implementation follows https://machinelearningmastery.com/how-to-implement-pix2pix-gan-models-from-scratch-with-keras/.
# example of defining a composite model for training the generator model
from keras.optimizers import Adam
from keras.initializers import RandomNormal
from keras.models import Model
from keras.layers import Input
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Activation
from keras.layers import Concatenate
from keras.layers import Dropout
from keras.layers import BatchNormalization
from keras.utils.vis_utils import plot_model
# define the discriminator model
def define_discriminator(image_shape):
# weight initialization
init = RandomNormal(stddev=0.02)
# source image input
in_src_image = Input(shape=image_shape)
# target image input
in_target_image = Input(shape=image_shape)
# concatenate images channel-wise
merged = Concatenate()([in_src_image, in_target_image])
# C64
d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(merged)
d = LeakyReLU(alpha=0.2)(d)
# C128
d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
d = BatchNormalization()(d)
d = LeakyReLU(alpha=0.2)(d)
# C256
d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
d = BatchNormalization()(d)
d = LeakyReLU(alpha=0.2)(d)
# C512
d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
d = BatchNormalization()(d)
d = LeakyReLU(alpha=0.2)(d)
# second last output layer
d = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
d = BatchNormalization()(d)
d = LeakyReLU(alpha=0.2)(d)
# patch output
d = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
patch_out = Activation('sigmoid')(d)
# define model
model = Model([in_src_image, in_target_image], patch_out)
# compile model
opt = Adam(lr=0.0002, beta_1=0.5)
model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
return model
# define an encoder block
def define_encoder_block(layer_in, n_filters, batchnorm=True):
# weight initialization
init = RandomNormal(stddev=0.02)
# add downsampling layer
g = Conv2D(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
# conditionally add batch normalization
if batchnorm:
g = BatchNormalization()(g, training=True)
# leaky relu activation
g = LeakyReLU(alpha=0.2)(g)
return g
# define a decoder block
def decoder_block(layer_in, skip_in, n_filters, dropout=True):
# weight initialization
init = RandomNormal(stddev=0.02)
# add upsampling layer
g = Conv2DTranspose(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
# add batch normalization
g = BatchNormalization()(g, training=True)
# conditionally add dropout
if dropout:
g = Dropout(0.5)(g, training=True)
# merge with skip connection
g = Concatenate()([g, skip_in])
# relu activation
g = Activation('relu')(g)
return g
# define the standalone generator model
def define_generator(image_shape=(256,256,3)):
# weight initialization
init = RandomNormal(stddev=0.02)
# image input
in_image = Input(shape=image_shape)
# encoder model: C64-C128-C256-C512-C512-C512-C512-C512
e1 = define_encoder_block(in_image, 64, batchnorm=False)
e2 = define_encoder_block(e1, 128)
e3 = define_encoder_block(e2, 256)
e4 = define_encoder_block(e3, 512)
e5 = define_encoder_block(e4, 512)
e6 = define_encoder_block(e5, 512)
e7 = define_encoder_block(e6, 512)
# bottleneck, no batch norm and relu
b = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(e7)
b = Activation('relu')(b)
# decoder model: CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128
d1 = decoder_block(b, e7, 512)
d2 = decoder_block(d1, e6, 512)
d3 = decoder_block(d2, e5, 512)
d4 = decoder_block(d3, e4, 512, dropout=False)
d5 = decoder_block(d4, e3, 256, dropout=False)
d6 = decoder_block(d5, e2, 128, dropout=False)
d7 = decoder_block(d6, e1, 64, dropout=False)
# output
g = Conv2DTranspose(3, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d7)
out_image = Activation('tanh')(g)
# define model
model = Model(in_image, out_image)
return model
# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model, image_shape):
# make weights in the discriminator not trainable
for layer in d_model.layers:
if not isinstance(layer, BatchNormalization):
layer.trainable = False
# define the source image
in_src = Input(shape=image_shape)
# connect the source image to the generator input
gen_out = g_model(in_src)
# connect the source input and generator output to the discriminator input
dis_out = d_model([in_src, gen_out])
# src image as input, generated image and classification output
model = Model(in_src, [dis_out, gen_out])
# compile model
opt = Adam(lr=0.0002, beta_1=0.5)
model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1,100])
return model
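Finally, here is a minimal sketch of how the three models fit together and what one training update looks like. The paired batches X_realA (source) and X_realB (target) are placeholders introduced here; random arrays stand in for a real dataset of paired images scaled to [-1, 1].

import numpy as np

# wire the models together (256x256x3 images assumed, as in the code above)
image_shape = (256, 256, 3)
d_model = define_discriminator(image_shape)
g_model = define_generator(image_shape)
gan_model = define_gan(g_model, d_model, image_shape)

# placeholder batch of paired images; replace with real data scaled to [-1, 1]
n_batch = 1
X_realA = np.random.uniform(-1, 1, (n_batch, 256, 256, 3)).astype('float32')
X_realB = np.random.uniform(-1, 1, (n_batch, 256, 256, 3)).astype('float32')

# per-patch labels matching the discriminator's output grid
n_patch = d_model.output_shape[1]
y_real = np.ones((n_batch, n_patch, n_patch, 1))
y_fake = np.zeros((n_batch, n_patch, n_patch, 1))

# one training update
X_fakeB = g_model.predict(X_realA)                           # generator translates the source
d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real) # discriminator on a real pair
d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake) # discriminator on a generated pair
g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])  # adversarial + weighted L1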