Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,871 Bytes
efa71f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
from typing import Any, Callable, Dict, Optional
import torch
import torch.nn as nn
from .gaussian_diffusion import GaussianDiffusion
from .k_diffusion import karras_sample
# Module-level default values for the Karras sampler settings. Nothing in the
# visible part of this file reads them; presumably callers pass them as the
# karras_steps / sigma_min / sigma_max / s_churn arguments of sample_latents()
# -- TODO confirm against the call sites.
DEFAULT_KARRAS_STEPS = 64
DEFAULT_KARRAS_SIGMA_MIN = 1e-3
DEFAULT_KARRAS_SIGMA_MAX = 160
DEFAULT_KARRAS_S_CHURN = 0.0
def uncond_guide_model(
    model: Callable[..., torch.Tensor], scale: float
) -> Callable[..., torch.Tensor]:
    """Wrap ``model`` so its epsilon prediction is blended with guidance.

    The returned callable expects a batch whose length is even. Only the
    first half of ``x_t`` is used: it is duplicated along dim 0 and passed to
    ``model`` together with ``ts`` and any keyword arguments. The first three
    channels (dim 1) of the model output are treated as epsilon; the first
    batch chunk is taken as the conditional prediction and the second as the
    unconditional one. They are combined as
    ``uncond + scale * (cond - uncond)`` and the result is written to both
    halves of the batch. All remaining channels are passed through unchanged.

    Args:
        model: callable invoked as ``model(x, ts, **kwargs)`` returning a
            tensor with the batch on dim 0 and channels on dim 1.
        scale: guidance weight applied to ``cond - uncond``.

    Returns:
        A callable with the signature ``(x_t, ts, **kwargs)`` producing a
        tensor of the same shape as the wrapped model's output.
    """

    def model_fn(x_t, ts, **kwargs):
        n_half = len(x_t) // 2
        first_half = x_t[:n_half]
        # Both halves see identical inputs; presumably the caller's kwargs
        # make the second half unconditional (e.g. zeroed conditioning).
        doubled = torch.cat((first_half, first_half), dim=0)
        raw = model(doubled, ts, **kwargs)

        eps = raw[:, :3]
        rest = raw[:, 3:]

        cond_eps, uncond_eps = eps.chunk(2, dim=0)
        guided = uncond_eps + scale * (cond_eps - uncond_eps)

        # Repeat the guided epsilon so the output batch size matches the input.
        guided_full = torch.cat((guided, guided), dim=0)
        return torch.cat((guided_full, rest), dim=1)

    return model_fn
def sample_latents(
    *,
    batch_size: int,
    model: nn.Module,
    diffusion: GaussianDiffusion,
    model_kwargs: Dict[str, Any],
    guidance_scale: float,
    clip_denoised: bool,
    use_fp16: bool,
    use_karras: bool,
    karras_steps: int,
    sigma_min: float,
    sigma_max: float,
    s_churn: float,
    device: Optional[torch.device] = None,
    progress: bool = False,
) -> torch.Tensor:
    """Draw ``batch_size`` latent samples from ``model`` with ``diffusion``.

    Args:
        batch_size: number of latents to return.
        model: network to sample from; must expose ``d_latent``. If it has a
            ``cached_model_kwargs(batch_size, model_kwargs)`` method, that is
            called first and its result replaces ``model_kwargs``.
        diffusion: object providing ``p_sample_loop`` (used when
            ``use_karras`` is false) and passed to ``karras_sample``
            otherwise.
        model_kwargs: conditioning tensors forwarded to the model. The
            caller's dict is never modified.
        guidance_scale: guidance weight. The values ``0.0`` and ``1.0`` both
            disable guidance; for any other value every tensor in
            ``model_kwargs`` is concatenated with a zeroed copy of itself
            along dim 0 before sampling.
        clip_denoised: forwarded to the sampler.
        use_fp16: enables ``torch.autocast`` around the sampling call.
        use_karras: choose ``karras_sample`` instead of
            ``diffusion.p_sample_loop``.
        karras_steps, sigma_min, sigma_max, s_churn: forwarded to
            ``karras_sample``; unused on the non-Karras path.
        device: sampling device; defaults to the device of the model's first
            parameter.
        progress: forwarded to the sampler.

    Returns:
        The sampler's output tensor. On the guided non-Karras path the
        internally doubled batch is truncated back to ``batch_size`` rows.
    """
    if device is None:
        device = next(model.parameters()).device

    if hasattr(model, "cached_model_kwargs"):
        model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)

    # Single source of truth for "is guidance active", so the conditioning
    # doubling below and the model wrapping on the non-Karras path can never
    # disagree (previously 0.0 wrapped the model without doubling the kwargs).
    use_guidance = guidance_scale != 1.0 and guidance_scale != 0.0

    if use_guidance:
        # Build a new dict rather than writing into the caller's one, which
        # would otherwise grow on every call that reuses the same kwargs.
        model_kwargs = {
            key: torch.cat([value, torch.zeros_like(value)], dim=0)
            for key, value in model_kwargs.items()
        }

    sample_shape = (batch_size, model.d_latent)
    with torch.autocast(device_type=device.type, enabled=use_fp16):
        if use_karras:
            # karras_sample receives guidance_scale and the un-doubled shape;
            # NOTE(review): it is assumed to duplicate the batch internally
            # when guidance is active -- confirm in k_diffusion.
            samples = karras_sample(
                diffusion=diffusion,
                model=model,
                shape=sample_shape,
                steps=karras_steps,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
                device=device,
                sigma_min=sigma_min,
                sigma_max=sigma_max,
                s_churn=s_churn,
                guidance_scale=guidance_scale,
                progress=progress,
            )
        else:
            internal_batch_size = batch_size
            if use_guidance:
                model = uncond_guide_model(model, guidance_scale)
                internal_batch_size *= 2
            samples = diffusion.p_sample_loop(
                model,
                shape=(internal_batch_size, *sample_shape[1:]),
                model_kwargs=model_kwargs,
                device=device,
                clip_denoised=clip_denoised,
                progress=progress,
            )
            if use_guidance:
                # The guided wrapper only feeds the first half of the batch
                # to the model, so the trailing rows are not valid samples.
                samples = samples[:batch_size]
    return samples
|