Sometimes, you may want to crop the input images that you feed to your neural network. Although cropping is, strictly speaking, a data preprocessing step in many cases, it can be useful to move it into the neural network itself, because then you do not need to adapt the full dataset in advance.
In TensorFlow and Keras, cropping your input data is relatively easy, using the Cropping layers readily available there.
In PyTorch, this is different, because Cropping layers are not part of the PyTorch API.
In this article, you will learn how you can perform Cropping within PyTorch anyway, by using the ZeroPad2d layer, which performs zero padding. By using it in an inverse way, we can remove pixels (and hence perform cropping) instead of adding them.
Ready? Let's take a look. 😎
For creating our Cropping layer, we will be using the ZeroPad2d layer that is available within PyTorch.
Normally, it's used for adding a box of pixels around the input data - which is what padding does. In that case, it's used with positive padding. In the image below, on the left, you can see what happens when it's called with a +1 padding - an extra box of zero-valued pixels is added around the input image.
Now, what if we used a -1 padding instead? You would expect that padding then works in the opposite direction, meaning that a box is not added, but removed. And precisely this effect is what we will use for creating a Cropping layer for your PyTorch model.
Calling Zero Padding with a positive padding results in a zero-valued box of pixels being added to your input image. Using a negative padding removes data from your image.
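To make the effect concrete, here is a minimal sketch on a tiny 4x4 input rather than a real image. The tensor x and the variable names are just illustrative choices; the only PyTorch pieces used are torch.arange, reshape and nn.ZeroPad2d.

```python
import torch
from torch import nn

# A single 4x4 "image" with values 0..15, shaped (batch, channels, height, width)
x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)

# Positive padding adds a border of zeros on every side
padded = nn.ZeroPad2d(1)(x)

# Negative padding removes one row/column on every side
cropped = nn.ZeroPad2d(-1)(x)

print(padded.shape)   # torch.Size([1, 1, 6, 6])
print(cropped.shape)  # torch.Size([1, 1, 2, 2])
print(cropped[0, 0])  # the inner 2x2 block: values 5, 6, 9 and 10
```

With +1 the 4x4 input grows to 6x6, and with -1 it shrinks to the inner 2x2 block, which is exactly the cropping behavior we are after.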
Let's now take a look at how we can use ZeroPad2d to build a Cropping layer with PyTorch. First, it's time to write down our imports.
import os
import torch
from torch import nn
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
These are relatively straightforward: there are many torch related imports, which are explained in our articles on PyTorch based networks such as the ConvNet.
Time to move forward with the CroppingNetwork. Here it is:
class CroppingNetwork(nn.Module):
    '''
    Simple network that crops its input with four stacked ZeroPad2d(-1) layers
    '''
    def __init__(self):
        super().__init__()
        # Each layer removes one pixel from every side of the image
        self.layers = nn.Sequential(
            nn.ZeroPad2d(-1),
            nn.ZeroPad2d(-1),
            nn.ZeroPad2d(-1),
            nn.ZeroPad2d(-1),
        )

    def forward(self, x):
        '''Forward pass'''
        return self.layers(x)
It is actually really simple! By specifying nn.ZeroPad2d with a padding of -1, we remove 1 column of pixels on the left, 1 on the right, as well as a row from the top and the bottom of the image.
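The crop does not have to be the same on every side. Besides a single integer, nn.ZeroPad2d also accepts a 4-tuple in the order (left, right, top, bottom). The sketch below uses that form with arbitrary, illustrative crop sizes and a random tensor instead of a real image, and compares the result against plain tensor slicing.

```python
import torch
from torch import nn

# Random stand-in for one MNIST-sized image: (batch, channels, height, width)
x = torch.rand(1, 1, 28, 28)

# Remove 2 columns on the left, 0 on the right, 1 row on top, 3 rows at the bottom
asymmetric_crop = nn.ZeroPad2d((-2, 0, -1, -3))
y = asymmetric_crop(x)

print(y.shape)  # torch.Size([1, 1, 24, 26])

# The same crop expressed with slicing, for comparison
print(torch.equal(y, x[:, :, 1:-3, 2:]))  # True
```

The height shrinks by 1 + 3 = 4 rows and the width by 2 columns, so a 28x28 input becomes 24x26.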
Our input images - MNIST images - have an input shape of (1, 28, 28) - or (28, 28) when we reshape them. Since we repeat the layer four times, we remove 4 pixels from the left, 4 from the right, 4 from the top and 4 from the bottom. This means that the shape of our outputs will be (1, 20, 20) - or (20, 20) when we reshape them.
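If you want to verify this shape arithmetic without downloading MNIST, a quick check along the following lines should do. It is only a sketch: the random batch stands in for real images, and the variable names are illustrative. It also shows that four stacked -1 paddings give the same result as slicing 4 pixels off every side, or as a single ZeroPad2d(-4) layer.

```python
import torch
from torch import nn

# Random stand-in for a minibatch of 10 MNIST images
x = torch.rand(10, 1, 28, 28)

# The same four stacked layers as in the CroppingNetwork
layers = nn.Sequential(
    nn.ZeroPad2d(-1),
    nn.ZeroPad2d(-1),
    nn.ZeroPad2d(-1),
    nn.ZeroPad2d(-1),
)
out = layers(x)

print(out.shape)  # torch.Size([10, 1, 20, 20])

# Equivalent to slicing 4 pixels off each side ...
print(torch.equal(out, x[:, :, 4:-4, 4:-4]))  # True

# ... and to one layer with a padding of -4
print(torch.equal(out, nn.ZeroPad2d(-4)(x)))  # True
```

Each side loses 4 pixels, so both spatial dimensions go from 28 to 28 - 2 * 4 = 20.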
What remains is stitching everything together:
if __name__ == '__main__':

    # Set fixed random number seed
    torch.manual_seed(42)

    # Prepare MNIST dataset
    dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    trainloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=1)

    # Initialize the CroppingNetwork
    croppingnet = CroppingNetwork()

    # Load the first minibatch and unpack inputs and targets
    inputs, targets = next(iter(trainloader))

    # Feed samples through the network
    cropped_samples = croppingnet(inputs)

    # Iterate over the samples in the minibatch
    for i in range(len(inputs)):

        # Reshape the samples
        reshaped_original = inputs[i].reshape(28, 28)
        reshaped_cropped = cropped_samples[i].reshape(20, 20)

        # Plot the original and the cropped sample side by side
        fig, (ax1, ax2) = plt.subplots(1, 2)
        fig.set_size_inches(9, 5, forward=True)
        fig.suptitle('Original sample (left) and Cropped sample (right)')
        ax1.imshow(reshaped_original)
        ax2.imshow(reshaped_cropped)
        plt.show()
The code above uses the PyTorch DataLoader for loading the first minibatch of samples, feeds them through the CroppingNetwork, and visualizes the results.
And here they are - some examples of what is produced by the cropping network:
PyTorch. (n.d.). ZeroPad2d — PyTorch 1.10.0 documentation. https://pytorch.org/docs/stable/generated/torch.nn.ZeroPad2d.html