
How to use K-fold Cross Validation with PyTorch?

February 2, 2021 by Chris

Machine learning models must be evaluated with a test set after they have been trained. We do this to ensure that they have not overfit and that they still work with real-life data, whose distribution may deviate slightly from that of the training set.

But simply evaluating with a single train/test split may not be enough to make your model truly robust.

For example, take the situation where you have a dataset composed of samples from two classes. Most of the samples in the first 80% of your dataset belong to class A, whereas most of the samples in the remaining 20% belong to class B. If you took a simple 80/20 hold-out split, your training and test sets would have vastly different distributions - and evaluation might lead to wrong conclusions.

That's something you want to avoid. In this article, you'll therefore learn about another technique that can be applied - K-fold Cross Validation. By generating train/test splits across multiple folds, you can run multiple training and evaluation sessions, each with a different split. You'll also see how you can use K-fold Cross Validation with PyTorch, one of the leading libraries for neural networks these days.

After reading this tutorial, you will...

- Understand why K-fold Cross Validation makes model evaluation more robust than a simple hold-out split.
- Know how K-fold Cross Validation works at a high level.
- Be able to apply K-fold Cross Validation to a PyTorch model using Scikit-learn's KFold.

Update 29/Mar/2021: fixed possible issue with weight leaks.

Update 15/Feb/2021: fixed small textual error.

Summary and code example: K-fold Cross Validation with PyTorch

Model evaluation is often performed with a hold-out split, where an 80/20 split is typically made: 80% of your dataset is used for training the model and 20% for evaluating it. While this is a simple approach, it is also quite naïve, since it assumes that the data is representative across the splits, that the dataset is not a time series, and that there are no redundant samples shared between the splits.
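Such a hold-out split can be made in one line, for example with Scikit-learn's train_test_split. Here is a minimal sketch with dummy data, just to illustrate the naïve approach:

```
import numpy as np
from sklearn.model_selection import train_test_split

# Dummy dataset: 100 samples with 4 features each, and 100 binary labels.
X, y = np.random.rand(100, 4), np.random.randint(0, 2, size=100)

# Naïve 80/20 hold-out split: assumes both parts are representative of the whole.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(len(X_train), len(X_test))  # 80 20
```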

K-fold Cross Validation is a more robust evaluation technique. It splits the dataset into \(k\) folds, and in each of \(k\) runs it uses \(k-1\) folds for training and the remaining fold for testing. Using the training folds, you can then train your model, and subsequently evaluate it with the testing fold. This allows you to train the model multiple times with different dataset configurations. Even better, it allows you to be more confident in your model evaluation results.
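To get a feeling for what these folds look like, here is a minimal sketch that prints the train/test indices which Scikit-learn's KFold generates for a tiny dummy dataset of ten samples:

```
from sklearn.model_selection import KFold

# Tiny dummy dataset: ten samples, identified by their indices 0..9.
samples = list(range(10))

# 5 folds: in each fold, 8 samples are used for training and 2 for testing.
kfold = KFold(n_splits=5, shuffle=True, random_state=42)
for fold, (train_ids, test_ids) in enumerate(kfold.split(samples)):
    print(f'Fold {fold}: train={train_ids}, test={test_ids}')
```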

Below, you will see a full example of using K-fold Cross Validation with PyTorch, using Scikit-learn's KFold functionality. You can use it straight away in your own project. If you want to understand things in more detail, however, it's best to continue reading the rest of the tutorial as well! 🚀

```
import os
import torch
from torch import nn
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import transforms
from sklearn.model_selection import KFold


def reset_weights(m):
    '''
    Try resetting model weights to avoid weight leakage.
    '''
    for layer in m.children():
        if hasattr(layer, 'reset_parameters'):
            print(f'Reset trainable parameters of layer = {layer}')
            layer.reset_parameters()


class SimpleConvNet(nn.Module):
    '''
    Simple Convolutional Neural Network
    '''
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(26 * 26 * 10, 50),
            nn.ReLU(),
            nn.Linear(50, 20),
            nn.ReLU(),
            nn.Linear(20, 10)
        )

    def forward(self, x):
        '''Forward pass'''
        return self.layers(x)


if __name__ == '__main__':

    # Configuration options
    k_folds = 5
    num_epochs = 1
    loss_function = nn.CrossEntropyLoss()

    # For fold results
    results = {}

    # Set fixed random number seed
    torch.manual_seed(42)

    # Prepare MNIST dataset by concatenating Train/Test part; we split later.
    dataset_train_part = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor(), train=True)
    dataset_test_part = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor(), train=False)
    dataset = ConcatDataset([dataset_train_part, dataset_test_part])

    # Define the K-fold Cross Validator
    kfold = KFold(n_splits=k_folds, shuffle=True)

    # Start print
    print('--------------------------------')
```
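From here, the example continues inside the __main__ block with the actual per-fold training and evaluation loop. A minimal sketch of how this can be wired up - using torch.utils.data.SubsetRandomSampler to feed each fold's indices into DataLoaders, reusing the reset_weights function from above, and with Adam, a batch size of 10 and a learning rate of 1e-4 as assumed (not prescribed) choices - looks like this:

```
    # Per-fold training and evaluation, continuing inside the __main__ block above.
    for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):

        # Print fold header
        print(f'FOLD {fold}')
        print('--------------------------------')

        # Sample the elements of this fold randomly, without replacement.
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
        test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)

        # Define data loaders for the training and testing data of this fold.
        trainloader = DataLoader(dataset, batch_size=10, sampler=train_subsampler)
        testloader = DataLoader(dataset, batch_size=10, sampler=test_subsampler)

        # Initialize the neural network and reset its weights for this fold.
        network = SimpleConvNet()
        network.apply(reset_weights)

        # Initialize the optimizer (Adam with lr=1e-4 is an assumed choice).
        optimizer = torch.optim.Adam(network.parameters(), lr=1e-4)

        # Train for the configured number of epochs.
        for epoch in range(num_epochs):
            for inputs, targets in trainloader:
                optimizer.zero_grad()
                loss = loss_function(network(inputs), targets)
                loss.backward()
                optimizer.step()

        # Evaluate this fold on its held-out test part.
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, targets in testloader:
                outputs = network(inputs)
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

        # Store and print the accuracy for this fold.
        results[fold] = 100.0 * correct / total
        print(f'Accuracy for fold {fold}: {results[fold]:.2f} %')

    # Print the average accuracy over all folds.
    print(f'Average accuracy over {k_folds} folds: {sum(results.values()) / len(results):.2f} %')
```

Resetting the weights at the start of every fold prevents weight leakage, where a model trained in one fold would otherwise start the next fold with already-trained parameters and make the evaluation overly optimistic.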
