Although we hear a lot about deep learning these days, there is a wide variety of other machine learning techniques that can still be very useful. Decision tree learning is one of them. By recursively partitioning your feature space into segments that group together samples sharing a class outcome, it becomes possible to build predictive models for both classification and regression.
In today's tutorial, you will learn to build a decision tree for classification. You will do so using Python and one of the key machine learning libraries for the Python ecosystem, Scikit-learn. After reading it, you will understand how decision trees break a classification problem into smaller sub questions, how the CART algorithm picks splits using Gini impurity or entropy, and how to train and visualize a decision tree with Scikit-learn.
Are you ready? Let's take a look! 😎
Suppose that you have a dataset that describes wines in many columns, and the wine variety in the last column.
These independent variables can be used to build a predictive model that, given some new inputs, tells us whether a specific measurement comes from a wine of variety one, two or three.
As you already know, there are many techniques for building predictive models. Deep neural networks are very popular these days, but there are also approaches that are a bit more classic - but not necessarily wrong.
Decision trees are one such technique. They essentially work by breaking down the decision-making process into many smaller questions. In the wine scenario, as an example, you know that wines can be separated by color. This distinguishes between wine varieties that make white wine and varieties that make red wine. There are more such questions that can be asked: what is the alcohol content? What is the magnesium content? And so forth.
An example of a decision tree. Each variety (there are three) is represented by a different color - orange, green and purple. Both the color and its intensity indicate the estimated class at each sub-question stage. For example, the first question already points towards class 2, and that path gets stronger as you move down the tree. Still, it is possible to end up with class 1 or class 3 - by simply taking the other path, or by diverting further down the road.
By structuring these questions in a smart way, you can separate the classes (in this case, the varieties) by simply providing answers that point you to a specific variety. And precisely that is what decision trees are: they are tree-like structures that break your classification problem into many smaller sub questions given the inputs you have.
Decision trees can be constructed manually. More relevant however is the automated construction of decision trees. And that is precisely what you will be looking at today, by building one with Scikit-learn.
In today's tutorial, you will be building a decision tree for classification with the DecisionTreeClassifier class in Scikit-learn. Under the hood, Scikit-learn learns such a tree with the Classification And Regression Trees (CART) algorithm - at least, an optimized version of it. Let's first take a look at how this algorithm works, before we build a classification decision tree.
At a high level, a CART tree is built in the following way, using some split evaluation criterion (we will cover that in a few moments):
- For every feature, evaluate a set of candidate split points and compute the split evaluation criterion for each.
- Pick the candidate split that scores best, and partition the samples into a left and a right subset accordingly.
- Repeat the procedure recursively for both subsets.
In other words, the decision tree learning process is a recursive process that picks the best split at each level for building the tree, until the tree is exhausted or a user-defined criterion (such as maximum tree depth) is reached.
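To make the recursion tangible, here is a minimal from-scratch sketch in plain Python. This is not Scikit-learn's actual (heavily optimized) implementation - just a bare-bones illustration with helper names of my own choosing, using Gini impurity as the split evaluation criterion (explained in the next section):
def gini(labels):
    """ Gini impurity of a list of class labels (covered in detail below). """
    total = len(labels)
    return 1.0 - sum((labels.count(c) / total) ** 2 for c in set(labels))

def best_split(X, y):
    """ Exhaustively evaluate every feature/threshold pair and return the best one. """
    best = None
    for feature in range(len(X[0])):
        for threshold in sorted({row[feature] for row in X}):
            left = [label for row, label in zip(X, y) if row[feature] <= threshold]
            right = [label for row, label in zip(X, y) if row[feature] > threshold]
            if not left or not right:
                continue
            # Weighted impurity of the two subsets created by this candidate split
            score = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
            if best is None or score < best[0]:
                best = (score, feature, threshold)
    return best

def build_tree(X, y, depth=0, max_depth=3):
    """ Recursively pick the best split until a node is pure or max_depth is reached. """
    split = best_split(X, y)
    if len(set(y)) == 1 or depth == max_depth or split is None:
        return max(set(y), key=y.count)  # leaf node: predict the majority class
    _, feature, threshold = split
    left = [i for i, row in enumerate(X) if row[feature] <= threshold]
    right = [i for i, row in enumerate(X) if row[feature] > threshold]
    return {
        'feature': feature,
        'threshold': threshold,
        'left': build_tree([X[i] for i in left], [y[i] for i in left], depth + 1, max_depth),
        'right': build_tree([X[i] for i in right], [y[i] for i in right], depth + 1, max_depth),
    }

# Tiny example: one feature, two classes that a single split separates cleanly
print(build_tree([[1.0], [2.0], [6.0], [7.0]], ['A', 'A', 'B', 'B']))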
Now, regarding the split evaluation criterion: Scikit-learn based CART trees support two criteria, Gini impurity and entropy.
The first - and default - split evaluation metric available in Scikit-learn's decision tree learner is Gini impurity. The metric is defined in the following way:
Gini impurity (named after Italian mathematician Corrado Gini) is a measure of how often a randomly chosen element from the set would be incorrectly labeled if it was randomly labeled according to the distribution of labels in the subset.
Wikipedia (2004)
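Put into a formula - this is the standard way of writing it, added here for reference - the Gini impurity of a set whose samples belong to class k with probability p_k is:
$$ G = \sum_{k=1}^{K} p_k (1 - p_k) = 1 - \sum_{k=1}^{K} p_k^2 $$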
Suppose that we randomly pick a sample from our set, and then randomly assign it a label according to the distribution of labels in that set. What is the probability that we classify it wrongly? That's the Gini impurity for that set of samples.
For example, if we have 100 samples, where 25 belong to class A and 75 to class B, these are our probabilities: p(A) = 25/100 = 0.25 and p(B) = 75/100 = 0.75.
So, what's the probability of classifying a sample wrongly? That's the probability of drawing a class A sample but labeling it as B, plus the probability of drawing a class B sample but labeling it as A: 0.25 × 0.75 + 0.75 × 0.25 = 18.75% + 18.75% = 37.5%. In other words, the Gini impurity of this data scenario with random classification is 0.375.
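As a quick check, the same arithmetic in a couple of lines of Python:
p_a, p_b = 25 / 100, 75 / 100
# Chance of drawing a class A sample but labeling it B, plus drawing B and labeling it A
print(p_a * (1 - p_a) + p_b * (1 - p_b))  # 0.375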
By choosing splits that minimize Gini impurity, we get purer subsets - and hence a better classification for our selection.
Suppose that instead of randomly classifying our samples, we add a decision boundary: we split our sample space in two. In other words, we add a split.
We can simply compute the Gini impurity of this split by computing a weighted average of the Gini impurities of both sides of the split.
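Written out - standard notation, with n_left and n_right denoting the number of samples that end up on each side of the split - that weighted average is:
$$ G_{split} = \frac{n_{left}}{n_{left} + n_{right}} \, G_{left} + \frac{n_{right}}{n_{left} + n_{right}} \, G_{right} $$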
Suppose that we add the following split to the very simple two-dimensional dataset below, generated by the OPTICS clustering algorithm:
Now, for both sides of the split, we repeat the same:
On the left, you can clearly see that Gini impurity is 0: if we pick a sample there, it can only be classified as blue, because blue is the only class available on that side.
On the right, impurity is very low, but not zero: there are some blue samples available, and Gini impurity is approximately 0.00398.
Clearly, a better split is available at X[0] ≈ 5, where Gini impurity would be 0... ;-) But this is just for demonstrative purposes!
Now that you understand how Gini impurity can be computed given a split, we can look at the final aspect of judging the goodness of a split using Gini impurity: how do we decide how much a split contributes?
At each level of your decision tree, you know the following: the Gini impurity of the current set of samples (before splitting), and, for every candidate split, the weighted Gini impurity of the two subsets it would produce.
Picking the best split now boils down to choosing the split with the greatest reduction in total Gini impurity, which can be computed with the weighted average mentioned before. In the case above...
If this is the greatest reduction of Gini impurity (by computing the difference between existing impurity and resulting impurity), then it's the split to choose! :)
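In formula form, the reduction achieved by a candidate split is the parent's impurity minus the weighted impurity of its children (with n the total number of samples in the parent node) - and the split with the largest reduction wins:
$$ \Delta G = G_{parent} - \left( \frac{n_{left}}{n} G_{left} + \frac{n_{right}}{n} G_{right} \right) $$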
A similar but slightly different metric that can be used is that of entropy:
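For reference, a standard way of writing the entropy of a set with class probabilities p_k is:
$$ H = - \sum_{k=1}^{K} p_k \log_2 p_k $$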
For using entropy, you'll have to repeat all the steps executed above. It then simply boils down to plugging the probabilities computed above into the formula... and you pick the split that yields the lowest weighted entropy, i.e. the largest information gain.
Model performance-wise, there is little reason to choose between Gini impurity and entropy. In an analytical comparison, Raileanu and Stoffel (2004) identified that the two criteria disagree in only a small fraction of cases, so in practice the choice rarely matters for the resulting tree.
In other words, I would go with Gini impurity - it is also slightly cheaper to compute, since it avoids logarithms - and I assume that's why it's the default option in Scikit-learn, too! :)
Now that you understand some of the theory behind CART trees, it's time to build one such tree for classification. You will use one of the standard machine learning libraries for this purpose: Scikit-learn. It's a three-step process: loading the dataset, initializing and training the decision tree, and finally plotting the trained tree.
Before writing any code, it's important that you have installed all the dependencies on your machine:
pip install -U scikit-learn
pip install matplotlib
If you have been a frequent reader of MachineCurve tutorials, you know that I favor out-of-the-box datasets that come preinstalled with the machine learning libraries used during tutorials.
The reason is simple: although in the real world data is key to success, these tutorials are meant to tell you something about the models you're building, and lengthy sections on datasets can therefore be distracting.
For that reason, today, you will be using one of the datasets that comes with Scikit-learn out of the box: the wine dataset.
The wine dataset is a classic and very easy multi-class classification dataset.
Scikit-learn
It is a dataset with 178 samples and 13 attributes that assigns each sample to a wine variety (indeed, we're using a dataset similar to the one you have read about before!). The dataset has 3 wine varieties. These are the attributes that are part of the wine dataset:
- Alcohol
- Malic acid
- Ash
- Alcalinity of ash
- Magnesium
- Total phenols
- Flavanoids
- Nonflavanoid phenols
- Proanthocyanins
- Color intensity
- Hue
- OD280/OD315 of diluted wines
- Proline
In other words, across the various dimensions of the independent variables, many splits can be made, and for each of them a Gini impurity or entropy value can be computed... after which we can choose the best split every time.
Now that you understand something about decision tree learning and today's dataset, let's start writing some code. Open up a Python file in your favorite IDE or create a Jupyter Notebook, and let's add some imports:
from sklearn.datasets import load_wine
from sklearn import tree
import matplotlib.pyplot as plt
These imports speak pretty much for themselves. The first is related to the dataset that you will be using, the second is the representation of decision trees within Scikit-learn, and the last one is the PyPlot functionality from Matplotlib.
In Python, it's good practice to work with definitions. They make code reusable and allow you to structure your code into logical segments. In today's model, you will apply these definitions too.
The first one that you will create is one for loading your dataset. It simply calls load_wine(...) and passes the return_X_y attribute set to True. This way, your dataset will be returned as two separate arrays - X and y.
def load_dataset():
    """ Load today's dataset. """
    return load_wine(return_X_y=True)
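Calling it gives you the feature matrix and the target vector as NumPy arrays; for the wine dataset, that is 178 samples with 13 features each:
X, y = load_dataset()
print(X.shape)  # (178, 13)
print(y.shape)  # (178,)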
Next up, you will specify a definition that returns names of the features (the independent variables) and the eventual class names.
def feature_and_class_names():
    """ Define feature and class names. """
    feature_names = ["Alcohol","Malic acid","Ash","Alcalinity of ash","Magnesium","Total phenols","Flavanoids","Nonflavanoid phenols","Proanthocyanins","Color intensity","Hue","OD280/OD315 of diluted wines","Proline",]
    class_names = ["Class 1", "Class 2", "Class 3"]
    return feature_names, class_names
Per the Scikit-learn documentation of the DecisionTreeClassifier model type that you will use, there are a number of configurable options - most importantly the split criterion (which defaults to gini), the splitter strategy, the maximum tree depth, and the minimum number of samples required to split a node.
Let's now create a definition for initializing your decision tree. We simply stick with the defaults: Gini impurity, picking the best split at each node, and no explicit maximum depth - growth is only bounded by the minimum number of samples necessary for generating a split. In other words, we accept the risk of overfitting in order to keep the configuration simple. In practice, that wouldn't be a good idea.
def init_tree():
    """ Initialize the DecisionTreeClassifier. """
    return tree.DecisionTreeClassifier()
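Because no arguments are passed, Scikit-learn's defaults apply. Spelled out explicitly, the call above is equivalent to the following (these are the documented defaults for the most relevant parameters):
def init_tree():
    """ Initialize the DecisionTreeClassifier with its defaults spelled out. """
    return tree.DecisionTreeClassifier(
        criterion='gini',       # split evaluation metric: Gini impurity
        splitter='best',        # evaluate candidate splits and pick the best one
        max_depth=None,         # grow until leaves are pure or become too small to split
        min_samples_split=2,    # minimum number of samples required to split a node
    )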
Then, we can add a definition for training the tree:
def train_tree(empty_tree, X, Y):
    """ Train the DecisionTreeClassifier. """
    return empty_tree.fit(X, Y)
Finally, what's left is a definition for plotting the decision tree:
def plot_tree(trained_tree):
    """ Plot the DecisionTreeClassifier. """
    # Load feature and class names
    feature_names, class_names = feature_and_class_names()
    # Plot tree
    tree.plot_tree(trained_tree, feature_names=feature_names, class_names=class_names, fontsize=12, rounded=True, filled=True)
    plt.show()
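Depending on your screen, the rendered tree can look a bit cramped. If that happens, an optional tweak is to create a larger Matplotlib figure just before plotting:
# Inside plot_tree(), right before the call to tree.plot_tree(...):
plt.figure(figsize=(20, 12))  # enlarge the canvas so the node labels stay readable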
Then, you merge everything together into one definition that runs the process end to end, and call it when the script starts:
def decision_tree_classifier():
    """ End-to-end training of decision tree classifier. """
    # Load dataset
    X, Y = load_dataset()
    # Train the decision tree
    decision_tree = init_tree()
    trained_tree = train_tree(decision_tree, X, Y)
    # Plot the trained decision tree
    plot_tree(trained_tree)

if __name__ == '__main__':
    decision_tree_classifier()
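If you want a quick sanity check, you could also print the accuracy on the training data inside decision_tree_classifier(). Because the tree is allowed to grow until its leaves are pure, this should be at or very close to 100% - which also hints at why unconstrained trees tend to overfit:
# Optional, after training the tree:
print(f'Training accuracy: {trained_tree.score(X, Y):.3f}')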
If you want to get started immediately, here is the full code example for creating a classification decision tree with Scikit-learn.
from sklearn.datasets import load_wine
from sklearn import tree
import matplotlib.pyplot as plt

def load_dataset():
    """ Load today's dataset. """
    return load_wine(return_X_y=True)

def feature_and_class_names():
    """ Define feature and class names. """
    feature_names = ["Alcohol","Malic acid","Ash","Alcalinity of ash","Magnesium","Total phenols","Flavanoids","Nonflavanoid phenols","Proanthocyanins","Color intensity","Hue","OD280/OD315 of diluted wines","Proline",]
    class_names = ["Class 1", "Class 2", "Class 3"]
    return feature_names, class_names

def init_tree():
    """ Initialize the DecisionTreeClassifier. """
    return tree.DecisionTreeClassifier()

def train_tree(empty_tree, X, Y):
    """ Train the DecisionTreeClassifier. """
    return empty_tree.fit(X, Y)

def plot_tree(trained_tree):
    """ Plot the DecisionTreeClassifier. """
    # Load feature and class names
    feature_names, class_names = feature_and_class_names()
    # Plot tree
    tree.plot_tree(trained_tree, feature_names=feature_names, class_names=class_names, fontsize=12, rounded=True, filled=True)
    plt.show()

def decision_tree_classifier():
    """ End-to-end training of decision tree classifier. """
    # Load dataset
    X, Y = load_dataset()
    # Train the decision tree
    decision_tree = init_tree()
    trained_tree = train_tree(decision_tree, X, Y)
    # Plot the trained decision tree
    plot_tree(trained_tree)

if __name__ == '__main__':
    decision_tree_classifier()
Scikit-learn. (n.d.). 1.10. Decision trees - scikit-learn documentation. Retrieved January 21, 2022, from https://scikit-learn.org/stable/modules/tree.html
Scikit-learn. (n.d.). Sklearn.tree.DecisionTreeClassifier. scikit-learn. Retrieved January 21, 2022, from https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier
Scikit-learn. (n.d.). Sklearn.datasets.load_wine. scikit-learn. Retrieved January 21, 2022, from https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_wine.html#sklearn.datasets.load_wine
Wikipedia. (2004, April 5). Decision tree learning. Wikipedia, the free encyclopedia. Retrieved January 22, 2022, from https://en.wikipedia.org/wiki/Decision_tree_learning
Raileanu, L. E., & Stoffel, K. (2004). Theoretical comparison between the Gini index and information gain criteria. Annals of Mathematics and Artificial Intelligence, 41(1), 77-93. https://doi.org/10.1023/b:amai.0000018580.96245.c6