turboedit committed on
Commit 849c21d • 1 Parent(s): 66cc9a6

Delete turbo_edit

turbo_edit/run_configs/noise_shift_3_steps.yaml DELETED
@@ -1,19 +0,0 @@
1
- breakdown: "x_t_hat_c"
2
- cross_r: 0.9
3
- eta_reconstruct: 1
4
- eta_retrieve: 1
5
- max_norm_zs: [-1, -1, 15.5]
6
- model: "stabilityai/sdxl-turbo"
7
- noise_shift_delta: 1
8
- noise_timesteps: [599, 299, 0]
9
- timesteps: [799, 499, 199]
10
- num_steps_inversion: 5
11
- step_start: 1
12
- real_cfg_scale: 0
13
- real_cfg_scale_save: 0
14
- scheduler_type: "ddpm"
15
- seed: 2
16
- self_r: 0.5
17
- ws1: 1.5
18
- ws2: 1
19
- clean_step_timestep: 0
turbo_edit/run_configs/noise_shift_guidance_1_5.yaml DELETED
@@ -1,18 +0,0 @@
1
- breakdown: "x_t_hat_c"
2
- cross_r: 0.9
3
- eta: 1
4
- max_norm_zs: [-1, -1, -1, 15.5]
5
- model: "stabilityai/sdxl-turbo"
6
- noise_shift_delta: 1
7
- noise_timesteps: null
8
- num_steps_inversion: 5
9
- step_start: 1
10
- real_cfg_scale: 0
11
- real_cfg_scale_save: 0
12
- scheduler_type: "ddpm"
13
- seed: 2
14
- self_r: 0.5
15
- timesteps: null
16
- ws1: 1.5
17
- ws2: 1
18
- clean_step_timestep: 0
turbo_edit/utils.py DELETED
@@ -1,1357 +0,0 @@
1
- import itertools
2
- from typing import List, Optional, Union
3
- import PIL
4
- import PIL.Image
5
- import torch
6
- from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
7
- from diffusers.utils import make_image_grid
8
- from PIL import Image, ImageDraw, ImageFont
9
- import os
10
- from diffusers.utils import (
11
- logging,
12
- USE_PEFT_BACKEND,
13
- scale_lora_layers,
14
- unscale_lora_layers,
15
- )
16
- from diffusers.loaders import (
17
- StableDiffusionXLLoraLoaderMixin,
18
- )
19
- from diffusers.image_processor import VaeImageProcessor
20
-
21
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
22
-
23
- from diffusers.models.lora import adjust_lora_scale_text_encoder
24
- from diffusers import DiffusionPipeline
25
-
26
-
27
- VECTOR_DATA_FOLDER = "vector_data"
28
- VECTOR_DATA_DICT = "vector_data"
29
-
30
-
31
- def encode_image(image: PIL.Image, pipe: DiffusionPipeline):
32
- pipe.image_processor: VaeImageProcessor = pipe.image_processor # type: ignore
33
- image = pipe.image_processor.pil_to_numpy(image)
34
- image = pipe.image_processor.numpy_to_pt(image)
35
- image = image.to(pipe.device)
36
- return (
37
- pipe.vae.encode(
38
- pipe.image_processor.preprocess(image),
39
- ).latent_dist.mode()
40
- * pipe.vae.config.scaling_factor
41
- )
42
-
43
-
44
- def decode_latents(latent, pipe):
45
- latent_img = pipe.vae.decode(
46
- latent / pipe.vae.config.scaling_factor, return_dict=False
47
- )[0]
48
- return pipe.image_processor.postprocess(latent_img, output_type="pil")
49
-
50
-
51
- def get_device(argv, args=None):
52
- import sys
53
-
54
- def debugger_is_active():
55
- return hasattr(sys, "gettrace") and sys.gettrace() is not None
56
-
57
- if args:
58
- return (
59
- torch.device("cuda")
60
- if (torch.cuda.is_available() and not debugger_is_active())
61
- and not args.force_use_cpu
62
- else torch.device("cpu")
63
- )
64
-
65
- return (
66
- torch.device("cuda")
67
- if (torch.cuda.is_available() and not debugger_is_active())
68
- and not "cpu" in set(argv[1:])
69
- else torch.device("cpu")
70
- )
71
-
72
-
73
- def deterministic_ddim_step(
74
- model_output: torch.FloatTensor,
75
- timestep: int,
76
- sample: torch.FloatTensor,
77
- eta: float = 0.0,
78
- use_clipped_model_output: bool = False,
79
- generator=None,
80
- variance_noise: Optional[torch.FloatTensor] = None,
81
- return_dict: bool = True,
82
- scheduler=None,
83
- ):
84
-
85
- if scheduler.num_inference_steps is None:
86
- raise ValueError(
87
- "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
88
- )
89
-
90
- prev_timestep = (
91
- timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
92
- )
93
-
94
- # 2. compute alphas, betas
95
- alpha_prod_t = scheduler.alphas_cumprod[timestep]
96
- alpha_prod_t_prev = (
97
- scheduler.alphas_cumprod[prev_timestep]
98
- if prev_timestep >= 0
99
- else scheduler.final_alpha_cumprod
100
- )
101
-
102
- beta_prod_t = 1 - alpha_prod_t
103
-
104
- if scheduler.config.prediction_type == "epsilon":
105
- pred_original_sample = (
106
- sample - beta_prod_t ** (0.5) * model_output
107
- ) / alpha_prod_t ** (0.5)
108
- pred_epsilon = model_output
109
- elif scheduler.config.prediction_type == "sample":
110
- pred_original_sample = model_output
111
- pred_epsilon = (
112
- sample - alpha_prod_t ** (0.5) * pred_original_sample
113
- ) / beta_prod_t ** (0.5)
114
- elif scheduler.config.prediction_type == "v_prediction":
115
- pred_original_sample = (alpha_prod_t**0.5) * sample - (
116
- beta_prod_t**0.5
117
- ) * model_output
118
- pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
119
- else:
120
- raise ValueError(
121
- f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample`, or"
122
- " `v_prediction`"
123
- )
124
-
125
- # 4. Clip or threshold "predicted x_0"
126
- if scheduler.config.thresholding:
127
- pred_original_sample = scheduler._threshold_sample(pred_original_sample)
128
- elif scheduler.config.clip_sample:
129
- pred_original_sample = pred_original_sample.clamp(
130
- -scheduler.config.clip_sample_range,
131
- scheduler.config.clip_sample_range,
132
- )
133
-
134
- # 5. compute variance: "sigma_t(η)" -> see formula (16)
135
- # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
136
- variance = scheduler._get_variance(timestep, prev_timestep)
137
- std_dev_t = eta * variance ** (0.5)
138
-
139
- if use_clipped_model_output:
140
- # the pred_epsilon is always re-derived from the clipped x_0 in Glide
141
- pred_epsilon = (
142
- sample - alpha_prod_t ** (0.5) * pred_original_sample
143
- ) / beta_prod_t ** (0.5)
144
-
145
- # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
146
- pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
147
- 0.5
148
- ) * pred_epsilon
149
-
150
- # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
151
- prev_sample = (
152
- alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
153
- )
154
- return prev_sample
155
-
156
-
157
- def deterministic_euler_step(
158
- model_output: torch.FloatTensor,
159
- timestep: Union[float, torch.FloatTensor],
160
- sample: torch.FloatTensor,
161
- eta,
162
- use_clipped_model_output,
163
- generator,
164
- variance_noise,
165
- return_dict,
166
- scheduler,
167
- ):
168
- """
169
- Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
170
- process from the learned model outputs (most often the predicted noise).
171
-
172
- Args:
173
- model_output (`torch.FloatTensor`):
174
- The direct output from learned diffusion model.
175
- timestep (`float`):
176
- The current discrete timestep in the diffusion chain.
177
- sample (`torch.FloatTensor`):
178
- A current instance of a sample created by the diffusion process.
179
- generator (`torch.Generator`, *optional*):
180
- A random number generator.
181
- return_dict (`bool`):
182
- Whether or not to return a
183
- [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
184
-
185
- Returns:
186
- [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
187
- If return_dict is `True`,
188
- [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
189
- otherwise a tuple is returned where the first element is the sample tensor.
190
-
191
- """
192
-
193
- if (
194
- isinstance(timestep, int)
195
- or isinstance(timestep, torch.IntTensor)
196
- or isinstance(timestep, torch.LongTensor)
197
- ):
198
- raise ValueError(
199
- (
200
- "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
201
- " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
202
- " one of the `scheduler.timesteps` as a timestep."
203
- ),
204
- )
205
-
206
- if scheduler.step_index is None:
207
- scheduler._init_step_index(timestep)
208
-
209
- sigma = scheduler.sigmas[scheduler.step_index]
210
-
211
- # Upcast to avoid precision issues when computing prev_sample
212
- sample = sample.to(torch.float32)
213
-
214
- # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
215
- if scheduler.config.prediction_type == "epsilon":
216
- pred_original_sample = sample - sigma * model_output
217
- elif scheduler.config.prediction_type == "v_prediction":
218
- # * c_out + input * c_skip
219
- pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
220
- sample / (sigma**2 + 1)
221
- )
222
- elif scheduler.config.prediction_type == "sample":
223
- raise NotImplementedError("prediction_type not implemented yet: sample")
224
- else:
225
- raise ValueError(
226
- f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
227
- )
228
-
229
- sigma_from = scheduler.sigmas[scheduler.step_index]
230
- sigma_to = scheduler.sigmas[scheduler.step_index + 1]
231
- sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
232
- sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
233
-
234
- # 2. Convert to an ODE derivative
235
- derivative = (sample - pred_original_sample) / sigma
236
-
237
- dt = sigma_down - sigma
238
-
239
- prev_sample = sample + derivative * dt
240
-
241
- # Cast sample back to model compatible dtype
242
- prev_sample = prev_sample.to(model_output.dtype)
243
-
244
- # upon completion increase step index by one
245
- scheduler._step_index += 1
246
-
247
- return prev_sample
248
-
249
-
250
- def deterministic_non_ancestral_euler_step(
251
- model_output: torch.FloatTensor,
252
- timestep: Union[float, torch.FloatTensor],
253
- sample: torch.FloatTensor,
254
- eta: float = 0.0,
255
- use_clipped_model_output: bool = False,
256
- s_churn: float = 0.0,
257
- s_tmin: float = 0.0,
258
- s_tmax: float = float("inf"),
259
- s_noise: float = 1.0,
260
- generator: Optional[torch.Generator] = None,
261
- variance_noise: Optional[torch.FloatTensor] = None,
262
- return_dict: bool = True,
263
- scheduler=None,
264
- ):
265
- """
266
- Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
267
- process from the learned model outputs (most often the predicted noise).
268
-
269
- Args:
270
- model_output (`torch.FloatTensor`):
271
- The direct output from learned diffusion model.
272
- timestep (`float`):
273
- The current discrete timestep in the diffusion chain.
274
- sample (`torch.FloatTensor`):
275
- A current instance of a sample created by the diffusion process.
276
- s_churn (`float`):
277
- s_tmin (`float`):
278
- s_tmax (`float`):
279
- s_noise (`float`, defaults to 1.0):
280
- Scaling factor for noise added to the sample.
281
- generator (`torch.Generator`, *optional*):
282
- A random number generator.
283
- return_dict (`bool`):
284
- Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
285
- tuple.
286
-
287
- Returns:
288
- [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
289
- If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
290
- returned, otherwise a tuple is returned where the first element is the sample tensor.
291
- """
292
-
293
- if (
294
- isinstance(timestep, int)
295
- or isinstance(timestep, torch.IntTensor)
296
- or isinstance(timestep, torch.LongTensor)
297
- ):
298
- raise ValueError(
299
- (
300
- "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
301
- " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
302
- " one of the `scheduler.timesteps` as a timestep."
303
- ),
304
- )
305
-
306
- if not scheduler.is_scale_input_called:
307
- logger.warning(
308
- "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
309
- "See `StableDiffusionPipeline` for a usage example."
310
- )
311
-
312
- if scheduler.step_index is None:
313
- scheduler._init_step_index(timestep)
314
-
315
- # Upcast to avoid precision issues when computing prev_sample
316
- sample = sample.to(torch.float32)
317
-
318
- sigma = scheduler.sigmas[scheduler.step_index]
319
-
320
- gamma = (
321
- min(s_churn / (len(scheduler.sigmas) - 1), 2**0.5 - 1)
322
- if s_tmin <= sigma <= s_tmax
323
- else 0.0
324
- )
325
-
326
- sigma_hat = sigma * (gamma + 1)
327
-
328
- # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
329
- # NOTE: "original_sample" should not be an expected prediction_type but is left in for
330
- # backwards compatibility
331
- if (
332
- scheduler.config.prediction_type == "original_sample"
333
- or scheduler.config.prediction_type == "sample"
334
- ):
335
- pred_original_sample = model_output
336
- elif scheduler.config.prediction_type == "epsilon":
337
- pred_original_sample = sample - sigma_hat * model_output
338
- elif scheduler.config.prediction_type == "v_prediction":
339
- # denoised = model_output * c_out + input * c_skip
340
- pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
341
- sample / (sigma**2 + 1)
342
- )
343
- else:
344
- raise ValueError(
345
- f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
346
- )
347
-
348
- # 2. Convert to an ODE derivative
349
- derivative = (sample - pred_original_sample) / sigma_hat
350
-
351
- dt = scheduler.sigmas[scheduler.step_index + 1] - sigma_hat
352
-
353
- prev_sample = sample + derivative * dt
354
-
355
- # Cast sample back to model compatible dtype
356
- prev_sample = prev_sample.to(model_output.dtype)
357
-
358
- # upon completion increase step index by one
359
- scheduler._step_index += 1
360
-
361
- return prev_sample
362
-
363
-
364
- def deterministic_ddpm_step(
365
- model_output: torch.FloatTensor,
366
- timestep: Union[float, torch.FloatTensor],
367
- sample: torch.FloatTensor,
368
- eta,
369
- use_clipped_model_output,
370
- generator,
371
- variance_noise,
372
- return_dict,
373
- scheduler,
374
- ):
375
- """
376
- Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
377
- process from the learned model outputs (most often the predicted noise).
378
-
379
- Args:
380
- model_output (`torch.FloatTensor`):
381
- The direct output from learned diffusion model.
382
- timestep (`float`):
383
- The current discrete timestep in the diffusion chain.
384
- sample (`torch.FloatTensor`):
385
- A current instance of a sample created by the diffusion process.
386
- generator (`torch.Generator`, *optional*):
387
- A random number generator.
388
- return_dict (`bool`, *optional*, defaults to `True`):
389
- Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.
390
-
391
- Returns:
392
- [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
393
- If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a
394
- tuple is returned where the first element is the sample tensor.
395
-
396
- """
397
- t = timestep
398
-
399
- prev_t = scheduler.previous_timestep(t)
400
-
401
- if model_output.shape[1] == sample.shape[1] * 2 and scheduler.variance_type in [
402
- "learned",
403
- "learned_range",
404
- ]:
405
- model_output, predicted_variance = torch.split(
406
- model_output, sample.shape[1], dim=1
407
- )
408
- else:
409
- predicted_variance = None
410
-
411
- # 1. compute alphas, betas
412
- alpha_prod_t = scheduler.alphas_cumprod[t]
413
- alpha_prod_t_prev = (
414
- scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one
415
- )
416
- beta_prod_t = 1 - alpha_prod_t
417
- beta_prod_t_prev = 1 - alpha_prod_t_prev
418
- current_alpha_t = alpha_prod_t / alpha_prod_t_prev
419
- current_beta_t = 1 - current_alpha_t
420
-
421
- # 2. compute predicted original sample from predicted noise also called
422
- # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
423
- if scheduler.config.prediction_type == "epsilon":
424
- pred_original_sample = (
425
- sample - beta_prod_t ** (0.5) * model_output
426
- ) / alpha_prod_t ** (0.5)
427
- elif scheduler.config.prediction_type == "sample":
428
- pred_original_sample = model_output
429
- elif scheduler.config.prediction_type == "v_prediction":
430
- pred_original_sample = (alpha_prod_t**0.5) * sample - (
431
- beta_prod_t**0.5
432
- ) * model_output
433
- else:
434
- raise ValueError(
435
- f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
436
- " `v_prediction` for the DDPMScheduler."
437
- )
438
-
439
- # 3. Clip or threshold "predicted x_0"
440
- if scheduler.config.thresholding:
441
- pred_original_sample = scheduler._threshold_sample(pred_original_sample)
442
- elif scheduler.config.clip_sample:
443
- pred_original_sample = pred_original_sample.clamp(
444
- -scheduler.config.clip_sample_range, scheduler.config.clip_sample_range
445
- )
446
-
447
- # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
448
- # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
449
- pred_original_sample_coeff = (
450
- alpha_prod_t_prev ** (0.5) * current_beta_t
451
- ) / beta_prod_t
452
- current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
453
-
454
- # 5. Compute predicted previous sample µ_t
455
- # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
456
- pred_prev_sample = (
457
- pred_original_sample_coeff * pred_original_sample
458
- + current_sample_coeff * sample
459
- )
460
-
461
- return pred_prev_sample
462
-
463
-
464
- def normalize(
465
- z_t,
466
- i,
467
- max_norm_zs,
468
- ):
469
- max_norm = max_norm_zs[i]
470
- if max_norm < 0:
471
- return z_t, 1
472
-
473
- norm = torch.norm(z_t)
474
- if norm < max_norm:
475
- return z_t, 1
476
-
477
- coeff = max_norm / norm
478
- z_t = z_t * coeff
479
- return z_t, coeff
480
-
481
-
482
- def find_index(timesteps, timestep):
483
- for i, t in enumerate(timesteps):
484
- if t == timestep:
485
- return i
486
- return -1
487
-
488
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
489
- map_timpstep_to_index = {
490
- torch.tensor(799): 0,
491
- torch.tensor(599): 1,
492
- torch.tensor(399): 2,
493
- torch.tensor(199): 3,
494
- torch.tensor(799, device=device): 0,
495
- torch.tensor(599, device=device): 1,
496
- torch.tensor(399, device=device): 2,
497
- torch.tensor(199, device=device): 3,
498
- }
499
-
500
- def step_save_latents(
501
- self,
502
- model_output: torch.FloatTensor,
503
- timestep: int,
504
- sample: torch.FloatTensor,
505
- eta: float = 0.0,
506
- use_clipped_model_output: bool = False,
507
- generator=None,
508
- variance_noise: Optional[torch.FloatTensor] = None,
509
- return_dict: bool = True,
510
- ):
511
- # print(self._save_timesteps)
512
- # timestep_index = map_timpstep_to_index[timestep]
513
- # timestep_index = ((self._save_timesteps == timestep).nonzero(as_tuple=True)[0]).item()
514
- timestep_index = self._save_timesteps.index(timestep) if not self.clean_step_run else -1
515
- next_timestep_index = timestep_index + 1 if not self.clean_step_run else -1
516
- u_hat_t = self.step_function(
517
- model_output=model_output,
518
- timestep=timestep,
519
- sample=sample,
520
- eta=eta,
521
- use_clipped_model_output=use_clipped_model_output,
522
- generator=generator,
523
- variance_noise=variance_noise,
524
- return_dict=False,
525
- scheduler=self,
526
- )
527
-
528
- x_t_minus_1 = self.x_ts[next_timestep_index]
529
- self.x_ts_c_hat.append(u_hat_t)
530
-
531
- z_t = x_t_minus_1 - u_hat_t
532
- self.latents.append(z_t)
533
-
534
- z_t, _ = normalize(z_t, timestep_index, self._config.max_norm_zs)
535
-
536
- x_t_minus_1_predicted = u_hat_t + z_t
537
-
538
- if not return_dict:
539
- return (x_t_minus_1_predicted,)
540
-
541
- return DDIMSchedulerOutput(prev_sample=x_t_minus_1, pred_original_sample=None)
542
-
543
-
544
- def step_use_latents(
545
- self,
546
- model_output: torch.FloatTensor,
547
- timestep: int,
548
- sample: torch.FloatTensor,
549
- eta: float = 0.0,
550
- use_clipped_model_output: bool = False,
551
- generator=None,
552
- variance_noise: Optional[torch.FloatTensor] = None,
553
- return_dict: bool = True,
554
- ):
555
- # timestep_index = ((self._save_timesteps == timestep).nonzero(as_tuple=True)[0]).item()
556
- timestep_index = self._timesteps.index(timestep) if not self.clean_step_run else -1
557
- next_timestep_index = (
558
- timestep_index + 1 if not self.clean_step_run else -1
559
- )
560
- z_t = self.latents[next_timestep_index] # + 1 because latents[0] is X_T
561
-
562
- _, normalize_coefficient = normalize(
563
- z_t[0] if self._config.breakdown == "x_t_hat_c_with_zeros" else z_t,
564
- timestep_index,
565
- self._config.max_norm_zs,
566
- )
567
-
568
- if normalize_coefficient == 0:
569
- eta = 0
570
-
571
- # eta = normalize_coefficient
572
-
573
- x_t_hat_c_hat = self.step_function(
574
- model_output=model_output,
575
- timestep=timestep,
576
- sample=sample,
577
- eta=eta,
578
- use_clipped_model_output=use_clipped_model_output,
579
- generator=generator,
580
- variance_noise=variance_noise,
581
- return_dict=False,
582
- scheduler=self,
583
- )
584
-
585
- w1 = self._config.ws1[timestep_index]
586
- w2 = self._config.ws2[timestep_index]
587
-
588
- x_t_minus_1_exact = self.x_ts[next_timestep_index]
589
- x_t_minus_1_exact = x_t_minus_1_exact.expand_as(x_t_hat_c_hat)
590
-
591
- x_t_c_hat: torch.Tensor = self.x_ts_c_hat[next_timestep_index]
592
- if self._config.breakdown == "x_t_c_hat":
593
- raise NotImplementedError("breakdown x_t_c_hat not implemented yet")
594
-
595
- # x_t_c_hat = x_t_c_hat.expand_as(x_t_hat_c_hat)
596
- x_t_c = x_t_c_hat[0].expand_as(x_t_hat_c_hat)
597
-
598
- # if self._config.breakdown == "x_t_c_hat":
599
- # v1 = x_t_hat_c_hat - x_t_c_hat
600
- # v2 = x_t_c_hat - x_t_c
601
- if (
602
- self._config.breakdown == "x_t_hat_c"
603
- or self._config.breakdown == "x_t_hat_c_with_zeros"
604
- ):
605
- zero_index_reconstruction = 1 if not self.time_measure_n else 0
606
- edit_prompts_num = (
607
- (model_output.size(0) - zero_index_reconstruction) // 3
608
- if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p
609
- else (model_output.size(0) - zero_index_reconstruction) // 2
610
- )
611
- x_t_hat_c_indices = (zero_index_reconstruction, edit_prompts_num + zero_index_reconstruction)
612
- edit_images_indices = (
613
- edit_prompts_num + zero_index_reconstruction,
614
- (
615
- model_output.size(0)
616
- if self._config.breakdown == "x_t_hat_c"
617
- else zero_index_reconstruction + 2 * edit_prompts_num
618
- ),
619
- )
620
- x_t_hat_c = torch.zeros_like(x_t_hat_c_hat)
621
- x_t_hat_c[edit_images_indices[0] : edit_images_indices[1]] = x_t_hat_c_hat[
622
- x_t_hat_c_indices[0] : x_t_hat_c_indices[1]
623
- ]
624
- v1 = x_t_hat_c_hat - x_t_hat_c
625
- v2 = x_t_hat_c - normalize_coefficient * x_t_c
626
- if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p:
627
- path = os.path.join(
628
- self.folder_name,
629
- VECTOR_DATA_FOLDER,
630
- self.image_name,
631
- )
632
- if not hasattr(self, VECTOR_DATA_DICT):
633
- os.makedirs(path, exist_ok=True)
634
- self.vector_data = dict()
635
-
636
- x_t_0 = x_t_c_hat[1]
637
- empty_prompt_indices = (1 + 2 * edit_prompts_num, 1 + 3 * edit_prompts_num)
638
- x_t_hat_0 = x_t_hat_c_hat[empty_prompt_indices[0] : empty_prompt_indices[1]]
639
-
640
- self.vector_data[timestep.item()] = dict()
641
- self.vector_data[timestep.item()]["x_t_hat_c"] = x_t_hat_c[
642
- edit_images_indices[0] : edit_images_indices[1]
643
- ]
644
- self.vector_data[timestep.item()]["x_t_hat_0"] = x_t_hat_0
645
- self.vector_data[timestep.item()]["x_t_c"] = x_t_c[0].expand_as(x_t_hat_0)
646
- self.vector_data[timestep.item()]["x_t_0"] = x_t_0.expand_as(x_t_hat_0)
647
- self.vector_data[timestep.item()]["x_t_hat_c_hat"] = x_t_hat_c_hat[
648
- edit_images_indices[0] : edit_images_indices[1]
649
- ]
650
- self.vector_data[timestep.item()]["x_t_minus_1_noisy"] = x_t_minus_1_exact[
651
- 0
652
- ].expand_as(x_t_hat_0)
653
- self.vector_data[timestep.item()]["x_t_minus_1_clean"] = self.x_0s[
654
- next_timestep_index
655
- ].expand_as(x_t_hat_0)
656
-
657
- else: # no breakdown
658
- v1 = x_t_hat_c_hat - normalize_coefficient * x_t_c
659
- v2 = 0
660
-
661
- if self.save_intermediate_results and not self.p_to_p:
662
- delta = v1 + v2
663
- v1_plus_x0 = self.x_0s[next_timestep_index] + v1
664
- v2_plus_x0 = self.x_0s[next_timestep_index] + v2
665
- delta_plus_x0 = self.x_0s[next_timestep_index] + delta
666
-
667
- v1_images = decode_latents(v1, self.pipe)
668
- self.v1s_images.append(v1_images)
669
- v2_images = (
670
- decode_latents(v2, self.pipe)
671
- if self._config.breakdown != "no_breakdown"
672
- else [PIL.Image.new("RGB", (1, 1))]
673
- )
674
- self.v2s_images.append(v2_images)
675
- delta_images = decode_latents(delta, self.pipe)
676
- self.deltas_images.append(delta_images)
677
- v1_plus_x0_images = decode_latents(v1_plus_x0, self.pipe)
678
- self.v1_x0s.append(v1_plus_x0_images)
679
- v2_plus_x0_images = (
680
- decode_latents(v2_plus_x0, self.pipe)
681
- if self._config.breakdown != "no_breakdown"
682
- else [PIL.Image.new("RGB", (1, 1))]
683
- )
684
- self.v2_x0s.append(v2_plus_x0_images)
685
- delta_plus_x0_images = decode_latents(delta_plus_x0, self.pipe)
686
- self.deltas_x0s.append(delta_plus_x0_images)
687
-
688
- # print(f"v1 norm: {torch.norm(v1, dim=0).mean()}")
689
- # if self._config.breakdown != "no_breakdown":
690
- # print(f"v2 norm: {torch.norm(v2, dim=0).mean()}")
691
- # print(f"v sum norm: {torch.norm(v1 + v2, dim=0).mean()}")
692
-
693
- x_t_minus_1 = normalize_coefficient * x_t_minus_1_exact + w1 * v1 + w2 * v2
694
-
695
- if (
696
- self._config.breakdown == "x_t_hat_c"
697
- or self._config.breakdown == "x_t_hat_c_with_zeros"
698
- ):
699
- x_t_minus_1[x_t_hat_c_indices[0] : x_t_hat_c_indices[1]] = x_t_minus_1[
700
- edit_images_indices[0] : edit_images_indices[1]
701
- ] # update x_t_hat_c to be x_t_hat_c_hat
702
- if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p:
703
- x_t_minus_1[empty_prompt_indices[0] : empty_prompt_indices[1]] = (
704
- x_t_minus_1[edit_images_indices[0] : edit_images_indices[1]]
705
- )
706
- self.vector_data[timestep.item()]["x_t_minus_1_edited"] = x_t_minus_1[
707
- edit_images_indices[0] : edit_images_indices[1]
708
- ]
709
- if timestep == self._timesteps[-1]:
710
- torch.save(
711
- self.vector_data,
712
- os.path.join(
713
- path,
714
- f"{VECTOR_DATA_DICT}.pt",
715
- ),
716
- )
717
- # p_to_p_force_perfect_reconstruction
718
- if not self.time_measure_n:
719
- x_t_minus_1[0] = x_t_minus_1_exact[0]
720
-
721
- if not return_dict:
722
- return (x_t_minus_1,)
723
-
724
- return DDIMSchedulerOutput(
725
- prev_sample=x_t_minus_1,
726
- pred_original_sample=None,
727
- )
728
-
729
-
730
-
731
- def get_ddpm_inversion_scheduler(
732
- scheduler,
733
- step_function,
734
- config,
735
- timesteps,
736
- save_timesteps,
737
- latents,
738
- x_ts,
739
- x_ts_c_hat,
740
- save_intermediate_results,
741
- pipe,
742
- x_0,
743
- v1s_images,
744
- v2s_images,
745
- deltas_images,
746
- v1_x0s,
747
- v2_x0s,
748
- deltas_x0s,
749
- folder_name,
750
- image_name,
751
- time_measure_n,
752
- ):
753
- def step(
754
- model_output: torch.FloatTensor,
755
- timestep: int,
756
- sample: torch.FloatTensor,
757
- eta: float = 0.0,
758
- use_clipped_model_output: bool = False,
759
- generator=None,
760
- variance_noise: Optional[torch.FloatTensor] = None,
761
- return_dict: bool = True,
762
- ):
763
- # if scheduler.is_save:
764
- # start = timer()
765
- res_inv = step_save_latents(
766
- scheduler,
767
- model_output[:1, :, :, :],
768
- timestep,
769
- sample[:1, :, :, :],
770
- eta,
771
- use_clipped_model_output,
772
- generator,
773
- variance_noise,
774
- return_dict,
775
- )
776
- # end = timer()
777
- # print(f"Run Time Inv: {end - start}")
778
-
779
- res_inf = step_use_latents(
780
- scheduler,
781
- model_output[1:, :, :, :],
782
- timestep,
783
- sample[1:, :, :, :],
784
- eta,
785
- use_clipped_model_output,
786
- generator,
787
- variance_noise,
788
- return_dict,
789
- )
790
- # res = res_inv
791
- res = (torch.cat((res_inv[0], res_inf[0]), dim=0),)
792
- return res
793
- # return res
794
-
795
- scheduler.step_function = step_function
796
- scheduler.is_save = True
797
- scheduler._timesteps = timesteps
798
- scheduler._save_timesteps = save_timesteps if save_timesteps else timesteps
799
- scheduler._config = config
800
- scheduler.latents = latents
801
- scheduler.x_ts = x_ts
802
- scheduler.x_ts_c_hat = x_ts_c_hat
803
- scheduler.step = step
804
- scheduler.save_intermediate_results = save_intermediate_results
805
- scheduler.pipe = pipe
806
- scheduler.v1s_images = v1s_images
807
- scheduler.v2s_images = v2s_images
808
- scheduler.deltas_images = deltas_images
809
- scheduler.v1_x0s = v1_x0s
810
- scheduler.v2_x0s = v2_x0s
811
- scheduler.deltas_x0s = deltas_x0s
812
- scheduler.clean_step_run = False
813
- scheduler.x_0s = create_xts(
814
- config.noise_shift_delta,
815
- config.noise_timesteps,
816
- config.clean_step_timestep,
817
- None,
818
- pipe.scheduler,
819
- timesteps,
820
- x_0,
821
- no_add_noise=True,
822
- )
823
- scheduler.folder_name = folder_name
824
- scheduler.image_name = image_name
825
- scheduler.p_to_p = False
826
- scheduler.p_to_p_replace = False
827
- scheduler.time_measure_n = time_measure_n
828
- return scheduler
829
-
830
-
831
- def create_grid(
832
- images,
833
- p_to_p_images,
834
- prompts,
835
- original_image_path,
836
- ):
837
- images_len = len(images) if len(images) > 0 else len(p_to_p_images)
838
- images_size = images[0].size if len(images) > 0 else p_to_p_images[0].size
839
- x_0 = Image.open(original_image_path).resize(images_size)
840
-
841
- images_ = [x_0] + images + ([x_0] + p_to_p_images if p_to_p_images else [])
842
-
843
- l1 = 1 if len(images) > 0 else 0
844
- l2 = 1 if len(p_to_p_images) else 0
845
- grid = make_image_grid(images_, rows=l1 + l2, cols=images_len + 1, resize=None)
846
-
847
- width = images_size[0]
848
- height = width // 5
849
- font = ImageFont.truetype("font.ttf", width // 14)
850
-
851
- grid1 = Image.new("RGB", size=(grid.size[0], grid.size[1] + height))
852
- grid1.paste(grid, (0, 0))
853
-
854
- draw = ImageDraw.Draw(grid1)
855
-
856
- c_width = 0
857
- for prompt in prompts:
858
- if len(prompt) > 30:
859
- prompt = prompt[:30] + "\n" + prompt[30:]
860
- draw.text((c_width, width * 2), prompt, font=font, fill=(255, 255, 255))
861
- c_width += width
862
-
863
- return grid1
864
-
865
-
866
- def save_intermediate_results(
867
- v1s_images,
868
- v2s_images,
869
- deltas_images,
870
- v1_x0s,
871
- v2_x0s,
872
- deltas_x0s,
873
- folder_name,
874
- original_prompt,
875
- ):
876
- from diffusers.utils import make_image_grid
877
-
878
- path = f"{folder_name}/{original_prompt}_intermediate_results/"
879
- os.makedirs(path, exist_ok=True)
880
- make_image_grid(
881
- list(itertools.chain(*v1s_images)),
882
- rows=len(v1s_images),
883
- cols=len(v1s_images[0]),
884
- ).save(f"{path}v1s_images.png")
885
- make_image_grid(
886
- list(itertools.chain(*v2s_images)),
887
- rows=len(v2s_images),
888
- cols=len(v2s_images[0]),
889
- ).save(f"{path}v2s_images.png")
890
- make_image_grid(
891
- list(itertools.chain(*deltas_images)),
892
- rows=len(deltas_images),
893
- cols=len(deltas_images[0]),
894
- ).save(f"{path}deltas_images.png")
895
- make_image_grid(
896
- list(itertools.chain(*v1_x0s)),
897
- rows=len(v1_x0s),
898
- cols=len(v1_x0s[0]),
899
- ).save(f"{path}v1_x0s.png")
900
- make_image_grid(
901
- list(itertools.chain(*v2_x0s)),
902
- rows=len(v2_x0s),
903
- cols=len(v2_x0s[0]),
904
- ).save(f"{path}v2_x0s.png")
905
- make_image_grid(
906
- list(itertools.chain(*deltas_x0s)),
907
- rows=len(deltas_x0s[0]),
908
- cols=len(deltas_x0s),
909
- ).save(f"{path}deltas_x0s.png")
910
- for i, image in enumerate(list(itertools.chain(*deltas_x0s))):
911
- image.save(f"{path}deltas_x0s_{i}.png")
912
-
913
-
914
- # copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.py and removed the add_noise line
915
- def prepare_latents_no_add_noise(
916
- self,
917
- image,
918
- timestep,
919
- batch_size,
920
- num_images_per_prompt,
921
- dtype,
922
- device,
923
- generator=None,
924
- ):
925
- from diffusers.utils import deprecate
926
-
927
- if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
928
- raise ValueError(
929
- f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
930
- )
931
-
932
- image = image.to(device=device, dtype=dtype)
933
-
934
- batch_size = batch_size * num_images_per_prompt
935
-
936
- if image.shape[1] == 4:
937
- init_latents = image
938
-
939
- else:
940
- if isinstance(generator, list) and len(generator) != batch_size:
941
- raise ValueError(
942
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
943
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
944
- )
945
-
946
- elif isinstance(generator, list):
947
- init_latents = [
948
- self.retrieve_latents(
949
- self.vae.encode(image[i : i + 1]), generator=generator[i]
950
- )
951
- for i in range(batch_size)
952
- ]
953
- init_latents = torch.cat(init_latents, dim=0)
954
- else:
955
- init_latents = self.retrieve_latents(
956
- self.vae.encode(image), generator=generator
957
- )
958
-
959
- init_latents = self.vae.config.scaling_factor * init_latents
960
-
961
- if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
962
- # expand init_latents for batch_size
963
- deprecation_message = (
964
- f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
965
- " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
966
- " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
967
- " your script to pass as many initial images as text prompts to suppress this warning."
968
- )
969
- deprecate(
970
- "len(prompt) != len(image)",
971
- "1.0.0",
972
- deprecation_message,
973
- standard_warn=False,
974
- )
975
- additional_image_per_prompt = batch_size // init_latents.shape[0]
976
- init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
977
- elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
978
- raise ValueError(
979
- f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
980
- )
981
- else:
982
- init_latents = torch.cat([init_latents], dim=0)
983
-
984
- # get latents
985
- latents = init_latents
986
-
987
- return latents
988
-
989
-
990
- # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
991
- def encode_prompt_empty_prompt_zeros_sdxl(
992
- self,
993
- prompt: str,
994
- prompt_2: Optional[str] = None,
995
- device: Optional[torch.device] = None,
996
- num_images_per_prompt: int = 1,
997
- do_classifier_free_guidance: bool = True,
998
- negative_prompt: Optional[str] = None,
999
- negative_prompt_2: Optional[str] = None,
1000
- prompt_embeds: Optional[torch.FloatTensor] = None,
1001
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1002
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1003
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1004
- lora_scale: Optional[float] = None,
1005
- clip_skip: Optional[int] = None,
1006
- ):
1007
- r"""
1008
- Encodes the prompt into text encoder hidden states.
1009
-
1010
- Args:
1011
- prompt (`str` or `List[str]`, *optional*):
1012
- prompt to be encoded
1013
- prompt_2 (`str` or `List[str]`, *optional*):
1014
- The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
1015
- used in both text-encoders
1016
- device: (`torch.device`):
1017
- torch device
1018
- num_images_per_prompt (`int`):
1019
- number of images that should be generated per prompt
1020
- do_classifier_free_guidance (`bool`):
1021
- whether to use classifier free guidance or not
1022
- negative_prompt (`str` or `List[str]`, *optional*):
1023
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
1024
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1025
- less than `1`).
1026
- negative_prompt_2 (`str` or `List[str]`, *optional*):
1027
- The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
1028
- `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
1029
- prompt_embeds (`torch.FloatTensor`, *optional*):
1030
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1031
- provided, text embeddings will be generated from `prompt` input argument.
1032
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1033
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1034
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1035
- argument.
1036
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1037
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1038
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
1039
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1040
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1041
- weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1042
- input argument.
1043
- lora_scale (`float`, *optional*):
1044
- A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
1045
- clip_skip (`int`, *optional*):
1046
- Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1047
- the output of the pre-final layer will be used for computing the prompt embeddings.
1048
- """
1049
- device = device or self._execution_device
1050
-
1051
- # set lora scale so that monkey patched LoRA
1052
- # function of text encoder can correctly access it
1053
- if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
1054
- self._lora_scale = lora_scale
1055
-
1056
- # dynamically adjust the LoRA scale
1057
- if self.text_encoder is not None:
1058
- if not USE_PEFT_BACKEND:
1059
- adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
1060
- else:
1061
- scale_lora_layers(self.text_encoder, lora_scale)
1062
-
1063
- if self.text_encoder_2 is not None:
1064
- if not USE_PEFT_BACKEND:
1065
- adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
1066
- else:
1067
- scale_lora_layers(self.text_encoder_2, lora_scale)
1068
-
1069
- prompt = [prompt] if isinstance(prompt, str) else prompt
1070
-
1071
- if prompt is not None:
1072
- batch_size = len(prompt)
1073
- else:
1074
- batch_size = prompt_embeds.shape[0]
1075
-
1076
- # Define tokenizers and text encoders
1077
- tokenizers = (
1078
- [self.tokenizer, self.tokenizer_2]
1079
- if self.tokenizer is not None
1080
- else [self.tokenizer_2]
1081
- )
1082
- text_encoders = (
1083
- [self.text_encoder, self.text_encoder_2]
1084
- if self.text_encoder is not None
1085
- else [self.text_encoder_2]
1086
- )
1087
-
1088
- if prompt_embeds is None:
1089
- prompt_2 = prompt_2 or prompt
1090
- prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
1091
-
1092
- # textual inversion: process multi-vector tokens if necessary
1093
- prompt_embeds_list = []
1094
- prompts = [prompt, prompt_2]
1095
- for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
1096
-
1097
- text_inputs = tokenizer(
1098
- prompt,
1099
- padding="max_length",
1100
- max_length=tokenizer.model_max_length,
1101
- truncation=True,
1102
- return_tensors="pt",
1103
- )
1104
-
1105
- text_input_ids = text_inputs.input_ids
1106
- untruncated_ids = tokenizer(
1107
- prompt, padding="longest", return_tensors="pt"
1108
- ).input_ids
1109
-
1110
- if untruncated_ids.shape[-1] >= text_input_ids.shape[
1111
- -1
1112
- ] and not torch.equal(text_input_ids, untruncated_ids):
1113
- removed_text = tokenizer.batch_decode(
1114
- untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
1115
- )
1116
- logger.warning(
1117
- "The following part of your input was truncated because CLIP can only handle sequences up to"
1118
- f" {tokenizer.model_max_length} tokens: {removed_text}"
1119
- )
1120
-
1121
- prompt_embeds = text_encoder(
1122
- text_input_ids.to(device), output_hidden_states=True
1123
- )
1124
-
1125
- # We are only ALWAYS interested in the pooled output of the final text encoder
1126
- pooled_prompt_embeds = prompt_embeds[0]
1127
- if clip_skip is None:
1128
- prompt_embeds = prompt_embeds.hidden_states[-2]
1129
- else:
1130
- # "2" because SDXL always indexes from the penultimate layer.
1131
- prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
1132
-
1133
- if self.config.force_zeros_for_empty_prompt:
1134
- prompt_embeds[[i for i in range(len(prompt)) if prompt[i] == ""]] = 0
1135
- pooled_prompt_embeds[
1136
- [i for i in range(len(prompt)) if prompt[i] == ""]
1137
- ] = 0
1138
-
1139
- prompt_embeds_list.append(prompt_embeds)
1140
-
1141
- prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
1142
-
1143
- # get unconditional embeddings for classifier free guidance
1144
- zero_out_negative_prompt = (
1145
- negative_prompt is None and self.config.force_zeros_for_empty_prompt
1146
- )
1147
- if (
1148
- do_classifier_free_guidance
1149
- and negative_prompt_embeds is None
1150
- and zero_out_negative_prompt
1151
- ):
1152
- negative_prompt_embeds = torch.zeros_like(prompt_embeds)
1153
- negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
1154
- elif do_classifier_free_guidance and negative_prompt_embeds is None:
1155
- negative_prompt = negative_prompt or ""
1156
- negative_prompt_2 = negative_prompt_2 or negative_prompt
1157
-
1158
- # normalize str to list
1159
- negative_prompt = (
1160
- batch_size * [negative_prompt]
1161
- if isinstance(negative_prompt, str)
1162
- else negative_prompt
1163
- )
1164
- negative_prompt_2 = (
1165
- batch_size * [negative_prompt_2]
1166
- if isinstance(negative_prompt_2, str)
1167
- else negative_prompt_2
1168
- )
1169
-
1170
- uncond_tokens: List[str]
1171
- if prompt is not None and type(prompt) is not type(negative_prompt):
1172
- raise TypeError(
1173
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
1174
- f" {type(prompt)}."
1175
- )
1176
- elif batch_size != len(negative_prompt):
1177
- raise ValueError(
1178
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
1179
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
1180
- " the batch size of `prompt`."
1181
- )
1182
- else:
1183
- uncond_tokens = [negative_prompt, negative_prompt_2]
1184
-
1185
- negative_prompt_embeds_list = []
1186
- for negative_prompt, tokenizer, text_encoder in zip(
1187
- uncond_tokens, tokenizers, text_encoders
1188
- ):
1189
-
1190
- max_length = prompt_embeds.shape[1]
1191
- uncond_input = tokenizer(
1192
- negative_prompt,
1193
- padding="max_length",
1194
- max_length=max_length,
1195
- truncation=True,
1196
- return_tensors="pt",
1197
- )
1198
-
1199
- negative_prompt_embeds = text_encoder(
1200
- uncond_input.input_ids.to(device),
1201
- output_hidden_states=True,
1202
- )
1203
- # We are only ALWAYS interested in the pooled output of the final text encoder
1204
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
1205
- negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
1206
-
1207
- negative_prompt_embeds_list.append(negative_prompt_embeds)
1208
-
1209
- negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
1210
-
1211
- if self.text_encoder_2 is not None:
1212
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
1213
- else:
1214
- prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
1215
-
1216
- bs_embed, seq_len, _ = prompt_embeds.shape
1217
- # duplicate text embeddings for each generation per prompt, using mps friendly method
1218
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
1219
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
1220
-
1221
- if do_classifier_free_guidance:
1222
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
1223
- seq_len = negative_prompt_embeds.shape[1]
1224
-
1225
- if self.text_encoder_2 is not None:
1226
- negative_prompt_embeds = negative_prompt_embeds.to(
1227
- dtype=self.text_encoder_2.dtype, device=device
1228
- )
1229
- else:
1230
- negative_prompt_embeds = negative_prompt_embeds.to(
1231
- dtype=self.unet.dtype, device=device
1232
- )
1233
-
1234
- negative_prompt_embeds = negative_prompt_embeds.repeat(
1235
- 1, num_images_per_prompt, 1
1236
- )
1237
- negative_prompt_embeds = negative_prompt_embeds.view(
1238
- batch_size * num_images_per_prompt, seq_len, -1
1239
- )
1240
-
1241
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
1242
- bs_embed * num_images_per_prompt, -1
1243
- )
1244
- if do_classifier_free_guidance:
1245
- negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
1246
- 1, num_images_per_prompt
1247
- ).view(bs_embed * num_images_per_prompt, -1)
1248
-
1249
- if self.text_encoder is not None:
1250
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
1251
- # Retrieve the original scale by scaling back the LoRA layers
1252
- unscale_lora_layers(self.text_encoder, lora_scale)
1253
-
1254
- if self.text_encoder_2 is not None:
1255
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
1256
- # Retrieve the original scale by scaling back the LoRA layers
1257
- unscale_lora_layers(self.text_encoder_2, lora_scale)
1258
-
1259
- return (
1260
- prompt_embeds,
1261
- negative_prompt_embeds,
1262
- pooled_prompt_embeds,
1263
- negative_pooled_prompt_embeds,
1264
- )
1265
-
1266
-
1267
- def create_xts(
1268
- noise_shift_delta,
1269
- noise_timesteps,
1270
- clean_step_timestep,
1271
- generator,
1272
- scheduler,
1273
- timesteps,
1274
- x_0,
1275
- no_add_noise=False,
1276
- ):
1277
- if noise_timesteps is None:
1278
- noising_delta = noise_shift_delta * (timesteps[0] - timesteps[1])
1279
- noise_timesteps = [timestep - int(noising_delta) for timestep in timesteps]
1280
-
1281
- first_x_0_idx = len(noise_timesteps)
1282
- for i in range(len(noise_timesteps)):
1283
- if noise_timesteps[i] <= 0:
1284
- first_x_0_idx = i
1285
- break
1286
-
1287
- noise_timesteps = noise_timesteps[:first_x_0_idx]
1288
-
1289
- x_0_expanded = x_0.expand(len(noise_timesteps), -1, -1, -1)
1290
- noise = (
1291
- torch.randn(x_0_expanded.size(), generator=generator, device="cpu").to(
1292
- x_0.device
1293
- )
1294
- if not no_add_noise
1295
- else torch.zeros_like(x_0_expanded)
1296
- )
1297
- x_ts = scheduler.add_noise(
1298
- x_0_expanded,
1299
- noise,
1300
- torch.IntTensor(noise_timesteps),
1301
- )
1302
- x_ts = [t.unsqueeze(dim=0) for t in list(x_ts)]
1303
- x_ts += [x_0] * (len(timesteps) - first_x_0_idx)
1304
- x_ts += [x_0]
1305
- if clean_step_timestep > 0:
1306
- x_ts += [x_0]
1307
- return x_ts
1308
-
1309
-
1310
- # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
1311
- def add_noise(
1312
- self,
1313
- original_samples: torch.FloatTensor,
1314
- noise: torch.FloatTensor,
1315
- image_timesteps: torch.IntTensor,
1316
- noise_timesteps: torch.IntTensor,
1317
- ) -> torch.FloatTensor:
1318
- # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
1319
- # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
1320
- # for the subsequent add_noise calls
1321
- self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
1322
- alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
1323
- image_timesteps = image_timesteps.to(original_samples.device)
- noise_timesteps = noise_timesteps.to(original_samples.device)
1324
-
1325
- sqrt_alpha_prod = alphas_cumprod[image_timesteps] ** 0.5
1326
- sqrt_alpha_prod = sqrt_alpha_prod.flatten()
1327
- while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
1328
- sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
1329
-
1330
- sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[noise_timesteps]) ** 0.5
1331
- sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
1332
- while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
1333
- sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
1334
-
1335
- noisy_samples = (
1336
- sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
1337
- )
1338
- return noisy_samples
1339
-
1340
-
1341
- def make_image_grid(
1342
- images: List[PIL.Image.Image], rows: int, cols: int, resize: int = None, size=None
1343
- ) -> PIL.Image.Image:
1344
- """
1345
- Prepares a single grid of images. Useful for visualization purposes.
1346
- """
1347
- assert len(images) == rows * cols
1348
-
1349
- if resize is not None:
1350
- images = [img.resize((resize, resize)) for img in images]
1351
-
1352
- w, h = size if size is not None else images[0].size
1353
- grid = Image.new("RGB", size=(cols * w, rows * h))
1354
-
1355
- for i, img in enumerate(images):
1356
- grid.paste(img, box=(i % cols * w, i // cols * h))
1357
- return grid