Every now and then, you might need to demonstrate your Keras model structure. There are a couple of things you can do when this need arises. First, you could send the person who needs this overview your code, requiring them to derive the model architecture themselves. If you're nicer, you send them a diagram of your architecture.
...but creating such diagrams is often a hassle when you have to do it manually. Solutions like www.draw.io are used quite often in those cases, because they are (relatively) quick and dirty, allowing you to create diagrams fast.
However, there's a better solution: the built-in plot_model facility within Keras. It allows you to create a visualization of your model architecture. In this blog, I'll show you how to create such a visualization. Specifically, I focus on the model itself, discussing its architecture so that you fully understand what happens. Subsequently, I'll list some software dependencies that you'll need - including a highlight about a bug in Keras that results in a weird error related to pydot and GraphViz, which are used for visualization. Finally, I present the code used for visualization and the end result.
After reading this tutorial, you will...

- Understand what the plot_model() util in TensorFlow 2.0/Keras does.
- Know how to use it to visualize the architecture of your own Keras model.

Note that model code is also available on GitHub.
Update 22/Jan/2021: ensured that the tutorial is up-to-date and reflects code for TensorFlow 2.0. It can now be used with recent versions of the library. Also performed some header changes and textual improvements based on the switch from Keras 1.0 to TensorFlow 2.0. Also added an example of horizontal plotting.
If you want to get started straight away, here is the code that you can use for visualizing your TensorFlow 2.0/Keras model with plot_model:
from tensorflow.keras.utils import plot_model
plot_model(model, to_file='model.png')
Make sure to read the rest of this tutorial if you want to understand everything in more detail!
To show you how to visualize a Keras model, I think it's best if we discuss one first.
Today, we will visualize the Convolutional Neural Network that we created earlier to demonstrate the benefits of using CNNs over densely-connected ones.
This is the code of that model:
import tensorflow
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D
# Model configuration
img_width, img_height = 28, 28
batch_size = 250
no_epochs = 25
no_classes = 10
validation_split = 0.2
verbosity = 1
# Load MNIST dataset
(input_train, target_train), (input_test, target_test) = mnist.load_data()
# Set input shape
sample_shape = input_train[0].shape
img_width, img_height = sample_shape[0], sample_shape[1]
input_shape = (img_width, img_height, 1)
# Reshape data
input_train = input_train.reshape(len(input_train), input_shape[0], input_shape[1], input_shape[2])
input_test = input_test.reshape(len(input_test), input_shape[0], input_shape[1], input_shape[2])
# Parse numbers as floats
input_train = input_train.astype('float32')
input_test = input_test.astype('float32')
# Scale the data to the [0, 1] range
input_train = input_train / 255
input_test = input_test / 255
# Convert target vectors to categorical targets
target_train = tensorflow.keras.utils.to_categorical(target_train, no_classes)
target_test = tensorflow.keras.utils.to_categorical(target_test, no_classes)
# Create the model
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dense(no_classes, activation='softmax'))
# Compile the model
model.compile(loss=tensorflow.keras.losses.categorical_crossentropy,
              optimizer=tensorflow.keras.optimizers.Adam(),
              metrics=['accuracy'])
# Fit data to model
model.fit(input_train, target_train,
          batch_size=batch_size,
          epochs=no_epochs,
          verbose=verbosity,
          validation_split=validation_split)
# Generate generalization metrics
score = model.evaluate(input_test, target_test, verbose=0)
print(f'Test loss: {score[0]} / Test accuracy: {score[1]}')
What does it do?
I'd suggest that you read the original post if you wish to understand it in depth, but I'll briefly cover it here.
It simply classifies the MNIST dataset. This dataset contains 28 x 28 pixel images of digits, i.e. the numbers 0 to 9, and our CNN classifies them with a staggering 99% accuracy. It does so by combining two convolutional blocks (which consist of a two-dimensional convolutional layer, two-dimensional max pooling and dropout) with densely-connected layers. It's the best of both worlds in terms of interpreting the image and generating final predictions.
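By the way, if you only need a quick textual overview rather than a diagram, Keras can already give you one out of the box. The call below assumes the model variable from the code above:

# Print a textual overview: layer names, output shapes and parameter counts
model.summary()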
But how to visualize this model's architecture? Let's find out.
The plot_model util

Utilities. I love them, because they make my life easier. They're often relatively simple functions that can be called upon to perform some relatively simple actions. Don't be fooled, however, because these actions often benefit one's efficiency greatly - in this case, by not having to visualize a model architecture yourself in tools like draw.io.
I'm talking about the plot_model util, which ships with Keras.
It allows you to create a visualization of your Keras neural network.
More specifically, the Keras docs define it as follows:
from tensorflow.keras.utils import plot_model
plot_model(model, to_file='model.png')
From the Keras utilities, you need to import the function, after which it can be used with very minimal parameters: the model instance and the to_file parameter, which essentially specifies a location on disk where the model visualization is stored.

If you wish, you can supply some additional parameters as well (an example call follows below):

- show_shapes (False by default), which controls whether the shapes of the layer outputs are shown in the graph. This is beneficial if, besides the architecture, you also need to understand how it transforms data.
- show_dtype (False by default), with which you can indicate whether to show layer data types on the plot.
- show_layer_names (True by default), which determines whether the names of the layers are displayed.
- rankdir (TB by default), which can be used to indicate whether you want a vertical or horizontal plot: TB is vertical, LR is horizontal.
- expand_nested (False by default), which controls how nested models are displayed.
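For example, if you also want the output shapes and layer names drawn into the diagram, a call could look like the sketch below. Note that the filename model_plot.png is just an arbitrary name picked for this example.

from tensorflow.keras.utils import plot_model

# Sketch: plot the architecture including the output shape of every layer,
# keeping the layer names visible (True is already the default).
plot_model(model,
           to_file='model_plot.png',
           show_shapes=True,
           show_layer_names=True)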
However, for a simple visualization, you likely don't need them. Let's now take a look at what we would need if we were to create such a visualization.

If you wish to run the code presented in this blog successfully, you need to install a few software dependencies:

- TensorFlow 2.x, which includes Keras;
- pydot, which plot_model uses to describe the graph;
- Graphviz, which pydot calls to render it.

Preferably, you'll run this from an Anaconda environment, which allows you to run these packages in an isolated fashion. Note that many people report that a pip-based installation of Graphviz doesn't work; rather, you'll have to install it separately into your host OS from their website. Bummer!
`pydot` failed to call GraphViz

When trying to visualize my Keras neural network with plot_model, I ran into this error:

OSError: `pydot` failed to call GraphViz. Please install GraphViz (https://www.graphviz.org/) and ensure that its executables are in the $PATH.
...which essentially made sense at first, because I didn't have Graphviz installed.
...but which didn't after I installed it, because the error kept reappearing, even after restarting the Anaconda terminal.
Sometimes, it helps to install pydotplus as well with pip install pydotplus. Another solution, although not preferred, is to downgrade your pydot version.
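If the error keeps coming back, it can also help to check from Python whether the Graphviz executables are actually reachable. This is just a small diagnostic sketch, not part of the tutorial code:

import shutil
import pydot

# 'dot' is the Graphviz executable that pydot invokes under the hood.
print(shutil.which('dot'))  # prints its path if Graphviz is on your PATH, otherwise None

# Try rendering an empty graph through Graphviz; this fails if pydot cannot call it.
try:
    pydot.Dot().create()
    print('pydot can call Graphviz, so plot_model should work.')
except Exception as error:
    print('pydot cannot call Graphviz:', error)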
When adapting the code from my original CNN, stripping away the elements I don't need for visualizing the model architecture, I end up with this:
import tensorflow
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.utils import plot_model
# Load MNIST dataset
(input_train, target_train), (input_test, target_test) = mnist.load_data()
# Set input shape
sample_shape = input_train[0].shape
img_width, img_height = sample_shape[0], sample_shape[1]
input_shape = (img_width, img_height, 1)
# Number of classes
no_classes = 10
# Reshape data
input_train = input_train.reshape(len(input_train), input_shape[0], input_shape[1], input_shape[2])
input_test = input_test.reshape(len(input_test), input_shape[0], input_shape[1], input_shape[2])
# Create the model
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dense(no_classes, activation='softmax'))
plot_model(model, to_file='model.png')
You'll first perform the imports that you still need in order to successfully run the Python code. Specifically, you'll import the Keras library, the Sequential API and certain layers - this is obviously dependent on what you want. Do you want to use the Functional API? That's perfectly fine. Other layers? Fine too. I just used these because the CNN serves as our example.
Note that I also imported plot_model with from tensorflow.keras.utils import plot_model, and reshaped the data to accommodate the Conv2D layer.
Speaking of architecture: that's what I kept in. Based on the Keras Sequential API, I apply the two convolutional blocks discussed previously, before flattening their output and feeding it to the densely-connected layers that generate the final prediction. And, of course, we need no_classes = 10 to ensure that our final Dense layer works as well.
However, in this case, no such prediction is generated. Rather, the model instance is used by plot_model to generate a model visualization stored on disk as model.png. Likely, you'll add hyperparameter tuning and data fitting later on - but hey, that's not the purpose of this blog.
And your final end result looks like this:
Indeed, above we saw that we can use the rankdir attribute (which is set to TB, i.e. vertical, by default) to generate a horizontal plot! This is new, and highly preferred, as we sometimes don't want these massive vertical plots.
Making a horizontal plot of your TensorFlow/Keras model simply involves adding the rankdir='LR' (i.e. horizontal) attribute:
plot_model(model, to_file='model.png', rankdir='LR')
Which gets you this:
Awesome!
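The rankdir parameter combines freely with the other parameters discussed earlier. For instance, a horizontal plot that also shows the output shapes could be generated like this (the filename is again just an arbitrary example):

# Sketch: horizontal plot that also draws the output shape of every layer
plot_model(model, to_file='model_horizontal.png', rankdir='LR', show_shapes=True)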
In this blog, you've seen how to create a Keras model visualization based on the plot_model util provided by the library. I hope you found it useful - let me know in the comments section, I'd appreciate it! If not, let me know as well, so I can improve. For now: happy engineering!
Note that model code is also available on GitHub.
How to create a CNN classifier with Keras? - MachineCurve. (2019, September 24). Retrieved from https://www.machinecurve.com/index.php/2019/09/17/how-to-create-a-cnn-classifier-with-keras/
Keras. (n.d.). Visualization. Retrieved from https://keras.io/visualization/
Avoid wasting resources with EarlyStopping and ModelCheckpoint in Keras - MachineCurve. (2019, June 3). Retrieved from https://www.machinecurve.com/index.php/2019/05/30/avoid-wasting-resources-with-earlystopping-and-modelcheckpoint-in-keras/
Pydot issue - Issue #7 - XifengGuo/CapsNet-Keras. (n.d.). Retrieved from https://github.com/XifengGuo/CapsNet-Keras/issues/7#issuecomment-536100376
TensorFlow. (2021). Tf.keras.utils.plot_model. Retrieved from https://www.tensorflow.org/api_docs/python/tf/keras/utils/plot_model