
Building a Decision Tree for classification with Python and Scikit-learn

January 23, 2022 by Chris

Although we hear a lot about deep learning these days, there is a wide variety of other machine learning techniques that can still be very useful. Decision tree learning is one of them. By recursively partitioning your feature space into segments that group common elements yielding a class outcome together, it becomes possible to build predictive models for both classification and regression.

In today's tutorial, you will learn to build a decision tree for classification. You will do so using Python and one of the key machine learning libraries for the Python ecosystem, Scikit-learn. After reading it, you will understand what decision trees are, how they are learned with the CART algorithm using Gini impurity or entropy as the split criterion, and how to build, train and visualize one with Scikit-learn's DecisionTreeClassifier class.

Are you ready? Let's take a look! 😎

What are decision trees?

Suppose that you have a dataset that describes wines in many columns, and the wine variety in the last column.

These independent variables can be used to build a predictive model that, given some new inputs, tells us whether a specific measurement comes from a wine of variety one, two or three.

As you already know, there are many techniques for building predictive models. Deep neural networks are very popular these days, but there are also approaches that are a bit more classic - but not necessarily wrong.

Decision trees are one such technique. They essentially work by breaking down the decision-making process into many smaller questions. In the wine scenario, as an example, you know that wines can be separated by color. This distinguishes between wine varieties that make white wine and varieties that make red wine. There are more such questions that can be asked: what is the alcohol content? What is the magnesium content? And so forth.

An example of a decision tree. Each variety (there are three) is represented by a different color - orange, green and purple. Both color and color intensity point towards an estimated class at a given sub-question stage. For example, the first question points towards class 2, and that path gets stronger further down the tree. Still, it is possible to end up with class 1 or class 3 - by simply taking the other path or diverting further down the road.

By structuring these questions in a smart way, you can separate the classes (in this case, the varieties) by simply providing answers that point you to a specific variety. And precisely that is what decision trees are: they are tree-like structures that break your classification problem into many smaller sub questions given the inputs you have.

Decision trees can be constructed manually. More relevant however is the automated construction of decision trees. And that is precisely what you will be looking at today, by building one with Scikit-learn.

How are decision tree classifiers learned in Scikit-learn?

In today's tutorial, you will be building a decision tree for classification with the DecisionTreeClassifier class in Scikit-learn. To learn a decision tree, Scikit-learn uses an optimized version of the Classification And Regression Trees (CART) algorithm. Let's first take a look at how this algorithm works, before we build a classification decision tree.

Learning a CART tree

At a high level, a CART tree is built in the following way, using some split evaluation criterion (we will cover that in a few moments):

  1. Compute all candidate splits that can be made (often, this is a selection over the entire feature space). In other words, do this for each of the independent variables, each combined with a threshold value. For example, in the tree above, "Proline <= 755.0" in the root node is one such split at the first level: the proline variable, with 755.0 as the threshold value.
  2. For each split, compute the value of the split evaluation criterion.
  3. Pick the one with the best value as the split.
  4. Repeat this process for the next level, until split generation is exhausted (by either a lack of further independent variables or a user-constrained depth of the decision tree).

In other words, decision tree learning is a recursive process that picks the best split at each level until the tree is exhausted or a user-defined stopping criterion (such as a maximum tree depth) is reached. A simplified sketch of this split search is shown below.
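
To make these steps concrete, here is a minimal sketch of such an exhaustive split search in plain NumPy. It is not how Scikit-learn implements CART internally - the helper name best_split is made up for illustration, and the criterion parameter stands in for a split evaluation function such as the ones covered below.

import numpy as np

def best_split(X, y, criterion):
    """ Evaluate all candidate splits (steps 1 and 2) and return the best
    one (step 3) as (feature_index, threshold, weighted_criterion_value). """
    best_feature, best_threshold, best_value = None, None, float("inf")
    n_samples, n_features = X.shape
    for feature in range(n_features):
        values = np.unique(X[:, feature])
        # Candidate thresholds: midpoints between consecutive unique feature values
        thresholds = (values[:-1] + values[1:]) / 2.0
        for threshold in thresholds:
            left = y[X[:, feature] <= threshold]
            right = y[X[:, feature] > threshold]
            # Weighted average of the criterion values on both sides of the split
            value = (len(left) * criterion(left) + len(right) * criterion(right)) / n_samples
            if value < best_value:
                best_feature, best_threshold, best_value = feature, threshold, value
    return best_feature, best_threshold, best_value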

Now, regarding the split evaluation criterion: Scikit-learn based CART trees support two criteria, Gini impurity and entropy.

Gini impurity

The first - and default - split evaluation metric available in Scikit's decision tree learner is Gini impurity.

The metric is defined in the following way: for a node whose samples have class probabilities $p_i$ (for classes $i = 1, \dots, C$), the Gini impurity equals $\sum_{i=1}^{C} p_i (1 - p_i) = 1 - \sum_{i=1}^{C} p_i^2$.

Gini impurity (named after Italian mathematician Corrado Gini) is a measure of how often a randomly chosen element from the set would be incorrectly labeled if it was randomly labeled according to the distribution of labels in the subset.

Wikipedia (2004)

Suppose that we pick a random sample from the set and label it randomly according to the distribution of labels in that set.

What is the probability that we classify it wrongly? That's the Gini impurity of that set.

Random classification

For example, if we have 100 samples, where 25 belong to class A and 75 to class B, these are our probabilities: P(A) = 25/100 = 0.25 and P(B) = 75/100 = 0.75.

So, what's the probability of classifying a sample wrongly? The chance of drawing a class A sample but labeling it as class B is 0.25 × 0.75 = 18.75%, and the chance of drawing a class B sample but labeling it as class A is 0.75 × 0.25 = 18.75%.

That's 18.75% + 18.75% = 37.5%. In other words, the Gini impurity of this data scenario with random classification is 0.375.

By minimizing the Gini impurity of the scenario, we get the best classification for our selection.
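
As a quick sanity check, here is a small snippet that computes the Gini impurity for this 25/75 scenario. The helper name gini is just for illustration; it is not part of Scikit-learn's public API.

import numpy as np

def gini(labels):
    """ Gini impurity: the probability of mislabeling a randomly drawn sample
    when labeling it randomly according to the label distribution. """
    _, counts = np.unique(labels, return_counts=True)
    probabilities = counts / counts.sum()
    return 1.0 - np.sum(probabilities ** 2)

# 25 samples of class A and 75 samples of class B
labels = np.array(["A"] * 25 + ["B"] * 75)
print(gini(labels))  # 0.375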

Adding a split

Suppose that instead of randomly classifying our samples, we add a decision boundary: we split our sample space in two. In other words, we add a split.

We can simply compute the Gini impurity of this split by computing a weighted average of the Gini impurities of both sides of the split.

Suppose that we add the following split to the very simple two-dimensional dataset below, generated by the OPTICS clustering algorithm:

Now, for both sides of the split, we repeat the same:

On the left, you can clearly see that the Gini impurity is 0: if we pick a sample, it can only be classified as blue, because blue is the only class available on that side.

On the right, impurity is very low, but not zero: there are some blue samples available, and Gini impurity is approximately 0.00398.

Clearly, a better split is available at X[0] ~ 5, where Gini impurity would be 0... ;-) But this is just for demonstrative purposes!

Now, how good is a split?

Now that you understand how the Gini impurity of a split can be computed, we can look at the final aspect of determining goodness-of-split with Gini impurity: how do we decide how much a split contributes?

At each level of your decision tree, you know the following: the Gini impurity of the current node, and, for every candidate split, the weighted Gini impurity of the resulting child nodes.

Picking the best split now involves picking the split with the greatest reduction in total Gini impurity. This reduction can be computed with the weighted average mentioned before: subtract the weighted impurity of the child nodes from the impurity of the current node. In the case above, the weighted impurity after the split is close to zero, so the reduction is almost the entire impurity of the node before the split.

If this is the greatest reduction of Gini impurity (computed as the difference between the existing impurity and the resulting weighted impurity), then it's the split to choose! :)
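
Expressed in code, this goodness-of-split (often called the Gini gain) could look like the sketch below. It reuses the gini helper defined earlier and is, again, just an illustrative assumption rather than Scikit-learn's internal implementation.

def gini_gain(parent_labels, left_labels, right_labels):
    """ Reduction in Gini impurity achieved by splitting parent_labels
    into left_labels and right_labels. """
    n = len(parent_labels)
    # Weighted average of the child impurities
    weighted_child_impurity = (len(left_labels) * gini(left_labels)
                               + len(right_labels) * gini(right_labels)) / n
    # Existing impurity minus resulting impurity: the greater, the better
    return gini(parent_labels) - weighted_child_impurity

The candidate split with the largest gini_gain is the one that gets chosen at that level of the tree.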

Entropy

A similar but slightly different metric that can be used is entropy: $H = -\sum_{i=1}^{C} p_i \log_2 p_i$, where $p_i$ is again the probability of class $i$ within the node.

For using entropy, you'll have to repeat all the steps executed above. It simply boils down to plugging the class probabilities into the entropy formula instead... and you pick the split that yields the lowest weighted entropy, i.e. the greatest information gain.
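
For completeness, here is the entropy counterpart of the gini helper above - again an illustrative sketch, not a Scikit-learn function:

import numpy as np

def entropy(labels):
    """ Shannon entropy of the class distribution in labels. """
    _, counts = np.unique(labels, return_counts=True)
    probabilities = counts / counts.sum()
    return -np.sum(probabilities * np.log2(probabilities))

# The same 25/75 scenario as before
labels = np.array(["A"] * 25 + ["B"] * 75)
print(entropy(labels))  # approximately 0.811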

Choosing between Gini impurity and entropy

Model performance-wise, there is little reason to choose between Gini impurity and entropy. In a theoretical comparison, Raileanu and Stoffel (2004) found that the two criteria disagree in only about 2% of cases, so in practice the choice rarely matters.

In other words, I would go with Gini impurity - it avoids computing logarithms and is therefore slightly cheaper to evaluate, and I assume that's why it's the default option in Scikit-learn, too! :)

Building a Decision Tree for classification with Scikit-learn

Now that you understand some of the theory behind CART trees, it's time to build one such tree for classification. You will use one of the default machine learning libraries for this purpose, being Scikit-learn. It's a three-step process: ensuring that the dependencies are installed, getting to know today's dataset, and writing the actual code.

Ensure that you have installed the dependencies

Before writing any code, it's important that you have installed all the dependencies on your machine: a recent version of Python 3, Scikit-learn and Matplotlib (NumPy comes along as a dependency of both). If needed, they can be installed with pip install scikit-learn matplotlib.

Today's dataset

If you have been a frequent reader of MachineCurve tutorials, you know that I favor out-of-the-box datasets that come preinstalled with machine learning libraries used during tutorials.

The reason is simple: although in the real world data is key to success, these tutorials are meant to teach you something about the models you're building, and lengthy sections on datasets can be distracting.

For that reason, today, you will be using one of the datasets that comes with Scikit-learn out of the box: the wine dataset.

The wine dataset is a classic and very easy multi-class classification dataset.

Scikit-learn

It is a dataset with 178 samples and 13 attributes that assigns each sample to a wine variety (indeed, a dataset similar to the one you read about before!). The dataset has 3 wine varieties. These are the attributes that are part of the wine dataset: Alcohol, Malic acid, Ash, Alcalinity of ash, Magnesium, Total phenols, Flavanoids, Nonflavanoid phenols, Proanthocyanins, Color intensity, Hue, OD280/OD315 of diluted wines, and Proline.

In other words, in the various dimensions of the independent variables, many splits can be made using which many Gini impurity/entropy values can be computed... after which we can choose the best split every time.
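
If you want to inspect the dataset yourself before building the tree, a quick look is possible with the snippet below. It assumes that load_wine() returns its default Bunch object, which exposes data, target, feature_names and target_names attributes:

from sklearn.datasets import load_wine

wine = load_wine()
print(wine.data.shape)      # (178, 13): 178 samples, 13 attributes
print(wine.target.shape)    # (178,): one wine variety per sample
print(wine.feature_names)   # the 13 attribute names listed above
print(wine.target_names)    # ['class_0' 'class_1' 'class_2']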

Specifying the Python imports

Now that you understand something about decision tree learning and today's dataset, let's start writing some code. Open up a Python file in your favorite IDE or create a Jupyter Notebook, and let's add some imports:

from sklearn.datasets import load_wine
from sklearn import tree
import matplotlib.pyplot as plt

These imports speak pretty much for themselves. The first is related to the dataset that you will be using. The second is the representation of decision trees within Scikit-learn, and the third is the PyPlot functionality from Matplotlib.

Loading our dataset

In Python, it's good practice to work with definitions. They make code reusable and allow you to structure your code into logical segments. In today's model, you will apply these definitions too.

The first one that you will create is one for loading your dataset. It simply calls load_wine(...) with the return_X_y parameter set to True. This way, your dataset is returned as two separate arrays - X (the features) and y (the targets).

def load_dataset():
    """ Load today's dataset. """
    return load_wine(return_X_y=True)

Defining feature and class names

Next up, you will specify a definition that returns names of the features (the independent variables) and the eventual class names.

def feature_and_class_names():
    """ Define feature and class names. """
    feature_names = ["Alcohol","Malic acid","Ash","Alcalinity of ash","Magnesium","Total phenols","Flavanoids","Nonflavanoid phenols","Proanthocyanins","Color intensity","Hue","OD280/OD315 of diluted wines","Proline",]
    class_names = ["Class 1", "Class 2", "Class 3"]

    return feature_names, class_names

Initializing the classifier and fitting the data

Per the Scikit-learn documentation of the DecisionTreeClassifier model type that you will use, there are quite a few options that you can configure, such as the split criterion (criterion), the splitting strategy (splitter), the maximum tree depth (max_depth) and the minimum number of samples required to split a node (min_samples_split).

Let's now create a definition for initializing your decision tree. We stick with the defaults: Gini impurity as the criterion, best splitting, and no maximum depth, so growth is only bounded by the minimum number of samples necessary for generating a split. In other words, we accept the risk of overfitting in order to keep the example simple. In practice, that wouldn't be a good idea.

def init_tree():
    """ Initialize the DecisionTreeClassifier. """
    return tree.DecisionTreeClassifier()
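
If you do want to constrain the tree - which you typically will in practice - a variant could look as follows. The specific values (max_depth=3, min_samples_split=5, random_state=42) are arbitrary choices for illustration, not recommendations:

def init_constrained_tree():
    """ Initialize a DecisionTreeClassifier with explicit constraints. """
    return tree.DecisionTreeClassifier(
        criterion='gini',      # split evaluation criterion (the default)
        splitter='best',       # pick the best split at each node (the default)
        max_depth=3,           # limit tree depth to reduce overfitting
        min_samples_split=5,   # require at least 5 samples to split a node
        random_state=42        # make results reproducible
    )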

Then, we can add a definition for training the tree:

def train_tree(empty_tree, X, Y):
    """ Train the DecisionTreeClassifier. """
    return empty_tree.fit(X, Y)

Plotting the decision tree

Finally, what's left is a definition for plotting the decision tree:

def plot_tree(trained_tree):
    """ Plot the DecisionTreeClassifier. """

    # Load feature and class names
    feature_names, class_names = feature_and_class_names()

    # Plot tree
    tree.plot_tree(trained_tree, feature_names=feature_names, class_names=class_names, fontsize=12, rounded=True, filled=True)
    plt.show()

Merging everything together

Then, you merge everything together into one definition that performs the end-to-end process, and call it when the script is executed:

def decision_tree_classifier():
    """ End-to-end training of decision tree classifier. """

    # Load dataset
    X, Y = load_dataset()

    # Train the decision tree
    classifier = init_tree()
    trained_tree = train_tree(classifier, X, Y)

    # Plot the trained decision tree
    plot_tree(trained_tree)


if __name__ == '__main__':
    decision_tree_classifier()
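
Note that this example trains and plots the tree on the full dataset, which tells you nothing about how well it generalizes. As an optional extension - not part of the listing above - you could hold out a test set and measure accuracy, for example like this:

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

def evaluate_tree():
    """ Train on a training split and evaluate accuracy on held-out data. """
    X, Y = load_dataset()
    X_train, X_test, Y_train, Y_test = train_test_split(
        X, Y, test_size=0.25, random_state=42
    )
    classifier = train_tree(init_tree(), X_train, Y_train)
    predictions = classifier.predict(X_test)
    print(f"Test accuracy: {accuracy_score(Y_test, predictions):.3f}")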

Full model code

If you want to get started immediately, here is the full code example for creating a classification decision tree with Scikit-learn.

from sklearn.datasets import load_wine
from sklearn import tree
import matplotlib.pyplot as plt

def load_dataset():
    """ Load today's dataset. """
    return load_wine(return_X_y=True)


def feature_and_class_names():
    """ Define feature and class names. """
    feature_names = ["Alcohol","Malic acid","Ash","Alcalinity of ash","Magnesium","Total phenols","Flavanoids","Nonflavanoid phenols","Proanthocyanins","Color intensity","Hue","OD280/OD315 of diluted wines","Proline",]
    class_names = ["Class 1", "Class 2", "Class 3"]

    return feature_names, class_names


def init_tree():
    """ Initialize the DecisionTreeClassifier. """
    return tree.DecisionTreeClassifier()


def train_tree(empty_tree, X, Y):
    """ Train the DecisionTreeClassifier. """
    return empty_tree.fit(X, Y)


def plot_tree(trained_tree):
    """ Plot the DecisionTreeClassifier. """

    # Load feature and class names
    feature_names, class_names = feature_and_class_names()

    # Plot tree
    tree.plot_tree(trained_tree, feature_names=feature_names, class_names=class_names, fontsize=12, rounded=True, filled=True)
    plt.show()


def decision_tree_classifier():
    """ End-to-end training of decision tree classifier. """

    # Load dataset
    X, Y = load_dataset()

    # Train the decision tree
    classifier = init_tree()
    trained_tree = train_tree(classifier, X, Y)

    # Plot the trained decision tree
    plot_tree(trained_tree)


if __name__ == '__main__':
    decision_tree_classifier()

References

Scikit-learn. (n.d.). 1.10. Decision trees — scikit-learn documentation. Retrieved January 21, 2022, from https://scikit-learn.org/stable/modules/tree.html

Scikit-learn. (n.d.). Sklearn.tree.DecisionTreeClassifier. scikit-learn. Retrieved January 21, 2022, from https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier

Scikit-learn. (n.d.). Sklearn.datasets.load_wine. scikit-learn. Retrieved January 21, 2022, from https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_wine.html#sklearn.datasets.load_wine

Wikipedia. (2004, April 5). Decision tree learning. Wikipedia, the free encyclopedia. Retrieved January 22, 2022, from https://en.wikipedia.org/wiki/Decision_tree_learning

Raileanu, L. E., & Stoffel, K. (2004). Theoretical comparison between the Gini index and information gain criteria. Annals of Mathematics and Artificial Intelligence, 41(1), 77-93. https://doi.org/10.1023/b:amai.0000018580.96245.c6
