import torch
from diffusers import LCMScheduler
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    StableDiffusionPipeline,
    rescale_noise_cfg,
    retrieve_timesteps,
)
class Hack_SDPipe_Stepwise(StableDiffusionPipeline):
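    '''
    A StableDiffusionPipeline subclass that exposes the individual stages of
    the denoising loop (prompt encoding, per-step noise prediction, x0
    solving, VAE encode/decode) as standalone methods, so the loop can be
    driven and inspected step by step from outside the pipeline.
    '''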
    @torch.no_grad()
    def _use_lcm(self, use=True, ckpt='latent-consistency/lcm-lora-sdv1-5'):
        if use:
            self.use_lcm = True
            adapter_id = ckpt
            self.scheduler = LCMScheduler.from_config(self.scheduler.config)
            # load and fuse the LCM-LoRA weights; LCM is distilled to run
            # without classifier-free guidance, so disable it
            self._guidance_scale = 0.0
            self.load_lora_weights(adapter_id)
            self.fuse_lora()
        else:
            self.use_lcm = False
            self._guidance_scale = 7.5
    @torch.no_grad()
    def re_init(self, num_inference_steps=50):
        # hyper-parameters
        eta = 0.0
        timesteps = None
        generator = None
        self._clip_skip = None
        self._interrupt = False
        self._guidance_rescale = 0.0
        self.added_cond_kwargs = None
        self._cross_attention_kwargs = None
        self._do_classifier_free_guidance = self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
        # 2. Define call parameters
        batch_size = 1
        device = self._execution_device
        # 4. Prepare timesteps
        self.timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        self.extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        # 6.2 Optionally get Guidance Scale Embedding
        self.timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * 1)
            self.timestep_cond = self.get_guidance_scale_embedding(guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim).to(device=device)
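    # Note (added): call _use_lcm(...) before re_init(...) -- re_init reads
    # self._guidance_scale, which in this class is only set inside _use_lcm.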
    @torch.no_grad()
    def _encode_text_prompt(self, prompt, negative_prompt='fake,ugly,unreal'):
        # 3. Encode input prompt
        lora_scale = (self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None)
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            self._execution_device,
            1,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=None,
            negative_prompt_embeds=None,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )
        # For classifier-free guidance we need two forward passes.
        # Here we concatenate the unconditional and text embeddings into a
        # single batch to avoid doing two forward passes.
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        return prompt_embeds
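    # Note (added): with CFG enabled the returned embeds are the concatenation
    # [negative; positive], matching the latent duplication in _step_noise below.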
    @torch.no_grad()
    def _step_noise(self, latents, time_step, prompt_embeds):
        # expand the latents if we are doing classifier-free guidance
        latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, time_step)
        # predict the noise residual
        noise_pred = self.unet(
            latent_model_input,
            time_step,
            encoder_hidden_states=prompt_embeds,
            timestep_cond=self.timestep_cond,
            cross_attention_kwargs=self.cross_attention_kwargs,
            added_cond_kwargs=self.added_cond_kwargs,
            return_dict=False,
        )[0]
        # perform guidance
        if self.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
        if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
            # Based on Section 3.4 of https://arxiv.org/pdf/2305.08891.pdf
            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
        return noise_pred
    # @torch.no_grad()
    def _encode(self, input):
        '''
        single condition encoding
        input:  B3HW image tensor
        return: B4H'W' latent tensor
        if running low-VRAM with the VAE on CPU, the input must be on CPU as well
        '''
        h = self.vae.encoder(input)
        moments = self.vae.quant_conv(h)
        # deterministic encode: use the posterior mean, ignore the variance
        mean, logvar = torch.chunk(moments, 2, dim=1)
        # scale latent
        latent = mean * self.vae.config.scaling_factor
        return latent
    def _decode(self, latent):
        '''
        single target decoding
        input:  B4H'W' latent tensor
        return: B3HW image tensor
        '''
        # unscale latent
        latent = latent / self.vae.config.scaling_factor
        # decode
        z = self.vae.post_quant_conv(latent)
        output = self.vae.decoder(z)
        return output
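    # Usage sketch (added; `pipe` and `img` are assumed names): a VAE
    # roundtrip with a B3HW tensor in [-1, 1] on the VAE's device/dtype:
    #   lat = pipe._encode(img)   # -> B x 4 x H/8 x W/8
    #   rec = pipe._decode(lat)   # -> B x 3 x H x W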
    def _solve_x0_full_step(self, latents, noise_pred, t):
        # closed-form x0 for epsilon-prediction:
        #   x_t = alpha_t * x0 + sigma_t * eps  =>  x0 = (x_t - sigma_t * eps) / alpha_t
        self.alpha_t = torch.sqrt(self.scheduler.alphas_cumprod).to(t.device)
        self.sigma_t = torch.sqrt(1 - self.scheduler.alphas_cumprod).to(t.device)
        a_t, s_t = self.alpha_t[t], self.sigma_t[t]
        x0_latents = (latents - s_t * noise_pred) / a_t
        x0 = self._decode(x0_latents)
        return x0_latents, x0
    def _solve_x0(self, latents, noise_pred, t):
        x0_latents = self.scheduler.step(noise_pred, t.squeeze(), latents)
        # undo the scheduler's internal step counter: this is only a peek at
        # the x0 estimate, not a real denoising step
        self.scheduler._step_index -= 1
        # results
        x0_latents = x0_latents.denoised if self.use_lcm else x0_latents.pred_original_sample
        x0 = self._decode(x0_latents)
        return x0_latents, x0
    def _step_denoise(self, latents, noise_pred, t):
        latents = self.scheduler.step(noise_pred, t.squeeze(), latents).prev_sample
        return latents
    def xt_x0_noise(
        self,
        xt_latents: torch.Tensor,
        x0_latents: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        # Make sure alphas_cumprod and timesteps have the same device and dtype
        # as the latents; moving alphas_cumprod once avoids redundant CPU-to-GPU
        # transfers on subsequent calls.
        alphas_cumprod = self.scheduler.alphas_cumprod.to(dtype=xt_latents.dtype, device=xt_latents.device)
        timesteps = timesteps.to(xt_latents.device)
        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(xt_latents.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(xt_latents.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
        noise = (xt_latents - sqrt_alpha_prod * x0_latents) / sqrt_one_minus_alpha_prod
        return noise
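    # Derivation (added note): scheduler.add_noise computes
    #   x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps,
    # so xt_x0_noise simply inverts that identity to recover eps
    # from (x_t, x0) at timestep t.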
    def _solve_noise_given_x0_latent(self, latents, x0_latents, t):
        noise = self.xt_x0_noise(latents, x0_latents, t)
        # -------------------- noise for supervision -----------------
        if self.scheduler.config.prediction_type == "epsilon":
            # the supervision target is the recovered noise itself
            noise = noise
        elif self.scheduler.config.prediction_type == "v_prediction":
            # convert (x0, eps) into the v-prediction target
            noise = self.scheduler.get_velocity(x0_latents, noise, t)
        # ------------------------------------------------------------
        return noise
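# ---------------------------------------------------------------------------
# Usage sketch (added; not part of the original file). A minimal manual
# denoising loop driven through the stepwise hooks above. The model id,
# prompt, and 512x512 (64x64 latent) resolution are illustrative
# assumptions, not values taken from this repo.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    pipe = Hack_SDPipe_Stepwise.from_pretrained(
        'runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16
    ).to('cuda')
    pipe._use_lcm(use=False)              # stock scheduler, CFG scale 7.5
    pipe.re_init(num_inference_steps=50)
    prompt_embeds = pipe._encode_text_prompt('a photo of a cat')
    latents = torch.randn(
        1, pipe.unet.config.in_channels, 64, 64,
        device=pipe._execution_device, dtype=prompt_embeds.dtype,
    ) * pipe.scheduler.init_noise_sigma
    for t in pipe.timesteps:
        noise_pred = pipe._step_noise(latents, t, prompt_embeds)
        latents = pipe._step_denoise(latents, noise_pred, t)
    image = pipe._decode(latents)         # B3HW tensor, roughly in [-1, 1]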