
TensorFlow model optimization: an introduction to Pruning

September 23, 2020 by Chris

Enjoying the benefits of machine learning models means deploying them in the field once training has finished. However, if you're counting on predictions for new data being generated quickly - a process called model inference - you may be in for a surprise. If you really want your models to run fast, you will likely need powerful hardware - such as high-end GPUs - which comes at significant cost.

If you don't, your models will run slower - sometimes really slowly, especially when they are big. And big models are very common in today's state-of-the-art machine learning.

Fortunately, modern machine learning frameworks such as TensorFlow attempt to help machine learning engineers. Through extensions such as TF Lite, methods such as quantization can be used to optimize your model. With quantization, the number representation of your model is changed (for example, from 32-bit floats to 8-bit integers) to benefit size and speed, often at the cost of precision. In this article, however, we'll take a look at model pruning. Firstly, we'll look at why model optimization is necessary. Subsequently, we'll introduce pruning - by looking at how neural networks work and by questioning why we should keep weights that don't contribute to model performance.

Following the theoretical part of this article, we'll build a Keras model and subsequently apply pruning to optimize it. This shows you how to apply pruning to your TensorFlow/Keras model with a real example. Finally, when we know how to do this, we'll continue by combining pruning with quantization for compound optimization. Obviously, this also includes adding quantization to the Keras example that we created before.
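To make the two ideas a bit more tangible before we start, here is a minimal sketch in plain Python. It is not the TensorFlow implementation: the helper names (`prune_by_magnitude`, `quantize_int8`, `dequantize`) and the example weights are made up for illustration, and the sketch assumes the simplest possible variants - pruning by absolute weight value and symmetric 8-bit quantization with a single scale factor.

```python
def prune_by_magnitude(weights, sparsity):
    """Set the given fraction of weights with the smallest absolute value to zero."""
    n_prune = int(len(weights) * sparsity)
    # Indices of the weights, ordered from smallest to largest magnitude
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    to_prune = set(order[:n_prune])
    return [0.0 if i in to_prune else w for i, w in enumerate(weights)]


def quantize_int8(weights):
    """Map floats to integers in [-127, 127] using one shared scale factor."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    return [round(w / scale) for w in weights], scale


def dequantize(quantized, scale):
    """Approximately recover the original floats from the integers."""
    return [q * scale for q in quantized]


# Hypothetical weights of a single, tiny layer
weights = [0.82, -0.03, 0.40, 0.01, -0.65, 0.07, -0.002, 0.29]

# Step 1: pruning removes the half of the weights that matter least
pruned = prune_by_magnitude(weights, sparsity=0.5)

# Step 2: quantization stores what is left with fewer bits per number
quantized, scale = quantize_int8(pruned)
restored = dequantize(quantized, scale)

print(pruned)     # [0.82, 0.0, 0.4, 0.0, -0.65, 0.0, 0.0, 0.29]
print(quantized)  # [127, 0, 62, 0, -101, 0, 0, 45]
```

Note that the zeros introduced by pruning stay zero after quantization, which is one reason why the two techniques combine well. The restored values differ slightly from the pruned ones: that is the loss of precision mentioned above.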

Are you ready? Let's go! 😎

Update 02/Oct/2020: added reference to article about pruning schedules as a suggestion.

The need for model optimization

Machine learning models can be used for a wide variety of use cases, for example the detection of objects:


If you're into object detection, you have likely heard about machine learning architectures like R-CNN, Faster R-CNN, YOLO (recently, version 5 was released!) and others. These are state-of-the-art architectures that can be used to detect objects very effectively after being trained on a dataset.

Such architectures are typically composed of a pipeline that includes a feature extraction model, a region proposal network, and subsequently a classification model (Data Science Stack Exchange, n.d.). Consequently, this pipeline is capable of extracting interesting features from your input data, detecting regions of interest for classification, and finally classifying those regions - resulting in videos like the one above.

Now, while they perform very well at object detection, the neural networks used for classification (and sometimes also for feature extraction or region selection) come with a downside: they are very big.

For example, the neural nets, which can include VGG-16, ResNet-50, and others, have the following sizes when used as a tf.keras application (for example, as a convolutional base):

| Model | Size | Top-1 Accuracy | Top-5 Accuracy | Parameters | Depth |
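The relationship between the Parameters and Size columns can be approximated with a quick back-of-the-envelope calculation. The sketch below assumes that every parameter is stored as a 32-bit float (4 bytes) and ignores any file overhead; the helper `approx_size_mb` is hypothetical, and the parameter counts are the ones commonly listed for VGG16 and ResNet50 in the Keras Applications overview, so verify them against the documentation for your version.

```python
BYTES_PER_FLOAT32 = 4


def approx_size_mb(num_parameters, bytes_per_parameter=BYTES_PER_FLOAT32):
    """Estimate the storage needed for the weights, in MB (1 MB = 1024 * 1024 bytes)."""
    return num_parameters * bytes_per_parameter / (1024 ** 2)


# Assumed parameter counts (not taken from this article; check the Keras docs)
parameter_counts = {
    "VGG16": 138_357_544,
    "ResNet50": 25_636_712,
}

for name, count in parameter_counts.items():
    as_float32 = approx_size_mb(count)
    # Storing each parameter in 1 byte instead of 4, as 8-bit quantization does
    as_int8 = approx_size_mb(count, bytes_per_parameter=1)
    print(f"{name}: ~{as_float32:.0f} MB as float32, ~{as_int8:.0f} MB as int8")
```

Under these assumptions, VGG16 comes out at roughly 528 MB and ResNet50 at roughly 98 MB, which shows why models of this kind are hard to run on constrained hardware without optimization.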
