CycleGAN TensorFlow Example for Style Translation
In this post, we're going to take a look at a CycleGAN example for image-to-image translation using TensorFlow. But first, let's talk about what CycleGAN actually is.
Basically, it's a variation of a generative adversarial network (GAN) that learns to translate images from one domain to another. In other words, we can give it a photograph as input, and it will transform it to look like a painting in a certain style.
Creating such a model is exactly what we're aiming to do here. We'll be working with a set of Monet paintings and a set of photographs.
Unlike a regular GAN, CycleGAN requires two generators and two discriminators. The reason is that it needs to be able to transform a photo into Monet style and back into the original photo. Hence the "Cycle" in the name.
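To make the cycle concrete before we write any real code, here is a rough sketch of the two round trips the model learns (the generator names match the ones we'll define later in this post):
# conceptual sketch only — the actual models are built further below
# photo -> Monet style -> back to photo:
#   fake_monet   = monet_generator(real_photo)
#   cycled_photo = photo_generator(fake_monet)    # should end up close to real_photo
# Monet -> photo style -> back to Monet:
#   fake_photo   = photo_generator(real_monet)
#   cycled_monet = monet_generator(fake_photo)    # should end up close to real_monet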
Coding CycleGAN using TensorFlow
First, as in all other ML examples, we need to import the necessary libraries, which include TensorFlow, Keras, NumPy, and a few others.
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
from glob import glob
import os
import PIL.Image
The next thing we need to do is declare a few constants that we'll use throughout the script. These include the paths to our locally stored dataset, the image size, the number of output color channels, and the cycle-consistency loss weight.
# constants
root = os.path.dirname(__file__)
MONET_FILES = glob(os.path.join(root, 'data', 'monet_tfrec', '*.tfrec'))
PHOTO_FILES = glob(os.path.join(root, 'data', 'photo_tfrec', '*.tfrec'))
AUTOTUNE = tf.data.AUTOTUNE
IMAGE_SIZE = 256
OUTPUT_CHANNELS = 3
LAMBDA_CYCLE = 10
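As a quick, optional sanity check, we can verify that the glob patterns actually matched some files before going any further; if either list is empty, the paths above need adjusting:
# optional sanity check: make sure the TFRecord files were found
assert MONET_FILES, 'no Monet TFRecord files found - check the data/monet_tfrec path'
assert PHOTO_FILES, 'no photo TFRecord files found - check the data/photo_tfrec path'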
The next step is to define the functions that load and preprocess the datasets. In this example, we're working with TFRecord files.
# decode individual image
def load_image(example):
    tfrecord_format = {
        'image_name': tf.io.FixedLenFeature([], tf.string),
        'image': tf.io.FixedLenFeature([], tf.string),
        'target': tf.io.FixedLenFeature([], tf.string)
    }
    example = tf.io.parse_single_example(example, tfrecord_format)
    image = tf.image.decode_jpeg(example['image'], channels=3)
    image = (tf.cast(image, tf.float32) / 127.5) - 1
    image = tf.reshape(image, [IMAGE_SIZE, IMAGE_SIZE, 3])
    return image

# load dataset from TFRecord files
def load_dataset(files):
    dataset = tf.data.TFRecordDataset(files)
    dataset = dataset.map(load_image, num_parallel_calls=AUTOTUNE)
    return dataset
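If you want to confirm the pipeline works before training, a quick optional check like the following (assuming the dataset files are in place) pulls a single image and verifies its shape and value range:
# optional: inspect one decoded image
sample = next(iter(load_dataset(MONET_FILES)))
print(sample.shape)  # (256, 256, 3)
print(float(tf.reduce_min(sample)), float(tf.reduce_max(sample)))  # roughly -1.0 and 1.0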
Okay, now we're ready to build the generator and discriminator models. We'll define a function that constructs each type of architecture, but first we'll define helper functions that produce the downsampling and upsampling groups of layers.
Generators in CycleGAN are essentially encoder-decoders with an hourglass-shaped architecture (ours is a U-Net, with skip connections between the two halves). This is what allows them to accept images as input, whereas traditional GANs generate output from random noise.
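To see the hourglass shape in numbers, here's a small illustrative snippet showing how the spatial resolution changes through the eight stride-2 downsampling steps we'll use below (the decoder then mirrors this back up to 256):
# illustrative only: each stride-2 convolution halves the spatial resolution
sizes = [IMAGE_SIZE]
for _ in range(8):
    sizes.append(sizes[-1] // 2)
print(sizes)  # [256, 128, 64, 32, 16, 8, 4, 2, 1]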
# group of layers for downsampling
def downsample(filters, size, apply_instancenorm=True):
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
    result = keras.Sequential()
    result.add(layers.Conv2D(filters, size, strides=2, padding='same', kernel_initializer=initializer, use_bias=False))
    if apply_instancenorm:
        result.add(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))
    result.add(layers.LeakyReLU())
    return result
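As a quick optional check of what one of these blocks does, we can run a dummy batch through it; the spatial size is halved and the channel count becomes the filter count:
# optional: one downsampling block halves height and width
block = downsample(64, 4)
print(block(tf.zeros([1, 256, 256, 3])).shape)  # (1, 128, 128, 64)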
# group of layers for upsampling
def upsample(filters, size, apply_dropout=False):
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
    result = keras.Sequential()
    result.add(layers.Conv2DTranspose(filters, size, strides=2, padding='same', kernel_initializer=initializer, use_bias=False))
    result.add(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))
    if apply_dropout:
        result.add(layers.Dropout(0.5))
    result.add(layers.ReLU())
    return result
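And the matching check for an upsampling block, which doubles the spatial size:
# optional: one upsampling block doubles height and width
block = upsample(64, 4)
print(block(tf.zeros([1, 128, 128, 256])).shape)  # (1, 256, 256, 64)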
# function for creating a generator model
def create_generator_model():
    inputs = layers.Input(shape=[IMAGE_SIZE, IMAGE_SIZE, 3])
    down_stack = [
        downsample(64, 4, apply_instancenorm=False),
        downsample(128, 4),
        downsample(256, 4),
        downsample(512, 4),
        downsample(512, 4),
        downsample(512, 4),
        downsample(512, 4),
        downsample(512, 4),
    ]
    up_stack = [
        upsample(512, 4, apply_dropout=True),
        upsample(512, 4, apply_dropout=True),
        upsample(512, 4, apply_dropout=True),
        upsample(512, 4),
        upsample(256, 4),
        upsample(128, 4),
        upsample(64, 4),
    ]
    initializer = tf.random_normal_initializer(0., 0.02)
    last = layers.Conv2DTranspose(OUTPUT_CHANNELS, 4, strides=2, padding='same', kernel_initializer=initializer, activation='tanh')
    # downsample through the model, keeping intermediate outputs for skip connections
    x = inputs
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)
    skips = reversed(skips[:-1])
    # upsample and concatenate with the matching skip connection
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = layers.Concatenate()([x, skip])
    x = last(x)
    return keras.Model(inputs=inputs, outputs=x)
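Since the generator is image-to-image, a dummy input should come back with exactly the same shape; this optional check confirms it:
# optional: the generator maps an image to an image of the same shape
gen = create_generator_model()
print(gen(tf.zeros([1, IMAGE_SIZE, IMAGE_SIZE, 3])).shape)  # (1, 256, 256, 3)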
# function for creating discriminator model
def create_discriminator_model():
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
    inputs = layers.Input(shape=[IMAGE_SIZE, IMAGE_SIZE, 3], name='input_image')
    x = inputs
    x = downsample(64, 4, False)(x)
    x = downsample(128, 4)(x)
    x = downsample(256, 4)(x)
    x = layers.ZeroPadding2D()(x)
    x = layers.Conv2D(512, 4, strides=1, kernel_initializer=initializer, use_bias=False)(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(x)
    x = layers.LeakyReLU()(x)
    x = layers.ZeroPadding2D()(x)
    x = layers.Conv2D(1, 4, strides=1, kernel_initializer=initializer)(x)
    return keras.Model(inputs=inputs, outputs=x)
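Note that this discriminator is a PatchGAN: instead of a single real/fake score, it outputs a grid of scores, each judging one patch of the input. A quick optional check makes the output shape visible:
# optional: the discriminator outputs a 30x30 grid of patch scores (logits)
disc = create_discriminator_model()
print(disc(tf.zeros([1, IMAGE_SIZE, IMAGE_SIZE, 3])).shape)  # (1, 30, 30, 1)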
Loss functions
This type of model requires a bit of finessing when it comes to loss functions. Unlike an ordinary GAN, where we only need a generator loss and a discriminator loss, this one also needs a cycle-consistency loss, plus an identity loss that discourages each generator from changing images already in its target domain.
In the following section, we'll define all the loss functions we'll be using to train our model.
# define necessary loss functions
def generator_loss(generated):
    cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
    loss = cross_entropy(tf.ones_like(generated), generated)
    return loss

def discriminator_loss(real, generated):
    cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
    real_loss = cross_entropy(tf.ones_like(real), real)
    generated_loss = cross_entropy(tf.zeros_like(generated), generated)
    loss = real_loss + generated_loss
    return loss * 0.5

def cycle_loss(real_image, cycled_image):
    loss = tf.reduce_mean(tf.abs(real_image - cycled_image))
    return loss * LAMBDA_CYCLE

def identity_loss(real_image, same_image):
    loss = tf.reduce_mean(tf.abs(real_image - same_image))
    return loss * LAMBDA_CYCLE * 0.5
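A tiny numeric example shows how the LAMBDA_CYCLE weighting plays out: identical images give zero cycle loss, and a mean absolute difference of 0.1 gets scaled up to roughly 1.0:
# optional: sanity-check the cycle loss weighting
x = tf.ones([1, IMAGE_SIZE, IMAGE_SIZE, 3])
print(float(cycle_loss(x, x)))        # 0.0
print(float(cycle_loss(x, x * 0.9)))  # ~1.0 (0.1 mean abs diff * LAMBDA_CYCLE)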
TensorFlow CycleGAN class
Now, we’re finally ready to define the CycleGAN model class, where we’re also going to override train_step and compile functions.
# define CycleGAN model class
class CycleGAN(keras.Model):
    def __init__(self, monet_generator, photo_generator, monet_discriminator, photo_discriminator):
        super(CycleGAN, self).__init__()
        self.monet_generator = monet_generator
        self.photo_generator = photo_generator
        self.monet_discriminator = monet_discriminator
        self.photo_discriminator = photo_discriminator

    def compile(self, monet_gen_optimizer, photo_gen_optimizer, monet_disc_optimizer, photo_disc_optimizer,
                gen_loss_fn, disc_loss_fn, cycle_loss_fn, identity_loss_fn):
        super(CycleGAN, self).compile()
        self.monet_gen_optimizer = monet_gen_optimizer
        self.photo_gen_optimizer = photo_gen_optimizer
        self.monet_disc_optimizer = monet_disc_optimizer
        self.photo_disc_optimizer = photo_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    def train_step(self, batch_data):
        real_monet, real_photo = batch_data
        with tf.GradientTape(persistent=True) as tape:
            # full cycles: photo -> Monet -> photo and Monet -> photo -> Monet
            fake_monet = self.monet_generator(real_photo, training=True)
            cycled_photo = self.photo_generator(fake_monet, training=True)
            fake_photo = self.photo_generator(real_monet, training=True)
            cycled_monet = self.monet_generator(fake_photo, training=True)
            # identity mappings: each generator applied to its own domain
            same_monet = self.monet_generator(real_monet, training=True)
            same_photo = self.photo_generator(real_photo, training=True)
            # discriminator outputs for real and generated images
            disc_real_monet = self.monet_discriminator(real_monet, training=True)
            disc_real_photo = self.photo_discriminator(real_photo, training=True)
            disc_fake_monet = self.monet_discriminator(fake_monet, training=True)
            disc_fake_photo = self.photo_discriminator(fake_photo, training=True)
            # compute all losses
            monet_gen_loss = self.gen_loss_fn(disc_fake_monet)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)
            monet_cycle_loss = self.cycle_loss_fn(real_monet, cycled_monet)
            photo_cycle_loss = self.cycle_loss_fn(real_photo, cycled_photo)
            total_cycle_loss = monet_cycle_loss + photo_cycle_loss
            monet_identity_loss = self.identity_loss_fn(real_monet, same_monet)
            photo_identity_loss = self.identity_loss_fn(real_photo, same_photo)
            total_monet_gen_loss = monet_gen_loss + total_cycle_loss + monet_identity_loss
            total_photo_gen_loss = photo_gen_loss + total_cycle_loss + photo_identity_loss
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)
        # compute and apply gradients for all four models
        monet_gen_gradients = tape.gradient(total_monet_gen_loss, self.monet_generator.trainable_variables)
        photo_gen_gradients = tape.gradient(total_photo_gen_loss, self.photo_generator.trainable_variables)
        monet_disc_gradients = tape.gradient(monet_disc_loss, self.monet_discriminator.trainable_variables)
        photo_disc_gradients = tape.gradient(photo_disc_loss, self.photo_discriminator.trainable_variables)
        self.monet_gen_optimizer.apply_gradients(zip(monet_gen_gradients, self.monet_generator.trainable_variables))
        self.photo_gen_optimizer.apply_gradients(zip(photo_gen_gradients, self.photo_generator.trainable_variables))
        self.monet_disc_optimizer.apply_gradients(zip(monet_disc_gradients, self.monet_discriminator.trainable_variables))
        self.photo_disc_optimizer.apply_gradients(zip(photo_disc_gradients, self.photo_discriminator.trainable_variables))
        return {
            "monet_gen_loss": total_monet_gen_loss,
            "photo_gen_loss": total_photo_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss
        }
In the next step, we put all of this together: we load the data, create the models and optimizers, then compile and train the CycleGAN.
# load and preprocess data
monet_ds = load_dataset(MONET_FILES).batch(1)
photo_ds = load_dataset(PHOTO_FILES).batch(1)
# define models
monet_generator = create_generator_model()
photo_generator = create_generator_model()
monet_discriminator = create_discriminator_model()
photo_discriminator = create_discriminator_model()
# define optimizers
monet_gen_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
photo_gen_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
monet_disc_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
photo_disc_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
# create cycle GAN model, compile, and train it
model = CycleGAN(
    monet_generator=monet_generator,
    photo_generator=photo_generator,
    monet_discriminator=monet_discriminator,
    photo_discriminator=photo_discriminator
)
model.compile(
    monet_gen_optimizer=monet_gen_optimizer,
    photo_gen_optimizer=photo_gen_optimizer,
    monet_disc_optimizer=monet_disc_optimizer,
    photo_disc_optimizer=photo_disc_optimizer,
    gen_loss_fn=generator_loss,
    disc_loss_fn=discriminator_loss,
    cycle_loss_fn=cycle_loss,
    identity_loss_fn=identity_loss,
)
model.fit(
    tf.data.Dataset.zip((monet_ds, photo_ds)),
    epochs=25
)
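Training takes a while, so it's worth persisting the result. A minimal sketch (the filename is just an example) saves the Monet generator's weights so they can be reloaded later without retraining:
# optional: save the trained generator weights for later reuse
monet_generator.save_weights(os.path.join(root, 'monet_generator.h5'))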
To see the results, we'll transform every photograph in the dataset with the Monet generator and save the outputs locally.
# transform all photographs and save them locally
output_dir = os.path.join(root, 'transformed')
os.makedirs(output_dir, exist_ok=True)  # make sure the output directory exists
i = 1
for img in photo_ds:
    prediction = monet_generator(img, training=False)[0].numpy()
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)  # scale back from [-1, 1] to [0, 255]
    im = PIL.Image.fromarray(prediction)
    im.save(os.path.join(output_dir, str(i) + '.jpg'))
    i += 1
Conclusion
To conclude, we demonstrated how to create a CycleGAN model with TensorFlow and train it on a set of Monet paintings and a set of photographs.
I hope this article helped you gain a better understanding of how it all works, and perhaps even leads you to create a model that generates amazing results.