turboedit committed on
Commit 29303b0 · verified · 1 Parent(s): c813f05

Upload turbo_edit/utils.py with huggingface_hub

Files changed (1)
  1. turbo_edit/utils.py +1357 -0
turbo_edit/utils.py ADDED
@@ -0,0 +1,1357 @@
import itertools
from typing import List, Optional, Union
import PIL
import PIL.Image
import torch
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
from diffusers.utils import make_image_grid
from PIL import Image, ImageDraw, ImageFont
import os
from diffusers.utils import (
    logging,
    USE_PEFT_BACKEND,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.loaders import (
    StableDiffusionXLLoraLoaderMixin,
)
from diffusers.image_processor import VaeImageProcessor

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers import DiffusionPipeline


VECTOR_DATA_FOLDER = "vector_data"
VECTOR_DATA_DICT = "vector_data"
def encode_image(image: PIL.Image, pipe: DiffusionPipeline):
    pipe.image_processor: VaeImageProcessor = pipe.image_processor  # type: ignore
    image = pipe.image_processor.pil_to_numpy(image)
    image = pipe.image_processor.numpy_to_pt(image)
    image = image.to(pipe.device)
    return (
        pipe.vae.encode(
            pipe.image_processor.preprocess(image),
        ).latent_dist.mode()
        * pipe.vae.config.scaling_factor
    )


def decode_latents(latent, pipe):
    latent_img = pipe.vae.decode(
        latent / pipe.vae.config.scaling_factor, return_dict=False
    )[0]
    return pipe.image_processor.postprocess(latent_img, output_type="pil")
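# Hedged usage sketch (not part of the original upload): the two helpers above round-trip an
# image through the pipeline's VAE. The checkpoint name and file paths below are hypothetical,
# for illustration only.
#     pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo")  # hypothetical checkpoint
#     latent = encode_image(PIL.Image.open("input.png"), pipe)
#     reconstruction = decode_latents(latent, pipe)[0]  # postprocess returns a list of PIL images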
def get_device(argv, args=None):
    import sys

    def debugger_is_active():
        return hasattr(sys, "gettrace") and sys.gettrace() is not None

    if args:
        return (
            torch.device("cuda")
            if (torch.cuda.is_available() and not debugger_is_active())
            and not args.force_use_cpu
            else torch.device("cpu")
        )

    return (
        torch.device("cuda")
        if (torch.cuda.is_available() and not debugger_is_active())
        and "cpu" not in set(argv[1:])
        else torch.device("cpu")
    )
def deterministic_ddim_step(
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
    scheduler=None,
):

    if scheduler.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    prev_timestep = (
        timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
    )

    # 2. compute alphas, betas
    alpha_prod_t = scheduler.alphas_cumprod[timestep]
    alpha_prod_t_prev = (
        scheduler.alphas_cumprod[prev_timestep]
        if prev_timestep >= 0
        else scheduler.final_alpha_cumprod
    )

    beta_prod_t = 1 - alpha_prod_t

    if scheduler.config.prediction_type == "epsilon":
        pred_original_sample = (
            sample - beta_prod_t ** (0.5) * model_output
        ) / alpha_prod_t ** (0.5)
        pred_epsilon = model_output
    elif scheduler.config.prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (
            sample - alpha_prod_t ** (0.5) * pred_original_sample
        ) / beta_prod_t ** (0.5)
    elif scheduler.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (
            beta_prod_t**0.5
        ) * model_output
        pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 4. Clip or threshold "predicted x_0"
    if scheduler.config.thresholding:
        pred_original_sample = scheduler._threshold_sample(pred_original_sample)
    elif scheduler.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -scheduler.config.clip_sample_range,
            scheduler.config.clip_sample_range,
        )

    # 5. compute variance: "sigma_t(η)" -> see formula (16)
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    variance = scheduler._get_variance(timestep, prev_timestep)
    std_dev_t = eta * variance ** (0.5)

    if use_clipped_model_output:
        # the pred_epsilon is always re-derived from the clipped x_0 in Glide
        pred_epsilon = (
            sample - alpha_prod_t ** (0.5) * pred_original_sample
        ) / beta_prod_t ** (0.5)

    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
        0.5
    ) * pred_epsilon

    # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    prev_sample = (
        alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
    )
    return prev_sample
def deterministic_euler_step(
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    eta,
    use_clipped_model_output,
    generator,
    variance_noise,
    return_dict,
    scheduler,
):
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        model_output (`torch.FloatTensor`):
            The direct output from learned diffusion model.
        timestep (`float`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            A current instance of a sample created by the diffusion process.
        generator (`torch.Generator`, *optional*):
            A random number generator.
        return_dict (`bool`):
            Whether or not to return a
            [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.

    Returns:
        [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
            If return_dict is `True`,
            [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
            otherwise a tuple is returned where the first element is the sample tensor.
    """

    if (
        isinstance(timestep, int)
        or isinstance(timestep, torch.IntTensor)
        or isinstance(timestep, torch.LongTensor)
    ):
        raise ValueError(
            (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            ),
        )

    if scheduler.step_index is None:
        scheduler._init_step_index(timestep)

    sigma = scheduler.sigmas[scheduler.step_index]

    # Upcast to avoid precision issues when computing prev_sample
    sample = sample.to(torch.float32)

    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
    if scheduler.config.prediction_type == "epsilon":
        pred_original_sample = sample - sigma * model_output
    elif scheduler.config.prediction_type == "v_prediction":
        # * c_out + input * c_skip
        pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
            sample / (sigma**2 + 1)
        )
    elif scheduler.config.prediction_type == "sample":
        raise NotImplementedError("prediction_type not implemented yet: sample")
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    sigma_from = scheduler.sigmas[scheduler.step_index]
    sigma_to = scheduler.sigmas[scheduler.step_index + 1]
    sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

    # 2. Convert to an ODE derivative
    derivative = (sample - pred_original_sample) / sigma

    dt = sigma_down - sigma

    prev_sample = sample + derivative * dt

    # Cast sample back to model compatible dtype
    prev_sample = prev_sample.to(model_output.dtype)

    # upon completion increase step index by one
    scheduler._step_index += 1

    return prev_sample
def deterministic_non_ancestral_euler_step(
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    s_churn: float = 0.0,
    s_tmin: float = 0.0,
    s_tmax: float = float("inf"),
    s_noise: float = 1.0,
    generator: Optional[torch.Generator] = None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
    scheduler=None,
):
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        model_output (`torch.FloatTensor`):
            The direct output from learned diffusion model.
        timestep (`float`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            A current instance of a sample created by the diffusion process.
        s_churn (`float`):
        s_tmin (`float`):
        s_tmax (`float`):
        s_noise (`float`, defaults to 1.0):
            Scaling factor for noise added to the sample.
        generator (`torch.Generator`, *optional*):
            A random number generator.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
            tuple.

    Returns:
        [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
            returned, otherwise a tuple is returned where the first element is the sample tensor.
    """

    if (
        isinstance(timestep, int)
        or isinstance(timestep, torch.IntTensor)
        or isinstance(timestep, torch.LongTensor)
    ):
        raise ValueError(
            (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            ),
        )

    if not scheduler.is_scale_input_called:
        logger.warning(
            "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
            "See `StableDiffusionPipeline` for a usage example."
        )

    if scheduler.step_index is None:
        scheduler._init_step_index(timestep)

    # Upcast to avoid precision issues when computing prev_sample
    sample = sample.to(torch.float32)

    sigma = scheduler.sigmas[scheduler.step_index]

    gamma = (
        min(s_churn / (len(scheduler.sigmas) - 1), 2**0.5 - 1)
        if s_tmin <= sigma <= s_tmax
        else 0.0
    )

    sigma_hat = sigma * (gamma + 1)

    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
    # NOTE: "original_sample" should not be an expected prediction_type but is left in for
    # backwards compatibility
    if (
        scheduler.config.prediction_type == "original_sample"
        or scheduler.config.prediction_type == "sample"
    ):
        pred_original_sample = model_output
    elif scheduler.config.prediction_type == "epsilon":
        pred_original_sample = sample - sigma_hat * model_output
    elif scheduler.config.prediction_type == "v_prediction":
        # denoised = model_output * c_out + input * c_skip
        pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
            sample / (sigma**2 + 1)
        )
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    # 2. Convert to an ODE derivative
    derivative = (sample - pred_original_sample) / sigma_hat

    dt = scheduler.sigmas[scheduler.step_index + 1] - sigma_hat

    prev_sample = sample + derivative * dt

    # Cast sample back to model compatible dtype
    prev_sample = prev_sample.to(model_output.dtype)

    # upon completion increase step index by one
    scheduler._step_index += 1

    return prev_sample
def deterministic_ddpm_step(
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    eta,
    use_clipped_model_output,
    generator,
    variance_noise,
    return_dict,
    scheduler,
):
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        model_output (`torch.FloatTensor`):
            The direct output from learned diffusion model.
        timestep (`float`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            A current instance of a sample created by the diffusion process.
        generator (`torch.Generator`, *optional*):
            A random number generator.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.

    Returns:
        [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.
    """
    t = timestep

    prev_t = scheduler.previous_timestep(t)

    if model_output.shape[1] == sample.shape[1] * 2 and scheduler.variance_type in [
        "learned",
        "learned_range",
    ]:
        model_output, predicted_variance = torch.split(
            model_output, sample.shape[1], dim=1
        )
    else:
        predicted_variance = None

    # 1. compute alphas, betas
    alpha_prod_t = scheduler.alphas_cumprod[t]
    alpha_prod_t_prev = (
        scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one
    )
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev
    current_beta_t = 1 - current_alpha_t

    # 2. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
    if scheduler.config.prediction_type == "epsilon":
        pred_original_sample = (
            sample - beta_prod_t ** (0.5) * model_output
        ) / alpha_prod_t ** (0.5)
    elif scheduler.config.prediction_type == "sample":
        pred_original_sample = model_output
    elif scheduler.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (
            beta_prod_t**0.5
        ) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for the DDPMScheduler."
        )

    # 3. Clip or threshold "predicted x_0"
    if scheduler.config.thresholding:
        pred_original_sample = scheduler._threshold_sample(pred_original_sample)
    elif scheduler.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -scheduler.config.clip_sample_range, scheduler.config.clip_sample_range
        )

    # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_original_sample_coeff = (
        alpha_prod_t_prev ** (0.5) * current_beta_t
    ) / beta_prod_t
    current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

    # 5. Compute predicted previous sample µ_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_prev_sample = (
        pred_original_sample_coeff * pred_original_sample
        + current_sample_coeff * sample
    )

    return pred_prev_sample
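# Hedged sanity note (added comment, not in the original upload): for the `epsilon` prediction
# type used by the deterministic steps above, the forward process is
#     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
# so the "predicted x_0" line inverts it exactly when the true noise is known, e.g.:
#     a = torch.tensor(0.7)                  # hypothetical alpha_bar_t
#     x0 = torch.randn(1, 4, 8, 8)
#     eps = torch.randn_like(x0)
#     xt = a**0.5 * x0 + (1 - a) ** 0.5 * eps
#     torch.allclose((xt - (1 - a) ** 0.5 * eps) / a**0.5, x0, atol=1e-6)  # True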
def normalize(
    z_t,
    i,
    max_norm_zs,
):
    max_norm = max_norm_zs[i]
    if max_norm < 0:
        return z_t, 1

    norm = torch.norm(z_t)
    if norm < max_norm:
        return z_t, 1

    coeff = max_norm / norm
    z_t = z_t * coeff
    return z_t, coeff
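# Hedged example (added comment, not in the original upload): `normalize` rescales a latent so
# its global L2 norm does not exceed max_norm_zs[i]; a negative entry disables clipping for
# that step. The numbers below are illustrative only.
#     z = torch.ones(4) * 3.0      # ||z|| = 6
#     normalize(z, 0, [4.0])       # -> (z * 2/3, coeff ~= 0.667)
#     normalize(z, 0, [-1])        # -> (z, 1), clipping disabled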
def find_index(timesteps, timestep):
    for i, t in enumerate(timesteps):
        if t == timestep:
            return i
    return -1


device = "cuda:0" if torch.cuda.is_available() else "cpu"
map_timpstep_to_index = {
    torch.tensor(799): 0,
    torch.tensor(599): 1,
    torch.tensor(399): 2,
    torch.tensor(199): 3,
    torch.tensor(799, device=device): 0,
    torch.tensor(599, device=device): 1,
    torch.tensor(399, device=device): 2,
    torch.tensor(199, device=device): 3,
}
def step_save_latents(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
):
    # print(self._save_timesteps)
    # timestep_index = map_timpstep_to_index[timestep]
    # timestep_index = ((self._save_timesteps == timestep).nonzero(as_tuple=True)[0]).item()
    timestep_index = self._save_timesteps.index(timestep) if not self.clean_step_run else -1
    next_timestep_index = timestep_index + 1 if not self.clean_step_run else -1
    u_hat_t = self.step_function(
        model_output=model_output,
        timestep=timestep,
        sample=sample,
        eta=eta,
        use_clipped_model_output=use_clipped_model_output,
        generator=generator,
        variance_noise=variance_noise,
        return_dict=False,
        scheduler=self,
    )

    x_t_minus_1 = self.x_ts[next_timestep_index]
    self.x_ts_c_hat.append(u_hat_t)

    z_t = x_t_minus_1 - u_hat_t
    self.latents.append(z_t)

    z_t, _ = normalize(z_t, timestep_index, self._config.max_norm_zs)

    x_t_minus_1_predicted = u_hat_t + z_t

    if not return_dict:
        return (x_t_minus_1_predicted,)

    return DDIMSchedulerOutput(prev_sample=x_t_minus_1, pred_original_sample=None)
def step_use_latents(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
):
    # timestep_index = ((self._save_timesteps == timestep).nonzero(as_tuple=True)[0]).item()
    timestep_index = self._timesteps.index(timestep) if not self.clean_step_run else -1
    next_timestep_index = (
        timestep_index + 1 if not self.clean_step_run else -1
    )
    z_t = self.latents[next_timestep_index]  # + 1 because latents[0] is X_T

    _, normalize_coefficient = normalize(
        z_t[0] if self._config.breakdown == "x_t_hat_c_with_zeros" else z_t,
        timestep_index,
        self._config.max_norm_zs,
    )

    if normalize_coefficient == 0:
        eta = 0

    # eta = normalize_coefficient

    x_t_hat_c_hat = self.step_function(
        model_output=model_output,
        timestep=timestep,
        sample=sample,
        eta=eta,
        use_clipped_model_output=use_clipped_model_output,
        generator=generator,
        variance_noise=variance_noise,
        return_dict=False,
        scheduler=self,
    )

    w1 = self._config.ws1[timestep_index]
    w2 = self._config.ws2[timestep_index]

    x_t_minus_1_exact = self.x_ts[next_timestep_index]
    x_t_minus_1_exact = x_t_minus_1_exact.expand_as(x_t_hat_c_hat)

    x_t_c_hat: torch.Tensor = self.x_ts_c_hat[next_timestep_index]
    if self._config.breakdown == "x_t_c_hat":
        raise NotImplementedError("breakdown x_t_c_hat not implemented yet")

    # x_t_c_hat = x_t_c_hat.expand_as(x_t_hat_c_hat)
    x_t_c = x_t_c_hat[0].expand_as(x_t_hat_c_hat)

    # if self._config.breakdown == "x_t_c_hat":
    #     v1 = x_t_hat_c_hat - x_t_c_hat
    #     v2 = x_t_c_hat - x_t_c
    if (
        self._config.breakdown == "x_t_hat_c"
        or self._config.breakdown == "x_t_hat_c_with_zeros"
    ):
        zero_index_reconstruction = 1 if not self.time_measure_n else 0
        edit_prompts_num = (
            (model_output.size(0) - zero_index_reconstruction) // 3
            if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p
            else (model_output.size(0) - zero_index_reconstruction) // 2
        )
        x_t_hat_c_indices = (zero_index_reconstruction, edit_prompts_num + zero_index_reconstruction)
        edit_images_indices = (
            edit_prompts_num + zero_index_reconstruction,
            (
                model_output.size(0)
                if self._config.breakdown == "x_t_hat_c"
                else zero_index_reconstruction + 2 * edit_prompts_num
            ),
        )
        x_t_hat_c = torch.zeros_like(x_t_hat_c_hat)
        x_t_hat_c[edit_images_indices[0] : edit_images_indices[1]] = x_t_hat_c_hat[
            x_t_hat_c_indices[0] : x_t_hat_c_indices[1]
        ]
        v1 = x_t_hat_c_hat - x_t_hat_c
        v2 = x_t_hat_c - normalize_coefficient * x_t_c
        if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p:
            path = os.path.join(
                self.folder_name,
                VECTOR_DATA_FOLDER,
                self.image_name,
            )
            if not hasattr(self, VECTOR_DATA_DICT):
                os.makedirs(path, exist_ok=True)
                self.vector_data = dict()

            x_t_0 = x_t_c_hat[1]
            empty_prompt_indices = (1 + 2 * edit_prompts_num, 1 + 3 * edit_prompts_num)
            x_t_hat_0 = x_t_hat_c_hat[empty_prompt_indices[0] : empty_prompt_indices[1]]

            self.vector_data[timestep.item()] = dict()
            self.vector_data[timestep.item()]["x_t_hat_c"] = x_t_hat_c[
                edit_images_indices[0] : edit_images_indices[1]
            ]
            self.vector_data[timestep.item()]["x_t_hat_0"] = x_t_hat_0
            self.vector_data[timestep.item()]["x_t_c"] = x_t_c[0].expand_as(x_t_hat_0)
            self.vector_data[timestep.item()]["x_t_0"] = x_t_0.expand_as(x_t_hat_0)
            self.vector_data[timestep.item()]["x_t_hat_c_hat"] = x_t_hat_c_hat[
                edit_images_indices[0] : edit_images_indices[1]
            ]
            self.vector_data[timestep.item()]["x_t_minus_1_noisy"] = x_t_minus_1_exact[
                0
            ].expand_as(x_t_hat_0)
            self.vector_data[timestep.item()]["x_t_minus_1_clean"] = self.x_0s[
                next_timestep_index
            ].expand_as(x_t_hat_0)

    else:  # no breakdown
        v1 = x_t_hat_c_hat - normalize_coefficient * x_t_c
        v2 = 0

    if self.save_intermediate_results and not self.p_to_p:
        delta = v1 + v2
        v1_plus_x0 = self.x_0s[next_timestep_index] + v1
        v2_plus_x0 = self.x_0s[next_timestep_index] + v2
        delta_plus_x0 = self.x_0s[next_timestep_index] + delta

        v1_images = decode_latents(v1, self.pipe)
        self.v1s_images.append(v1_images)
        v2_images = (
            decode_latents(v2, self.pipe)
            if self._config.breakdown != "no_breakdown"
            else [PIL.Image.new("RGB", (1, 1))]
        )
        self.v2s_images.append(v2_images)
        delta_images = decode_latents(delta, self.pipe)
        self.deltas_images.append(delta_images)
        v1_plus_x0_images = decode_latents(v1_plus_x0, self.pipe)
        self.v1_x0s.append(v1_plus_x0_images)
        v2_plus_x0_images = (
            decode_latents(v2_plus_x0, self.pipe)
            if self._config.breakdown != "no_breakdown"
            else [PIL.Image.new("RGB", (1, 1))]
        )
        self.v2_x0s.append(v2_plus_x0_images)
        delta_plus_x0_images = decode_latents(delta_plus_x0, self.pipe)
        self.deltas_x0s.append(delta_plus_x0_images)

    # print(f"v1 norm: {torch.norm(v1, dim=0).mean()}")
    # if self._config.breakdown != "no_breakdown":
    #     print(f"v2 norm: {torch.norm(v2, dim=0).mean()}")
    #     print(f"v sum norm: {torch.norm(v1 + v2, dim=0).mean()}")

    x_t_minus_1 = normalize_coefficient * x_t_minus_1_exact + w1 * v1 + w2 * v2

    if (
        self._config.breakdown == "x_t_hat_c"
        or self._config.breakdown == "x_t_hat_c_with_zeros"
    ):
        x_t_minus_1[x_t_hat_c_indices[0] : x_t_hat_c_indices[1]] = x_t_minus_1[
            edit_images_indices[0] : edit_images_indices[1]
        ]  # update x_t_hat_c to be x_t_hat_c_hat
        if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p:
            x_t_minus_1[empty_prompt_indices[0] : empty_prompt_indices[1]] = (
                x_t_minus_1[edit_images_indices[0] : edit_images_indices[1]]
            )
            self.vector_data[timestep.item()]["x_t_minus_1_edited"] = x_t_minus_1[
                edit_images_indices[0] : edit_images_indices[1]
            ]
            if timestep == self._timesteps[-1]:
                torch.save(
                    self.vector_data,
                    os.path.join(
                        path,
                        f"{VECTOR_DATA_DICT}.pt",
                    ),
                )
    # p_to_p_force_perfect_reconstruction
    if not self.time_measure_n:
        x_t_minus_1[0] = x_t_minus_1_exact[0]

    if not return_dict:
        return (x_t_minus_1,)

    return DDIMSchedulerOutput(
        prev_sample=x_t_minus_1,
        pred_original_sample=None,
    )
731
+ def get_ddpm_inversion_scheduler(
732
+ scheduler,
733
+ step_function,
734
+ config,
735
+ timesteps,
736
+ save_timesteps,
737
+ latents,
738
+ x_ts,
739
+ x_ts_c_hat,
740
+ save_intermediate_results,
741
+ pipe,
742
+ x_0,
743
+ v1s_images,
744
+ v2s_images,
745
+ deltas_images,
746
+ v1_x0s,
747
+ v2_x0s,
748
+ deltas_x0s,
749
+ folder_name,
750
+ image_name,
751
+ time_measure_n,
752
+ ):
753
+ def step(
754
+ model_output: torch.FloatTensor,
755
+ timestep: int,
756
+ sample: torch.FloatTensor,
757
+ eta: float = 0.0,
758
+ use_clipped_model_output: bool = False,
759
+ generator=None,
760
+ variance_noise: Optional[torch.FloatTensor] = None,
761
+ return_dict: bool = True,
762
+ ):
763
+ # if scheduler.is_save:
764
+ # start = timer()
765
+ res_inv = step_save_latents(
766
+ scheduler,
767
+ model_output[:1, :, :, :],
768
+ timestep,
769
+ sample[:1, :, :, :],
770
+ eta,
771
+ use_clipped_model_output,
772
+ generator,
773
+ variance_noise,
774
+ return_dict,
775
+ )
776
+ # end = timer()
777
+ # print(f"Run Time Inv: {end - start}")
778
+
779
+ res_inf = step_use_latents(
780
+ scheduler,
781
+ model_output[1:, :, :, :],
782
+ timestep,
783
+ sample[1:, :, :, :],
784
+ eta,
785
+ use_clipped_model_output,
786
+ generator,
787
+ variance_noise,
788
+ return_dict,
789
+ )
790
+ # res = res_inv
791
+ res = (torch.cat((res_inv[0], res_inf[0]), dim=0),)
792
+ return res
793
+ # return res
794
+
795
+ scheduler.step_function = step_function
796
+ scheduler.is_save = True
797
+ scheduler._timesteps = timesteps
798
+ scheduler._save_timesteps = save_timesteps if save_timesteps else timesteps
799
+ scheduler._config = config
800
+ scheduler.latents = latents
801
+ scheduler.x_ts = x_ts
802
+ scheduler.x_ts_c_hat = x_ts_c_hat
803
+ scheduler.step = step
804
+ scheduler.save_intermediate_results = save_intermediate_results
805
+ scheduler.pipe = pipe
806
+ scheduler.v1s_images = v1s_images
807
+ scheduler.v2s_images = v2s_images
808
+ scheduler.deltas_images = deltas_images
809
+ scheduler.v1_x0s = v1_x0s
810
+ scheduler.v2_x0s = v2_x0s
811
+ scheduler.deltas_x0s = deltas_x0s
812
+ scheduler.clean_step_run = False
813
+ scheduler.x_0s = create_xts(
814
+ config.noise_shift_delta,
815
+ config.noise_timesteps,
816
+ config.clean_step_timestep,
817
+ None,
818
+ pipe.scheduler,
819
+ timesteps,
820
+ x_0,
821
+ no_add_noise=True,
822
+ )
823
+ scheduler.folder_name = folder_name
824
+ scheduler.image_name = image_name
825
+ scheduler.p_to_p = False
826
+ scheduler.p_to_p_replace = False
827
+ scheduler.time_measure_n = time_measure_n
828
+ return scheduler
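# Hedged usage sketch (added comment, not in the original upload): the function above
# monkey-patches a diffusers scheduler so that each `step` call splits the batch into an
# inversion row (index 0, handled by `step_save_latents`) and edit rows (the remaining
# indices, handled by `step_use_latents`). Names such as `cfg`, `ts`, `x_ts` and
# `x_0_latent` below are hypothetical placeholders.
#     ts = list(pipe.scheduler.timesteps)
#     pipe.scheduler = get_ddpm_inversion_scheduler(
#         pipe.scheduler, deterministic_ddpm_step, cfg, ts, None,
#         latents=[], x_ts=x_ts, x_ts_c_hat=[], save_intermediate_results=False,
#         pipe=pipe, x_0=x_0_latent, v1s_images=[], v2s_images=[], deltas_images=[],
#         v1_x0s=[], v2_x0s=[], deltas_x0s=[], folder_name="out", image_name="img",
#         time_measure_n=False,
#     )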
def create_grid(
    images,
    p_to_p_images,
    prompts,
    original_image_path,
):
    images_len = len(images) if len(images) > 0 else len(p_to_p_images)
    images_size = images[0].size if len(images) > 0 else p_to_p_images[0].size
    x_0 = Image.open(original_image_path).resize(images_size)

    images_ = [x_0] + images + ([x_0] + p_to_p_images if p_to_p_images else [])

    l1 = 1 if len(images) > 0 else 0
    l2 = 1 if len(p_to_p_images) else 0
    grid = make_image_grid(images_, rows=l1 + l2, cols=images_len + 1, resize=None)

    width = images_size[0]
    height = width // 5
    font = ImageFont.truetype("font.ttf", width // 14)

    grid1 = Image.new("RGB", size=(grid.size[0], grid.size[1] + height))
    grid1.paste(grid, (0, 0))

    draw = ImageDraw.Draw(grid1)

    c_width = 0
    for prompt in prompts:
        if len(prompt) > 30:
            prompt = prompt[:30] + "\n" + prompt[30:]
        draw.text((c_width, width * 2), prompt, font=font, fill=(255, 255, 255))
        c_width += width

    return grid1
def save_intermediate_results(
    v1s_images,
    v2s_images,
    deltas_images,
    v1_x0s,
    v2_x0s,
    deltas_x0s,
    folder_name,
    original_prompt,
):
    from diffusers.utils import make_image_grid

    path = f"{folder_name}/{original_prompt}_intermediate_results/"
    os.makedirs(path, exist_ok=True)
    make_image_grid(
        list(itertools.chain(*v1s_images)),
        rows=len(v1s_images),
        cols=len(v1s_images[0]),
    ).save(f"{path}v1s_images.png")
    make_image_grid(
        list(itertools.chain(*v2s_images)),
        rows=len(v2s_images),
        cols=len(v2s_images[0]),
    ).save(f"{path}v2s_images.png")
    make_image_grid(
        list(itertools.chain(*deltas_images)),
        rows=len(deltas_images),
        cols=len(deltas_images[0]),
    ).save(f"{path}deltas_images.png")
    make_image_grid(
        list(itertools.chain(*v1_x0s)),
        rows=len(v1_x0s),
        cols=len(v1_x0s[0]),
    ).save(f"{path}v1_x0s.png")
    make_image_grid(
        list(itertools.chain(*v2_x0s)),
        rows=len(v2_x0s),
        cols=len(v2_x0s[0]),
    ).save(f"{path}v2_x0s.png")
    make_image_grid(
        list(itertools.chain(*deltas_x0s)),
        rows=len(deltas_x0s[0]),
        cols=len(deltas_x0s),
    ).save(f"{path}deltas_x0s.png")
    for i, image in enumerate(list(itertools.chain(*deltas_x0s))):
        image.save(f"{path}deltas_x0s_{i}.png")
# copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.py and removed the add_noise line
def prepare_latents_no_add_noise(
    self,
    image,
    timestep,
    batch_size,
    num_images_per_prompt,
    dtype,
    device,
    generator=None,
):
    from diffusers.utils import deprecate

    if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
        )

    image = image.to(device=device, dtype=dtype)

    batch_size = batch_size * num_images_per_prompt

    if image.shape[1] == 4:
        init_latents = image

    else:
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        elif isinstance(generator, list):
            init_latents = [
                self.retrieve_latents(
                    self.vae.encode(image[i : i + 1]), generator=generator[i]
                )
                for i in range(batch_size)
            ]
            init_latents = torch.cat(init_latents, dim=0)
        else:
            init_latents = self.retrieve_latents(
                self.vae.encode(image), generator=generator
            )

        init_latents = self.vae.config.scaling_factor * init_latents

    if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
        # expand init_latents for batch_size
        deprecation_message = (
            f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
            " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
            " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
            " your script to pass as many initial images as text prompts to suppress this warning."
        )
        deprecate(
            "len(prompt) != len(image)",
            "1.0.0",
            deprecation_message,
            standard_warn=False,
        )
        additional_image_per_prompt = batch_size // init_latents.shape[0]
        init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
    elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
        raise ValueError(
            f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
        )
    else:
        init_latents = torch.cat([init_latents], dim=0)

    # get latents
    latents = init_latents

    return latents
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt_empty_prompt_zeros_sdxl(
    self,
    prompt: str,
    prompt_2: Optional[str] = None,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[str] = None,
    negative_prompt_2: Optional[str] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in both text-encoders
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
    """
    device = device or self._execution_device

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if self.text_encoder is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
            else:
                scale_lora_layers(self.text_encoder_2, lora_scale)

    prompt = [prompt] if isinstance(prompt, str) else prompt

    if prompt is not None:
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # Define tokenizers and text encoders
    tokenizers = (
        [self.tokenizer, self.tokenizer_2]
        if self.tokenizer is not None
        else [self.tokenizer_2]
    )
    text_encoders = (
        [self.text_encoder, self.text_encoder_2]
        if self.text_encoder is not None
        else [self.text_encoder_2]
    )

    if prompt_embeds is None:
        prompt_2 = prompt_2 or prompt
        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

        # textual inversion: process multi-vector tokens if necessary
        prompt_embeds_list = []
        prompts = [prompt, prompt_2]
        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):

            text_inputs = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids
            untruncated_ids = tokenizer(
                prompt, padding="longest", return_tensors="pt"
            ).input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[
                -1
            ] and not torch.equal(text_input_ids, untruncated_ids):
                removed_text = tokenizer.batch_decode(
                    untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = text_encoder(
                text_input_ids.to(device), output_hidden_states=True
            )

            # We are only ALWAYS interested in the pooled output of the final text encoder
            pooled_prompt_embeds = prompt_embeds[0]
            if clip_skip is None:
                prompt_embeds = prompt_embeds.hidden_states[-2]
            else:
                # "2" because SDXL always indexes from the penultimate layer.
                prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

            if self.config.force_zeros_for_empty_prompt:
                prompt_embeds[[i for i in range(len(prompt)) if prompt[i] == ""]] = 0
                pooled_prompt_embeds[
                    [i for i in range(len(prompt)) if prompt[i] == ""]
                ] = 0

            prompt_embeds_list.append(prompt_embeds)

        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    # get unconditional embeddings for classifier free guidance
    zero_out_negative_prompt = (
        negative_prompt is None and self.config.force_zeros_for_empty_prompt
    )
    if (
        do_classifier_free_guidance
        and negative_prompt_embeds is None
        and zero_out_negative_prompt
    ):
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        negative_prompt_2 = negative_prompt_2 or negative_prompt

        # normalize str to list
        negative_prompt = (
            batch_size * [negative_prompt]
            if isinstance(negative_prompt, str)
            else negative_prompt
        )
        negative_prompt_2 = (
            batch_size * [negative_prompt_2]
            if isinstance(negative_prompt_2, str)
            else negative_prompt_2
        )

        uncond_tokens: List[str]
        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = [negative_prompt, negative_prompt_2]

        negative_prompt_embeds_list = []
        for negative_prompt, tokenizer, text_encoder in zip(
            uncond_tokens, tokenizers, text_encoders
        ):

            max_length = prompt_embeds.shape[1]
            uncond_input = tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            negative_prompt_embeds = text_encoder(
                uncond_input.input_ids.to(device),
                output_hidden_states=True,
            )
            # We are only ALWAYS interested in the pooled output of the final text encoder
            negative_pooled_prompt_embeds = negative_prompt_embeds[0]
            negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

            negative_prompt_embeds_list.append(negative_prompt_embeds)

        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    if self.text_encoder_2 is not None:
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
    else:
        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        if self.text_encoder_2 is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(
                dtype=self.text_encoder_2.dtype, device=device
            )
        else:
            negative_prompt_embeds = negative_prompt_embeds.to(
                dtype=self.unet.dtype, device=device
            )

        negative_prompt_embeds = negative_prompt_embeds.repeat(
            1, num_images_per_prompt, 1
        )
        negative_prompt_embeds = negative_prompt_embeds.view(
            batch_size * num_images_per_prompt, seq_len, -1
        )

    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    if do_classifier_free_guidance:
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
            1, num_images_per_prompt
        ).view(bs_embed * num_images_per_prompt, -1)

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    if self.text_encoder_2 is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder_2, lora_scale)

    return (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    )
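# Note (added comment, not in the original upload): this helper mirrors the stock SDXL
# `encode_prompt`, with one behavioural difference visible above — when
# `force_zeros_for_empty_prompt` is set, the embeddings and pooled embeddings of any batch
# entry whose prompt is the empty string are zeroed out, not only the negative prompt.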
def create_xts(
    noise_shift_delta,
    noise_timesteps,
    clean_step_timestep,
    generator,
    scheduler,
    timesteps,
    x_0,
    no_add_noise=False,
):
    if noise_timesteps is None:
        noising_delta = noise_shift_delta * (timesteps[0] - timesteps[1])
        noise_timesteps = [timestep - int(noising_delta) for timestep in timesteps]

    first_x_0_idx = len(noise_timesteps)
    for i in range(len(noise_timesteps)):
        if noise_timesteps[i] <= 0:
            first_x_0_idx = i
            break

    noise_timesteps = noise_timesteps[:first_x_0_idx]

    x_0_expanded = x_0.expand(len(noise_timesteps), -1, -1, -1)
    noise = (
        torch.randn(x_0_expanded.size(), generator=generator, device="cpu").to(
            x_0.device
        )
        if not no_add_noise
        else torch.zeros_like(x_0_expanded)
    )
    x_ts = scheduler.add_noise(
        x_0_expanded,
        noise,
        torch.IntTensor(noise_timesteps),
    )
    x_ts = [t.unsqueeze(dim=0) for t in list(x_ts)]
    x_ts += [x_0] * (len(timesteps) - first_x_0_idx)
    x_ts += [x_0]
    if clean_step_timestep > 0:
        x_ts += [x_0]
    return x_ts
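# Hedged worked example (added comment, not in the original upload): with the 4-step
# schedule timesteps = [799, 599, 399, 199] (the same values that appear in
# map_timpstep_to_index) and noise_shift_delta = 1, the step gap is 200, so the derived
# noise_timesteps become [599, 399, 199, -1]; the non-positive entry is dropped and the
# remaining positions are padded with the clean latent x_0:
#     create_xts(1, None, 0, None, pipe.scheduler, [799, 599, 399, 199], x_0)
#     # -> [add_noise(x_0, 599), add_noise(x_0, 399), add_noise(x_0, 199), x_0, x_0]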
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
    self,
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    image_timesteps: torch.IntTensor,
    noise_timesteps: torch.IntTensor,
) -> torch.FloatTensor:
    # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
    # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
    # for the subsequent add_noise calls
    self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
    alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
    image_timesteps = image_timesteps.to(original_samples.device)
    noise_timesteps = noise_timesteps.to(original_samples.device)

    sqrt_alpha_prod = alphas_cumprod[image_timesteps] ** 0.5
    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
    while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[noise_timesteps]) ** 0.5
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
    while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

    noisy_samples = (
        sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
    )
    return noisy_samples
def make_image_grid(
    images: List[PIL.Image.Image], rows: int, cols: int, resize: int = None, size=None
) -> PIL.Image.Image:
    """
    Prepares a single grid of images. Useful for visualization purposes.
    """
    assert len(images) == rows * cols

    if resize is not None:
        images = [img.resize((resize, resize)) for img in images]

    if size is None:
        # default the grid cell size to the size of the first image
        size = images[0].size
    w, h = size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(images):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
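# Hedged usage example (added comment, not in the original upload): tile four images into a
# 2x2 grid, forcing every cell to 256x256 via `resize`. The file paths are hypothetical.
#     thumbs = [Image.open(p) for p in ["a.png", "b.png", "c.png", "d.png"]]
#     make_image_grid(thumbs, rows=2, cols=2, resize=256).save("grid.png")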