import argparse
import os  # used below to make sure the save directory exists

import torch
from torch import nn
import yaml

import utils, data_setup, model, engine
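# `utils`, `data_setup`, `model` and `engine` are local modules from this
# project: seed/device helpers, dataloader construction, the TinyVGG
# definition, and the train/test loop, respectively.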

# Fix the RNG seed so runs are reproducible
utils.set_seed(0)

parser = argparse.ArgumentParser(fromfile_prefix_chars = '@')

parser.add_argument('-nw', '--num-workers', help = 'Number of workers for the dataloaders.',
                    type = int, default = 0)
parser.add_argument('-ne', '--num-epochs', help = 'Number of epochs to train the model for.',
                    type = int, default = 25)
parser.add_argument('-bs', '--batch-size', help = 'Size of the batches the training set is split into.',
                    type = int, default = 100)
parser.add_argument('-lr', '--learning-rate', help = 'Learning rate for the optimizer.',
                    type = float, default = 0.001)
parser.add_argument('-p', '--patience', help = 'Number of epochs to wait before early stopping.',
                    type = int, default = 10)
parser.add_argument('-md', '--min-delta', help = 'Minimum decrease in loss needed to reset patience.',
                    type = float, default = 0.001)

args = parser.parse_args()
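
# Because of fromfile_prefix_chars = '@', the arguments can also be read from
# a text file, one argument per line, e.g.
#   python train.py @train_args.txt
# (train_args.txt being whatever hypothetical file holds the flags).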


if __name__ == '__main__':

    # Echo the run configuration before training starts
    print(f"{'#' * 50}\n"
          f'\033[1mTraining hyperparameters:\033[0m \n'
          f' - num-workers: {args.num_workers} \n'
          f' - num-epochs: {args.num_epochs} \n'
          f' - batch-size: {args.batch_size} \n'
          f' - learning-rate: {args.learning_rate} \n'
          f' - patience: {args.patience} \n'
          f' - min-delta: {args.min_delta} \n'
          f"{'#' * 50}")

    # Build the MNIST train/test dataloaders
    train_dl, test_dl = data_setup.get_dataloaders(root = './mnist_data',
                                                   batch_size = args.batch_size,
                                                   num_workers = args.num_workers)
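    # num_workers = 0 keeps data loading in the main process; values > 0 spawn
    # worker processes so batches can be prepared while the model trains.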

    # Where trained weights and settings are written; created if missing
    save_dir = '../saved_models'
    os.makedirs(save_dir, exist_ok = True)

    base_name = 'tiny_vgg_less_compute'
    mod_name = f'{base_name}_model.pth'

    # Architecture hyperparameters forwarded to the TinyVGG constructor
    mod_kwargs = {
        'num_blks': 2,
        'num_convs': 2,
        'in_channels': 1,
        'hidden_channels': 5,
        'fc_hidden_dim': 128,
        'num_classes': len(train_dl.dataset.classes)
    }
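    # Deriving num_classes from the dataset keeps the classifier head in sync
    # with the data (10 classes for MNIST's digits).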

    vgg_mod = model.TinyVGG(**mod_kwargs).to(utils.DEVICE)
    # torch.compile returns the optimized module rather than compiling
    # in place, so keep the result
    vgg_mod = torch.compile(vgg_mod)
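    # Compilation needs PyTorch 2.0+ and is purely a speed optimization; the
    # call can be dropped on older versions without changing results. Beware
    # that a compiled module's state_dict keys may gain an '_orig_mod.'
    # prefix, which matters if the checkpoint is later loaded into an
    # uncompiled model.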

    # Persist both the CLI arguments and the model kwargs so this exact
    # configuration can be reconstructed later
    with open(f'{save_dir}/{base_name}_settings.yaml', 'w') as f:
        yaml.dump({'train_kwargs': vars(args), 'mod_kwargs': mod_kwargs}, f)
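    # Sketch of how the settings file could be consumed later (assuming the
    # same local modules are importable):
    #   with open(f'{save_dir}/{base_name}_settings.yaml') as f:
    #       settings = yaml.safe_load(f)
    #   restored = model.TinyVGG(**settings['mod_kwargs'])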

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(params = vgg_mod.parameters(), lr = args.learning_rate)

    # Run the training loop; early stopping (patience/min_delta) and
    # checkpointing (save_mod/save_dir/mod_name) are handled by engine.train
    mod_res = engine.train(model = vgg_mod,
                           train_dl = train_dl,
                           test_dl = test_dl,
                           loss_fn = loss_fn,
                           optimizer = optimizer,
                           num_epochs = args.num_epochs,
                           patience = args.patience,
                           min_delta = args.min_delta,
                           device = utils.DEVICE,
                           save_mod = True,
                           save_dir = save_dir,
                           mod_name = mod_name)
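    # mod_res presumably holds the per-epoch metrics engine.train returns
    # (e.g. loss/accuracy curves), useful for plotting after the run.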