Machine learning models must be evaluated with a test set after they have been trained. We do this to ensure that models have not overfit and to ensure that they work with real-life datasets, which may have slightly deviating distributions compared to the training set.
But to evaluate your model robustly, a single train/test split may not be enough.
For example, take a dataset composed of samples from two classes, where most samples in the first 80% of the dataset belong to class A and most samples in the remaining 20% belong to class B. If you took a simple 80/20 hold-out split without shuffling, your training and test sets would have vastly different distributions, and evaluation might lead to wrong conclusions.
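To make this concrete, here is a small sketch with a hypothetical, synthetic set of 1,000 labels ordered as described. The exact proportions (90% class A in the first 800 samples, 90% class B in the last 200) are assumptions chosen for illustration. Taking the first 80% for training without shuffling produces two splits whose class distributions are mirror images of each other.

```python
from collections import Counter


def class_fractions(labels):
    """Return the fraction of each class label in a list of labels."""
    counts = Counter(labels)
    return {label: count / len(labels) for label, count in sorted(counts.items())}


# Hypothetical ordered dataset: the first 800 samples are 90% 'A',
# the last 200 samples are 90% 'B'.
labels = (["A"] * 720 + ["B"] * 80) + (["A"] * 20 + ["B"] * 180)

# Naive 80/20 hold-out split without shuffling.
split_index = int(0.8 * len(labels))
train_labels = labels[:split_index]
test_labels = labels[split_index:]

print("train:", class_fractions(train_labels))  # train: {'A': 0.9, 'B': 0.1}
print("test: ", class_fractions(test_labels))   # test:  {'A': 0.1, 'B': 0.9}
```

A model trained on the first split mostly sees class A, yet it is evaluated mostly on class B, so the test score says little about how the model would perform on data drawn from the full distribution.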
That's something you want to avoid. In this article, you'll therefore learn about another technique that can be applied: K-fold Cross Validation. By generating train/test splits across multiple folds, you can perform multiple training and testing sessions, each with a different split. You'll also see how you can use K-fold Cross Validation with PyTorch, one of the leading libraries for neural networks these days.
After reading this tutorial, you will...
Update 29/Mar/2021: fixed possible issue with weight leaks.
Update 15/Feb/2021: fixed small textual error.
Model evaluation is often performed with a hold-out split, where the dataset is split once (often 80/20): 80% of it is used for training the model and 20% for evaluating it. While this approach is simple, it is also very naïve, since it assumes that the data is representative across the splits, that it is not a time series dataset, and that there are no redundant samples within the datasets.
K-fold Cross Validation is a more robust evaluation technique. It splits the dataset into \(k\) folds and runs \(k\) rounds: in each round, \(k-1\) folds are used for training and the remaining fold for testing, so that every fold serves as the test set exactly once. You train the model on the training folds and subsequently evaluate it on the test fold. This allows you to train the model multiple times with different dataset configurations. Even better, because you can average the results across folds, it allows you to be more confident in your model evaluation results.
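The splitting logic itself fits in a few lines. The sketch below uses plain Python rather than Scikit-learn; the function name `k_fold_indices` is hypothetical. It optionally shuffles the indices once, then cuts them into \(k\) contiguous folds, giving the first `n_samples % k` folds one extra sample (the fold sizes Scikit-learn documents for `KFold`).

```python
import random


def k_fold_indices(n_samples, k, shuffle=True, seed=None):
    """Yield (train_indices, test_indices) for each of k folds.

    Every index appears in exactly one test fold; the remaining
    k - 1 folds form the training set for that round.
    """
    if not 2 <= k <= n_samples:
        raise ValueError("k must be between 2 and n_samples")
    indices = list(range(n_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)

    # The first n_samples % k folds get one extra sample.
    base, extra = divmod(n_samples, k)
    start = 0
    for fold in range(k):
        size = base + (1 if fold < extra else 0)
        test = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        yield train, test
        start += size


for fold, (train, test) in enumerate(k_fold_indices(10, k=5, seed=42)):
    print(f"Fold {fold}: {len(train)} train, {len(test)} test, test indices {sorted(test)}")
```

With \(k = 5\), each sample is used for training in four rounds and for testing in exactly one, which is what lets you average five independent test scores instead of relying on a single split.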
Below, you will see a full example of using K-fold Cross Validation with PyTorch, using Scikit-learn's `KFold` functionality. You can use it right away. If you want to understand things in more detail, however, it's best to continue reading the rest of the tutorial as well! 🚀
```
import os
import torch
from torch import nn
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import transforms
from sklearn.model_selection import KFold


def reset_weights(m):
    '''
    Try resetting model weights to avoid
    weight leakage.
    '''
    for layer in m.children():
        if hasattr(layer, 'reset_parameters'):
            print(f'Reset trainable parameters of layer = {layer}')
            layer.reset_parameters()


class SimpleConvNet(nn.Module):
    '''
    Simple Convolutional Neural Network
    '''
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(26 * 26 * 10, 50),
            nn.ReLU(),
            nn.Linear(50, 20),
            nn.ReLU(),
            nn.Linear(20, 10)
        )

    def forward(self, x):
        '''Forward pass'''
        return self.layers(x)


if __name__ == '__main__':

    # Configuration options
    k_folds = 5
    num_epochs = 1
    loss_function = nn.CrossEntropyLoss()

    # For fold results
    results = {}

    # Set fixed random number seed
    torch.manual_seed(42)

    # Prepare MNIST dataset by concatenating Train/Test part; we split later.
    dataset_train_part = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor(), train=True)
    dataset_test_part = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor(), train=False)
    dataset = ConcatDataset([dataset_train_part, dataset_test_part])

    # Define the K-fold Cross Validator
    kfold = KFold(n_splits=k_folds, shuffle=True)
```
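The listing stops after the cross validator has been defined. As a rough sketch of how the folds could then be consumed, the following self-contained example applies the same pattern to a small synthetic dataset instead of MNIST. For every fold it builds `DataLoader`s restricted to that fold's indices via `SubsetRandomSampler`, creates a fresh model and resets it with a `reset_weights` helper like the one above, trains, and records the test accuracy. The random data, the hypothetical `TinyNet` model, and the hyperparameters (batch size, Adam optimizer, learning rate, epochs) are assumptions for illustration, not the tutorial's settings.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset
from sklearn.model_selection import KFold


def reset_weights(m):
    """Reset direct child layers that support it (intended for Module.apply)."""
    for layer in m.children():
        if hasattr(layer, "reset_parameters"):
            layer.reset_parameters()


class TinyNet(nn.Module):
    """Hypothetical small classifier standing in for SimpleConvNet."""

    def __init__(self, n_features=20, n_classes=2):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(n_features, 16),
            nn.ReLU(),
            nn.Linear(16, n_classes),
        )

    def forward(self, x):
        return self.layers(x)


torch.manual_seed(42)

# Synthetic stand-in for MNIST: 500 samples, 20 features, 2 classes.
features = torch.randn(500, 20)
targets = (features[:, 0] + features[:, 1] > 0).long()
dataset = TensorDataset(features, targets)

k_folds, num_epochs = 5, 3
loss_function = nn.CrossEntropyLoss()
kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)
results = {}

for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
    # Restrict each loader to this fold's indices.
    trainloader = DataLoader(dataset, batch_size=32, sampler=SubsetRandomSampler(train_ids))
    testloader = DataLoader(dataset, batch_size=32, sampler=SubsetRandomSampler(test_ids))

    # Fresh model per fold; apply() visits every submodule, so the
    # children of the nn.Sequential container are reset as well.
    network = TinyNet()
    network.apply(reset_weights)
    optimizer = torch.optim.Adam(network.parameters(), lr=1e-2)

    network.train()
    for epoch in range(num_epochs):
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            loss = loss_function(network(inputs), labels)
            loss.backward()
            optimizer.step()

    # Evaluate on the held-out fold only.
    network.eval()
    correct = total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            predicted = network(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    results[fold] = 100.0 * correct / total
    print(f"Fold {fold}: accuracy {results[fold]:.1f}%")

print(f"Average accuracy: {sum(results.values()) / len(results):.1f}%")
```

Creating a new model and resetting its parameters at the start of each fold ensures that no weights learned on one fold's training data leak into the evaluation of another fold.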