# Spaces: Running on Zero
from typing import Any, Callable, Dict, Optional

import torch
import torch.nn as nn

from .gaussian_diffusion import GaussianDiffusion
from .k_diffusion import karras_sample

# Module-level defaults for the Karras sampling arguments.
# NOTE(review): nothing in this part of the file reads these; the names line up
# with the `karras_steps`, `sigma_min`, `sigma_max` and `s_churn` keyword
# arguments of `sample_latents`, which declares no defaults of its own, so
# callers presumably pass these values in -- confirm against call sites.
DEFAULT_KARRAS_STEPS = 64
DEFAULT_KARRAS_SIGMA_MIN = 1e-3
DEFAULT_KARRAS_SIGMA_MAX = 160
DEFAULT_KARRAS_S_CHURN = 0.0


def uncond_guide_model(
    model: Callable[..., torch.Tensor], scale: float
) -> Callable[..., torch.Tensor]:
    """Wrap ``model`` so part of its output is mixed between two batch halves.

    The returned callable has the signature ``model_fn(x_t, ts, **kwargs)``.
    On each call it:

    1. keeps only the first half of ``x_t`` along dim 0 and duplicates it, so
       both halves of the batch see the same input;
    2. calls ``model(combined, ts, **kwargs)`` (``ts`` and ``kwargs`` are
       forwarded untouched, so they must already be sized for the full batch);
    3. splits the output along dim 1 into the first three columns (``eps``)
       and everything else (``rest``);
    4. treats the first half of ``eps`` as the conditional prediction and the
       second half as the unconditional one, and replaces both halves with
       ``uncond + scale * (cond - uncond)``;
    5. returns ``eps`` and the unmodified ``rest`` concatenated along dim 1.

    Args:
        model: Callable invoked as ``model(x, ts, **kwargs)`` returning a
            tensor with at least two dimensions.
        scale: Mixing weight; ``1.0`` reproduces the conditional half and
            ``0.0`` the unconditional half.

    Returns:
        The wrapping callable described above.

    NOTE(review): the split index ``3`` is hard-coded. ``sample_latents`` in
    this file samples tensors of shape ``(batch_size, model.d_latent)``, so if
    the model output has that same layout only the first three latent columns
    are mixed -- confirm this is intended.
    """

    def model_fn(x_t: torch.Tensor, ts: Any, **kwargs: Any) -> torch.Tensor:
        # Only the first half of the incoming batch is used; the second half
        # is overwritten with a copy so both halves share the same x_t.
        half = x_t[: len(x_t) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = model(combined, ts, **kwargs)

        # Columns [0, 3) are mixed; any remaining columns pass through as-is.
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)
        half_eps = uncond_eps + scale * (cond_eps - uncond_eps)

        # Duplicate the mixed result so the output keeps the doubled batch.
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

    return model_fn


def sample_latents(
    *,
    batch_size: int,
    model: nn.Module,
    diffusion: GaussianDiffusion,
    model_kwargs: Dict[str, Any],
    guidance_scale: float,
    clip_denoised: bool,
    use_fp16: bool,
    use_karras: bool,
    karras_steps: int,
    sigma_min: float,
    sigma_max: float,
    s_churn: float,
    device: Optional[torch.device] = None,
    progress: bool = False,
) -> torch.Tensor:
    """Draw samples of shape ``(batch_size, model.d_latent)`` from ``model``.

    All arguments are keyword-only.

    Args:
        batch_size: Number of samples requested.
        model: Network to sample from. Must expose ``d_latent``; if it also
            has a ``cached_model_kwargs(batch_size, model_kwargs)`` method,
            that method's return value replaces ``model_kwargs``.
        diffusion: Passed to ``karras_sample`` on the Karras path; on the
            other path its ``p_sample_loop`` method runs the sampling.
        model_kwargs: Conditioning inputs forwarded to the sampler. When
            guidance is active every value is concatenated along dim 0 with a
            zero tensor of the same shape, so every value must be a tensor.
            The caller's dict is not modified.
        guidance_scale: Guidance weight. The values ``0.0`` and ``1.0`` both
            disable guidance on either sampling path.
        clip_denoised: Forwarded to the sampler.
        use_fp16: Enables ``torch.autocast`` for ``device.type`` while
            sampling.
        use_karras: Selects ``karras_sample`` instead of
            ``diffusion.p_sample_loop``.
        karras_steps: Forwarded as ``steps`` to ``karras_sample``.
        sigma_min: Forwarded to ``karras_sample``.
        sigma_max: Forwarded to ``karras_sample``.
        s_churn: Forwarded to ``karras_sample``.
        device: Target device; defaults to the device of the model's first
            parameter.
        progress: Forwarded to the sampler.

    Returns:
        Whatever the selected sampler returns.

    Raises:
        StopIteration: If ``device`` is ``None`` and ``model`` has no
            parameters.
    """
    sample_shape = (batch_size, model.d_latent)

    if device is None:
        device = next(model.parameters()).device

    if hasattr(model, "cached_model_kwargs"):
        model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)

    # One predicate drives both the kwargs doubling and the model wrapping on
    # the p_sample_loop path, so the two can never disagree (previously a
    # scale of 0.0 doubled the batch without doubling the conditioning).
    use_guidance = guidance_scale != 1.0 and guidance_scale != 0.0

    if use_guidance:
        # Work on a shallow copy so the caller's mapping is left untouched;
        # the second half of each value is an all-zero stand-in.
        doubled_kwargs = model_kwargs.copy()
        for key, value in model_kwargs.items():
            doubled_kwargs[key] = torch.cat(
                [value, torch.zeros_like(value)], dim=0
            )
        model_kwargs = doubled_kwargs

    with torch.autocast(device_type=device.type, enabled=use_fp16):
        if use_karras:
            samples = karras_sample(
                diffusion=diffusion,
                model=model,
                shape=sample_shape,
                steps=karras_steps,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
                device=device,
                sigma_min=sigma_min,
                sigma_max=sigma_max,
                s_churn=s_churn,
                guidance_scale=guidance_scale,
                progress=progress,
            )
        else:
            internal_batch_size = batch_size
            if use_guidance:
                model = uncond_guide_model(model, guidance_scale)
                internal_batch_size *= 2
            # NOTE(review): under guidance the requested shape has
            # 2 * batch_size rows and the result is returned untrimmed, while
            # the Karras path is given `sample_shape` (batch_size rows) --
            # confirm whether callers expect the extra rows to be dropped.
            samples = diffusion.p_sample_loop(
                model,
                shape=(internal_batch_size, *sample_shape[1:]),
                model_kwargs=model_kwargs,
                device=device,
                clip_denoised=clip_denoised,
                progress=progress,
            )

    return samples