In a different blog post, we studied the concept of a Variational Autoencoder (or VAE) in detail. These generative models learn the distribution of their input data, and can then be used to manipulate data or generate entirely new samples.
But there's a difference between theory and practice. While it's nice to understand neural networks in theory, it's even more fun to actually create them with a particular framework. It makes them really usable.
Today, we'll use the Keras deep learning framework to create a convolutional variational autoencoder. We subsequently train it on the MNIST dataset, and also show you what our latent space looks like as well as new samples generated from the latent space.
But first, let's take a look at what VAEs are.
Are you ready?
Let's go!
Update 17/08/2020: added a fix for an issue with vae.fit().
If you are already familiar with variational autoencoders or wish to find the implementation straight away, I'd suggest skipping this section. In any other case, it may be worth the read.
Contrary to a normal autoencoder, which learns to encode some input into a point in latent space, Variational Autoencoders (VAEs) learn to encode their input into a multivariate probability distribution over latent space, which, given their configuration, is usually a Gaussian one:
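In formula form, this is commonly written as follows (the notation \(q_\phi\) for the encoder with parameters \(\phi\) is standard VAE notation, added here for clarity):

\[ q_\phi(z \mid x) = \mathcal{N}\big(z;\, \mu(x),\, \sigma^2(x)\, I\big) \]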
Sampling from the distribution gives a point in latent space that, given the distribution, is oriented around some mean value \(\mu\) and standard deviation \(\sigma\), like the points in this two-dimensional distribution:
Combining this with a Kullback-Leibler divergence term in the loss function leads to a latent space that is both continuous and complete: for every point sampled close to the distribution's mean and standard deviation (which is, in our case, the standard normal distribution), the output should both be similar to outputs for nearby samples and make sense.
Continuity and completeness in the latent space.
Besides the regular stuff one can do with an autoencoder (like denoising and dimensionality reduction), the principles of a VAE outlined above allow us to use variational autoencoders for generative purposes.
Samples generated with a VAE trained on the Fashion MNIST dataset.
I would really recommend my blog "What is a Variational Autoencoder (VAE)?" if you are interested in understanding VAEs in more detail. However, based on the high-level recap above, I hope that you now both understand (1) how VAEs work at a high level and (2) what this allows you to do with them: using them for generative purposes.
Let's now take a look at how we will use VAEs today.
Today, we'll use the Keras deep learning framework for creating a VAE. It consists of three individual parts: the encoder, the decoder and the VAE as a whole. We do so using the Keras Functional API, which allows us to combine layers very easily.
The MNIST dataset will be used for training the autoencoder. This dataset contains thousands of 28 x 28 pixel images of handwritten digits, as we can see below. As such, our autoencoder will learn the distribution of handwritten digits across a two-dimensional latent space, which we can then use to generate new samples of the kind we like.
Samples from the MNIST dataset
This is the structure of the encoder:
Model: "encoder"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
encoder_input (InputLayer) (None, 28, 28, 1) 0
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 14, 14, 8) 80 encoder_input[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 14, 14, 8) 32 conv2d_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 7, 7, 16) 1168 batch_normalization_1[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 7, 7, 16) 64 conv2d_2[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten) (None, 784) 0 batch_normalization_2[0][0]
__________________________________________________________________________________________________
dense_1 (Dense) (None, 20) 15700 flatten_1[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 20) 80 dense_1[0][0]
__________________________________________________________________________________________________
latent_mu (Dense) (None, 2) 42 batch_normalization_3[0][0]
__________________________________________________________________________________________________
latent_sigma (Dense) (None, 2) 42 batch_normalization_3[0][0]
__________________________________________________________________________________________________
z (Lambda) (None, 2) 0 latent_mu[0][0]
latent_sigma[0][0]
==================================================================================================
Total params: 17,208
Trainable params: 17,120
Non-trainable params: 88
And the decoder:
Model: "decoder"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
decoder_input (InputLayer) (None, 2) 0
_________________________________________________________________
dense_2 (Dense) (None, 784) 2352
_________________________________________________________________
batch_normalization_4 (Batch (None, 784) 3136
_________________________________________________________________
reshape_1 (Reshape) (None, 7, 7, 16) 0
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 14, 14, 16) 2320
_________________________________________________________________
batch_normalization_5 (Batch (None, 14, 14, 16) 64
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 28, 28, 8) 1160
_________________________________________________________________
batch_normalization_6 (Batch (None, 28, 28, 8) 32
_________________________________________________________________
decoder_output (Conv2DTransp (None, 28, 28, 1) 73
=================================================================
Total params: 9,137
Trainable params: 7,521
Non-trainable params: 1,616
And, finally, the VAE as a whole:
_________________________________________________________________
Model: "vae"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
encoder_input (InputLayer) (None, 28, 28, 1) 0
_________________________________________________________________
encoder (Model) [(None, 2), (None, 2), (N 17208
_________________________________________________________________
decoder (Model) (None, 28, 28, 1) 9137
=================================================================
Total params: 26,345
Trainable params: 24,641
Non-trainable params: 1,704
From the final summary, we can see that indeed, the VAE takes in samples of shape \((28, 28, 1)\) and returns samples in the same format. Great! π
Let's now start working on our model. Open up your Explorer / Finder, navigate to some folder, and create a new Python file, e.g. variational_autoencoder.py
. Now, open this file in your code editor, and let's start coding! π
Before we begin, it's important that you ensure that you have all the required dependencies installed on your system. Based on the imports below, that means Keras, a backend such as TensorFlow, NumPy, and Matplotlib.
Let's now declare everything that we will import:

- The Model container from keras.models. This allows us to instantiate the models eventually.
- The layers we need from keras.layers: Conv2D, Conv2DTranspose, Input, Flatten, Dense, Lambda, Reshape and BatchNormalization.
- The mnist dataset, which is the dataset we'll be training our VAE on.
- With binary_crossentropy, we can compute reconstruction loss.
- The Keras backend (imported as K), which contains calls for tensor manipulations that we'll use.

This is the code that includes our imports:
'''
Variational Autoencoder (VAE) with the Keras Functional API.
'''
import keras
from keras.layers import Conv2D, Conv2DTranspose, Input, Flatten, Dense, Lambda, Reshape
from keras.layers import BatchNormalization
from keras.models import Model
from keras.datasets import mnist
from keras.losses import binary_crossentropy
from keras import backend as K
import numpy as np
import matplotlib.pyplot as plt
Next thing: importing the MNIST dataset. Since MNIST is part of the Keras Datasets, we can import it easily - by calling mnist.load_data()
. Love the Keras simplicity!
# Load MNIST dataset
(input_train, target_train), (input_test, target_test) = mnist.load_data()
Importing the data is followed by setting config parameters for data and model.
# Data & model configuration
img_width, img_height = input_train.shape[1], input_train.shape[2]
batch_size = 128
no_epochs = 100
validation_split = 0.2
verbosity = 1
latent_dim = 2
num_channels = 1
The width and height in our configuration settings are determined by the training data. In our case, img_width = img_height = 28, as the MNIST dataset contains samples that are 28 x 28 pixels.
Batch size is set to 128 samples per (mini)batch, which is quite normal. The same is true for the number of epochs, which was set to 100. 20% of the training data is used for validation purposes. This is also quite normal. Nothing special here.
Verbosity mode is set to 1 (i.e., True), which means that all output is shown on screen.
The final two configuration settings are more interesting. First, the latent space will be two-dimensional, which means that a significant information bottleneck will be created; with autoencoders, this should yield good results on relatively simple datasets. Finally, the num_channels parameter is set to the number of image channels: for RGB data, it's 3 (red, green, blue), and for grayscale data (such as MNIST), it's 1.
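For example, here is a minimal sketch of how this configuration could change for RGB data, using the CIFAR-10 dataset that also ships with Keras (purely an illustration; the rest of this post sticks with MNIST):

# Hypothetical alternative: RGB data from the CIFAR-10 Keras dataset
from keras.datasets import cifar10

(input_train, target_train), (input_test, target_test) = cifar10.load_data()
img_width, img_height = input_train.shape[1], input_train.shape[2]  # 32 x 32 pixels
num_channels = 3  # red, green and blue channels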
The next thing is data preprocessing:
# Reshape data
input_train = input_train.reshape(input_train.shape[0], img_height, img_width, num_channels)
input_test = input_test.reshape(input_test.shape[0], img_height, img_width, num_channels)
input_shape = (img_height, img_width, num_channels)
# Parse numbers as floats
input_train = input_train.astype('float32')
input_test = input_test.astype('float32')
# Normalize data
input_train = input_train / 255
input_test = input_test / 255
First, we reshape the data so that it takes the shape (X, 28, 28, 1), where X is the number of samples in either the training or testing dataset. We also set (28, 28, 1) as input_shape
.
Next, we parse the numbers as floats, which presumably speeds up the training process, and normalize it, which the neural network appreciates. And that's it already for data preprocessing :-)
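As an optional sanity check, the shapes should now look as follows (MNIST has 60,000 training and 10,000 testing samples):

# Optional: verify the preprocessed shapes
print(input_train.shape)  # (60000, 28, 28, 1)
print(input_test.shape)   # (10000, 28, 28, 1)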
Now, it's time to create the encoder. This is a three-step process: firstly, we define its layers. Secondly, we perform something that is known as the reparameterization trick, which makes sampling differentiable and allows us to link the encoder to the decoder later, to instantiate the VAE as a whole. Thirdly and finally, we instantiate the encoder.
The first step in the three-step process is the definition of our encoder. Following the connection process of the Keras Functional API, we link the layers together:
# =================
# Encoder
# =================
# Definition
i = Input(shape=input_shape, name='encoder_input')
cx = Conv2D(filters=8, kernel_size=3, strides=2, padding='same', activation='relu')(i)
cx = BatchNormalization()(cx)
cx = Conv2D(filters=16, kernel_size=3, strides=2, padding='same', activation='relu')(cx)
cx = BatchNormalization()(cx)
x = Flatten()(cx)
x = Dense(20, activation='relu')(x)
x = BatchNormalization()(x)
mu = Dense(latent_dim, name='latent_mu')(x)
sigma = Dense(latent_dim, name='latent_sigma')(x)
Let's now take a look at the individual lines of code in more detail.
- The first layer is the Input layer. It accepts data with input_shape = (28, 28, 1) and is named encoder_input. It's actually a pretty dumb layer, haha.
- The two Conv2D layers, each followed by BatchNormalization, downsample the input while learning more filters as we go: 8 in the first block, 16 in the second.
- Next up is the Flatten layer. It's a relatively dumb layer too, and only serves to flatten the multidimensional data from the convolutional layers into a one-dimensional shape. This has to be done because the densely-connected layers that we use next require data to have this shape.
- The final two layers, mu and sigma, are actually not separate from each other: look at the previous layer they are linked to (both x, i.e. the Dense(20) layer). The first outputs the mean values \(\mu\) of the encoded input and the second one outputs the standard deviations \(\sigma\). With these, we can sample the random variables that constitute the point in latent space onto which some input is mapped.

That's it for the layers of our encoder. The next step is to retrieve the shape of the final Conv2D output:
# Get Conv2D shape for Conv2DTranspose operation in decoder
conv_shape = K.int_shape(cx)
We'll need it when defining the layers of our decoder. I won't bother you with the details yet, as they are best explained when we're a bit further down the road. However, just remember to come back here if you wonder why we need some conv_shape
value in the decoder, okay?
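If you're curious, you can already print this value (assuming the encoder definition above has run); given the architecture, it should be (None, 7, 7, 16):

# Optional: inspect the stored shape
print(conv_shape)  # (None, 7, 7, 16)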
Let's now take a look at the second part of our encoder segment: the reparameterization trick.
While for a mathematically sound explanation of the so-called "reparameterization trick", introduced to VAEs by Kingma & Welling (2013), I must refer you to Gregory Gundersen's "The Reparameterization Trick", I'll try to explain the need for reparameterization briefly.
If you use neural networks (or, to be more precise, gradient descent) for optimizing the variational autoencoder, you effectively minimize some expected loss value, which can be estimated with Monte Carlo techniques (Huang, n.d.). However, this requires that the loss function is differentiable, which is not necessarily the case here, because it depends on the parameters of a probability distribution that we sample from. In this case, it's possible to rewrite the equation, but then it no longer has the form of an expectation, making it impossible to use the Monte Carlo techniques that were usable before.
However, if we can reparameterize the sample fed to the function into the shape \(\mu + \sigma \times \epsilon\), where \(\epsilon\) is standard-normally distributed noise, it becomes possible to use gradient descent for estimating the gradients accurately (Gundersen, 2018; Huang, n.d.).
And that's precisely what we'll do in our code. We "sample" the value for \(z\) from the computed \(\mu\) and \(\sigma\) values as mu + K.exp(sigma / 2) * eps. Note that our sigma layer effectively outputs the log variance \(\log \sigma^2\), which is why K.exp(sigma / 2) recovers the standard deviation.
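In formula form:

\[ z = \mu + \exp\!\left(\tfrac{\log \sigma^2}{2}\right) \cdot \epsilon = \mu + \sigma \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) \]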
# Define sampling with reparameterization trick
def sample_z(args):
mu, sigma = args
batch = K.shape(mu)[0]
dim = K.int_shape(mu)[1]
eps = K.random_normal(shape=(batch, dim))
return mu + K.exp(sigma / 2) * eps
We then use this with a Lambda
to ensure that correct gradients are computed during the backwards pass based on our values for mu
and sigma
:
# Use reparameterization trick to ensure correct gradient
z = Lambda(sample_z, output_shape=(latent_dim, ), name='z')([mu, sigma])
Now, it's time to instantiate the encoder - taking inputs through input layer i
, and outputting the values generated by the mu
, sigma
and z
layers (i.e., the individual means and standard deviations, and the point sampled from the random variable represented by them):
# Instantiate encoder
encoder = Model(i, [mu, sigma, z], name='encoder')
encoder.summary()
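As a quick and entirely optional sanity check (assuming the preprocessing steps above have run), you can feed a few test samples to the encoder and inspect the shapes of its three outputs:

# Optional sanity check: encode a few test samples
sample_mu, sample_sigma, sample_z = encoder.predict(input_test[:5])
print(sample_mu.shape, sample_sigma.shape, sample_z.shape)  # (5, 2) (5, 2) (5, 2)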
Now that we've got the encoder, it's time to start working on the decoder :)
Creating the decoder is a bit simpler and boils down to a two-step process: defining it, and instantiating it.
Firstly, we'll define the layers of our decoder - just as we've done when defining the structure of our encoder.
# =================
# Decoder
# =================
# Definition
d_i = Input(shape=(latent_dim, ), name='decoder_input')
x = Dense(conv_shape[1] * conv_shape[2] * conv_shape[3], activation='relu')(d_i)
x = BatchNormalization()(x)
x = Reshape((conv_shape[1], conv_shape[2], conv_shape[3]))(x)
cx = Conv2DTranspose(filters=16, kernel_size=3, strides=2, padding='same', activation='relu')(x)
cx = BatchNormalization()(cx)
cx = Conv2DTranspose(filters=8, kernel_size=3, strides=2, padding='same', activation='relu')(cx)
cx = BatchNormalization()(cx)
o = Conv2DTranspose(filters=num_channels, kernel_size=3, activation='sigmoid', padding='same', name='decoder_output')(cx)
Let's walk through the layers:

- The first layer is the Input layer, the decoder_input layer. It takes input with the shape (latent_dim, ), which as we will see is the vector we sampled for z with our encoder.
- Next, we need to convert this (latent_dim, ) input into some shape that can be reshaped into the output shape of the last convolutional layer of our encoder. This is what we stored in the conv_shape variable earlier. We thus add a Dense layer which has conv_shape[1] * conv_shape[2] * conv_shape[3] outputs, and converts the latent space into many outputs.
- We then use a Reshape layer to convert the output of the Dense layer into the output shape of the last convolutional layer: (conv_shape[1], conv_shape[2], conv_shape[3]) = (7, 7, 16). Sixteen filters learnt, with 7 x 7 pixels per filter.
- Next, we use Conv2DTranspose and BatchNormalization in the exact opposite order as in our encoder to upsample our data into 28 x 28 pixels (which is equal to the width and height of our inputs). However, we still have 8 filters at this point, so the shape so far is (28, 28, 8).
- The final layer is another Conv2DTranspose layer, which does nothing to the width and height of the data, but ensures that the number of filters learned equals num_channels. For MNIST data, where num_channels = 1, this means that the shape of our output will be (28, 28, 1), just as it has to be :) This last layer also uses Sigmoid activation, which allows us to use binary crossentropy loss when computing the reconstruction loss part of our loss function.

The next thing we do is instantiate the decoder:
# Instantiate decoder
decoder = Model(d_i, o, name='decoder')
decoder.summary()
It takes the inputs from the decoder input layer d_i
and outputs whatever is output by the output layer o
. Simple :)
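Again optional: because the decoder accepts any point in latent space, we can already feed it an arbitrary two-dimensional vector and verify that a 28 x 28 x 1 output comes out (it will look like noise for now, since the model is untrained):

# Optional check: decode an arbitrary latent vector
random_point = np.array([[0.5, -1.0]])  # any array of shape (1, latent_dim) works
generated = decoder.predict(random_point)
print(generated.shape)  # (1, 28, 28, 1)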
Now that the encoder and decoder are complete, we can create the VAE as a whole:
# =================
# VAE as a whole
# =================
# Instantiate VAE
vae_outputs = decoder(encoder(i)[2])
vae = Model(i, vae_outputs, name='vae')
vae.summary()
If you think about it, the outputs of the entire VAE are the original inputs, encoded by the encoder, and decoded by the decoder.
That's how we arrive at vae_outputs = decoder(encoder(i)[2])
: inputs i
are encoded by the encoder
into [mu, sigma, z]
(the individual means and standard deviations with the sampled z
as well). We then take the sampled z
values (hence the [2]
) and feed it to the decoder
, which ensures that we arrive at correct VAE output.
We then instantiate the model: i are indeed our inputs, and vae_outputs are the outputs. We call the model vae, because it simply is.
Now that we have defined our model, we can proceed with model configuration. Usually, with neural networks, this is done with model.compile
, where a loss function is specified such as binary crossentropy. However, when we look at how VAEs are optimized, we see that it's not a simple loss function that is used: we use reconstruction loss (in our case, binary crossentropy loss) together with KL divergence loss to ensure that our latent space is both continuous and complete.
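Written out (a standard formulation, with the scaling matching the code below, where \(w\) and \(h\) are the image width and height):

\[ \mathcal{L} = \underbrace{w \cdot h \cdot \text{BCE}(x, \hat{x})}_{\text{reconstruction loss}} + \underbrace{D_{KL}\big(q(z \mid x)\, \|\, \mathcal{N}(0, I)\big)}_{\text{KL divergence loss}} \]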
We define it as follows:
# Define loss
def kl_reconstruction_loss(true, pred):
# Reconstruction loss
reconstruction_loss = binary_crossentropy(K.flatten(true), K.flatten(pred)) * img_width * img_height
# KL divergence loss
kl_loss = 1 + sigma - K.square(mu) - K.exp(sigma)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5
# Total loss = reconstruction loss + KL divergence loss
return K.mean(reconstruction_loss + kl_loss)
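The three kl_loss lines implement the closed-form KL divergence between a diagonal Gaussian and the standard normal distribution, with sigma interpreted as the log variance \(\log \sigma^2\):

\[ D_{KL}\big(\mathcal{N}(\mu, \sigma^2 I)\, \|\, \mathcal{N}(0, I)\big) = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right) \]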
- reconstruction_loss is the binary crossentropy value computed for the flattened true values (representing our targets, i.e. our ground truth) and the pred prediction values generated by our VAE. It's multiplied by img_width and img_height to compensate for the per-pixel averaging that flattening introduces.
- kl_loss computes the KL divergence term shown above, which is added to the reconstruction loss.

Now that we have defined our custom loss function, we can compile our model. We do so using the Adam optimizer and our kl_reconstruction_loss custom loss function.
# Compile VAE
vae.compile(optimizer='adam', loss=kl_reconstruction_loss)
# Train autoencoder
vae.fit(input_train, input_train, epochs = no_epochs, batch_size = batch_size, validation_split = validation_split, verbose = verbosity)
Once compiled, we can call vae.fit
to start the training process. Note that we set input_train
both as our features and targets, as is usual with autoencoders. For the rest, we configure the training process as defined previously, in the model configuration step.
Even though you can now actually train your VAE, it's best to wait just a bit more - because we'll add some code for visualization purposes:
Some credits first, though: the code for the two visualizers was originally created (and found by me) in the Keras Docs, at the link here, as well as in François Chollet's blog post, here. All credits for the original ideas go to the authors of these articles. I made some adaptations to the code to accommodate this blog post:
Visualizing inputs mapped onto the latent space is simply taking some input data, feeding it to the encoder, taking the mean values \(\mu\) for the predictions, and plotting them in a scatter plot:
# =================
# Results visualization
# Credits for original visualization code: https://keras.io/examples/variational_autoencoder_deconv/
# (François Chollet).
# Adapted to accommodate this VAE.
# =================
def viz_latent_space(encoder, data):
input_data, target_data = data
mu, _, _ = encoder.predict(input_data)
plt.figure(figsize=(8, 10))
plt.scatter(mu[:, 0], mu[:, 1], c=target_data)
plt.xlabel('z - dim 1')
plt.ylabel('z - dim 2')
plt.colorbar()
plt.show()
Visualizing samples from the latent space entails a bit more work. First, we'll have to create a figure filled with zeros, as well as a linear space spanning \([-4, +4]\) in both latent dimensions (i.e., a few standard deviations around the prior's mean \(\mu = 0\)) that we can iterate over. We take a sample from the grid (determined by our current \(x\) and \(y\) positions) and feed it to the decoder. We then replace the zeros in our figure with the output, and finally plot the entire figure on screen. This includes reshaping grayscale (single-channel) data if necessary.
def viz_decoded(encoder, decoder, data):
num_samples = 15
figure = np.zeros((img_width * num_samples, img_height * num_samples, num_channels))
grid_x = np.linspace(-4, 4, num_samples)
grid_y = np.linspace(-4, 4, num_samples)[::-1]
for i, yi in enumerate(grid_y):
for j, xi in enumerate(grid_x):
z_sample = np.array([[xi, yi]])
x_decoded = decoder.predict(z_sample)
digit = x_decoded[0].reshape(img_width, img_height, num_channels)
figure[i * img_width: (i + 1) * img_width,
j * img_height: (j + 1) * img_height] = digit
plt.figure(figsize=(10, 10))
start_range = img_width // 2
end_range = num_samples * img_width + start_range + 1
pixel_range = np.arange(start_range, end_range, img_width)
sample_range_x = np.round(grid_x, 1)
sample_range_y = np.round(grid_y, 1)
plt.xticks(pixel_range, sample_range_x)
plt.yticks(pixel_range, sample_range_y)
plt.xlabel('z - dim 1')
plt.ylabel('z - dim 2')
# matplotlib.pyplot.imshow() needs a 2D array, or a 3D array with the third dimension being of shape 3 or 4!
# So reshape if necessary
fig_shape = np.shape(figure)
if fig_shape[2] == 1:
figure = figure.reshape((fig_shape[0], fig_shape[1]))
# Show image
plt.imshow(figure)
plt.show()
Using the visualizers is, however, much easier:
# Plot results
data = (input_test, target_test)
viz_latent_space(encoder, data)
viz_decoded(encoder, decoder, data)
Let's now run our model. Open up a terminal which has access to all the required dependencies, cd
to the folder where your Python file is located, and run it, e.g. python variational_autoencoder.py
.
The training process should now begin with some visualizations being output after it finishes! :)
Marc, one of our readers, reported an issue with the model when running the VAE with TensorFlow 2.3.0 (and possibly also newer versions): https://github.com/tensorflow/probability/issues/519
By adding the following lines of code near the top of your script, this issue can be resolved:

import tensorflow as tf
tf.config.experimental_run_functions_eagerly(True)
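Note that experimental_run_functions_eagerly has since been deprecated; on more recent TensorFlow 2.x versions, the non-experimental variant should achieve the same (I haven't verified this across every version):

tf.config.run_functions_eagerly(True)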
Even though I would recommend reading the entire post first before you start playing with the code (because the structures are intrinsically linked), it may be that you wish to take the full code and start fiddling right away. In this case, having the full code at once may be worthwhile to you, so here you go:
'''
Variational Autoencoder (VAE) with the Keras Functional API.
'''
import keras
from keras.layers import Conv2D, Conv2DTranspose, Input, Flatten, Dense, Lambda, Reshape
from keras.layers import BatchNormalization
from keras.models import Model
from keras.datasets import mnist
from keras.losses import binary_crossentropy
from keras import backend as K
import numpy as np
import matplotlib.pyplot as plt
# Load MNIST dataset
(input_train, target_train), (input_test, target_test) = mnist.load_data()
# Data & model configuration
img_width, img_height = input_train.shape[1], input_train.shape[2]
batch_size = 128
no_epochs = 100
validation_split = 0.2
verbosity = 1
latent_dim = 2
num_channels = 1
# Reshape data
input_train = input_train.reshape(input_train.shape[0], img_height, img_width, num_channels)
input_test = input_test.reshape(input_test.shape[0], img_height, img_width, num_channels)
input_shape = (img_height, img_width, num_channels)
# Parse numbers as floats
input_train = input_train.astype('float32')
input_test = input_test.astype('float32')
# Normalize data
input_train = input_train / 255
input_test = input_test / 255
# =================
# Encoder
# =================
# Definition
i = Input(shape=input_shape, name='encoder_input')
cx = Conv2D(filters=8, kernel_size=3, strides=2, padding='same', activation='relu')(i)
cx = BatchNormalization()(cx)
cx = Conv2D(filters=16, kernel_size=3, strides=2, padding='same', activation='relu')(cx)
cx = BatchNormalization()(cx)
x = Flatten()(cx)
x = Dense(20, activation='relu')(x)
x = BatchNormalization()(x)
mu = Dense(latent_dim, name='latent_mu')(x)
sigma = Dense(latent_dim, name='latent_sigma')(x)
# Get Conv2D shape for Conv2DTranspose operation in decoder
conv_shape = K.int_shape(cx)
# Define sampling with reparameterization trick
def sample_z(args):
mu, sigma = args
batch = K.shape(mu)[0]
dim = K.int_shape(mu)[1]
eps = K.random_normal(shape=(batch, dim))
return mu + K.exp(sigma / 2) * eps
# Use reparameterization trick to ensure correct gradient
z = Lambda(sample_z, output_shape=(latent_dim, ), name='z')([mu, sigma])
# Instantiate encoder
encoder = Model(i, [mu, sigma, z], name='encoder')
encoder.summary()
# =================
# Decoder
# =================
# Definition
d_i = Input(shape=(latent_dim, ), name='decoder_input')
x = Dense(conv_shape[1] * conv_shape[2] * conv_shape[3], activation='relu')(d_i)
x = BatchNormalization()(x)
x = Reshape((conv_shape[1], conv_shape[2], conv_shape[3]))(x)
cx = Conv2DTranspose(filters=16, kernel_size=3, strides=2, padding='same', activation='relu')(x)
cx = BatchNormalization()(cx)
cx = Conv2DTranspose(filters=8, kernel_size=3, strides=2, padding='same', activation='relu')(cx)
cx = BatchNormalization()(cx)
o = Conv2DTranspose(filters=num_channels, kernel_size=3, activation='sigmoid', padding='same', name='decoder_output')(cx)
# Instantiate decoder
decoder = Model(d_i, o, name='decoder')
decoder.summary()
# =================
# VAE as a whole
# =================
# Instantiate VAE
vae_outputs = decoder(encoder(i)[2])
vae = Model(i, vae_outputs, name='vae')
vae.summary()
# Define loss
def kl_reconstruction_loss(true, pred):
# Reconstruction loss
reconstruction_loss = binary_crossentropy(K.flatten(true), K.flatten(pred)) * img_width * img_height
# KL divergence loss
kl_loss = 1 + sigma - K.square(mu) - K.exp(sigma)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5
# Total loss = reconstruction loss + KL divergence loss
return K.mean(reconstruction_loss + kl_loss)
# Compile VAE
vae.compile(optimizer='adam', loss=kl_reconstruction_loss)
# Train autoencoder
vae.fit(input_train, input_train, epochs = no_epochs, batch_size = batch_size, validation_split = validation_split, verbose = verbosity)
# =================
# Results visualization
# Credits for original visualization code: https://keras.io/examples/variational_autoencoder_deconv/
# (François Chollet).
# Adapted to accommodate this VAE.
# =================
def viz_latent_space(encoder, data):
input_data, target_data = data
mu, _, _ = encoder.predict(input_data)
plt.figure(figsize=(8, 10))
plt.scatter(mu[:, 0], mu[:, 1], c=target_data)
plt.xlabel('z - dim 1')
plt.ylabel('z - dim 2')
plt.colorbar()
plt.show()
def viz_decoded(encoder, decoder, data):
num_samples = 15
figure = np.zeros((img_width * num_samples, img_height * num_samples, num_channels))
grid_x = np.linspace(-4, 4, num_samples)
grid_y = np.linspace(-4, 4, num_samples)[::-1]
for i, yi in enumerate(grid_y):
for j, xi in enumerate(grid_x):
z_sample = np.array([[xi, yi]])
x_decoded = decoder.predict(z_sample)
digit = x_decoded[0].reshape(img_width, img_height, num_channels)
figure[i * img_width: (i + 1) * img_width,
j * img_height: (j + 1) * img_height] = digit
plt.figure(figsize=(10, 10))
start_range = img_width // 2
end_range = num_samples * img_width + start_range + 1
pixel_range = np.arange(start_range, end_range, img_width)
sample_range_x = np.round(grid_x, 1)
sample_range_y = np.round(grid_y, 1)
plt.xticks(pixel_range, sample_range_x)
plt.yticks(pixel_range, sample_range_y)
plt.xlabel('z - dim 1')
plt.ylabel('z - dim 2')
# matplotlib.pyplot.imshow() needs a 2D array, or a 3D array with the third dimension being of shape 3 or 4!
# So reshape if necessary
fig_shape = np.shape(figure)
if fig_shape[2] == 1:
figure = figure.reshape((fig_shape[0], fig_shape[1]))
# Show image
plt.imshow(figure)
plt.show()
# Plot results
data = (input_test, target_test)
viz_latent_space(encoder, data)
viz_decoded(encoder, decoder, data)
Now, time for the results :)
Training the model for 100 epochs yields this visualization of the latent space:
As we can see, around \((0, 0)\) our latent space is pretty continuous as well as complete. Somewhere around \((0, -1.5)\) we see some holes, as well as near the edges (e.g. \((3, -3)\)). We can see these issues in the actual sampling too:
Especially in the corners on the right, we see the issue with completeness, which yields outputs that do not make sense. Some issues with continuity are visible wherever the samples are blurred. However, generally speaking, I'm quite happy with the results!
However, let's see if we can make them even better :)
In their paper "Unsupervised representation learning with deep convolutional generative adversarial networks", Radford et al. (2015) introduce the concept of a deep convolutional generative adversarial network, or DCGAN. While a GAN represents the other branch of generative models, results have suggested that deep convolutional architectures for generative models may produce better results with VAEs as well.
So, as an extension of our original post, we changed the architecture of our model into deeper and wider convolutional layers, in line with Radford et al. (2015). We changed the encoder into:
i = Input(shape=input_shape, name='encoder_input')
cx = Conv2D(filters=128, kernel_size=5, strides=2, padding='same', activation='relu')(i)
cx = BatchNormalization()(cx)
cx = Conv2D(filters=256, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx = BatchNormalization()(cx)
cx = Conv2D(filters=512, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx = BatchNormalization()(cx)
cx = Conv2D(filters=1024, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx = BatchNormalization()(cx)
x = Flatten()(cx)
x = Dense(20, activation='relu')(x)
x = BatchNormalization()(x)
mu = Dense(latent_dim, name='latent_mu')(x)
sigma = Dense(latent_dim, name='latent_sigma')(x)
And the decoder
into:
# Definition
d_i = Input(shape=(latent_dim, ), name='decoder_input')
x = Dense(conv_shape[1] * conv_shape[2] * conv_shape[3], activation='relu')(d_i)
x = BatchNormalization()(x)
x = Reshape((conv_shape[1], conv_shape[2], conv_shape[3]))(x)
cx = Conv2DTranspose(filters=1024, kernel_size=5, strides=2, padding='same', activation='relu')(x)
cx = BatchNormalization()(cx)
cx = Conv2DTranspose(filters=512, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx = BatchNormalization()(cx)
cx = Conv2DTranspose(filters=256, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx = BatchNormalization()(cx)
cx = Conv2DTranspose(filters=128, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx = BatchNormalization()(cx)
o = Conv2DTranspose(filters=num_channels, kernel_size=3, activation='sigmoid', padding='same', name='decoder_output')(cx)
While our original VAE had approximately 26,000 parameters, this one has approximately 8.7M:
_________________________________________________________________
Model: "vae"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
encoder_input (InputLayer) (None, 28, 28, 1) 0
_________________________________________________________________
encoder (Model) [(None, 2), (None, 2), (N 4044984
_________________________________________________________________
decoder (Model) (None, 28, 28, 1) 4683521
=================================================================
Total params: 8,728,505
Trainable params: 8,324,753
Non-trainable params: 403,752
However, even after training it for only 5 epochs, results have become considerably better:
Latent space (left) also looks better compared to our initial VAE (right):
Interestingly, however, the plot on the left is actually zoomed in, as we now also have some outliers:
Interesting result :)
In this blog post, we've seen how to create a variational autoencoder with Keras. We first looked at what VAEs are, and why they are different from regular autoencoders. We then created a neural network implementation with Keras and explained it step by step, so that you can easily reproduce it yourself while understanding what happens.
To compare against our initial 26K-parameter VAE, we expanded the architecture into a DCGAN-like one of approximately 8.7M parameters, covering both the encoder and the decoder. This yielded better results, but also increased the number of outliers.
I hope you've learnt something from this article :) If you did, please let me know by leaving a comment in the comments section below! If you have questions or remarks, please do the same!
Thank you for reading MachineCurve today and happy engineering!
Keras. (n.d.). Variational autoencoder deconv. Retrieved from https://keras.io/examples/variational_autoencoder_deconv/
Gundersen, G. (2018, April 29). The reparameterization trick. Retrieved from http://gregorygundersen.com/blog/2018/04/29/reparameterization/
Kingma, D. P., & Welling, M. (2013). Auto-encoding variational Bayes. arXiv preprint arXiv:1312.6114.
Huang, G. (n.d.). Reparametrization trick · Machine learning. Retrieved from https://gabrielhuang.gitbooks.io/machine-learning/content/reparametrization-trick.html
Wiseodd. (2016, December 10). Variational autoencoder: Intuition and implementation. Retrieved from http://wiseodd.github.io/techblog/2016/12/10/variational-autoencoder/
Keras Blog. (n.d.). Building autoencoders in Keras. Retrieved from https://blog.keras.io/building-autoencoders-in-keras.html
Radford, A., Metz, L., & Chintala, S. (2015). Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434.