import numpy as np
import torch
from tqdm import tqdm
from typing import List
from torchvision.utils import make_grid
from base import BaseTrainer
from utils import inf_loop
import sys
from sklearn.mixture import GaussianMixture


class Trainer(BaseTrainer):
    """
    Trainer class

    Note:
        Inherited from BaseTrainer.
    """
    def __init__(self, model, train_criterion, metrics, optimizer, config, data_loader,
                 valid_data_loader=None, test_data_loader=None, lr_scheduler=None, len_epoch=None, val_criterion=None):
        super().__init__(model, train_criterion, metrics, optimizer, config, val_criterion)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.test_data_loader = test_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_test = self.test_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.train_loss_list: List[float] = []
        self.val_loss_list: List[float] = []
        self.test_loss_list: List[float] = []
    # Visdom visualization

    def _eval_metrics(self, output, label):
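        """
        Evaluate every configured metric on one batch and log each value.

        Assumes each entry of `self.metrics` is a callable of the form
        `metric(output, label) -> float`, as suggested by the loop below.
        """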
        acc_metrics = np.zeros(len(self.metrics))
        for i, metric in enumerate(self.metrics):
            acc_metrics[i] += metric(output, label)
            self.writer.add_scalar('{}'.format(metric.__name__), acc_metrics[i])
        return acc_metrics

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current training epoch.
        :return: A log that contains all information you want to save.

        Note:
            If you have additional information to record, for example:
                > additional_log = {"x": x, "y": y}
            merge it with log before return. i.e.
                > log = {**log, **additional_log}
                > return log
            The metrics in log must have the key 'metrics'.
        """
        self.model.train()

        total_loss = 0
        total_metrics = np.zeros(len(self.metrics))

        with tqdm(self.data_loader) as progress:
            for batch_idx, (data, label, indexs, _) in enumerate(progress):
                progress.set_description_str(f'Train epoch {epoch}')
                data, label = data.to(self.device), label.long().to(self.device)

                output = self.model(data)
                loss = self.train_criterion(indexs.cpu().detach().numpy().tolist(), output, label)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
                self.writer.add_scalar('loss', loss.item())
                self.train_loss_list.append(loss.item())
                total_loss += loss.item()
                total_metrics += self._eval_metrics(output, label)

                if batch_idx % self.log_step == 0:
                    progress.set_postfix_str(' {} Loss: {:.6f}'.format(
                        self._progress(batch_idx),
                        loss.item()))
                    self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

                if batch_idx == self.len_epoch:
                    break
        # if hasattr(self.data_loader, 'run'):
        #     self.data_loader.run()

        log = {
            'loss': total_loss / self.len_epoch,
            'metrics': (total_metrics / self.len_epoch).tolist(),
        }
        if self.lr_scheduler is not None:
            # get_last_lr() replaces the deprecated get_lr(); guarding on None
            # avoids a crash when no scheduler is configured
            log['learning rate'] = self.lr_scheduler.get_last_lr()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(val_log)
        if self.do_test:
            test_log, test_meta = self._test_epoch(epoch)
            log.update(test_log)
        else:
            test_meta = [0, 0]

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :return: A log that contains information about validation

        Note:
            The validation metrics in log must have the key 'val_metrics'.
        """
        self.model.eval()

        total_val_loss = 0
        total_val_metrics = np.zeros(len(self.metrics))
        with torch.no_grad():
            with tqdm(self.valid_data_loader) as progress:
                for batch_idx, (data, label, _, _) in enumerate(progress):
                    progress.set_description_str(f'Valid epoch {epoch}')
                    data, label = data.to(self.device), label.to(self.device)

                    output = self.model(data)
                    loss = self.val_criterion(output, label)

                    self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
                    self.writer.add_scalar('loss', loss.item())
                    self.val_loss_list.append(loss.item())
                    total_val_loss += loss.item()
                    total_val_metrics += self._eval_metrics(output, label)
                    self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')

        return {
            'val_loss': total_val_loss / len(self.valid_data_loader),
            'val_metrics': (total_val_metrics / len(self.valid_data_loader)).tolist()
        }

    def _test_epoch(self, epoch):
        """
        Test after training an epoch

        :return: A log that contains information about the test set, plus the
            per-sample model outputs and targets.

        Note:
            The test metrics in log must have the key 'test_metrics'.
        """
        self.model.eval()

        total_test_loss = 0
        total_test_metrics = np.zeros(len(self.metrics))
        results = np.zeros((len(self.test_data_loader.dataset), self.config['num_classes']), dtype=np.float32)
        tar_ = np.zeros((len(self.test_data_loader.dataset),), dtype=np.float32)
        with torch.no_grad():
            with tqdm(self.test_data_loader) as progress:
                for batch_idx, (data, label, indexs, _) in enumerate(progress):
                    progress.set_description_str(f'Test epoch {epoch}')
                    data, label = data.to(self.device), label.to(self.device)

                    output = self.model(data)
                    loss = self.val_criterion(output, label)

                    self.writer.set_step((epoch - 1) * len(self.test_data_loader) + batch_idx, 'test')
                    self.writer.add_scalar('loss', loss.item())
                    self.test_loss_list.append(loss.item())
                    total_test_loss += loss.item()
                    total_test_metrics += self._eval_metrics(output, label)
                    self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

                    # store per-sample outputs and targets at their dataset indices
                    results[indexs.cpu().numpy()] = output.cpu().numpy()
                    tar_[indexs.cpu().numpy()] = label.cpu().numpy()

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')

        return {
            'test_loss': total_test_loss / len(self.test_data_loader),
            'test_metrics': (total_test_metrics / len(self.test_data_loader)).tolist()
        }, [results, tar_]

    def _warmup_epoch(self, epoch):
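        """
        Warm-up pass trained with plain cross-entropy, feeding softmax
        probabilities to `self.train_criterion.update_hist` for each batch.

        :param epoch: Current warm-up epoch.
        :return: A log dict in the same format as `_train_epoch` returns.
        """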
        total_loss = 0
        total_metrics = np.zeros(len(self.metrics))
        self.model.train()

        data_loader = self.data_loader  # self.loader.run('warmup')
        with tqdm(data_loader) as progress:
            # note: the warm-up loader is expected to yield 5-tuples here,
            # unlike the 4-tuples unpacked in _train_epoch
            for batch_idx, (data, label, _, indexs, _) in enumerate(progress):
                progress.set_description_str(f'Warm up epoch {epoch}')
                data, label = data.to(self.device), label.long().to(self.device)

                self.optimizer.zero_grad()
                output = self.model(data)
                out_prob = torch.nn.functional.softmax(output, dim=1).detach()
                self.train_criterion.update_hist(indexs.cpu().detach().numpy().tolist(), out_prob)

                loss = torch.nn.functional.cross_entropy(output, label)
                loss.backward()
                self.optimizer.step()

                self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
                self.writer.add_scalar('loss', loss.item())
                self.train_loss_list.append(loss.item())
                total_loss += loss.item()
                total_metrics += self._eval_metrics(output, label)

                if batch_idx % self.log_step == 0:
                    progress.set_postfix_str(' {} Loss: {:.6f}'.format(
                        self._progress(batch_idx),
                        loss.item()))
                    self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

                if batch_idx == self.len_epoch:
                    break

        if hasattr(self.data_loader, 'run'):
            self.data_loader.run()
        log = {
            'loss': total_loss / self.len_epoch,
            'noise detection rate': 0.0,
            'metrics': (total_metrics / self.len_epoch).tolist(),
        }
        if self.lr_scheduler is not None:
            # get_last_lr() replaces the deprecated get_lr()
            log['learning rate'] = self.lr_scheduler.get_last_lr()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(val_log)
        if self.do_test:
            test_log, test_meta = self._test_epoch(epoch)
            log.update(test_log)
        else:
            test_meta = [0, 0]

        return log

    def _progress(self, batch_idx):
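        """Format a '[current/total (percent%)]' progress string for tqdm output."""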
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)