When you train supervised machine learning models, you'll likely try several of them to find out which performs best. Sooner or later, that raises the question: how can I compare models objectively?
Training and testing datasets exist for this purpose. By splitting a small part off your full dataset, you create a test set which (1) has not been seen by the model during training, and which (2) you assume to approximate the distribution of the population, i.e. the real-world scenario you wish to build a predictive model for.
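To make this concrete, here is a minimal sketch of such a split with Scikit-learn's `train_test_split`. The toy data, the 80/20 ratio and the variable names are assumptions made for illustration only; they are not taken from the example used later in this post.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy data standing in for a real dataset (assumption): 100 samples, 4 features.
rng = np.random.default_rng(42)
features = rng.normal(size=(100, 4))
labels = rng.integers(0, 2, size=100)

# Split 20% off the full dataset as a test set the model never sees in training.
X_train, X_test, y_train, y_test = train_test_split(
    features, labels, test_size=0.2, shuffle=True, random_state=42
)

print(X_train.shape, X_test.shape)  # (80, 4) (20, 4)
```

Shuffling before splitting matters when the samples are stored in some order (for example, sorted by class), because otherwise the test set may not resemble the rest of the data.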
Now, when generating such a split, you should ensure that your splits are relatively unbiased. In this blog post, we'll cover one technique for doing so: K-fold Cross Validation. First, we'll show how such splits can be made naïvely, i.e. with a simple hold-out split. Then, we introduce K-fold Cross Validation, show how it works, and explain why it can produce more reliable performance estimates. This is followed by an example, created with Keras and Scikit-learn's KFold class.
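As a first impression of what K-fold Cross Validation does, the sketch below applies Scikit-learn's `KFold` to ten toy samples. The sample array, the choice of 5 folds and the fixed `random_state` are assumptions for illustration. Each sample ends up in the test part of exactly one fold and in the training part of all the others.

```python
import numpy as np
from sklearn.model_selection import KFold

# Ten toy samples (assumption); their values double as their indices.
samples = np.arange(10)

# 5 folds: every split trains on 8 samples and tests on the remaining 2.
kfold = KFold(n_splits=5, shuffle=True, random_state=0)

test_indices_seen = []
for fold_no, (train_idx, test_idx) in enumerate(kfold.split(samples), start=1):
    print(f"Fold {fold_no}: train={train_idx}, test={test_idx}")
    test_indices_seen.extend(test_idx.tolist())

# Across all folds, each sample was used for testing exactly once.
assert sorted(test_indices_seen) == list(range(10))
```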
Are you ready? Let's go! 😎
Update 12/Feb/2021: added TensorFlow 2 to title; some styling changes.
Update 11/Jan/2021: added code example to start using K-fold CV straight away.
Update 04/Aug/2020: clarified the (in my view) necessity of validation set even after K-fold CV.
Update 11/Jun/2020: improved K-fold cross validation code based on reader comments.
This quick code example can be used to perform K-fold Cross Validation with your TensorFlow/Keras model straight away. If you want to understand it in more detail, make sure to read the rest of the article below!
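```
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.losses import sparse_categorical_crossentropy
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import KFold
import numpy as np

# Model configuration.
# NOTE: these names are used further down but are not defined in the listing
# as shown here; the values are assumptions chosen to fit CIFAR-10.
batch_size = 50
input_shape = (32, 32, 3)
loss_function = sparse_categorical_crossentropy
no_classes = 10
no_epochs = 25
num_folds = 10
verbosity = 1

# Load CIFAR-10 (assumed data source, based on the cifar10 import above)
(input_train, target_train), (input_test, target_test) = cifar10.load_data()

# Merge the predefined train/test sets: K-fold CV generates its own splits
inputs = np.concatenate((input_train, input_test), axis=0)
targets = np.concatenate((target_train, target_test), axis=0)

# Scale pixel values to [0, 1] (assumed preprocessing step)
inputs = inputs.astype('float32') / 255.0

# Define the K-fold cross validator
kfold = KFold(n_splits=num_folds, shuffle=True)

fold_no = 1
for train, test in kfold.split(inputs, targets):

    # Define the model architecture
    model = Sequential()
    model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(256, activation='relu'))
    model.add(Dense(128, activation='relu'))
    model.add(Dense(no_classes, activation='softmax'))

    # Compile the model with a fresh optimizer, so that no optimizer state
    # carries over from the previous fold
    model.compile(loss=loss_function, optimizer=Adam(), metrics=['accuracy'])

    # Generate a print
    # NOTE: the listing is cut off inside this print call. The message and the
    # steps below are a minimal completion, not necessarily the original code.
    print(f'Training for fold {fold_no} ...')

    # Train on this fold's training part, then evaluate on its held-out part
    model.fit(inputs[train], targets[train], batch_size=batch_size,
              epochs=no_epochs, verbose=verbosity)
    scores = model.evaluate(inputs[test], targets[test], verbose=0)
    print(f'Score for fold {fold_no}: loss={scores[0]:.4f}, accuracy={scores[1]:.4f}')

    fold_no += 1
```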