Data scientists find some aspects of their job really frustrating. Data preprocessing is one of them, and the same is true for generating visualizations and other kinds of reports: they're boring, nobody reads them, and creating them takes a lot of time.
What if there were an alternative that allowed you to create interactive visualizations of your data science results within minutes?
That's what we're going to find out today. You're going to explore Streamlit, a free and open source package for creating data-driven web apps. More specifically, you will generate visualizations of the image datasets in tensorflow.keras.datasets: the MNIST dataset, the Fashion MNIST dataset, and the CIFAR-10 and CIFAR-100 datasets. The resulting tool lets you easily walk through these datasets, generating plots on the fly.
After reading this tutorial, you will...
tensorflow.keras datasets.
Are you ready? Let's take a look!
While the job of data scientists can be cool, it can also be really frustrating - especially when it comes to visualizing your datasets.
Creating an application that shows what you have built, or what you want to build, can be just as frustrating.
No more. Say hello to Streamlit, a free and open source package with which you can create data-driven web apps in minutes.
Really, it takes almost no time to build your data dashboard - and we're going to see how to use it today.
Let's now write some code!
To run the code in this tutorial, install the following dependencies if you haven't already:
pip install streamlit
pip install tensorflow
pip install matplotlib
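If you want to verify that these dependencies are available before continuing, a small check with the standard library's importlib.util.find_spec() can help. This is a minimal sketch: the helper name missing_dependencies() is hypothetical and not part of any of the packages above, and it assumes the importable module names (streamlit, tensorflow, matplotlib) match the pip package names, which is the case here.

```python
import importlib.util


def missing_dependencies(module_names):
    """Return the top-level module names that cannot be found.

    find_spec() locates a top-level module without importing it, so the
    check stays fast even for heavy packages such as TensorFlow.
    """
    return [name for name in module_names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_dependencies(["streamlit", "tensorflow", "matplotlib"])
    if missing:
        print("Please install:", ", ".join(missing))
    else:
        print("All dependencies are available.")
```

Run it in the same environment in which you will later start Streamlit, so that it checks the right set of installed packages.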
Let's now take a look at writing our tool. Creating an interactive visualization for the Keras datasets involves the following steps:
- Writing the get_dataset_mappings() def.
- Writing the load_dataset() def.
- Writing the draw_images() def.
- Writing the do_streamlit() def.
- Writing the __main__ part.
However, let's begin by creating a file for our code, say keras-image-datasets.py.
A screenshot from the visualization generated by our tool.
The first thing we do, as always, is specify the dependencies that we need:
import streamlit as st
import tensorflow
from tensorflow.keras import datasets
import matplotlib.pyplot as plt
We need streamlit because it is the runtime for our interactive visualization. With tensorflow and tensorflow.keras, we can load the datasets. Finally, we use Matplotlib's pyplot API for visualizing the images.
We can then write the mapping from dataset names to datasets:
def get_dataset_mappings():
    """
    Map each dataset name (key) to its
    tensorflow.keras.datasets module.
    """
    mappings = {
        'CIFAR-10': datasets.cifar10,
        'CIFAR-100': datasets.cifar100,
        'Fashion MNIST': datasets.fashion_mnist,
        'MNIST': datasets.mnist
    }
    return mappings
This function provides a string -> dataset mapping: a dictionary that converts an input string into the corresponding tensorflow.keras.datasets module. For example, looking up the 'MNIST' key returns datasets.mnist, whose load_data() function loads the MNIST dataset. We can use this dictionary to emulate switch-like behavior, which Python did not offer until the match statement was introduced in Python 3.10.
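A plain dictionary lookup raises a bare KeyError for an unknown name. In the app itself, the select box only offers valid keys, but if you reuse the mapping elsewhere, a clearer error can help. The sketch below is illustrative and not part of the tutorial's code: get_loader() is a hypothetical helper, and to keep it runnable without TensorFlow it uses stand-in functions instead of the real datasets modules.

```python
def get_loader(name, mappings):
    """Look up name in mappings, emulating a switch statement with a dict.

    Raises a ValueError listing the valid choices when name is unknown.
    """
    try:
        return mappings[name]
    except KeyError:
        choices = ", ".join(sorted(mappings))
        raise ValueError(f"Unknown dataset {name!r}; choose one of: {choices}") from None


# Hypothetical stand-ins for the datasets.mnist / datasets.cifar10 modules
stand_in_mappings = {
    "MNIST": lambda: "loading MNIST",
    "CIFAR-10": lambda: "loading CIFAR-10",
}

print(get_loader("MNIST", stand_in_mappings)())  # prints: loading MNIST
```

With the real mapping, the returned value would be a module, so you would call get_loader(name, get_dataset_mappings()).load_data() instead of calling the result directly.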
Subsequently, we define load_dataset(). It takes a name argument. First, it retrieves the dataset mappings discussed above. It then looks up the corresponding Keras dataset and calls its load_data() function, which returns the training and testing splits as ((x_train, y_train), (x_test, y_test)). As you can see, we only use the training inputs, which we return as the output of this def.
def load_dataset(name):
    """
    Load a dataset
    """
    # Define name mapping
    name_mapping = get_dataset_mappings()
    # Get train data
    (X, _), (_, _) = name_mapping[name].load_data()
    # Return data
    return X
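Streamlit reruns the whole script from top to bottom whenever a widget changes, so load_dataset() reloads the arrays on every slider move. Keras keeps the downloaded files locally (by default under ~/.keras/datasets), so only the first call downloads anything, but the data is still read again each time. One way to avoid this is to memoize the function. The sketch below uses functools.lru_cache from the standard library; fake_load_data() and load_dataset_cached() are hypothetical names, and the stand-in loader only mimics the return structure of load_data() so that the example runs without TensorFlow. In the app itself, Streamlit's st.cache_data decorator (available in recent Streamlit versions) should serve a similar purpose.

```python
import functools

# Counts how often the (slow) loader actually runs
CALLS = {"count": 0}


def fake_load_data(name):
    """Hypothetical stand-in for name_mapping[name].load_data().

    Mimics the ((x_train, y_train), (x_test, y_test)) return structure.
    """
    CALLS["count"] += 1
    x_train = tuple(f"{name} image {i}" for i in range(3))
    return (x_train, None), (None, None)


@functools.lru_cache(maxsize=None)
def load_dataset_cached(name):
    """Load the training inputs once per dataset name, then reuse them."""
    (X, _), (_, _) = fake_load_data(name)
    return X
```

Calling load_dataset_cached('MNIST') twice runs fake_load_data() only once; the second call returns the cached object. Because the cached value is shared between calls, it should not be modified in place.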
Now that we have a dataset, we can draw some images!
With draw_images(), we generate a multiplot with the samples that we selected.
For this, we specify a dataset (data), the position/index of our starting image (start_index), and the number of rows (num_rows) and columns (num_cols) that we want to show.
First of all, we generate Matplotlib subplots: a grid of num_rows by num_cols.
Then, using the rows and columns, we compute the total number of images in show_items. We then keep an iterator index and loop over each row and col, filling that subplot with the image at index start_index + iterator.
Finally, we render the figure in the dashboard with Streamlit's st.pyplot() wrapper and return its result.
def draw_images(data, start_index, num_rows, num_cols):
    """
    Generate multiplot with selected samples.
    """
    # Get figure and axes
    fig, axs = plt.subplots(num_rows, num_cols)
    # Show number of items
    show_items = num_rows * num_cols
    # Iterate over items from start index
    iterator = 0
    for row in range(0, num_rows):
        for col in range(0, num_cols):
            index = iterator + start_index
            axs[row, col].imshow(data[index])
            axs[row, col].axis('off')
            iterator += 1
    # Return figure
    return st.pyplot(fig)
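The iterator bookkeeping in draw_images() boils down to row-major indexing: the image in subplot (row, col) has index start_index + row * num_cols + col. The hypothetical helper grid_indices() below is not part of the tutorial's code; it computes this grid of indices in plain Python, which makes the mapping easy to check without Matplotlib or Streamlit.

```python
def grid_indices(start_index, num_rows, num_cols):
    """Return the dataset index shown in each subplot, row by row.

    Equivalent to the iterator in draw_images(): the counter starts at 0 and
    increases by one for every subplot, moving left to right, top to bottom.
    """
    return [
        [start_index + row * num_cols + col for col in range(num_cols)]
        for row in range(num_rows)
    ]


print(grid_indices(10, 2, 3))  # [[10, 11, 12], [13, 14, 15]]
```

The last entry, start_index + num_rows * num_cols - 1, is the highest index that draw_images() will try to read from data.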
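It is good practice in Python to keep as much of your code as possible in functions. That's why we finally define do_streamlit(), which does nothing more than set up the Streamlit dashboard and process user interactions.
It involves the following steps:
- Applying the dark_background Matplotlib style and setting the title of the dashboard.
- Adding a select box for choosing the dataset, and loading the selected dataset with the load_dataset() def.
- Storing the number of images in the dataset as maximum_length. We use maximum_length to bound the 'Start at picture' slider, whose highest value is the last valid index, maximum_length - 1, so that the start position stays within the input shape.
- Adding sliders for the start picture and for the number of rows and columns.
- Drawing the images within a try/except statement, because invalid combinations, although minimized, remain possible. For example, if picture_id is set to a value less than no_rows * no_cols below maximum_length, there are not enough images left to fill the grid, and image generation fails with an IndexError. We catch it and show a warning instead. We could fix this with some additional code, but chose to keep things simple. Who needs the final images if you can visualize many in between?
def do_streamlit():
    """
    Set up the Streamlit dashboard and capture
    interactions.
    """
    # Styling
    plt.style.use('dark_background')
    # Set title
    st.title('Interactive visualization of Keras image datasets')
    # Define select box
    dataset_selection = st.selectbox('Dataset', ('CIFAR-10', 'CIFAR-100', 'Fashion MNIST', 'MNIST'))
    # Dataset
    dataset = load_dataset(dataset_selection)
    # Number of images in dataset
    maximum_length = dataset.shape[0]
    # Define sliders (the last valid index is maximum_length - 1)
    picture_id = st.slider('Start at picture', 0, maximum_length - 1, 0)
    no_rows = st.slider('Number of rows', 2, 30, 5)
    no_cols = st.slider('Number of columns', 2, 30, 5)
    # Show image: draw_images() renders the figure itself via st.pyplot()
    try:
        draw_images(dataset, picture_id, no_rows, no_cols)
    except IndexError:
        # Not enough images left after picture_id to fill the grid
        st.warning('Not enough images left to fill the grid. Choose a lower start picture.')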
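One way to write that additional code is to cap the start position so that the grid always fits: the last image shown has index picture_id + no_rows * no_cols - 1, which must stay below maximum_length. The helper max_start_index() below is a hypothetical sketch of that bound, not part of the tutorial's code. To use it in do_streamlit(), the row and column sliders would have to be defined before the 'Start at picture' slider, so that their values are known when computing its maximum.

```python
def max_start_index(maximum_length, num_rows, num_cols):
    """Largest start index for which a num_rows x num_cols grid still fits.

    The grid shows indices start .. start + num_rows * num_cols - 1, and the
    last one must be at most maximum_length - 1. Returns 0 when the dataset
    holds fewer images than the grid, in which case no start index fits.
    """
    return max(maximum_length - num_rows * num_cols, 0)


print(max_start_index(60000, 5, 5))  # 59975
```

For example, with the 60,000 MNIST training images and a 5 x 5 grid, the highest valid start index is 59975, whose grid ends exactly at the last image, index 59999.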
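Finally, we write the runtime if statement, which checks whether the file is executed as the main script rather than imported as a module. Streamlit runs your script as __main__, so the condition holds under streamlit run. If so, we invoke everything with do_streamlit().
if __name__ == '__main__':
    do_streamlit()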
I can understand if you'd rather play with the full code than follow all the individual steps above. That's why you can also find the full code below. Make sure to read the rest of the article in order to understand everything that is going on! :)
import streamlit as st
import tensorflow
from tensorflow.keras import datasets
import matplotlib.pyplot as plt


def get_dataset_mappings():
    """
    Map each dataset name (key) to its
    tensorflow.keras.datasets module.
    """
    mappings = {
        'CIFAR-10': datasets.cifar10,
        'CIFAR-100': datasets.cifar100,
        'Fashion MNIST': datasets.fashion_mnist,
        'MNIST': datasets.mnist
    }
    return mappings


def load_dataset(name):
    """
    Load a dataset
    """
    # Define name mapping
    name_mapping = get_dataset_mappings()
    # Get train data
    (X, _), (_, _) = name_mapping[name].load_data()
    # Return data
    return X


def draw_images(data, start_index, num_rows, num_cols):
    """
    Generate multiplot with selected samples.
    """
    # Get figure and axes
    fig, axs = plt.subplots(num_rows, num_cols)
    # Show number of items
    show_items = num_rows * num_cols
    # Iterate over items from start index
    iterator = 0
    for row in range(0, num_rows):
        for col in range(0, num_cols):
            index = iterator + start_index
            axs[row, col].imshow(data[index])
            axs[row, col].axis('off')
            iterator += 1
    # Return figure
    return st.pyplot(fig)


def do_streamlit():
    """
    Set up the Streamlit dashboard and capture
    interactions.
    """
    # Styling
    plt.style.use('dark_background')
    # Set title
    st.title('Interactive visualization of Keras image datasets')
    # Define select box
    dataset_selection = st.selectbox('Dataset', ('CIFAR-10', 'CIFAR-100', 'Fashion MNIST', 'MNIST'))
    # Dataset
    dataset = load_dataset(dataset_selection)
    # Number of images in dataset
    maximum_length = dataset.shape[0]
    # Define sliders (the last valid index is maximum_length - 1)
    picture_id = st.slider('Start at picture', 0, maximum_length - 1, 0)
    no_rows = st.slider('Number of rows', 2, 30, 5)
    no_cols = st.slider('Number of columns', 2, 30, 5)
    # Show image: draw_images() renders the figure itself via st.pyplot()
    try:
        draw_images(dataset, picture_id, no_rows, no_cols)
    except IndexError:
        # Not enough images left after picture_id to fill the grid
        st.warning('Not enough images left to fill the grid. Choose a lower start picture.')


if __name__ == '__main__':
    do_streamlit()
Let's now take a look at what happens when we run the code.
Open a terminal and make sure that the environment in which you installed the dependencies is activated.
Then run:
streamlit run keras-image-datasets.py
Your browser should open relatively quickly, and this is what you should see:
You can use the selectors at the top to customize the output image. With Dataset, you can pick one of the image-based TensorFlow Keras datasets. With Number of rows and Number of columns, you configure the dimensions of the image grid. Finally, with Start at picture, you choose the index of the image in the top left corner; the other images follow in order of index, left to right and top to bottom.
For example, let's switch to the Fashion MNIST dataset:
This is what we get:
Then, we also tune the start position, the number of rows and the number of columns:
As you can see, we have created a tool that lets us quickly explore the Keras datasets!
With some adaptation, it should even be possible to explore your own dataset with this tool, but that's for another tutorial :)
Now that you have read this tutorial, you...
tensorflow.keras datasets.
I hope that it was useful for your learning process! Please feel free to share what you have learned in the comments section. I'd love to hear from you. Please do the same if you have any questions or other remarks.
Thank you for reading MachineCurve today and happy engineering!