Spaces:
Running
on
Zero
Running
on
Zero
from typing import * | |
import copy | |
import torch | |
import torch.nn.functional as F | |
from torch.utils.data import DataLoader | |
from easydict import EasyDict as edict | |
from ..basic import BasicTrainer | |
class SparseStructureVaeTrainer(BasicTrainer):
    """
    Trainer for Sparse Structure VAE.

    Args:
        models (dict[str, nn.Module]): Models to train.
        dataset (torch.utils.data.Dataset): Dataset.
        output_dir (str): Output directory.
        load_dir (str): Load directory.
        step (int): Step to load.
        batch_size (int): Batch size.
        batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
        batch_split (int): Split batch with gradient accumulation.
        max_steps (int): Max steps.
        optimizer (dict): Optimizer config.
        lr_scheduler (dict): Learning rate scheduler config.
        elastic (dict): Elastic memory management config.
        grad_clip (float or dict): Gradient clip config.
        ema_rate (float or list): Exponential moving average rates.
        fp16_mode (str): FP16 mode.
            - None: No FP16.
            - 'inflat_all': Hold an inflated fp32 master param for all params.
            - 'amp': Automatic mixed precision.
        fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
        finetune_ckpt (dict): Finetune checkpoint.
        log_param_stats (bool): Log parameter stats.
        i_print (int): Print interval.
        i_log (int): Log interval.
        i_sample (int): Sample interval.
        i_save (int): Save interval.
        i_ddpcheck (int): DDP check interval.
        loss_type (str): Reconstruction loss type. 'bce' for binary cross entropy,
            'l1' for L1 loss on sigmoid probabilities, 'dice' for soft Dice loss.
            Any other value raises ValueError when losses are computed.
        lambda_kl (float): KL divergence loss weight.
    """

    def __init__(
        self,
        *args,
        loss_type='bce',
        lambda_kl=1e-6,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.loss_type = loss_type
        self.lambda_kl = lambda_kl

    def training_losses(
        self,
        ss: torch.Tensor,
        **kwargs
    ) -> Tuple[Dict, Dict]:
        """
        Compute training losses.

        Args:
            ss: The [N x 1 x H x W x D] tensor of binary sparse structure.

        Returns:
            A pair (terms, status). `terms` holds the key "loss" with the scalar
            total loss, the reconstruction term under the key named by
            `self.loss_type`, and "kl". `status` is always an empty dict.

        Raises:
            ValueError: if `self.loss_type` is not 'bce', 'l1' or 'dice'.
        """
        # Convert once; the same float target feeds the encoder and every loss term.
        target = ss.float()
        z, mean, logvar = self.training_models['encoder'](target, sample_posterior=True, return_raw=True)
        logits = self.training_models['decoder'](z)

        terms = edict(loss=0.0)
        if self.loss_type == 'bce':
            terms["bce"] = F.binary_cross_entropy_with_logits(logits, target, reduction='mean')
            terms["loss"] = terms["loss"] + terms["bce"]
        elif self.loss_type == 'l1':
            terms["l1"] = F.l1_loss(torch.sigmoid(logits), target, reduction='mean')
            terms["loss"] = terms["loss"] + terms["l1"]
        elif self.loss_type == 'dice':
            probs = torch.sigmoid(logits)
            # Soft Dice computed over the whole batch at once; the +1 smoothing keeps
            # the ratio defined (and the loss 0) when prediction and target are both empty.
            terms["dice"] = 1 - (2 * (probs * target).sum() + 1) / (probs.sum() + target.sum() + 1)
            terms["loss"] = terms["loss"] + terms["dice"]
        else:
            raise ValueError(f'Invalid loss type {self.loss_type}')

        # Closed-form KL between N(mean, exp(logvar)) and N(0, I), averaged over all elements.
        terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
        terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"]

        return terms, {}

    def snapshot(self, suffix=None, num_samples=64, batch_size=1, verbose=False):
        """Delegate to BasicTrainer.snapshot, defaulting to batch_size=1 for this trainer."""
        super().snapshot(suffix=suffix, num_samples=num_samples, batch_size=batch_size, verbose=verbose)

    @torch.no_grad()  # inference only: do not build/retain autograd graphs for snapshot batches
    def run_snapshot(
        self,
        num_samples: int,
        batch_size: int,
        verbose: bool = False,
    ) -> Dict:
        """
        Reconstruct `num_samples` randomly drawn dataset samples with the VAE.

        Args:
            num_samples: Total number of samples to reconstruct.
            batch_size: Maximum number of samples per forward pass.
            verbose: Accepted for interface compatibility; not used here.

        Returns:
            dict with 'gt' (the input `ss` tensors) and 'recon' (reconstructions
            binarized as `logits > 0`, dtype long), each stored as
            {'value': Tensor, 'type': 'sample'} and concatenated along dim 0.
        """
        # NOTE(review): the dataset is deep-copied, presumably so snapshot sampling
        # cannot disturb the training dataset's state — confirm.
        dataloader = DataLoader(
            copy.deepcopy(self.dataset),
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
            collate_fn=getattr(self.dataset, 'collate_fn', None),
        )

        gts = []
        recons = []
        for start in range(0, num_samples, batch_size):
            n = min(batch_size, num_samples - start)
            # A fresh iterator per step draws the first batch of a new shuffle, so the
            # loop never exhausts the loader (samples may repeat across steps).
            data = next(iter(dataloader))
            args = {k: v[:n].cuda() if isinstance(v, torch.Tensor) else v[:n] for k, v in data.items()}
            z = self.models['encoder'](args['ss'].float(), sample_posterior=False)
            logits = self.models['decoder'](z)
            # logits > 0 is equivalent to sigmoid(logits) > 0.5.
            recon = (logits > 0).long()
            gts.append(args['ss'])
            recons.append(recon)

        sample_dict = {
            'gt': {'value': torch.cat(gts, dim=0), 'type': 'sample'},
            'recon': {'value': torch.cat(recons, dim=0), 'type': 'sample'},
        }
        return sample_dict