CycleGAN and Star(Cycle)GAN
Introduction
The first difference between CycleGAN and other GANs is that it performs unpaired image-to-image translation. Unpaired means the training data does not come in matched pairs: there is no ground-truth target image for each input, so the model only learns to transfer the style of one domain onto the other. In that sense it is a style-transfer GAN, but the mechanism is different from StyleGAN.
The second difference is in the translation mechanism. CycleGAN trains two generators, G: X → Y and F: Y → X, together with two discriminators, D_Y and D_X, one for each domain.
The adversarial loss becomes three losses:
- For the function G that converts X to Y, with $G: X \rightarrow Y$ and discriminator $D_Y$:
  $\mathcal{L}_{GAN}(G, D_Y, X, Y) = \mathbb{E}_{y \sim p_{data}(y)}[\log D_Y(y)] + \mathbb{E}_{x \sim p_{data}(x)}[\log(1 - D_Y(G(x)))]$
- For the function F that converts Y to X, with $F: Y \rightarrow X$ and discriminator $D_X$:
  $\mathcal{L}_{GAN}(F, D_X, Y, X) = \mathbb{E}_{x \sim p_{data}(x)}[\log D_X(x)] + \mathbb{E}_{y \sim p_{data}(y)}[\log(1 - D_X(F(y)))]$
- The cycle consistency loss encourages forward and backward cycle consistency:
  $\mathcal{L}_{cyc}(G, F) = \mathbb{E}_{x \sim p_{data}(x)}[\lVert F(G(x)) - x \rVert_1] + \mathbb{E}_{y \sim p_{data}(y)}[\lVert G(F(y)) - y \rVert_1]$
The full objective function is:
$\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{GAN}(G, D_Y, X, Y) + \mathcal{L}_{GAN}(F, D_X, Y, X) + \lambda\,\mathcal{L}_{cyc}(G, F)$
solved as $G^*, F^* = \arg\min_{G, F}\max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$.
The authors use an architecture similar to the one used for neural style transfer for the generator: three convolutions, several residual blocks, two fractionally-strided convolutions with stride 1/2 (for upsampling), and one final convolution that maps the features back to RGB, with instance normalization throughout.
During training, though, a least-squares loss replaces the negative log likelihood, which the authors found more stable and better for image quality. In particular, G is trained to minimize $\mathbb{E}_{x \sim p_{data}(x)}[(D(G(x)) - 1)^2]$, while D is trained to minimize $\mathbb{E}_{y \sim p_{data}(y)}[(D(y) - 1)^2] + \mathbb{E}_{x \sim p_{data}(x)}[D(G(x))^2]$.
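In code, this least-squares objective is simply a mean-squared error against targets of 1 (real) and 0 (fake). A minimal sketch, where the helper names are ours rather than part of any library:

import tensorflow as tf
from tensorflow import keras

# Least-squares (LSGAN) adversarial losses, matching the objectives above.
adv_loss_fn = keras.losses.MeanSquaredError()

def generator_loss_fn(fake_logits):
    # G wants D(G(x)) to be close to 1 ("real")
    return adv_loss_fn(tf.ones_like(fake_logits), fake_logits)

def discriminator_loss_fn(real_logits, fake_logits):
    # D wants D(y) close to 1 and D(G(x)) close to 0
    real_loss = adv_loss_fn(tf.ones_like(real_logits), real_logits)
    fake_loss = adv_loss_fn(tf.zeros_like(fake_logits), fake_logits)
    return 0.5 * (real_loss + fake_loss)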
Code example
We use the horse-and-zebra dataset provided in TensorFlow Datasets, a mix of real images of horses and zebras in different settings. The generator uses the structure of a ResNet. ResNet is short for Residual Network, in which the input of a block is connected directly to a later layer so that the later layer also receives the original feature maps. This is done because the information gets warped and transformed after many convolutional layers; to pass the original structure along as additional information, the input simply skips several layers and is added to the output of the block (via layers.add in the code below).
Input
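The model code below references a few names defined elsewhere in the original notebook (kernel_init, gamma_init, input_img_size, the tf/keras/layers/tfa aliases) as well as the data pipeline. A minimal setup sketch, assuming the cycle_gan/horse2zebra split from tensorflow_datasets and DCGAN-style RandomNormal(0, 0.02) initializers:

import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
from tensorflow import keras
from tensorflow.keras import layers

autotune = tf.data.AUTOTUNE

# Load the horse <-> zebra dataset from tensorflow_datasets
dataset, _ = tfds.load("cycle_gan/horse2zebra", with_info=True, as_supervised=True)
train_horses, train_zebras = dataset["trainA"], dataset["trainB"]

# Shapes and weight initializers referenced by the model code below
# (RandomNormal(0, 0.02) is the usual DCGAN-style choice; assumed here)
input_img_size = (256, 256, 3)
kernel_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

def preprocess_train_image(img, label):
    # Random flip, resize + random crop to 256x256, scale to [-1, 1] to match tanh
    img = tf.image.random_flip_left_right(img)
    img = tf.image.resize(img, [286, 286])
    img = tf.image.random_crop(img, size=[256, 256, 3])
    return (tf.cast(img, tf.float32) / 127.5) - 1.0

train_horses = train_horses.map(preprocess_train_image, num_parallel_calls=autotune).batch(1)
train_zebras = train_zebras.map(preprocess_train_image, num_parallel_calls=autotune).batch(1)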
class ReflectionPadding2D(layers.Layer):
    """Implements Reflection Padding as a layer.

    Args:
        padding(tuple): Amount of padding for the
        spatial dimensions.

    Returns:
        A padded tensor with the same type as the input tensor.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        super().__init__(**kwargs)

    def call(self, input_tensor, mask=None):
        padding_width, padding_height = self.padding
        padding_tensor = [
            [0, 0],
            [padding_height, padding_height],
            [padding_width, padding_width],
            [0, 0],
        ]
        return tf.pad(input_tensor, padding_tensor, mode="REFLECT")
def residual_block(
    x,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="valid",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    dim = x.shape[-1]
    input_tensor = x

    # First block: reflection pad -> conv -> instance norm -> activation
    x = ReflectionPadding2D()(input_tensor)
    x = layers.Conv2D(
        dim,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    x = activation(x)

    # Second block, then add the skip connection back to the input
    x = ReflectionPadding2D()(x)
    x = layers.Conv2D(
        dim,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    x = layers.add([input_tensor, x])
    return x
def downsample(
    x,
    filters,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding="same",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    x = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    if activation:
        x = activation(x)
    return x
def upsample(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding="same",
    kernel_initializer=kernel_init,
    gamma_initializer=gamma_init,
    use_bias=False,
):
    x = layers.Conv2DTranspose(
        filters,
        kernel_size,
        strides=strides,
        padding=padding,
        kernel_initializer=kernel_initializer,
        use_bias=use_bias,
    )(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    if activation:
        x = activation(x)
    return x
def get_resnet_generator(
    filters=64,
    num_downsampling_blocks=2,
    num_residual_blocks=9,
    num_upsample_blocks=2,
    gamma_initializer=gamma_init,
    name=None,
):
    img_input = layers.Input(shape=input_img_size, name=name + "_img_input")
    x = ReflectionPadding2D(padding=(3, 3))(img_input)
    x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    x = layers.Activation("relu")(x)

    # Downsampling
    for _ in range(num_downsampling_blocks):
        filters *= 2
        x = downsample(x, filters=filters, activation=layers.Activation("relu"))

    # Residual blocks
    for _ in range(num_residual_blocks):
        x = residual_block(x, activation=layers.Activation("relu"))

    # Upsampling
    for _ in range(num_upsample_blocks):
        filters //= 2
        x = upsample(x, filters, activation=layers.Activation("relu"))

    # Final block
    x = ReflectionPadding2D(padding=(3, 3))(x)
    x = layers.Conv2D(3, (7, 7), padding="valid")(x)
    x = layers.Activation("tanh")(x)

    model = keras.models.Model(img_input, x, name=name)
    return model
def get_discriminator(
    filters=64, kernel_initializer=kernel_init, num_downsampling=3, name=None
):
    img_input = layers.Input(shape=input_img_size, name=name + "_img_input")
    x = layers.Conv2D(
        filters,
        (4, 4),
        strides=(2, 2),
        padding="same",
        kernel_initializer=kernel_initializer,
    )(img_input)
    x = layers.LeakyReLU(0.2)(x)

    num_filters = filters
    for num_downsample_block in range(num_downsampling):
        num_filters *= 2
        if num_downsample_block < 2:
            x = downsample(
                x,
                filters=num_filters,
                activation=layers.LeakyReLU(0.2),
                kernel_size=(4, 4),
                strides=(2, 2),
            )
        else:
            # Last block keeps the spatial resolution (stride 1), PatchGAN style
            x = downsample(
                x,
                filters=num_filters,
                activation=layers.LeakyReLU(0.2),
                kernel_size=(4, 4),
                strides=(1, 1),
            )

    # Single-channel output: a patch of real/fake logits
    x = layers.Conv2D(
        1, (4, 4), strides=(1, 1), padding="same", kernel_initializer=kernel_initializer
    )(x)

    model = keras.models.Model(inputs=img_input, outputs=x, name=name)
    return model
# Get the generators
gen_G = get_resnet_generator(name="generator_G")
gen_F = get_resnet_generator(name="generator_F")
# Get the discriminators
disc_X = get_discriminator(name="discriminator_X")
disc_Y = get_discriminator(name="discriminator_Y")
Generator F
Generator G
Discriminator X
Discriminator Y
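The example above stops once the generators and discriminators are built. To show how the three losses from the introduction come together, here is a minimal sketch of one training step; it reuses the illustrative generator_loss_fn / discriminator_loss_fn helpers from earlier, uses the paper's λ = 10 for the cycle term, and omits refinements such as the identity loss and the buffer of previously generated images, so it is a sketch rather than the exact loop used for the results below.

# Assumes gen_G, gen_F, disc_X, disc_Y and the loss helpers defined earlier.
cycle_loss_fn = keras.losses.MeanAbsoluteError()  # L1 loss for cycle consistency
lambda_cycle = 10.0

gen_G_optimizer = keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
gen_F_optimizer = keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
disc_X_optimizer = keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
disc_Y_optimizer = keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

@tf.function
def train_step(real_x, real_y):
    with tf.GradientTape(persistent=True) as tape:
        fake_y = gen_G(real_x, training=True)     # X -> Y
        fake_x = gen_F(real_y, training=True)     # Y -> X
        cycled_x = gen_F(fake_y, training=True)   # X -> Y -> X
        cycled_y = gen_G(fake_x, training=True)   # Y -> X -> Y

        disc_fake_y = disc_Y(fake_y, training=True)
        disc_fake_x = disc_X(fake_x, training=True)
        disc_real_y = disc_Y(real_y, training=True)
        disc_real_x = disc_X(real_x, training=True)

        # Cycle-consistency term is shared by both generators
        cycle_loss = lambda_cycle * (
            cycle_loss_fn(real_x, cycled_x) + cycle_loss_fn(real_y, cycled_y)
        )
        gen_G_loss = generator_loss_fn(disc_fake_y) + cycle_loss
        gen_F_loss = generator_loss_fn(disc_fake_x) + cycle_loss

        disc_X_loss = discriminator_loss_fn(disc_real_x, disc_fake_x)
        disc_Y_loss = discriminator_loss_fn(disc_real_y, disc_fake_y)

    # One optimizer step per network
    for loss, model, opt in [
        (gen_G_loss, gen_G, gen_G_optimizer),
        (gen_F_loss, gen_F, gen_F_optimizer),
        (disc_X_loss, disc_X, disc_X_optimizer),
        (disc_Y_loss, disc_Y, disc_Y_optimizer),
    ]:
        grads = tape.gradient(loss, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))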
Results
StarGAN
For a normal CycleGAN, we need to build an independent model for each pair of image domains. StarGAN is a CycleGAN-style model that can handle multi-domain image translation. An image can belong to several domains at once, such as: white hair, blonde hair, with hat, wearing glasses, happy, sad, etc. With the usual CycleGAN we need one model per ordered pair of domains, so k domains require k(k-1) models (for example, 5 domains already require 20), whereas StarGAN covers all of them with a single generator. To achieve multi-domain translation, when training the generator G to translate an input image x into an output image y, a target domain label c is added, so the generator learns the mapping G(x, c) → y.
The adversarial loss becomes:
$\mathcal{L}_{adv} = \mathbb{E}_{x}[\log D_{src}(x)] + \mathbb{E}_{x, c}[\log(1 - D_{src}(G(x, c)))]$
The domain classification loss for real images is:
$\mathcal{L}_{cls}^{r} = \mathbb{E}_{x, c'}[-\log D_{cls}(c' \mid x)]$
The domain classification loss for fake images is:
$\mathcal{L}_{cls}^{f} = \mathbb{E}_{x, c}[-\log D_{cls}(c \mid G(x, c))]$
By minimizing the adversarial and classification losses, G learns to generate realistic images that are classified into the correct target domain. To enforce consistency, the authors use a cycle consistency (reconstruction) loss for the generator:
$\mathcal{L}_{rec} = \mathbb{E}_{x, c, c'}[\lVert x - G(G(x, c), c') \rVert_1]$
Combining all of these losses gives the objective functions for D and G:
$\mathcal{L}_{D} = -\mathcal{L}_{adv} + \lambda_{cls}\,\mathcal{L}_{cls}^{r}$
$\mathcal{L}_{G} = \mathcal{L}_{adv} + \lambda_{cls}\,\mathcal{L}_{cls}^{f} + \lambda_{rec}\,\mathcal{L}_{rec}$
In their experiments, they set $\lambda_{cls} = 1$ and $\lambda_{rec} = 10$.
For training, the generator has two convolutional layers with stride two (for downsampling), six residual blocks, and two transposed convolutional layers with stride two (for upsampling). A PatchGAN is used for the discriminator. StarGAN is trained on multiple datasets with different label sets, so the authors add a mask vector that specifies which dataset's labels the model should focus on at each step.
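To make the label conditioning concrete: StarGAN spatially replicates the target domain label and concatenates it with the input image along the channel axis before feeding it to the generator. A minimal sketch of that step (the helper name and the five-domain example are ours, not from the paper's code):

import tensorflow as tf

def concat_domain_label(img, label):
    """Tile a one-hot domain label over the spatial dims and concat with the image.

    img:   (batch, H, W, 3) input images
    label: (batch, num_domains) one-hot or multi-hot target domain labels
    """
    h, w = img.shape[1], img.shape[2]
    label = tf.reshape(label, [-1, 1, 1, label.shape[-1]])
    label = tf.tile(label, [1, h, w, 1])       # (batch, H, W, num_domains)
    return tf.concat([img, label], axis=-1)    # (batch, H, W, 3 + num_domains)

# Usage sketch: translate a batch to domain index 1 ("blonde hair", say),
# assuming 5 domains and a generator that accepts 3 + 5 input channels.
# x = concat_domain_label(images, tf.one_hot([1] * batch_size, depth=5))
# y = generator(x)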
The following example shows the results of a pretrained model. Notice that the middle input image is a drawing with a hand next to the face, which makes it a bit harder to translate.
Code example
We clone this repo, and here are the results: