Jechen00's picture
changed model used by app to have less compute
b1814f3
#####################################
# Packages & Dependencies
#####################################
import argparse
import torch
from torch import nn
import utils, data_setup, model, engine
import yaml
# Setup random seeds
utils.set_seed(0)
# Setup hyperparameters
parser = argparse.ArgumentParser(fromfile_prefix_chars = '@')
parser.add_argument('-nw', '--num-workers', help = 'Number of workers for dataloaders.',
type = int, default = 0)
parser.add_argument('-ne', '--num-epochs', help = 'Number of epochs to train model for.',
type = int, default = 25)
parser.add_argument('-bs', '--batch-size', help = 'Size of batches to split training set.',
type = int, default = 100)
parser.add_argument('-lr', '--learning-rate', help = 'Learning rate for the optimizer.',
type = float, default = 0.001)
parser.add_argument('-p', '--patience', help = 'Number of epochs to wait before early stopping.',
type = int, default = 10)
parser.add_argument('-md', '--min-delta', help = 'Minimum decrease in loss to reset patience.',
type = float, default = 0.001)
args = parser.parse_args()
#####################################
# Training Code
#####################################
if __name__ == '__main__':
print(f'{'#' * 50}\n'
f'\033[1mTraining hyperparameters:\033[0m \n'
f' - num-workers: {args.num_workers} \n'
f' - num-epochs: {args.num_epochs} \n'
f' - batch-size: {args.batch_size} \n'
f' - learning-rate: {args.learning_rate} \n'
f' - patience: {args.patience} \n'
f' - min-delta: {args.min_delta} \n'
f'{'#' * 50}')
# Get dataloaders
train_dl, test_dl = data_setup.get_dataloaders(root = './mnist_data',
batch_size = args.batch_size,
num_workers = args.num_workers)
# Set up saving directory and file name
save_dir = '../saved_models'
base_name = 'tiny_vgg_less_compute'
mod_name = f'{base_name}_model.pth'
# Get TinyVGG model
mod_kwargs = {
'num_blks': 2,
'num_convs': 2,
'in_channels': 1,
'hidden_channels': 5,
'fc_hidden_dim': 128,
'num_classes': len(train_dl.dataset.classes)
}
vgg_mod = model.TinyVGG(**mod_kwargs).to(utils.DEVICE)
torch.compile(vgg_mod)
# Save model kwargs and train settings
with open(f'{save_dir}/{base_name}_settings.yaml', 'w') as f:
yaml.dump({'train_kwargs': vars(args), 'mod_kwargs': mod_kwargs}, f)
# Get loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params = vgg_mod.parameters(), lr = args.learning_rate)
# Train model
mod_res = engine.train(model = vgg_mod,
train_dl = train_dl,
test_dl = test_dl,
loss_fn = loss_fn,
optimizer = optimizer,
num_epochs = args.num_epochs,
patience = args.patience,
min_delta = args.min_delta,
device = utils.DEVICE,
save_mod = True,
save_dir = save_dir,
mod_name = mod_name)