Flatten-T Swish is a relatively new (2018) activation function that attempts to combine the best of both worlds of the traditional ReLU and Sigmoid activation functions.
However, it's not readily available within the Keras deep learning framework, which only covers the standard activation functions like ReLU and Sigmoid.
Therefore, in today's blog, we'll implement it ourselves. First, we'll take a look at FTSwish - by providing a recap - and then implement it using Keras. This blog also includes an example model which uses FTSwish, and evaluates the model after training.
Are you ready?
Let's go! 😊
In our blog post "What is the FTSwish activation function?" we looked at what the Flatten-T Swish or FTSwish activation function is like. Here, we'll recap the essentials, so that you can understand with ease what we're going to build next.
We can define FTSwish as follows:
\begin{equation} FTSwish: f(x) = \begin{cases} T, & \text{if}\ x < 0 \\ \frac{x}{1 + e^{-x}} + T, & \text{otherwise} \\ \end{cases} \end{equation}
It's essentially a combination of the ReLU and Sigmoid activation functions, with some negative threshold T which ensures that negative inputs always yield the same nonzero output T.
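For example, with the threshold \(T = -1.0\) that we will also use below, a negative input saturates to the threshold, while a positive input yields the Swish-like value shifted down by 1:
\begin{equation} f(-2) = T = -1.0, \qquad f(2) = \frac{2}{1 + e^{-2}} - 1 \approx 1.76 - 1 = 0.76 \end{equation}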
It looks as follows:
And indeed, it does resemble Swish in a way:
Keras has a range of activation functions available, but FTSwish is not one of them. Fortunately, it's possible to define your own activations, so yes: we can still use FTSwish with Keras :) Let's now find out how.
In any Keras model, you'll first have to import the backend you're working with, in order to provide tensor-specific operations such as maximum:
from keras import backend as K
We can then define the FTSwish activation function as follows:
# Define the FTSwish activation function
t = -1.0
def ftswish(x):
    return K.maximum(t, K.relu(x)*K.sigmoid(x) + t)
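If you want to convince yourself that this definition behaves as intended, you can evaluate it on a few sample values with the backend. This is just a quick sanity check sketch, assuming the TensorFlow backend; the exact decimals may differ slightly:
# Quick sanity check of the ftswish definition (optional)
sample = K.constant([-2.0, 0.0, 2.0])
print(K.eval(ftswish(sample)))
# Expected output, approximately: [-1.0, -1.0, 0.76]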
Let's break the definition down into understandable steps:
- t is the threshold value \(T\), which in our case is -1.0. It ensures that negative inputs saturate to this value. Its value can be different, but take a look at the derivative plot to ensure that you'll have a smooth one.
- def (definition) ensures that we can use ftswish as some kind of function - mapping some input to an output. It also means that we can simply feed it to Keras later, to be used in processing.
- K.relu is the ReLU part.
- K.sigmoid is the Sigmoid part.
- We add t to the outcome of the multiplication.
- ReLU, which is \(0\) for negative inputs and \(x\) for others, can be rewritten to \(max(0, x)\) (indeed: \(x = 4\) yields an output of 4, while \(x = -2\) yields 0, in line with the ReLU definition). Hence, given the formula for FTSwish above, we can rewrite it as a max between t (the negative output) and the ReLU/Sigmoid combination (the positive output).
- We use K instead of np because we're performing these operations on multidimensional tensors.
Let's now create an example with Keras :) Open up your Explorer or Finder, navigate to some folder, and create a Python file, e.g. model_ftswish.py.
Now, open up model_ftswish.py in a code editor and start coding :) First, we'll add the imports:
'''
Keras model using Flatten-T Swish (FTSwish) activation function
Source for FTSwish activation function:
Chieng, H. H., Wahid, N., Ong, P., & Perla, S. R. K. (2018). Flatten-T Swish: a thresholded ReLU-Swish-like activation function for deep learning. arXiv preprint arXiv:1812.06247.
https://arxiv.org/abs/1812.06247
'''
import keras
from keras.datasets import cifar10
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
import matplotlib.pyplot as plt
import numpy as np
As expected, we'll import keras and a number of its sub-modules: the cifar10 dataset (which we'll use today), the Sequential API for easy stacking of our layers, the layers that are common in a ConvNet, and the Keras backend (which, in our case, maps to TensorFlow). Finally, we also import pyplot from Matplotlib and numpy.
Next, it's time to set some configuration values:
# Model configuration
img_width, img_height = 32, 32
batch_size = 250
no_epochs = 100
no_classes = 10
validation_split = 0.2
verbosity = 1
The CIFAR-10 dataset which we're using today contains 32 x 32 pixel images across 10 different classes. Hence, img_width = img_height = 32, and no_classes = 10. The batch_size is 250, which is a fairly OK setting based on experience - it balances between a large batch size and memory requirements. We train for 100 epochs, and use 20% of our training data for validation purposes. We output everything on screen by setting verbosity to 1.
We next load the CIFAR-10 data:
# Load CIFAR-10 dataset
(input_train, target_train), (input_test, target_test) = cifar10.load_data()
This easily loads the CIFAR-10 samples into our training and testing variables:
A few CIFAR-10 samples.
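If you'd like to verify that loading succeeded, you can print the shapes of the resulting arrays. For CIFAR-10, this should give you 50,000 training samples and 10,000 testing samples of 32 x 32 x 3 pixels:
# Optional: inspect the shapes of the loaded data
print(input_train.shape)   # (50000, 32, 32, 3)
print(target_train.shape)  # (50000, 1)
print(input_test.shape)    # (10000, 32, 32, 3)
print(target_test.shape)   # (10000, 1)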
After loading, we reshape the data based on the channels first / channels last approach used by our backend (to ensure that we can use a fixed input_shape):
# Reshape data based on channels first / channels last strategy.
# This is dependent on whether you use TF, Theano or CNTK as backend.
# Source: https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py
if K.image_data_format() == 'channels_first':
input_train = input_train.reshape(input_train.shape[0], 3, img_width, img_height)
input_test = input_test.reshape(input_test.shape[0], 3, img_width, img_height)
input_shape = (3, img_width, img_height)
else:
input_train = input_train.reshape(input_train.shape[0], img_width, img_height, 3)
input_test = input_test.reshape(input_test.shape[0], img_width, img_height, 3)
input_shape = (img_width, img_height, 3)
Then, we cast our numbers to float32 format, which can speed up the training process:
# Parse numbers as floats
input_train = input_train.astype('float32')
input_test = input_test.astype('float32')
This is followed by normalizing our data to be in the \([0, 1]\) range, which is appreciated by the neural network during optimization:
# Normalize data
input_train = input_train / 255
input_test = input_test / 255
Finally, we convert our targets into categorical format, which allows us to use categorical crossentropy loss later:
# Convert target vectors to categorical targets
target_train = keras.utils.to_categorical(target_train, no_classes)
target_test = keras.utils.to_categorical(target_test, no_classes)
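To illustrate what to_categorical does: each integer class label is turned into a one-hot vector of length no_classes. For example:
# Example: class 3 becomes a one-hot vector of length 10
print(keras.utils.to_categorical([3], no_classes))
# [[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]]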
We can next add the definition of the FTSwish activation function we created earlier:
# Define the FTSwish activation function
t = -1.0
def ftswish(x):
    return K.maximum(t, K.relu(x)*K.sigmoid(x) + t)
Then, we can create the architecture of our model.
# Create the model
model = Sequential()
model.add(Conv2D(64, kernel_size=(3, 3), activation=ftswish, input_shape=input_shape, kernel_initializer='he_normal'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.5))
model.add(Conv2D(128, kernel_size=(3, 3), activation=ftswish, kernel_initializer='he_normal'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.5))
model.add(Flatten())
model.add(Dense(512, kernel_initializer='he_normal', activation=ftswish))
model.add(Dense(256, kernel_initializer='he_normal', activation=ftswish))
model.add(Dense(no_classes, activation='softmax', kernel_initializer='he_normal'))
It's a relatively simple ConvNet, with two Conv2D layers, max pooling, Dropout and finally Dense layers for classification. We use He initialization because our activation function resembles ReLU.
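If you'd like to inspect the resulting architecture and the number of trainable parameters before training, you can add a summary call right after building the model:
# Optional: print an overview of the layers and parameter counts
model.summary()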
Next, we can compile the model:
# Compile the model
model.compile(loss=keras.losses.categorical_crossentropy,
optimizer=keras.optimizers.Adam(),
metrics=['accuracy'])
Because we are facing a multiclass classification problem with one-hot encoded vectors (by virtue of calling to_categorical), we'll be using categorical crossentropy. If you wish to skip the conversion to categorical targets, you might want to replace this with sparse categorical crossentropy, which supports integer targets.
For optimization, we use the Adam optimizer - the default choice for today's neural networks. Finally, we specify accuracy as an additional metric, which is more intuitive than crossentropy loss.
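For completeness, a minimal sketch of the alternative compilation step with sparse categorical crossentropy could look like this - but note that it only makes sense if you skip the to_categorical conversion above and keep the integer targets:
# Alternative: sparse categorical crossentropy with integer targets
# (only valid if target_train / target_test were NOT converted with to_categorical)
model.compile(loss=keras.losses.sparse_categorical_crossentropy,
              optimizer=keras.optimizers.Adam(),
              metrics=['accuracy'])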
Then, we fit the training data, configuring the model in line with how we specified our model configuration before:
# Fit data to model
history_FTSwish = model.fit(input_train, target_train,
batch_size=batch_size,
epochs=no_epochs,
verbose=verbosity,
validation_split=validation_split)
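One note before we continue: the names of the metrics stored in the history object differ between Keras versions. Newer versions use 'accuracy' and 'val_accuracy' (as assumed in the visualization code below), while older versions use 'acc' and 'val_acc'. If the plotting code raises a KeyError, you can check which keys your version produces:
# Check which metric names your Keras version stores in the history object
print(history_FTSwish.history.keys())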
The final thing we do is add code for evaluating the model (using our testing data) and for visualizing the training process:
# Generate evaluation metrics
score = model.evaluate(input_test, target_test, verbose=0)
print(f'Test loss for Keras FTSwish CNN: {score[0]} / Test accuracy: {score[1]}')
# Visualize model history
plt.plot(history_FTSwish.history['accuracy'], label='Training accuracy')
plt.plot(history_FTSwish.history['val_accuracy'], label='Validation accuracy')
plt.title('FTSwish training / validation accuracies')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(loc="upper left")
plt.show()
plt.plot(history_FTSwish.history['loss'], label='Training loss')
plt.plot(history_FTSwish.history['val_loss'], label='Validation loss')
plt.title('FTSwish training / validation loss values')
plt.ylabel('Loss value')
plt.xlabel('Epoch')
plt.legend(loc="upper left")
plt.show()
It's also possible to get the full model code at once, should you wish to start playing around with it. In that case, here you go:
'''
Keras model using Flatten-T Swish (FTSwish) activation function
Source for FTSwish activation function:
Chieng, H. H., Wahid, N., Ong, P., & Perla, S. R. K. (2018). Flatten-T Swish: a thresholded ReLU-Swish-like activation function for deep learning. arXiv preprint arXiv:1812.06247.
https://arxiv.org/abs/1812.06247
'''
import keras
from keras.datasets import cifar10
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
import matplotlib.pyplot as plt
import numpy as np
# Model configuration
img_width, img_height = 32, 32
batch_size = 250
no_epochs = 100
no_classes = 10
validation_split = 0.2
verbosity = 1
# Load CIFAR-10 dataset
(input_train, target_train), (input_test, target_test) = cifar10.load_data()
# Reshape data based on channels first / channels last strategy.
# This is dependent on whether you use TF, Theano or CNTK as backend.
# Source: https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py
if K.image_data_format() == 'channels_first':
input_train = input_train.reshape(input_train.shape[0], 3, img_width, img_height)
input_test = input_test.reshape(input_test.shape[0], 3, img_width, img_height)
input_shape = (3, img_width, img_height)
else:
input_train = input_train.reshape(input_train.shape[0], img_width, img_height, 3)
input_test = input_test.reshape(input_test.shape[0], img_width, img_height, 3)
input_shape = (img_width, img_height, 3)
# Parse numbers as floats
input_train = input_train.astype('float32')
input_test = input_test.astype('float32')
# Normalize data
input_train = input_train / 255
input_test = input_test / 255
# Convert target vectors to categorical targets
target_train = keras.utils.to_categorical(target_train, no_classes)
target_test = keras.utils.to_categorical(target_test, no_classes)
# Define the FTSwish activation function
t = -1.0
def ftswish(x):
    return K.maximum(t, K.relu(x)*K.sigmoid(x) + t)
# Create the model
model = Sequential()
model.add(Conv2D(64, kernel_size=(3, 3), activation=ftswish, input_shape=input_shape, kernel_initializer='he_normal'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.5))
model.add(Conv2D(128, kernel_size=(3, 3), activation=ftswish, kernel_initializer='he_normal'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.5))
model.add(Flatten())
model.add(Dense(512, kernel_initializer='he_normal', activation=ftswish))
model.add(Dense(256, kernel_initializer='he_normal', activation=ftswish))
model.add(Dense(no_classes, activation='softmax', kernel_initializer='he_normal'))
# Compile the model
model.compile(loss=keras.losses.categorical_crossentropy,
optimizer=keras.optimizers.Adam(),
metrics=['accuracy'])
# Fit data to model
history_FTSwish = model.fit(input_train, target_train,
batch_size=batch_size,
epochs=no_epochs,
verbose=verbosity,
validation_split=validation_split)
# Generate evaluation metrics
score = model.evaluate(input_test, target_test, verbose=0)
print(f'Test loss for Keras FTSwish CNN: {score[0]} / Test accuracy: {score[1]}')
# Visualize model history
plt.plot(history_FTSwish.history['accuracy'], label='Training accuracy')
plt.plot(history_FTSwish.history['val_accuracy'], label='Validation accuracy')
plt.title('FTSwish training / validation accuracies')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(loc="upper left")
plt.show()
plt.plot(history_FTSwish.history['loss'], label='Training loss')
plt.plot(history_FTSwish.history['val_loss'], label='Validation loss')
plt.title('FTSwish training / validation loss values')
plt.ylabel('Loss value')
plt.xlabel('Epoch')
plt.legend(loc="upper left")
plt.show()
Now that you have finished creating the model, it's time to train it - and to see the results :)
Open up a terminal that supports the dependencies listed above, cd into the folder where your Python file is located, and run python model_ftswish.py. The training process should then begin.
Once it finishes, you should also be able to see the results of the evaluation & visualization steps:
Test loss for Keras FTSwish CNN: 2.3128050004959104 / Test accuracy: 0.6650999784469604
As you can see, the loss is still quite high and the accuracy relatively low - the model is only correct in about two thirds of the cases. The CIFAR-10 dataset is relatively complex (with various objects in various shapes), and our fairly small model likely starts overfitting on it. Additional techniques such as data augmentation may help here.
But is it overfitting? Let's take a look at the visualizations.
Visually, the training process looks as follows.
Indeed: overfitting starts once the model hits the 66% accuracy mark. From that point on, performance in terms of validation loss gets worse and worse. Fixing this is not within the scope of this post, which was about FTSwish. However, it may be worthwhile to add extra Conv2D layers, to use Batch Normalization, or to apply data augmentation - a sketch of the latter follows below.
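As an illustration of that last suggestion, a minimal data augmentation sketch with Keras' ImageDataGenerator could look as follows. The transformation settings here are assumptions for illustration only and would need tuning; note also that validation_split does not apply when training from a generator, so validation data would have to be provided separately:
# Minimal data augmentation sketch (settings are illustrative, not tuned)
from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
    horizontal_flip=True,    # CIFAR-10 objects are usually left/right symmetric
    width_shift_range=0.1,   # shift images horizontally by up to 10%
    height_shift_range=0.1)  # shift images vertically by up to 10%

# Train on augmented batches instead of the raw training data
model.fit_generator(
    datagen.flow(input_train, target_train, batch_size=batch_size),
    epochs=no_epochs,
    verbose=verbosity)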
In this blog post, we've seen how to create and use the Flatten-T Swish (FTSwish) activation function with Keras. It included a recap of the FTSwish activation function, which was followed by an example implementation of the activation function.
I hope you've learnt something from today's blog post! Thanks for reading MachineCurve and happy engineering 😎
Chieng, H. H., Wahid, N., Ong, P., & Perla, S. R. K. (2018). Flatten-T Swish: a thresholded ReLU-Swish-like activation function for deep learning. arXiv preprint arXiv:1812.06247.