Jechen00's picture
changed model to require less compute
70fea38
#####################################
# Packages & Dependencies
#####################################
import torch
from torch import nn
#####################################
# VGG Model Class
#####################################
class VGGBlock(nn.Module):
'''
Defines a modified block in the VGG architecture,
which includes batch normalization between convolutional layers and ReLU activations.
Reference: https://poloclub.github.io/cnn-explainer/
Reference: https://d2l.ai/chapter_convolutional-modern/vgg.html
Args:
num_convs (int): Number of consecutive convolutional layers + ReLU activations.
in_channels (int): Number of channels in the input.
hidden_channels (int): Number of hidden channels between convolutional layers.
out_channels (int): Number of channels in the output.
'''
def __init__(self,
num_convs: int,
in_channels: int,
hidden_channels: int,
out_channels: int):
super().__init__()
self.layers = []
for i in range(num_convs):
conv_in = in_channels if i == 0 else hidden_channels
conv_out = out_channels if i == num_convs-1 else hidden_channels
self.layers += [
nn.Conv2d(conv_in, conv_out, kernel_size = 3, stride = 1, padding = 1),
nn.BatchNorm2d(conv_out),
nn.ReLU()
]
self.layers.append(nn.MaxPool2d(kernel_size = 2, stride = 2))
self.vgg_blk = nn.Sequential(*self.layers)
def forward(self, X: torch.Tensor) -> torch.Tensor:
'''
Forward pass of VGG block.
Args:
X (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width)
Returns:
torch.Tensor: Output tensor of shape (batch_size, out_channels, new_height, new_width)
'''
return self.vgg_blk(X)
class TinyVGG(nn.Module):
'''
Creates a simplified version of a VGG model, adapted from
https://github.com/poloclub/cnn-explainer/blob/master/tiny-vgg/tiny-vgg.py.
The main difference is that the hidden dimensions and number of convolutional layers
remain the same across VGG blocks and the classifier's linear layers has output fewer features.
Args:
num_blks (int): Number of VGG blocks to put in the model
num_convs (int): Number of consecutive convolutional layers + ReLU activations in each VGG block.
in_channels (int): Number of channels in the input.
hidden_channels (int): Number of hidden channels between convolutional layers.
fc_hidden_dim (int): Number of output (hidden) features for the first linear layer of the classifer.
num_classes (int): Number of class labels.
'''
def __init__(self,
num_blks: int,
num_convs: int,
in_channels: int,
hidden_channels: int,
fc_hidden_dim: int,
num_classes: int):
super().__init__()
self.all_blks = []
for i in range(num_blks):
conv_in = in_channels if i == 0 else hidden_channels
self.all_blks.append(
VGGBlock(num_convs, conv_in, hidden_channels, hidden_channels)
)
self.vgg_body = nn.Sequential(*self.all_blks)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.LazyLinear(fc_hidden_dim), nn.ReLU(), nn.Dropout(0.5),
nn.LazyLinear(num_classes)
)
self.vgg_body.apply(self._custom_init)
self.classifier.apply(self._custom_init)
def _custom_init(self, module):
'''
Initializes convolutional layer weights with Xavier initialization method.
Initializes convolutional layer biases to zero.
'''
if isinstance(module, (nn.Conv2d)):
nn.init.xavier_uniform_(module.weight)
nn.init.zeros_(module.bias)
def forward(self, X: torch.Tensor) -> torch.Tensor:
'''
Forward pass of the TinyVGG model.
Args:
X (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width).
Returns:
torch.Tensor: Logits of shape (batch_size, num_classes).
'''
X = self.vgg_body(X)
return self.classifier(X)