import torch
from diffusers import LCMScheduler
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    StableDiffusionPipeline,
    rescale_noise_cfg,
    retrieve_timesteps,
)


class Hack_SDPipe_Stepwise(StableDiffusionPipeline):

    @torch.no_grad()
    def _use_lcm(self, use=True, ckpt='latent-consistency/lcm-lora-sdv1-5'):
        if use:
            self.use_lcm = True
            adapter_id = ckpt
            self.scheduler = LCMScheduler.from_config(self.scheduler.config)
            # load and fuse lcm lora
            self._guidance_scale = 0.0
            self.load_lora_weights(adapter_id)
            self.fuse_lora()
        else:
            self.use_lcm = False
            self._guidance_scale = 7.5

    @torch.no_grad()
    def re_init(self,num_inference_steps=50):
        # hyper-parameters
        eta = 0.0
        timesteps = None
        generator = None
        self._clip_skip = None
        self._interrupt = False
        self._guidance_rescale = 0.0
        self.added_cond_kwargs = None
        self._cross_attention_kwargs = None
        # mirrors the base pipeline's `do_classifier_free_guidance` property:
        # CFG is active only when guidance_scale > 1 and the UNet lacks timestep conditioning
        self._do_classifier_free_guidance = self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

        # 2. Define call parameters
        batch_size = 1
        device = self._execution_device
        
        # 4. Prepare timesteps
        self.timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        self.extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 6.2 Optionally get Guidance Scale Embedding
        self.timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * 1)
            self.timestep_cond = self.get_guidance_scale_embedding(guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim).to(device=device)

    @torch.no_grad()
    def _encode_text_prompt(self,
                            prompt,
                            negative_prompt='fake,ugly,unreal'):
        # 3. Encode input prompt
        lora_scale = (self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None)
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            self._execution_device,
            1,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=None,
            negative_prompt_embeds=None,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        return prompt_embeds

    @torch.no_grad()
    def _step_noise(self,
                    latents,
                    time_step,
                    prompt_embeds):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, time_step)
        # predict the noise residual
        noise_pred = self.unet(
            latent_model_input,
            time_step,
            encoder_hidden_states=prompt_embeds,
            timestep_cond=self.timestep_cond,
            cross_attention_kwargs=self.cross_attention_kwargs,
            added_cond_kwargs=self.added_cond_kwargs,
            return_dict=False,
        )[0]
        # perform guidance
        if self.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

        if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
        return noise_pred

    # intentionally not decorated with @torch.no_grad(): gradients may need to
    # flow through the VAE encoder (e.g. for latent-space optimization)
    def _encode(self, input):
        '''
        Single-condition encoding (deterministic: uses the posterior mean, no sampling).
        input:  B3HW image tensor, e.g. (1, 3, 512, 512)
        return: B4H'W' latent tensor, e.g. (1, 4, 64, 64) for SD 1.x
        If running low-VRAM with the VAE on CPU, the input must also be on CPU.
        '''
        h = self.vae.encoder(input)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        # scale latent
        latent = mean * self.vae.config.scaling_factor
        return latent
    
    def _decode(self, latent):
        '''
        Single-target decoding (inverse of _encode).
        input:  B4H'W' latent tensor
        return: B3HW image tensor
        '''
        # scale latent
        latent = latent / self.vae.config.scaling_factor
        # decode
        z = self.vae.post_quant_conv(latent)
        output = self.vae.decoder(z)
        return output

    def _solve_x0_full_step(self, latents, noise_pred, t):
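        # epsilon parameterization: x_t = alpha_t * x0 + sigma_t * eps with
        # alpha_t = sqrt(alphas_cumprod[t]) and sigma_t = sqrt(1 - alphas_cumprod[t]),
        # so x0 = (x_t - sigma_t * eps) / alpha_t, as computed below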
        self.alpha_t = torch.sqrt(self.scheduler.alphas_cumprod).to(t.device)
        self.sigma_t = torch.sqrt(1-self.scheduler.alphas_cumprod).to(t.device)
        a_t, s_t = self.alpha_t[t], self.sigma_t[t]
        x0_latents = (latents - s_t * noise_pred) / a_t
        x0 = self._decode(x0_latents)
        return x0_latents, x0
        
    def _solve_x0(self, latents, noise_pred, t):
        x0_latents = self.scheduler.step(noise_pred, t.squeeze(), latents)
        # scheduler.step() advances the internal step index; roll it back so this
        # call only previews x0 (a "fake" denoise) without consuming the step
        self.scheduler._step_index -= 1
        # LCMSchedulerOutput exposes the x0 estimate as `.denoised`,
        # DDIM-style scheduler outputs as `.pred_original_sample`
        x0_latents = x0_latents.denoised if self.use_lcm else x0_latents.pred_original_sample
        x0 = self._decode(x0_latents)
        return x0_latents, x0

    def _step_denoise(self, latents, noise_pred, t):
        latents = self.scheduler.step(noise_pred, t.squeeze(), latents).prev_sample   
        return latents
    
    def xt_x0_noise(
        self,
        xt_latents: torch.Tensor,
        x0_latents: torch.Tensor,
        timesteps: torch.IntTensor,
        ) -> torch.Tensor:
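        # inverts the forward process x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps
        # to recover eps = (x_t - sqrt(abar_t) * x0) / sqrt(1 - abar_t)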
        # make sure alphas_cumprod and timesteps share the device and dtype of
        # xt_latents, avoiding redundant CPU-to-GPU transfers across repeated calls
        alphas_cumprod = self.scheduler.alphas_cumprod.to(dtype=xt_latents.dtype,device=xt_latents.device)
        timesteps = timesteps.to(xt_latents.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(xt_latents.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(xt_latents.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noise = (xt_latents - sqrt_alpha_prod * x0_latents) / sqrt_one_minus_alpha_prod
        return noise
    
    def _solve_noise_given_x0_latent(self, latents, x0_latents, t):
        noise = self.xt_x0_noise(latents, x0_latents, t)
        # -------------------- noise for supervision -----------------
        if self.scheduler.config.prediction_type == "epsilon":
            pass  # epsilon objective: the supervision target is the noise itself
        elif self.scheduler.config.prediction_type == "v_prediction":
            # convert (x0, eps) into the velocity target used by v-prediction schedulers
            noise = self.scheduler.get_velocity(x0_latents, noise, t)
        # ------------------------------------------------------------
        return noise
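

# -----------------------------------------------------------------------------
# Minimal usage sketch (NOT part of the original file). Everything below is an
# illustrative assumption: the base checkpoint "runwayml/stable-diffusion-v1-5",
# the prompt, the 4-step LCM schedule, and the 64x64 latent size. It shows the
# intended calling order: _use_lcm() -> re_init() -> _encode_text_prompt(),
# then one _step_noise() / _step_denoise() pair per timestep, with _solve_x0()
# as an optional mid-trajectory preview of the clean image.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = Hack_SDPipe_Stepwise.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)

    pipe._use_lcm(True)            # swap in LCMScheduler and fuse the LCM LoRA
    pipe.re_init(num_inference_steps=4)
    prompt_embeds = pipe._encode_text_prompt("a photo of an astronaut riding a horse")

    # start from pure Gaussian noise in latent space (64x64 latents -> 512x512 pixels)
    latents = torch.randn(
        1, pipe.unet.config.in_channels, 64, 64,
        device=pipe._execution_device, dtype=prompt_embeds.dtype,
    ) * pipe.scheduler.init_noise_sigma

    # inference only here; the un-decorated methods keep gradients available
    # for training-style uses
    with torch.no_grad():
        for t in pipe.timesteps:
            noise_pred = pipe._step_noise(latents, t, prompt_embeds)
            # optional: preview the current x0 estimate without consuming the step
            x0_latents, x0_preview = pipe._solve_x0(latents, noise_pred, t)
            latents = pipe._step_denoise(latents, noise_pred, t)

        image = pipe._decode(latents)  # B3HW tensor, roughly in [-1, 1]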