from typing import TypeVar, List, Tuple
import torch
from tqdm import tqdm
from abc import abstractmethod
from numpy import inf
from logger import TensorboardWriter
import numpy as np


class BaseTrainer:
    """
    Base class for all trainers.
    """
    def __init__(self, model, train_criterion, metrics, optimizer, config, val_criterion):
        self.config = config
        self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.train_criterion = train_criterion.to(self.device)
        self.val_criterion = val_criterion
        self.metrics = metrics
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1
        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)
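
    # Sketch of the config structure this constructor assumes (illustrative JSON-style
    # values only; the actual keys and values come from the project's config file and
    # its ConfigParser object, which also exposes get_logger(), save_dir, log_dir and resume):
    #
    #   {
    #       "n_gpu": 1,
    #       "arch": {...},
    #       "optimizer": {"type": "SGD", ...},
    #       "trainer": {
    #           "epochs": 100,
    #           "save_period": 1,
    #           "verbosity": 2,
    #           "monitor": "min val_loss",
    #           "early_stop": 10,
    #           "warmup": 0,
    #           "tensorboard": true
    #       }
    #   }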

    @abstractmethod
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epoch number
        """
        raise NotImplementedError
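
    # train() below calls self._warmup_epoch() for the first config['trainer']['warmup']
    # epochs, but the base class does not define that hook. The stub here is a minimal
    # sketch of the assumed subclass contract (not part of the original code); concrete
    # trainers that use warmup are expected to override it.
    def _warmup_epoch(self, epoch):
        """
        Warmup logic for an epoch, run before the regular training epochs.

        :param epoch: Current epoch number
        """
        raise NotImplementedError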

    def train(self):
        """
        Full training logic
        """
        not_improved_count = 0

        for epoch in tqdm(range(self.start_epoch, self.epochs + 1), desc='Total progress: '):
            if epoch <= self.config['trainer']['warmup']:
                result = self._warmup_epoch(epoch)
            else:
                result = self._train_epoch(epoch)

            # save logged information into log dict
            log = {'epoch': epoch}
            for key, value in result.items():
                if key == 'metrics':
                    log.update({mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
                elif key == 'val_metrics':
                    log.update({'val_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
                elif key == 'test_metrics':
                    log.update({'test_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
                else:
                    log[key] = value

            # print logged information to the screen
            for key, value in log.items():
                self.logger.info('    {:15s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric (mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning("Warning: Metric '{}' is not found. "
                                        "Model performance monitoring is disabled.".format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info("Validation performance didn't improve for {} epochs. "
                                     "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)
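
    # The result dict returned by _train_epoch / _warmup_epoch is assumed to look roughly
    # like the sketch below (illustrative values; any key other than 'metrics', 'val_metrics'
    # and 'test_metrics' is copied into the log unchanged):
    #
    #   {
    #       'loss': 0.42,
    #       'metrics': [m(output, target) for m in self.metrics],
    #       'val_loss': 0.40,
    #       'val_metrics': [...]
    #   }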

    def _prepare_device(self, n_gpu_use):
        """
        Setup GPU device if available, move model into configured device
        """
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning("Warning: There's no GPU available on this machine, "
                                "training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            self.logger.warning("Warning: The number of GPUs configured to use is {}, but only {} are available "
                                "on this machine.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param save_best: if True, save the checkpoint as 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            # _resume_checkpoint() reads checkpoint['config'], so the config is stored as well
            'config': self.config
        }
        # filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch))
        # torch.save(state, filename)
        # self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth at: {} ...".format(best_path))

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        # load architecture params from checkpoint.
        if checkpoint['config']['arch'] != self.config['arch']:
            self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
                                "checkpoint. This may yield an exception while state_dict is being loaded.")
        self.model.load_state_dict(checkpoint['state_dict'])

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
            self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
                                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))