Maybe-Ray

ResNet and Skip Connections

This is part 5 of my series. In the previous post, we looked at the VGG and GoogLeNet models. We implemented both models, discussed their architectural differences in detail, and then used pre-trained versions of them for transfer learning on the Caltech-256 dataset.

ResNet

As neural networks become deeper, they become increasingly difficult to train due to the vanishing gradient problem. This has been somewhat mitigated by the use of normalization techniques, such as Batch Normalization.
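To make the vanishing gradient problem concrete, here is a toy illustration that is not part of the original post. It builds a deep plain stack of Linear + Sigmoid layers (the depth and width are arbitrary, hypothetical choices) and compares the gradient norm of the first layer's weights with that of the last layer's. Each sigmoid has a derivative of at most 0.25, so the backpropagated signal tends to shrink at every layer on its way back to the input.

```python
import torch
from torch import nn

torch.manual_seed(0)

depth, width = 20, 64  # hypothetical toy sizes
layers = []
for _ in range(depth):
    layers += [nn.Linear(width, width), nn.Sigmoid()]
plain_net = nn.Sequential(*layers)

x = torch.randn(32, width)
loss = plain_net(x).sum()
loss.backward()

first_grad = plain_net[0].weight.grad.norm().item()   # first Linear layer
last_grad = plain_net[-2].weight.grad.norm().item()   # last Linear layer ([-1] is the Sigmoid)
print(f"first layer grad norm: {first_grad:.3e}")
print(f"last layer grad norm:  {last_grad:.3e}")
```

The first layer receives a gradient many orders of magnitude smaller than the last one, so its weights barely change during training.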

Deep neural networks also suffer from a second problem, called the degradation problem. This is the problem that residual neural networks try to solve.

In this post, we will go over the degradation problem and explain how residual networks solve it. We will then implement ResNet-18 from scratch and use a pre-trained ResNet-34 for transfer learning on the Caltech-256 dataset.

The Degradation Problem

In machine learning, we take input data x and try to learn a function that transforms it into the target y. Here, our neural network is that function, F, which takes the input and transforms it:

y=F(x)

In principle, a deeper network should perform at least as well as a shallower one: if the extra layers simply learned an identity mapping, the deeper network would compute exactly the same function as the shallow one. In practice, this is often not the case. As we add more layers, the training loss of the deeper network actually increases beyond a certain depth. Even though the deeper network has more capacity, the optimizer fails to find weights that make the extra layers behave like an identity mapping, so the loss goes up. Because it is the training loss that gets worse, this is not overfitting but an optimization difficulty, known as the degradation problem.

Because every layer applies learned weights followed by a nonlinear activation function, it is difficult for a stack of plain layers to learn to pass its input through unchanged.
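A small illustration of this point, using a hypothetical single Linear + ReLU layer: even if we hand-set the weights to the identity matrix and the bias to zero, the ReLU still clips every negative value. Because a ReLU never outputs negative numbers, no choice of weights can make this layer reproduce an input with negative entries.

```python
import torch
from torch import nn

width = 4
layer = nn.Sequential(nn.Linear(width, width), nn.ReLU())

# Hand-set the "ideal" identity weights: W = I, b = 0
with torch.no_grad():
    layer[0].weight.copy_(torch.eye(width))
    layer[0].bias.zero_()

x = torch.tensor([[1.5, -2.0, 0.3, -0.7]])
y = layer(x)
print(y)                      # negative entries are clipped to 0
print(torch.allclose(y, x))   # the layer is not an identity mapping
```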

Skip Connections: the solution

Skip connections help solve the degradation problem by letting the input bypass one or more layers and be added directly to their output.
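Written as an equation, with $\mathcal{F}$ standing for the stacked layers inside a single block (rather than the whole network $F$ used above), a residual block computes:

$$
y = \mathcal{F}(x) + x
$$

When the block changes the number of channels or the spatial size, $x$ no longer has the same shape as $\mathcal{F}(x)$, so the shortcut first applies a learned projection $W_s$ (a 1×1 convolution in practice):

$$
y = \mathcal{F}(x) + W_s x
$$

Differentiating the identity-shortcut case shows where the gradient benefit comes from:

$$
\frac{\partial y}{\partial x} = \frac{\partial \mathcal{F}(x)}{\partial x} + I
$$

The identity term $I$ gives the gradient a direct path back to the block's input, even when $\partial \mathcal{F} / \partial x$ is small. The implementation in this post also applies a ReLU after the addition, which this simplified form leaves out.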

In simpler terms, the input that a block of layers receives is added to that block's output. This is shown in the diagram below:

[Figure: residual block]

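As you can see, this makes identity mappings much easier to learn: to pass its input through unchanged, a block only has to drive the output of its layers toward zero, rather than learn an identity through several nonlinear layers. As a result, adding extra layers should not degrade performance when they are not needed.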
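A minimal sketch of this, using a hypothetical TinyResidual block (two 3×3 convolutions with a ReLU in between and no batch normalization, for simplicity): once the last convolution's weights and bias are zero, the residual branch outputs zero and the block returns its input. Like the implementation later in this post, the block applies a ReLU after the addition, so exact pass-through holds for non-negative inputs, which is what a block receives when the previous block also ended in a ReLU.

```python
import torch
from torch import nn


class TinyResidual(nn.Module):
    """Hypothetical minimal residual block: x -> ReLU(F(x) + x)."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = self.conv2(self.relu(self.conv1(x)))  # F(x)
        return self.relu(residual + x)                    # skip connection


torch.manual_seed(0)
block = TinyResidual(8)

# "Zero out" the residual branch so that F(x) = 0 for every input
with torch.no_grad():
    block.conv2.weight.zero_()
    block.conv2.bias.zero_()

# Non-negative input, as produced by a preceding ReLU
x = torch.relu(torch.randn(2, 8, 16, 16))
print(torch.allclose(block(x), x))  # the block acts as an identity mapping
```

Compare this with the plain Linear + ReLU layer above: there, zeroing the weights would output zeros, not the input.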

The shortcut connections have the added advantage of letting gradients flow more directly through the network during backpropagation, which helps the earlier layers learn.
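To see the gradient effect numerically, the toy sketch below (made-up sizes, fully connected layers instead of convolutions, not the ResNet itself) passes an input through the same 30 layers either plainly, h ← ReLU(Wh + b), or with a skip connection, h ← h + ReLU(Wh + b), and compares the gradient norm that reaches the input. The derivative of h + f(h) with respect to h is the identity plus the derivative of f, so the residual stack keeps a direct path back to the input.

```python
import torch
from torch import nn

torch.manual_seed(0)
depth, width = 30, 64  # hypothetical toy sizes

layers = nn.ModuleList(nn.Linear(width, width) for _ in range(depth))


def input_grad_norm(residual: bool) -> float:
    x = torch.randn(16, width, requires_grad=True)
    h = x
    for layer in layers:
        out = torch.relu(layer(h))
        h = h + out if residual else out  # skip connection only when residual=True
    h.sum().backward()
    return x.grad.norm().item()


plain = input_grad_norm(residual=False)
res = input_grad_norm(residual=True)
print(f"plain stack:    {plain:.3e}")
print(f"residual stack: {res:.3e}")
```

Both stacks share the same weights; only the skip connection differs, and the plain stack's input gradient ends up orders of magnitude smaller.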

Implementing the Model

The full version of the implementation code can be found in the Kaggle Notebook.

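The code below implements the residual blocks for the ResNet model. The conv_block function creates the convolution, batch normalization, and ReLU layers that are used repeatedly throughout the model. The stack_conv_block function stacks two of these blocks to form the residual branch, similar to the stacked convolutions in the VGG family of models. Finally, residual_block adds the shortcut connection, using an identity when the input and output shapes match and a 1×1 convolution when they do not.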

import torch
from torch import nn


def conv_block(in_chns, out_chns, padding=False, stride=1, ReLu=True):
    # 3x3 convolution -> batch norm -> (optional) ReLU.
    # padding="same" is only allowed for stride-1 convolutions; for a 3x3
    # kernel it is equivalent to padding=1, which also works with stride 2.
    pad = "same" if padding else 1
    block = nn.Sequential(
        nn.Conv2d(in_chns, out_chns, 3, stride=stride, padding=pad, bias=False),
        nn.BatchNorm2d(out_chns),
    )
    if ReLu:
        block.append(nn.ReLU(inplace=True))
    return block


def stack_conv_block(in_chns, out_chns, stride=1):
    # The residual branch: two 3x3 conv blocks. Only the first may downsample;
    # the second leaves out the ReLU, which is applied after the skip connection.
    return nn.Sequential(
        conv_block(in_chns, out_chns, stride=stride),
        conv_block(out_chns, out_chns, padding=True, ReLu=False),
    )


class residual_block(nn.Module):
    def __init__(self, in_chns, out_chns, stride=1):
        super().__init__()
        self.residual = stack_conv_block(in_chns, out_chns, stride=stride)
        if in_chns == out_chns and stride == 1:
            # Shapes already match: add the input unchanged
            self.shortcut = nn.Identity()
        else:
            # Projection shortcut: 1x1 conv + batch norm to match the
            # residual branch's channel count and spatial size
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_chns, out_chns, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_chns),
            )

        self.relu = nn.ReLU()

    def forward(self, x):
        res = self.residual(x)
        return self.relu(res + self.shortcut(x))
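Why does the shortcut need a 1×1 convolution when the stride or channel count changes? A quick standalone shape check (hypothetical 64-channel, 56×56 input) shows that a 3×3, stride-2, padding-1 convolution halves the spatial size and changes the channel count, so the raw input can no longer be added to it, while a 1×1 convolution with the same stride and output channels produces a matching shape. It also confirms that PyTorch rejects padding="same" for strided convolutions.

```python
import torch
from torch import nn

x = torch.randn(1, 64, 56, 56)  # hypothetical feature map

# First conv of the residual branch: 3x3, stride 2, padding 1, 64 -> 128 channels
residual = nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False)(x)

# Projection shortcut: 1x1 conv with the same stride and output channels
shortcut = nn.Conv2d(64, 128, 1, stride=2, bias=False)(x)

print(residual.shape)  # torch.Size([1, 128, 28, 28])
print(shortcut.shape)  # torch.Size([1, 128, 28, 28])
print(x.shape)         # torch.Size([1, 64, 56, 56]), cannot be added directly

# padding="same" is not accepted for strided convolutions
same_rejected = False
try:
    nn.Conv2d(64, 128, 3, stride=2, padding="same")
except ValueError as err:
    same_rejected = True
    print("ValueError:", err)
```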

The code below is the implementation of ResNet-18. The network ends with adaptive average pooling (nn.AdaptiveAvgPool2d(1)), a layer that appears frequently in PyTorch implementations of CNN models. It averages each of the 512 final feature maps down to a single value (global average pooling, as in the original ResNet), so its output is always 512 × 1 × 1 regardless of the input size. After flattening, the features can be passed into the Linear layer without worrying about the input resolution.

class ResNet18(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        # Four stages of two residual blocks each. Stage 1 keeps the 56x56
        # resolution from the stem; stages 2-4 downsample (stride 2) in their
        # first block while doubling the number of channels.
        self.residual_block_1 = nn.Sequential(
            residual_block(64, 64, 1), residual_block(64, 64, 1)
        )

        self.residual_block_2 = nn.Sequential(
            residual_block(64, 128, 2), residual_block(128, 128, 1)
        )

        self.residual_block_3 = nn.Sequential(
            residual_block(128, 256, 2), residual_block(256, 256, 1)
        )

        self.residual_block_4 = nn.Sequential(
            residual_block(256, 512, 2), residual_block(512, 512, 1)
        )

        self.features = nn.Sequential(
            # Stem: 7x7 conv (stride 2) + max pool, 224x224 -> 56x56
            nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, padding=1),
            self.residual_block_1,
            self.residual_block_2,
            self.residual_block_3,
            self.residual_block_4,
            # Global average pooling: (N, 512, H, W) -> (N, 512)
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )

        self.fc = nn.Linear(512, num_classes)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        return self.fc(features)
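A standalone check of the adaptive pooling behaviour described above, using random tensors as stand-ins for the last stage's feature maps (the spatial sizes are illustrative): nn.AdaptiveAvgPool2d(1) reduces any H×W map to 1×1 by averaging, so after nn.Flatten() the Linear layer always receives 512 values per image.

```python
import torch
from torch import nn

head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())

for size in (4, 7, 10):  # e.g. feature maps from different input resolutions
    feats = torch.randn(2, 512, size, size)  # stand-in for the final stage's output
    out = head(feats)
    print(size, tuple(out.shape))            # always (2, 512)
    # Adaptive pooling to 1x1 is the same as averaging over height and width
    assert torch.allclose(out, feats.mean(dim=(2, 3)), atol=1e-6)
```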

Transfer learning

The Kaggle Notebook has the full code; this post only covers the model setup, not the training loop or the dataset preparation.

The code below loads a ResNet-34 model pre-trained on ImageNet, freezes all of its layers, and replaces the final fully connected layer with a new classifier.

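from torch import nn
from torchvision.models import resnet34

# Number of target classes; set this to match your dataset
# (assumption: Caltech-256 has 256 object categories plus a clutter class)
num_classes = 257

# Loading the pre-trained ResNet-34 model with ImageNet weights
model = resnet34(weights="IMAGENET1K_V1")

# Freezing all pre-trained layers
for param in model.parameters():
    param.requires_grad = False

# Creating our classifier; ResNet-34's final layer receives 512 features
classifier = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(model.fc.in_features, 128),
    nn.BatchNorm1d(128),
    nn.ReLU(inplace=True),
    nn.Dropout(0.3),
    nn.Linear(128, num_classes),
)

# Adding our new classifier to the model and making sure its parameters are trainable
model.fc = classifier

for param in model.fc.parameters():
    param.requires_grad = True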
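Because only the new classifier should be updated, a common pattern is to pass just the parameters that still require gradients to the optimizer. The sketch below uses a small hypothetical ToyBackbone (not ResNet-34) so that it runs without downloading weights; the freezing and filtering logic is the same. The choice of Adam and the learning rate are assumptions, not the settings used in the notebook.

```python
import torch
from torch import nn


class ToyBackbone(nn.Module):
    """Hypothetical stand-in for a pre-trained model with a final `fc` layer."""

    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))


model = ToyBackbone(num_classes=10)

for p in model.parameters():   # freeze everything
    p.requires_grad = False
model.fc = nn.Linear(16, 5)    # new head: its parameters require grad by default

trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3)  # assumed optimizer and learning rate

n_trainable = sum(p.numel() for p in trainable)
n_total = sum(p.numel() for p in model.parameters())
print(f"trainable: {n_trainable} / {n_total}")
```

A freshly created nn.Linear already has requires_grad=True on its parameters, so the explicit loop that re-enables gradients on model.fc acts as a safeguard rather than a requirement.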

Results

The ResNet-34 model achieved a loss of 0.9302 and an accuracy of 78.5%. This is slightly worse than the VGG and GoogLeNet results from the last post. Larger variants of the model would likely have produced better results.

Conclusions

Residual networks, and the skip connections they introduced, are a staple of modern CNNs. Skip connections were adopted in later versions of the Inception networks (Inception-ResNet) and inspired architectures such as DenseNet. The ResNet family is still widely used today as the backbone for many computer vision tasks. An overall theme of the last couple of posts has been making CNN models deeper and more efficient, and ResNets are a testament to this: the largest ImageNet model in the original ResNet paper has 152 layers.

Series

  1. LeNet Implementation
  2. Exploring pre-trained Convolutional layers
  3. AlexNet Implementation
  4. VGG vs Inception: Just how deep can they go?
  5. ResNet and Skip Connections (you are here)
  6. Autoencoder
  7. Variational Autoencoders (VAEs)
  8. Generative Adversarial Networks (GANs)
