from typing import * import os import copy import functools import torch import torch.nn.functional as F from torch.utils.data import DataLoader import numpy as np from easydict import EasyDict as edict from ...modules import sparse as sp from ...utils.general_utils import dict_reduce from ...utils.data_utils import cycle, BalancedResumableSampler from .flow_matching import FlowMatchingTrainer from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin from .mixins.text_conditioned import TextConditionedMixin from .mixins.image_conditioned import ImageConditionedMixin class SparseFlowMatchingTrainer(FlowMatchingTrainer): """ Trainer for sparse diffusion model with flow matching objective. Args: models (dict[str, nn.Module]): Models to train. dataset (torch.utils.data.Dataset): Dataset. output_dir (str): Output directory. load_dir (str): Load directory. step (int): Step to load. batch_size (int): Batch size. batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. batch_split (int): Split batch with gradient accumulation. max_steps (int): Max steps. optimizer (dict): Optimizer config. lr_scheduler (dict): Learning rate scheduler config. elastic (dict): Elastic memory management config. grad_clip (float or dict): Gradient clip config. ema_rate (float or list): Exponential moving average rates. fp16_mode (str): FP16 mode. - None: No FP16. - 'inflat_all': Hold a inflated fp32 master param for all params. - 'amp': Automatic mixed precision. fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. finetune_ckpt (dict): Finetune checkpoint. log_param_stats (bool): Log parameter stats. i_print (int): Print interval. i_log (int): Log interval. i_sample (int): Sample interval. i_save (int): Save interval. i_ddpcheck (int): DDP check interval. t_schedule (dict): Time schedule for flow matching. sigma_min (float): Minimum noise level. """ def prepare_dataloader(self, **kwargs): """ Prepare dataloader. """ self.data_sampler = BalancedResumableSampler( self.dataset, shuffle=True, batch_size=self.batch_size_per_gpu, ) self.dataloader = DataLoader( self.dataset, batch_size=self.batch_size_per_gpu, num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())), pin_memory=True, drop_last=True, persistent_workers=True, collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split), sampler=self.data_sampler, ) self.data_iterator = cycle(self.dataloader) def training_losses( self, x_0: sp.SparseTensor, cond=None, **kwargs ) -> Tuple[Dict, Dict]: """ Compute training losses for a single timestep. Args: x_0: The [N x ... x C] sparse tensor of the inputs. cond: The [N x ...] tensor of additional conditions. kwargs: Additional arguments to pass to the backbone. Returns: a dict with the key "loss" containing a tensor of shape [N]. may also contain other keys for different terms. """ noise = x_0.replace(torch.randn_like(x_0.feats)) t = self.sample_t(x_0.shape[0]).to(x_0.device).float() x_t = self.diffuse(x_0, t, noise=noise) cond = self.get_cond(cond, **kwargs) pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs) assert pred.shape == noise.shape == x_0.shape target = self.get_v(x_0, noise, t) terms = edict() terms["mse"] = F.mse_loss(pred.feats, target.feats) terms["loss"] = terms["mse"] # log loss with time bins mse_per_instance = np.array([ F.mse_loss(pred.feats[x_0.layout[i]], target.feats[x_0.layout[i]]).item() for i in range(x_0.shape[0]) ]) time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1 for i in range(10): if (time_bin == i).sum() != 0: terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()} return terms, {} @torch.no_grad() def run_snapshot( self, num_samples: int, batch_size: int, verbose: bool = False, ) -> Dict: dataloader = DataLoader( copy.deepcopy(self.dataset), batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, ) # inference sampler = self.get_sampler() sample_gt = [] sample = [] cond_vis = [] for i in range(0, num_samples, batch_size): batch = min(batch_size, num_samples - i) data = next(iter(dataloader)) data = {k: v[:batch].cuda() if not isinstance(v, list) else v[:batch] for k, v in data.items()} noise = data['x_0'].replace(torch.randn_like(data['x_0'].feats)) sample_gt.append(data['x_0']) cond_vis.append(self.vis_cond(**data)) del data['x_0'] args = self.get_inference_cond(**data) res = sampler.sample( self.models['denoiser'], noise=noise, **args, steps=50, cfg_strength=3.0, verbose=verbose, ) sample.append(res.samples) sample_gt = sp.sparse_cat(sample_gt) sample = sp.sparse_cat(sample) sample_dict = { 'sample_gt': {'value': sample_gt, 'type': 'sample'}, 'sample': {'value': sample, 'type': 'sample'}, } sample_dict.update(dict_reduce(cond_vis, None, { 'value': lambda x: torch.cat(x, dim=0), 'type': lambda x: x[0], })) return sample_dict class SparseFlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, SparseFlowMatchingTrainer): """ Trainer for sparse diffusion model with flow matching objective and classifier-free guidance. Args: models (dict[str, nn.Module]): Models to train. dataset (torch.utils.data.Dataset): Dataset. output_dir (str): Output directory. load_dir (str): Load directory. step (int): Step to load. batch_size (int): Batch size. batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. batch_split (int): Split batch with gradient accumulation. max_steps (int): Max steps. optimizer (dict): Optimizer config. lr_scheduler (dict): Learning rate scheduler config. elastic (dict): Elastic memory management config. grad_clip (float or dict): Gradient clip config. ema_rate (float or list): Exponential moving average rates. fp16_mode (str): FP16 mode. - None: No FP16. - 'inflat_all': Hold a inflated fp32 master param for all params. - 'amp': Automatic mixed precision. fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. finetune_ckpt (dict): Finetune checkpoint. log_param_stats (bool): Log parameter stats. i_print (int): Print interval. i_log (int): Log interval. i_sample (int): Sample interval. i_save (int): Save interval. i_ddpcheck (int): DDP check interval. t_schedule (dict): Time schedule for flow matching. sigma_min (float): Minimum noise level. p_uncond (float): Probability of dropping conditions. """ pass class TextConditionedSparseFlowMatchingCFGTrainer(TextConditionedMixin, SparseFlowMatchingCFGTrainer): """ Trainer for sparse text-conditioned diffusion model with flow matching objective and classifier-free guidance. Args: models (dict[str, nn.Module]): Models to train. dataset (torch.utils.data.Dataset): Dataset. output_dir (str): Output directory. load_dir (str): Load directory. step (int): Step to load. batch_size (int): Batch size. batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. batch_split (int): Split batch with gradient accumulation. max_steps (int): Max steps. optimizer (dict): Optimizer config. lr_scheduler (dict): Learning rate scheduler config. elastic (dict): Elastic memory management config. grad_clip (float or dict): Gradient clip config. ema_rate (float or list): Exponential moving average rates. fp16_mode (str): FP16 mode. - None: No FP16. - 'inflat_all': Hold a inflated fp32 master param for all params. - 'amp': Automatic mixed precision. fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. finetune_ckpt (dict): Finetune checkpoint. log_param_stats (bool): Log parameter stats. i_print (int): Print interval. i_log (int): Log interval. i_sample (int): Sample interval. i_save (int): Save interval. i_ddpcheck (int): DDP check interval. t_schedule (dict): Time schedule for flow matching. sigma_min (float): Minimum noise level. p_uncond (float): Probability of dropping conditions. text_cond_model(str): Text conditioning model. """ pass class ImageConditionedSparseFlowMatchingCFGTrainer(ImageConditionedMixin, SparseFlowMatchingCFGTrainer): """ Trainer for sparse image-conditioned diffusion model with flow matching objective and classifier-free guidance. Args: models (dict[str, nn.Module]): Models to train. dataset (torch.utils.data.Dataset): Dataset. output_dir (str): Output directory. load_dir (str): Load directory. step (int): Step to load. batch_size (int): Batch size. batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. batch_split (int): Split batch with gradient accumulation. max_steps (int): Max steps. optimizer (dict): Optimizer config. lr_scheduler (dict): Learning rate scheduler config. elastic (dict): Elastic memory management config. grad_clip (float or dict): Gradient clip config. ema_rate (float or list): Exponential moving average rates. fp16_mode (str): FP16 mode. - None: No FP16. - 'inflat_all': Hold a inflated fp32 master param for all params. - 'amp': Automatic mixed precision. fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. finetune_ckpt (dict): Finetune checkpoint. log_param_stats (bool): Log parameter stats. i_print (int): Print interval. i_log (int): Log interval. i_sample (int): Sample interval. i_save (int): Save interval. i_ddpcheck (int): DDP check interval. t_schedule (dict): Time schedule for flow matching. sigma_min (float): Minimum noise level. p_uncond (float): Probability of dropping conditions. image_cond_model (str): Image conditioning model. """ pass