To end my series on building classical convolutional neural networks from scratch in PyTorch, we will build ResNet, a major breakthrough in computer vision that solved the problem of network performance degrading when the network gets too deep. It also introduced the concept of residual connections (more on this later). You can find the previous articles in the series, covering LeNet-5, AlexNet, and VGG, on my profile.
We will start by looking into the architecture and intuition behind how ResNet works. We will then compare it to VGG and examine how it solves some of the problems VGG had. Then, as before, we will load our dataset, CIFAR-10, and pre-process it to make it ready for modeling. Next, we will implement the basic building block of a ResNet (we will call this ResidualBlock) and use it to build our network. This network will then be trained on the pre-processed data, and finally we will see how the trained model performs on unseen data (the test set).
One of the drawbacks of VGG was that it couldn't go as deep as desired, because it started to lose generalization capability (i.e., it started overfitting). This is because, as a neural network gets deeper, the gradients from the loss function start to shrink to zero, and thus the weights are barely updated. This problem is known as the vanishing gradient problem. ResNet essentially solved it by using skip connections.
In the figure above, we can see that, in addition to the normal connections, there is a direct connection that skips some layers in the model (the skip connection). With the skip connection, the output changes from h(x) = f(wx + b) to h(x) = f(x) + x. Skip connections help because they provide an alternate, shortcut path for the gradients to flow through. Below is the architecture of the 34-layer ResNet.
In this article, we will be using the famous CIFAR-10 dataset, which has become one of the most common choices for beginner computer vision datasets. The dataset is a labeled subset of the 80 million tiny images dataset, collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
The dataset is divided into five training batches and one test batch, each with 10000 images. The test batch contains exactly 1000 randomly-selected images from each class. The training batches contain the remaining images in random order, but some training batches may contain more images from one class than another. Between them, the training batches contain exactly 5000 images from each class. The classes are completely mutually exclusive. There is no overlap between automobiles and trucks. "Automobile" includes sedans, SUVs, and things of that sort. "Truck" includes only big trucks. Neither includes pickup trucks.
Here are the classes in the dataset, as well as 10 random images from each:
Importing the Libraries
We will start by importing the libraries we will use. In addition, we will make sure that the notebook uses the GPU to train the model if one is available.
Loading the Dataset
Now we move on to loading our dataset. For this purpose, we will use the `torchvision` library, which not only provides quick access to hundreds of computer vision datasets, but also easy and intuitive methods to pre-process/transform them so that they are ready for modeling.
- We start by defining our `data_loader` function, which returns the training or test data depending on the arguments
- It's always good practice to normalize our data in deep learning projects, as it makes training faster and easier to converge. For this, we define the variable `normalize` with the mean and standard deviation of each channel (red, green, and blue) in the dataset. These can be calculated manually, but are also available online. `normalize` is used in the `transform` variable, where we resize the data, convert it to tensors, and then normalize it
- We make use of data loaders. Data loaders allow us to iterate through the data in batches; the data is loaded while iterating rather than all at once into RAM at the start. This is very helpful when dealing with large datasets of around a million images
- Depending on the `test` argument, we load either the train split (if `test=False`) or the test split (if `test=True`). In the case of train, the split is randomly divided into a training and a validation set (0.9:0.1)
ResNet from Scratch
How models work in PyTorch
Before moving on to building the residual block and the ResNet, let's first look into and understand how neural networks are defined in PyTorch:
- `nn.Module` provides a boilerplate for creating custom models, along with some necessary functionality that helps in training. That's why every custom model tends to inherit from `nn.Module`
- Then there are two main functions inside every custom model. First is the initialization function, `__init__`, where we define the various layers we will be using, and second is the `forward` function, which defines the sequence in which those layers will be executed on a given input
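As a minimal illustration, the two functions fit together like this (`TinyNet` is a hypothetical toy model, not part of the ResNet itself):

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self, num_classes=10):
        super(TinyNet, self).__init__()
        # Layers are declared in __init__ ...
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(8 * 32 * 32, num_classes)

    def forward(self, x):
        # ... and wired together in forward
        out = self.relu(self.conv(x))
        out = out.reshape(out.size(0), -1)
        return self.fc(out)

out = TinyNet()(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 10])
```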
Layers in PyTorch
Now coming to the different types of layers available in PyTorch that are useful to us:
- `nn.Conv2d`: These are the convolutional layers, which accept the number of input and output channels as arguments, along with the kernel size for the filter. They also accept strides or padding if we want to apply those
- `nn.BatchNorm2d`: This applies batch normalization to the output of a convolutional layer
- `nn.ReLU`: This is an activation function applied to various outputs in the network
- `nn.MaxPool2d`: This applies max pooling to the output with a given kernel size
- `nn.Dropout`: This applies dropout to the output with a given probability
- `nn.Linear`: This is basically a fully connected layer
- `nn.Sequential`: This is technically not a type of layer, but it helps in combining different operations that are part of the same step
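For instance, `nn.Sequential` lets us group a convolution, its batch norm, and the ReLU into a single step; the shapes below follow the first stage of ResNet-34:

```python
import torch
import torch.nn as nn

# One convolutional "step": convolution -> batch norm -> ReLU
block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
)

out = block(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 64, 112, 112])
```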
Before starting with the network, we need to build a ResidualBlock that we can re-use throughout the network. The block (as shown in the architecture) contains a skip connection that is an optional parameter (`downsample`). Note that in `forward`, this is applied directly to the input, `x`, and not to the output of the convolutional layers.
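A sketch of such a block is below; the two 3x3 convolutions with batch normalization follow the ResNet paper, while details such as bias terms can vary:

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        # Two 3x3 convolutions, each followed by batch normalization
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels))
        self.downsample = downsample
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x
        out = self.conv2(self.conv1(x))
        if self.downsample is not None:
            residual = self.downsample(x)  # the skip connection acts on the input x
        return self.relu(out + residual)

out = ResidualBlock(64, 64)(torch.randn(1, 64, 56, 56))
print(out.shape)  # torch.Size([1, 64, 56, 56])
```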
Now that we have created the ResidualBlock, we can build our ResNet.

Note that there are four stages in the architecture, containing 3, 4, 6, and 3 residual blocks respectively. To build each stage, we create a helper function `_make_layer`. The function adds the layers one by one, along with the residual blocks. After the stages, we add the average pooling and the final linear layer.
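A sketch of the full network under these assumptions is below (the ResidualBlock described above is repeated so the snippet is self-contained; the stage widths 64/128/256/512 follow the paper):

```python
import torch
import torch.nn as nn

# The ResidualBlock described above
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels))
        self.downsample = downsample
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x
        out = self.conv2(self.conv1(x))
        if self.downsample is not None:
            residual = self.downsample(x)
        return self.relu(out + residual)

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        self.inplanes = 64
        # Stem: 7x7 convolution followed by max pooling, as in the paper
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU())
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four stages of residual blocks (3, 4, 6, 3 for ResNet-34)
        self.layer0 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer1 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer2 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer3 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes:
            # 1x1 convolution so the skip connection matches the new shape
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride),
                nn.BatchNorm2d(planes))
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.maxpool(self.conv1(x))
        x = self.layer3(self.layer2(self.layer1(self.layer0(x))))
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
```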
It is always recommended to try out different values for the various hyperparameters in our model, but here we will use only one setting. Regardless, we recommend everyone try different ones and see which works best. The hyperparameters include the number of epochs, batch size, learning rate, and loss function, along with the optimizer. As we are building the 34-layer variant of ResNet, we need to pass the appropriate number of layers as well:
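One possible setting is sketched below; it assumes the `ResNet` and `ResidualBlock` classes and the `device` variable defined earlier, and the optimizer choice (SGD with momentum and weight decay) is just one reasonable option:

```python
num_classes = 10
num_epochs = 10
batch_size = 16
learning_rate = 0.01

# ResNet-34: 3, 4, 6, and 3 residual blocks in the four stages
model = ResNet(ResidualBlock, [3, 4, 6, 3]).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate,
                            weight_decay=0.001, momentum=0.9)
```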
Now, our model is ready for training, but first we need to know how model training works in PyTorch:
- We start by loading the images in batches using our `train_loader` for every epoch, and also move the data to the GPU using the `device` variable we defined earlier
- The model is then used to predict the labels, `model(images)`, and then we calculate the loss between the predictions and the ground truth using the loss function defined above
- Now comes the learning part: we backpropagate using `loss.backward()` and update the weights with `optimizer.step()`. One important thing required before every update is to set the gradients to zero using `optimizer.zero_grad()`, because otherwise the gradients are accumulated (the default behaviour in PyTorch)
- Lastly, after every epoch, we test our model on the validation set; as we don't need gradients when evaluating, we can turn them off using `with torch.no_grad()` to make the evaluation much faster
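The steps above can be sketched as follows. To keep the sketch self-contained and runnable anywhere, it uses a tiny stand-in model and random tensors in place of the ResNet and the CIFAR-10 loaders; in the article, `model` is the ResNet built earlier and the loaders come from our `data_loader` function:

```python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Stand-ins so the sketch runs on its own
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
train_loader = [(torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))) for _ in range(5)]
valid_loader = [(torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,)))]

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

num_epochs = 2
for epoch in range(num_epochs):
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)            # forward pass
        loss = criterion(outputs, labels)  # compare predictions to ground truth
        optimizer.zero_grad()              # reset accumulated gradients
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the weights

    # Validation: gradients are not needed, so evaluation runs faster
    with torch.no_grad():
        correct, total = 0, 0
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Epoch [{}/{}], Loss: {:.4f}, Validation accuracy: {:.2f}%'.format(
        epoch + 1, num_epochs, loss.item(), 100 * correct / total))
```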
Analyzing the output of the code, we can see that the model is learning: the loss is decreasing while the accuracy on the validation set is increasing with every epoch. But we may notice that it fluctuates at the end, which could mean the model is overfitting or that the `batch_size` is small. We will have to test to find out what's going on:
For testing, we use exactly the same code as for validation, but with the test set loader:
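The evaluation can be sketched as follows; as before, a tiny stand-in model and random tensors replace the trained ResNet and the loader from `data_loader(..., test=True)` so the snippet runs on its own:

```python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Stand-ins so the sketch runs on its own
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
test_loader = [(torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))) for _ in range(3)]

model.eval()
with torch.no_grad():
    correct, total = 0, 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the test images: {:.2f}%'.format(100 * correct / total))
```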
Using the above code and training the model for 10 epochs, we were able to achieve an accuracy of 82.87% on the test set:
Let's now conclude what we did in this article:
- We started by understanding the architecture and how ResNet works
- Next, we loaded and pre-processed the CIFAR10 dataset using `torchvision`
- Then, we learned how custom model definitions work in PyTorch and the different types of layers available in `torch.nn`
- We built our ResNet from scratch by building a ResidualBlock
- Finally, we trained and tested our model on the CIFAR10 dataset, and the model seemed to perform well on the test dataset with 82.87% accuracy
Through this article, we got a good introduction and hands-on learning, but we can learn much more if we extend this to other challenges:
- Try using different datasets, such as CIFAR-100, a subset of the ImageNet dataset, or the 80 million tiny images dataset
- Experiment with different hyperparameters and see the best combination of them for the model
- Finally, try adding layers to or removing layers from the model to see their impact on its capability. Better yet, try to build the ResNet-50 version of this model