To enjoy the benefits of machine learning models, they must be deployed in the field once training has finished. However, if you're counting on predictions for new data (a process called model inference) being generated at great speed, you may get a bit intimidated. If you really want your models to run fast, you'll likely have to buy powerful hardware, such as large GPUs, which comes at significant cost.
If you don't, your models will run slower, sometimes really slowly, especially when they are big. And big models are very common in today's state-of-the-art machine learning.
Fortunately, modern machine learning frameworks such as TensorFlow try to help machine learning engineers here. Through extensions such as TF Lite, methods such as quantization can be used to optimize your model. Quantization changes the number representation of your model (for example, from 32-bit floats to 8-bit integers) to reduce size and increase speed, often at the cost of some precision. In this article, however, we'll focus on model pruning. First, we'll look at why model optimization is necessary. Subsequently, we'll introduce pruning by looking at how neural networks work and by asking why we should keep weights that don't contribute to model performance.
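The idea of dropping weights that don't contribute can be made concrete with a minimal sketch of magnitude-based pruning in plain Python. It assumes that the weights with the smallest absolute values contribute least. This is a common heuristic, not a guarantee. The function name `magnitude_prune` and the toy weights are hypothetical and are not part of TensorFlow or Keras.

```python
def magnitude_prune(weights, sparsity):
    """Zero out the fraction `sparsity` of weights with the smallest magnitude.

    Hypothetical helper for illustration; not a TensorFlow/Keras API.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must be between 0.0 and 1.0")
    n_prune = int(round(len(weights) * sparsity))
    # Rank weight indices from smallest to largest absolute value
    ranked = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned_idx = set(ranked[:n_prune])
    # Keep the large-magnitude weights and set the rest to exactly zero
    return [0.0 if i in pruned_idx else w for i, w in enumerate(weights)]


# Hypothetical weights of one small layer, flattened into a list
weights = [0.8, -0.05, 0.3, 0.01, -0.6, 0.02, -0.2, 0.9]
pruned = magnitude_prune(weights, sparsity=0.5)
print(pruned)  # [0.8, 0.0, 0.3, 0.0, -0.6, 0.0, 0.0, 0.9]
```

Half of the weights are now exactly zero, while the largest ones, which presumably matter most, are untouched. In practice, pruning libraries apply a mask like this layer by layer. They typically also fine-tune the network afterwards, so that it can recover accuracy lost by pruning.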
After the theoretical part of this article, we'll build a Keras model and then apply pruning to optimize it, showing you how to prune a TensorFlow/Keras model with a real example. Finally, once we know how to do this, we'll combine pruning with quantization for compound optimization, which also means adding quantization to the Keras example we created before.
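Why pruning and quantization combine well can be illustrated without TensorFlow. The sketch below makes two assumptions. First, it uses a simple symmetric int8 scheme with a single scale per tensor. This is an illustration only; TF Lite's actual quantization is configured through its converter and may differ. Second, it starts from hypothetical weights that were already pruned. Because zero maps exactly to the int8 code 0, the sparsity created by pruning survives quantization, while each remaining weight shrinks from 4 bytes to 1 byte at the cost of a small rounding error. The helper names are hypothetical.

```python
def quantize_int8_symmetric(weights):
    """Map floats to int8 codes with one symmetric scale: w ≈ code * scale.

    Hypothetical helper for illustration; not the TF Lite implementation.
    """
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    # round() maps 0.0 to 0, so pruned weights stay exactly zero
    codes = [max(-127, min(127, round(w / scale))) for w in weights]
    return codes, scale


def dequantize(codes, scale):
    """Approximately reconstruct the float weights from int8 codes."""
    return [c * scale for c in codes]


# Hypothetical layer weights that were already pruned to 50% sparsity
pruned = [0.8, 0.0, 0.3, 0.0, -0.6, 0.0, 0.0, 0.9]
codes, scale = quantize_int8_symmetric(pruned)
restored = dequantize(codes, scale)

# Storing every weight as float32 takes 4 bytes each
dense_float32_bytes = 4 * len(pruned)
# Storing only non-zero int8 codes takes 1 byte each (index overhead ignored)
nonzero_int8_bytes = sum(1 for c in codes if c != 0)
# Rounding error introduced by quantization
max_error = max(abs(w - r) for w, r in zip(pruned, restored))

print(codes)                                    # [113, 0, 42, 0, -85, 0, 0, 127]
print(dense_float32_bytes, nonzero_int8_bytes)  # 32 4
print(max_error <= scale / 2)                   # True
```

The byte count ignores the overhead of recording where the non-zero values sit, so real savings are smaller. In practice, the zeros pay off mostly when the model file is compressed or when the runtime can exploit sparsity.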
Are you ready? Let's go! 😎
Update 02/Oct/2020: added reference to article about pruning schedules as a suggestion.
Machine learning models can be used for a wide variety of use cases, such as object detection.
If you're into object detection, you have likely heard of architectures such as R-CNN, Faster R-CNN, YOLO (recently, version 5 was released!) and others. These are state-of-the-art architectures that, once trained on a dataset, can detect objects very efficiently.
Architectures such as Faster R-CNN are composed of a pipeline that includes a feature extraction model, a region proposal network, and subsequently a classification model (Data Science Stack Exchange, n.d.). Consequently, such a pipeline extracts interesting features from your input data, detects regions of interest, and finally classifies those regions. The result is output such as videos in which detected objects are marked and labeled.
While these architectures perform very well at object detection, the neural networks used for classification (and sometimes also for feature extraction and region selection) come with a downside: they are very big.
| Model | Size | Top-1 Accuracy | Top-5 Accuracy | Parameters | Depth |
|---|---|---|---|---|---|
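The size of such models follows largely from their parameter counts: each parameter stored as a 32-bit float takes 4 bytes. The sketch below counts the parameters of one hypothetical fully connected layer and converts that count into storage. The layer has 4096 inputs and 4096 units, sizes chosen purely for illustration. The helper names are ours, not Keras functions. The last line also shows how much an 8-bit representation would save.

```python
def dense_params(n_inputs, n_units, use_bias=True):
    """Trainable parameters of a fully connected layer: weights plus biases."""
    return n_inputs * n_units + (n_units if use_bias else 0)


def size_mib(n_params, bytes_per_param=4):
    """Storage in MiB for n_params values (float32 = 4 bytes by default)."""
    return n_params * bytes_per_param / 2**20


# Hypothetical layer: 4096 inputs fully connected to 4096 units
params = dense_params(4096, 4096)
print(params)                         # 16781312
print(round(size_mib(params), 2))     # 64.02 (float32)
print(round(size_mib(params, 1), 2))  # 16.0 (int8)
```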