junbiao.chen
Trellis update
cc0c59d
from typing import *
import copy
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from easydict import EasyDict as edict
from ..basic import BasicTrainer
from ...pipelines import samplers
from ...utils.general_utils import dict_reduce
from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin
from .mixins.text_conditioned import TextConditionedMixin
from .mixins.image_conditioned import ImageConditionedMixin
class FlowMatchingTrainer(BasicTrainer):
    """
    Trainer for diffusion model with flow matching objective.

    Args:
        models (dict[str, nn.Module]): Models to train.
        dataset (torch.utils.data.Dataset): Dataset.
        output_dir (str): Output directory.
        load_dir (str): Load directory.
        step (int): Step to load.
        batch_size (int): Batch size.
        batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
        batch_split (int): Split batch with gradient accumulation.
        max_steps (int): Max steps.
        optimizer (dict): Optimizer config.
        lr_scheduler (dict): Learning rate scheduler config.
        elastic (dict): Elastic memory management config.
        grad_clip (float or dict): Gradient clip config.
        ema_rate (float or list): Exponential moving average rates.
        fp16_mode (str): FP16 mode.
            - None: No FP16.
            - 'inflat_all': Hold an inflated fp32 master param for all params.
            - 'amp': Automatic mixed precision.
        fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
        finetune_ckpt (dict): Finetune checkpoint.
        log_param_stats (bool): Log parameter stats.
        i_print (int): Print interval.
        i_log (int): Log interval.
        i_sample (int): Sample interval.
        i_save (int): Save interval.
        i_ddpcheck (int): DDP check interval.
        t_schedule (dict): Time schedule for flow matching, of the form
            ``{'name': ..., 'args': {...}}``. Supported names are 'uniform'
            and 'logitNormal' (the latter reads ``args['mean']`` and
            ``args['std']``). Defaults to logitNormal with mean 0.0, std 1.0.
        sigma_min (float): Minimum noise level.
    """

    def __init__(
        self,
        *args,
        t_schedule: Optional[dict] = None,
        sigma_min: float = 1e-5,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        if t_schedule is None:
            # Build a fresh dict per instance; a dict literal as the default
            # argument would be shared (and mutable) across all instances.
            t_schedule = {
                'name': 'logitNormal',
                'args': {
                    'mean': 0.0,
                    'std': 1.0,
                },
            }
        self.t_schedule = t_schedule
        self.sigma_min = sigma_min

    def diffuse(self, x_0: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Sample from q(x_t | x_0) by interpolating data and noise.

        x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * noise

        Args:
            x_0: The [N x C x ...] tensor of noiseless inputs.
            t: The [N] tensor of timesteps in [0, 1].
            noise: If specified, use this noise instead of generating new noise.

        Returns:
            x_t, the noisy version of x_0 at timestep t.
        """
        if noise is None:
            noise = torch.randn_like(x_0)
        assert noise.shape == x_0.shape, "noise must have same shape as x_0"
        # Reshape t to [N, 1, 1, ...] so it broadcasts over all non-batch dims.
        t = t.view(-1, *[1 for _ in range(len(x_0.shape) - 1)])
        x_t = (1 - t) * x_0 + (self.sigma_min + (1 - self.sigma_min) * t) * noise
        return x_t

    def reverse_diffuse(self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """
        Recover x_0 from x_t given the noise that was used to produce it.

        This inverts `diffuse`; it divides by (1 - t) and is therefore
        undefined at t == 1.
        """
        assert noise.shape == x_t.shape, "noise must have same shape as x_t"
        t = t.view(-1, *[1 for _ in range(len(x_t.shape) - 1)])
        x_0 = (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * noise) / (1 - t)
        return x_0

    def get_v(self, x_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Compute the velocity d(x_t)/dt of the diffusion path defined by `diffuse`.

        The path is linear in t, so the velocity does not depend on t; the
        argument is accepted for interface compatibility.
        """
        return (1 - self.sigma_min) * noise - x_0

    def get_cond(self, cond, **kwargs):
        """
        Get the conditioning data for training. Identity in this base class.
        """
        return cond

    def get_inference_cond(self, cond, **kwargs):
        """
        Get the conditioning keyword arguments for the sampler at inference.

        Returns a dict with ``cond`` plus any extra keyword arguments.
        """
        return {'cond': cond, **kwargs}

    def get_sampler(self, **kwargs) -> samplers.FlowEulerSampler:
        """
        Get the sampler for the diffusion process.
        """
        return samplers.FlowEulerSampler(self.sigma_min)

    def vis_cond(self, **kwargs):
        """
        Visualize the conditioning data. Returns an empty dict in this base class.
        """
        return {}

    def sample_t(self, batch_size: int) -> torch.Tensor:
        """
        Sample one timestep per batch element according to `self.t_schedule`.

        - 'uniform': t ~ U[0, 1).
        - 'logitNormal': t = sigmoid(z), z ~ N(mean, std).

        Raises:
            ValueError: If the schedule name is not recognized.
        """
        if self.t_schedule['name'] == 'uniform':
            t = torch.rand(batch_size)
        elif self.t_schedule['name'] == 'logitNormal':
            mean = self.t_schedule['args']['mean']
            std = self.t_schedule['args']['std']
            t = torch.sigmoid(torch.randn(batch_size) * std + mean)
        else:
            raise ValueError(f"Unknown t_schedule: {self.t_schedule['name']}")
        return t

    def training_losses(
        self,
        x_0: torch.Tensor,
        cond=None,
        **kwargs
    ) -> Tuple[Dict, Dict]:
        """
        Compute training losses for a single randomly sampled timestep per instance.

        Args:
            x_0: The [N x C x ...] tensor of noiseless inputs.
            cond: The [N x ...] tensor of additional conditions.
            kwargs: Additional arguments passed to `get_cond` and to the backbone.

        Returns:
            A tuple ``(terms, {})``. ``terms`` holds a scalar ``"mse"`` loss,
            ``"loss"`` (same tensor as ``"mse"``), and, for every time bin
            ``i`` in 0..9 that received at least one sample, ``"bin_{i}"`` =
            ``{"mse": <mean per-instance MSE in that bin>}`` for logging.
        """
        noise = torch.randn_like(x_0)
        t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
        x_t = self.diffuse(x_0, t, noise=noise)
        cond = self.get_cond(cond, **kwargs)

        # The backbone receives t scaled by 1000 (presumably the range its
        # timestep embedding was designed for).
        pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs)
        assert pred.shape == noise.shape == x_0.shape
        target = self.get_v(x_0, noise, t)
        terms = edict()
        terms["mse"] = F.mse_loss(pred, target)
        terms["loss"] = terms["mse"]

        # Log loss with time bins. All per-instance MSEs are computed in one
        # detached op with a single host transfer, rather than N `.item()`
        # calls that each force a device sync and extend the autograd graph.
        batch_size = x_0.shape[0]
        with torch.no_grad():
            sq_err = (pred.detach().float() - target.detach().float()) ** 2
            mse_per_instance = sq_err.reshape(batch_size, -1).mean(dim=1).double().cpu().numpy()
        # digitize against 11 edges on [0, 1] yields 1..11; after the -1 shift,
        # t == 1.0 would land in bin 10 and never be logged, so clip it into bin 9.
        time_bin = np.clip(np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1, 0, 9)
        for i in range(10):
            in_bin = time_bin == i
            if in_bin.any():
                terms[f"bin_{i}"] = {"mse": mse_per_instance[in_bin].mean()}
        return terms, {}

    @torch.no_grad()
    def run_snapshot(
        self,
        num_samples: int,
        batch_size: int,
        verbose: bool = False,
    ) -> Dict:
        """
        Generate samples from the current denoiser for visualization.

        Draws data from a shuffled DataLoader over a deep copy of the dataset,
        moves tensors to CUDA, and samples from Gaussian noise with 50 steps
        and cfg_strength=3.0.

        Args:
            num_samples: Total number of samples to generate.
            batch_size: Number of samples generated per sampler call.
            verbose: Passed through to the sampler.

        Returns:
            A dict with ``'sample_gt'`` and ``'sample'`` entries (each
            ``{'value': tensor, 'type': 'sample'}``), merged with the
            concatenated outputs of `vis_cond`.
        """
        dataloader = DataLoader(
            copy.deepcopy(self.dataset),
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
            collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
        )

        # inference
        sampler = self.get_sampler()
        sample_gt = []
        sample = []
        cond_vis = []
        for i in range(0, num_samples, batch_size):
            batch = min(batch_size, num_samples - i)
            # A fresh iterator each round: every chunk is the first batch of a
            # new shuffle, so chunks are drawn independently.
            data = next(iter(dataloader))
            data = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()}
            noise = torch.randn_like(data['x_0'])
            sample_gt.append(data['x_0'])
            cond_vis.append(self.vis_cond(**data))
            del data['x_0']
            args = self.get_inference_cond(**data)
            res = sampler.sample(
                self.models['denoiser'],
                noise=noise,
                **args,
                steps=50, cfg_strength=3.0, verbose=verbose,
            )
            sample.append(res.samples)

        sample_gt = torch.cat(sample_gt, dim=0)
        sample = torch.cat(sample, dim=0)
        sample_dict = {
            'sample_gt': {'value': sample_gt, 'type': 'sample'},
            'sample': {'value': sample, 'type': 'sample'},
        }
        sample_dict.update(dict_reduce(cond_vis, None, {
            'value': lambda x: torch.cat(x, dim=0),
            'type': lambda x: x[0],
        }))

        return sample_dict
class FlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, FlowMatchingTrainer):
    """
    Flow matching trainer with classifier-free guidance.

    This class only composes ``ClassifierFreeGuidanceMixin`` with
    ``FlowMatchingTrainer``; it defines no behavior of its own.

    Args:
        models (dict[str, nn.Module]): Models to train.
        dataset (torch.utils.data.Dataset): Dataset.
        output_dir (str): Output directory.
        load_dir (str): Load directory.
        step (int): Step to load.
        batch_size (int): Batch size.
        batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
        batch_split (int): Split batch with gradient accumulation.
        max_steps (int): Max steps.
        optimizer (dict): Optimizer config.
        lr_scheduler (dict): Learning rate scheduler config.
        elastic (dict): Elastic memory management config.
        grad_clip (float or dict): Gradient clip config.
        ema_rate (float or list): Exponential moving average rates.
        fp16_mode (str): FP16 mode.
            - None: No FP16.
            - 'inflat_all': Hold an inflated fp32 master param for all params.
            - 'amp': Automatic mixed precision.
        fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
        finetune_ckpt (dict): Finetune checkpoint.
        log_param_stats (bool): Log parameter stats.
        i_print (int): Print interval.
        i_log (int): Log interval.
        i_sample (int): Sample interval.
        i_save (int): Save interval.
        i_ddpcheck (int): DDP check interval.
        t_schedule (dict): Time schedule for flow matching.
        sigma_min (float): Minimum noise level.
        p_uncond (float): Probability of dropping conditions.
    """
class TextConditionedFlowMatchingCFGTrainer(TextConditionedMixin, FlowMatchingCFGTrainer):
    """
    Text-conditioned flow matching trainer with classifier-free guidance.

    This class only composes ``TextConditionedMixin`` with
    ``FlowMatchingCFGTrainer``; it defines no behavior of its own.

    Args:
        models (dict[str, nn.Module]): Models to train.
        dataset (torch.utils.data.Dataset): Dataset.
        output_dir (str): Output directory.
        load_dir (str): Load directory.
        step (int): Step to load.
        batch_size (int): Batch size.
        batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
        batch_split (int): Split batch with gradient accumulation.
        max_steps (int): Max steps.
        optimizer (dict): Optimizer config.
        lr_scheduler (dict): Learning rate scheduler config.
        elastic (dict): Elastic memory management config.
        grad_clip (float or dict): Gradient clip config.
        ema_rate (float or list): Exponential moving average rates.
        fp16_mode (str): FP16 mode.
            - None: No FP16.
            - 'inflat_all': Hold an inflated fp32 master param for all params.
            - 'amp': Automatic mixed precision.
        fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
        finetune_ckpt (dict): Finetune checkpoint.
        log_param_stats (bool): Log parameter stats.
        i_print (int): Print interval.
        i_log (int): Log interval.
        i_sample (int): Sample interval.
        i_save (int): Save interval.
        i_ddpcheck (int): DDP check interval.
        t_schedule (dict): Time schedule for flow matching.
        sigma_min (float): Minimum noise level.
        p_uncond (float): Probability of dropping conditions.
        text_cond_model (str): Text conditioning model.
    """
class ImageConditionedFlowMatchingCFGTrainer(ImageConditionedMixin, FlowMatchingCFGTrainer):
    """
    Image-conditioned flow matching trainer with classifier-free guidance.

    This class only composes ``ImageConditionedMixin`` with
    ``FlowMatchingCFGTrainer``; it defines no behavior of its own.

    Args:
        models (dict[str, nn.Module]): Models to train.
        dataset (torch.utils.data.Dataset): Dataset.
        output_dir (str): Output directory.
        load_dir (str): Load directory.
        step (int): Step to load.
        batch_size (int): Batch size.
        batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
        batch_split (int): Split batch with gradient accumulation.
        max_steps (int): Max steps.
        optimizer (dict): Optimizer config.
        lr_scheduler (dict): Learning rate scheduler config.
        elastic (dict): Elastic memory management config.
        grad_clip (float or dict): Gradient clip config.
        ema_rate (float or list): Exponential moving average rates.
        fp16_mode (str): FP16 mode.
            - None: No FP16.
            - 'inflat_all': Hold an inflated fp32 master param for all params.
            - 'amp': Automatic mixed precision.
        fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
        finetune_ckpt (dict): Finetune checkpoint.
        log_param_stats (bool): Log parameter stats.
        i_print (int): Print interval.
        i_log (int): Log interval.
        i_sample (int): Sample interval.
        i_save (int): Save interval.
        i_ddpcheck (int): DDP check interval.
        t_schedule (dict): Time schedule for flow matching.
        sigma_min (float): Minimum noise level.
        p_uncond (float): Probability of dropping conditions.
        image_cond_model (str): Image conditioning model.
    """