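"""Trainer for ELR (early-learning regularization) experiments.

Wraps the project's ``BaseTrainer`` with an epoch-based train/validate/test
loop plus a cross-entropy warm-up epoch that seeds the ELR criterion's
per-sample target history.
"""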
import numpy as np
import torch
from tqdm import tqdm
from typing import List
from torchvision.utils import make_grid
from base import BaseTrainer
from utils import inf_loop
class Trainer(BaseTrainer):
"""
Trainer class
Note:
Inherited from BaseTrainer.
"""
def __init__(self, model, train_criterion, metrics, optimizer, config, data_loader,
valid_data_loader=None, test_data_loader=None, lr_scheduler=None, len_epoch=None, val_criterion=None):
super().__init__(model, train_criterion, metrics, optimizer, config, val_criterion)
self.config = config
self.data_loader = data_loader
if len_epoch is None:
# epoch-based training
self.len_epoch = len(self.data_loader)
else:
# iteration-based training
self.data_loader = inf_loop(data_loader)
self.len_epoch = len_epoch
self.valid_data_loader = valid_data_loader
self.test_data_loader = test_data_loader
self.do_validation = self.valid_data_loader is not None
self.do_test = self.test_data_loader is not None
self.lr_scheduler = lr_scheduler
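        # log progress every int(sqrt(batch_size)) batches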
self.log_step = int(np.sqrt(data_loader.batch_size))
self.train_loss_list: List[float] = []
self.val_loss_list: List[float] = []
self.test_loss_list: List[float] = []
def _eval_metrics(self, output, label):
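        """Evaluate every configured metric on this batch and log each to the writer."""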
acc_metrics = np.zeros(len(self.metrics))
for i, metric in enumerate(self.metrics):
acc_metrics[i] += metric(output, label)
            self.writer.add_scalar(metric.__name__, acc_metrics[i])
return acc_metrics
def _train_epoch(self, epoch):
"""
Training logic for an epoch
:param epoch: Current training epoch.
:return: A log that contains all information you want to save.
Note:
If you have additional information to record, for example:
> additional_log = {"x": x, "y": y}
merge it with log before return. i.e.
> log = {**log, **additional_log}
> return log
The metrics in log must have the key 'metrics'.
"""
self.model.train()
total_loss = 0
total_metrics = np.zeros(len(self.metrics))
with tqdm(self.data_loader) as progress:
for batch_idx, (data, label, indexs, _) in enumerate(progress):
progress.set_description_str(f'Train epoch {epoch}')
data, label = data.to(self.device), label.long().to(self.device)
output = self.model(data)
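                # the ELR train criterion is index-aware: it tracks a running
                # target estimate per sample, so pass the batch's dataset indices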
loss = self.train_criterion(indexs.cpu().detach().numpy().tolist(), output, label)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
self.writer.add_scalar('loss', loss.item())
self.train_loss_list.append(loss.item())
total_loss += loss.item()
total_metrics += self._eval_metrics(output, label)
if batch_idx % self.log_step == 0:
progress.set_postfix_str(' {} Loss: {:.6f}'.format(
self._progress(batch_idx),
loss.item()))
self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
if batch_idx == self.len_epoch:
break
        log = {
            'loss': total_loss / self.len_epoch,
            'metrics': (total_metrics / self.len_epoch).tolist(),
            'learning rate': self.lr_scheduler.get_last_lr() if self.lr_scheduler is not None else None
        }
if self.do_validation:
val_log = self._valid_epoch(epoch)
log.update(val_log)
if self.do_test:
test_log, test_meta = self._test_epoch(epoch)
log.update(test_log)
else:
            test_meta = [0, 0]
if self.lr_scheduler is not None:
self.lr_scheduler.step()
return log
def _valid_epoch(self, epoch):
"""
Validate after training an epoch
:return: A log that contains information about validation
Note:
The validation metrics in log must have the key 'val_metrics'.
"""
self.model.eval()
total_val_loss = 0
total_val_metrics = np.zeros(len(self.metrics))
with torch.no_grad():
with tqdm(self.valid_data_loader) as progress:
for batch_idx, (data, label, _, _) in enumerate(progress):
progress.set_description_str(f'Valid epoch {epoch}')
data, label = data.to(self.device), label.to(self.device)
output = self.model(data)
loss = self.val_criterion(output, label)
self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
self.writer.add_scalar('loss', loss.item())
self.val_loss_list.append(loss.item())
total_val_loss += loss.item()
total_val_metrics += self._eval_metrics(output, label)
self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
# add histogram of model parameters to the tensorboard
for name, p in self.model.named_parameters():
self.writer.add_histogram(name, p, bins='auto')
return {
'val_loss': total_val_loss / len(self.valid_data_loader),
'val_metrics': (total_val_metrics / len(self.valid_data_loader)).tolist()
}
def _test_epoch(self, epoch):
"""
Test after training an epoch
:return: A log that contains information about test
Note:
            The test metrics in log must have the key 'test_metrics'.
"""
self.model.eval()
total_test_loss = 0
total_test_metrics = np.zeros(len(self.metrics))
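        # pre-allocate per-sample outputs and targets for the whole test set;
        # rows are filled in by dataset index and returned as test metadata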
results = np.zeros((len(self.test_data_loader.dataset), self.config['num_classes']), dtype=np.float32)
tar_ = np.zeros((len(self.test_data_loader.dataset),), dtype=np.float32)
with torch.no_grad():
with tqdm(self.test_data_loader) as progress:
                for batch_idx, (data, label, indexs, _) in enumerate(progress):
progress.set_description_str(f'Test epoch {epoch}')
data, label = data.to(self.device), label.to(self.device)
output = self.model(data)
loss = self.val_criterion(output, label)
self.writer.set_step((epoch - 1) * len(self.test_data_loader) + batch_idx, 'test')
self.writer.add_scalar('loss', loss.item())
self.test_loss_list.append(loss.item())
total_test_loss += loss.item()
total_test_metrics += self._eval_metrics(output, label)
self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
                    idx = indexs.cpu().numpy()
                    results[idx] = output.detach().cpu().numpy()
                    tar_[idx] = label.cpu().numpy()
# add histogram of model parameters to the tensorboard
for name, p in self.model.named_parameters():
self.writer.add_histogram(name, p, bins='auto')
return {
'test_loss': total_test_loss / len(self.test_data_loader),
'test_metrics': (total_test_metrics / len(self.test_data_loader)).tolist()
        }, [results, tar_]
def _warmup_epoch(self, epoch):
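        """Run one warm-up epoch with plain cross-entropy.

        Besides training the model, this seeds the train criterion's
        per-sample probability history via ``update_hist``.
        """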
total_loss = 0
total_metrics = np.zeros(len(self.metrics))
self.model.train()
        data_loader = self.data_loader
with tqdm(data_loader) as progress:
            for batch_idx, (data, label, _, indexs, _) in enumerate(progress):
progress.set_description_str(f'Warm up epoch {epoch}')
data, label = data.to(self.device), label.long().to(self.device)
self.optimizer.zero_grad()
output = self.model(data)
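                # record softmax probabilities so the criterion can build its
                # per-sample target history before ELR training starts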
                out_prob = torch.nn.functional.softmax(output, dim=1).detach()
self.train_criterion.update_hist(indexs.cpu().detach().numpy().tolist(), out_prob)
loss = torch.nn.functional.cross_entropy(output, label)
loss.backward()
self.optimizer.step()
self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
self.writer.add_scalar('loss', loss.item())
self.train_loss_list.append(loss.item())
total_loss += loss.item()
total_metrics += self._eval_metrics(output, label)
if batch_idx % self.log_step == 0:
progress.set_postfix_str(' {} Loss: {:.6f}'.format(
self._progress(batch_idx),
loss.item()))
self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
if batch_idx == self.len_epoch:
break
if hasattr(self.data_loader, 'run'):
self.data_loader.run()
        log = {
            'loss': total_loss / self.len_epoch,
            'noise detection rate': 0.0,
            'metrics': (total_metrics / self.len_epoch).tolist(),
            'learning rate': self.lr_scheduler.get_last_lr() if self.lr_scheduler is not None else None
        }
if self.do_validation:
val_log = self._valid_epoch(epoch)
log.update(val_log)
if self.do_test:
test_log, test_meta = self._test_epoch(epoch)
log.update(test_log)
else:
            test_meta = [0, 0]
return log
def _progress(self, batch_idx):
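        """Format a '[current/total (pct%)]' progress string for tqdm."""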
base = '[{}/{} ({:.0f}%)]'
if hasattr(self.data_loader, 'n_samples'):
current = batch_idx * self.data_loader.batch_size
total = self.data_loader.n_samples
else:
current = batch_idx
total = self.len_epoch
return base.format(current, total, 100.0 * current / total)
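

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): in this project the objects below are
# normally built from the JSON config and the loop is driven by
# BaseTrainer.train(); the variable names here are assumptions, not the
# project's API.
#
#     trainer = Trainer(model, train_criterion, metrics, optimizer, config,
#                       data_loader,
#                       valid_data_loader=valid_loader,
#                       test_data_loader=test_loader,
#                       lr_scheduler=scheduler)
#     trainer._warmup_epoch(1)   # seed the criterion's target history
#     for epoch in range(2, num_epochs + 1):
#         log = trainer._train_epoch(epoch)
# ---------------------------------------------------------------------------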