PseudoTerminal X commited on
Commit
c7b113a
·
verified ·
1 Parent(s): ecd948a

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +1299 -0
pipeline.py ADDED
@@ -0,0 +1,1299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import html
16
+ import inspect
17
+ import re
18
+ import urllib.parse as ul
19
+ from typing import Callable, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ from transformers import T5EncoderModel, T5Tokenizer
23
+
24
+ from diffusers.image_processor import PixArtImageProcessor, PipelineImageInput
25
+ from diffusers.models import AutoencoderKL, PixArtTransformer2DModel
26
+ from diffusers.schedulers import KarrasDiffusionSchedulers
27
+ from diffusers.utils import (
28
+ BACKENDS_MAPPING,
29
+ deprecate,
30
+ is_bs4_available,
31
+ is_ftfy_available,
32
+ logging,
33
+ replace_example_docstring,
34
+ )
35
+ from diffusers.utils.torch_utils import randn_tensor
36
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
37
+ from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import (
38
+ ASPECT_RATIO_256_BIN,
39
+ ASPECT_RATIO_512_BIN,
40
+ ASPECT_RATIO_1024_BIN,
41
+ )
42
+
43
+
44
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
45
+ def retrieve_latents(
46
+ encoder_output: torch.Tensor,
47
+ generator: Optional[torch.Generator] = None,
48
+ sample_mode: str = "sample",
49
+ ):
50
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
51
+ return encoder_output.latent_dist.sample(generator)
52
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
53
+ return encoder_output.latent_dist.mode()
54
+ elif hasattr(encoder_output, "latents"):
55
+ return encoder_output.latents
56
+ else:
57
+ raise AttributeError("Could not access latents of provided encoder_output")
58
+
59
+
60
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
61
+
62
+ if is_bs4_available():
63
+ from bs4 import BeautifulSoup
64
+
65
+ if is_ftfy_available():
66
+ import ftfy
67
+
68
+ def debug_print(message: str):
69
+ #print(message)
70
+ pass
71
+
72
+ ASPECT_RATIO_2048_BIN = {
73
+ "0.25": [1024.0, 4096.0],
74
+ "0.26": [1024.0, 3968.0],
75
+ "0.27": [1024.0, 3840.0],
76
+ "0.28": [1024.0, 3712.0],
77
+ "0.32": [1152.0, 3584.0],
78
+ "0.33": [1152.0, 3456.0],
79
+ "0.35": [1152.0, 3328.0],
80
+ "0.4": [1280.0, 3200.0],
81
+ "0.42": [1280.0, 3072.0],
82
+ "0.48": [1408.0, 2944.0],
83
+ "0.5": [1408.0, 2816.0],
84
+ "0.52": [1408.0, 2688.0],
85
+ "0.57": [1536.0, 2688.0],
86
+ "0.6": [1536.0, 2560.0],
87
+ "0.68": [1664.0, 2432.0],
88
+ "0.72": [1664.0, 2304.0],
89
+ "0.78": [1792.0, 2304.0],
90
+ "0.82": [1792.0, 2176.0],
91
+ "0.88": [1920.0, 2176.0],
92
+ "0.94": [1920.0, 2048.0],
93
+ "1.0": [2048.0, 2048.0],
94
+ "1.07": [2048.0, 1920.0],
95
+ "1.13": [2176.0, 1920.0],
96
+ "1.21": [2176.0, 1792.0],
97
+ "1.29": [2304.0, 1792.0],
98
+ "1.38": [2304.0, 1664.0],
99
+ "1.46": [2432.0, 1664.0],
100
+ "1.67": [2560.0, 1536.0],
101
+ "1.75": [2688.0, 1536.0],
102
+ "2.0": [2816.0, 1408.0],
103
+ "2.09": [2944.0, 1408.0],
104
+ "2.4": [3072.0, 1280.0],
105
+ "2.5": [3200.0, 1280.0],
106
+ "2.89": [3328.0, 1152.0],
107
+ "3.0": [3456.0, 1152.0],
108
+ "3.11": [3584.0, 1152.0],
109
+ "3.62": [3712.0, 1024.0],
110
+ "3.75": [3840.0, 1024.0],
111
+ "3.88": [3968.0, 1024.0],
112
+ "4.0": [4096.0, 1024.0],
113
+ }
114
+
115
+
116
+ EXAMPLE_DOC_STRING = """
117
+ Examples:
118
+ ```py
119
+ >>> import torch
120
+ >>> from diffusers import PixArtSigmaPipeline
121
+
122
+ >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-Sigma-XL-2-512-MS" too.
123
+ >>> pipe = PixArtSigmaPipeline.from_pretrained(
124
+ ... "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
125
+ ... )
126
+ >>> # Enable memory optimizations.
127
+ >>> # pipe.enable_model_cpu_offload()
128
+
129
+ >>> prompt = "A small cactus with a happy face in the Sahara desert."
130
+ >>> image = pipe(prompt).images[0]
131
+ ```
132
+ """
133
+
134
+
135
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
136
+ def retrieve_timesteps(
137
+ scheduler,
138
+ num_inference_steps: Optional[int] = None,
139
+ device: Optional[Union[str, torch.device]] = None,
140
+ timesteps: Optional[List[int]] = None,
141
+ sigmas: Optional[List[float]] = None,
142
+ **kwargs,
143
+ ):
144
+ """
145
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
146
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
147
+
148
+ Args:
149
+ scheduler (`SchedulerMixin`):
150
+ The scheduler to get timesteps from.
151
+ num_inference_steps (`int`):
152
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
153
+ must be `None`.
154
+ device (`str` or `torch.device`, *optional*):
155
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
156
+ timesteps (`List[int]`, *optional*):
157
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
158
+ `num_inference_steps` and `sigmas` must be `None`.
159
+ sigmas (`List[float]`, *optional*):
160
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
161
+ `num_inference_steps` and `timesteps` must be `None`.
162
+
163
+ Returns:
164
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
165
+ second element is the number of inference steps.
166
+ """
167
+ if timesteps is not None and sigmas is not None:
168
+ raise ValueError(
169
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
170
+ )
171
+ if timesteps is not None:
172
+ accepts_timesteps = "timesteps" in set(
173
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
174
+ )
175
+ if not accepts_timesteps:
176
+ raise ValueError(
177
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
178
+ f" timestep schedules. Please check whether you are using the correct scheduler."
179
+ )
180
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
181
+ timesteps = scheduler.timesteps
182
+ num_inference_steps = len(timesteps)
183
+ elif sigmas is not None:
184
+ accept_sigmas = "sigmas" in set(
185
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
186
+ )
187
+ if not accept_sigmas:
188
+ raise ValueError(
189
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
190
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
191
+ )
192
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
193
+ timesteps = scheduler.timesteps
194
+ num_inference_steps = len(timesteps)
195
+ else:
196
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
197
+ timesteps = scheduler.timesteps
198
+ return timesteps, num_inference_steps
199
+
200
+
201
+ class PixArtSigmaPipeline(DiffusionPipeline):
202
+ r"""
203
+ Pipeline for text-to-image generation using PixArt-Sigma.
204
+ """
205
+
206
+ bad_punct_regex = re.compile(
207
+ r"["
208
+ + "#®•©™&@·º½¾¿¡§~"
209
+ + r"\)"
210
+ + r"\("
211
+ + r"\]"
212
+ + r"\["
213
+ + r"\}"
214
+ + r"\{"
215
+ + r"\|"
216
+ + "\\"
217
+ + r"\/"
218
+ + r"\*"
219
+ + r"]{1,}"
220
+ ) # noqa
221
+
222
+ _optional_components = ["tokenizer", "text_encoder"]
223
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
224
+
225
+ def __init__(
226
+ self,
227
+ tokenizer: T5Tokenizer,
228
+ text_encoder: T5EncoderModel,
229
+ vae: AutoencoderKL,
230
+ transformer: PixArtTransformer2DModel,
231
+ scheduler: KarrasDiffusionSchedulers,
232
+ ):
233
+ super().__init__()
234
+
235
+ self.register_modules(
236
+ tokenizer=tokenizer,
237
+ text_encoder=text_encoder,
238
+ vae=vae,
239
+ transformer=transformer,
240
+ scheduler=scheduler,
241
+ )
242
+
243
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
244
+ self.image_processor = PixArtImageProcessor(
245
+ vae_scale_factor=self.vae_scale_factor
246
+ )
247
+
248
+ def get_timesteps(
249
+ self, num_inference_steps, strength, device, denoising_start=None
250
+ ):
251
+ # get the original timestep using init_timestep
252
+ if denoising_start is None and strength is not None:
253
+ init_timestep = min(
254
+ int(num_inference_steps * strength), num_inference_steps
255
+ )
256
+ debug_print(f"Init timestep: {init_timestep}")
257
+ t_start = max(num_inference_steps - init_timestep, 0)
258
+ debug_print(
259
+ f"t_start = max({num_inference_steps} - {init_timestep}, 0) = {t_start}"
260
+ )
261
+ else:
262
+ debug_print(f"denoising_start: {denoising_start}")
263
+ t_start = 0
264
+
265
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
266
+ # Strength is irrelevant if we directly request a timestep to start at;
267
+ # that is, strength is determined by the denoising_start instead.
268
+ if denoising_start is not None:
269
+ discrete_timestep_cutoff = int(
270
+ round(
271
+ self.scheduler.config.num_train_timesteps
272
+ - (denoising_start * self.scheduler.config.num_train_timesteps)
273
+ )
274
+ )
275
+
276
+ num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
277
+ if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
278
+ # if the scheduler is a 2nd order scheduler we might have to do +1
279
+ # because `num_inference_steps` might be even given that every timestep
280
+ # (except the highest one) is duplicated. If `num_inference_steps` is even it would
281
+ # mean that we cut the timesteps in the middle of the denoising step
282
+ # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
283
+ # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
284
+ num_inference_steps = num_inference_steps + 1
285
+
286
+ # because t_n+1 >= t_n, we slice the timesteps starting from the end
287
+ timesteps = timesteps[-num_inference_steps:]
288
+ return timesteps, num_inference_steps
289
+
290
+ return timesteps, num_inference_steps - t_start
291
+
292
+ # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->300
293
+ def encode_prompt(
294
+ self,
295
+ prompt: Union[str, List[str]],
296
+ do_classifier_free_guidance: bool = True,
297
+ negative_prompt: str = "",
298
+ num_images_per_prompt: int = 1,
299
+ device: Optional[torch.device] = None,
300
+ prompt_embeds: Optional[torch.Tensor] = None,
301
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
302
+ prompt_attention_mask: Optional[torch.Tensor] = None,
303
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
304
+ clean_caption: bool = False,
305
+ max_sequence_length: int = 300,
306
+ **kwargs,
307
+ ):
308
+ r"""
309
+ Encodes the prompt into text encoder hidden states.
310
+
311
+ Args:
312
+ prompt (`str` or `List[str]`, *optional*):
313
+ prompt to be encoded
314
+ negative_prompt (`str` or `List[str]`, *optional*):
315
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
316
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
317
+ PixArt-Alpha, this should be "".
318
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
319
+ whether to use classifier free guidance or not
320
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
321
+ number of images that should be generated per prompt
322
+ device: (`torch.device`, *optional*):
323
+ torch device to place the resulting embeddings on
324
+ prompt_embeds (`torch.Tensor`, *optional*):
325
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
326
+ provided, text embeddings will be generated from `prompt` input argument.
327
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
328
+ Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
329
+ string.
330
+ clean_caption (`bool`, defaults to `False`):
331
+ If `True`, the function will preprocess and clean the provided caption before encoding.
332
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
333
+ """
334
+
335
+ if "mask_feature" in kwargs:
336
+ deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
337
+ deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
338
+
339
+ if device is None:
340
+ device = self._execution_device
341
+
342
+ if prompt is not None and isinstance(prompt, str):
343
+ batch_size = 1
344
+ elif prompt is not None and isinstance(prompt, list):
345
+ batch_size = len(prompt)
346
+ else:
347
+ batch_size = prompt_embeds.shape[0]
348
+
349
+ # See Section 3.1. of the paper.
350
+ max_length = max_sequence_length
351
+
352
+ if prompt_embeds is None:
353
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
354
+ text_inputs = self.tokenizer(
355
+ prompt,
356
+ padding="max_length",
357
+ max_length=max_length,
358
+ truncation=True,
359
+ add_special_tokens=True,
360
+ return_tensors="pt",
361
+ )
362
+ text_input_ids = text_inputs.input_ids
363
+ untruncated_ids = self.tokenizer(
364
+ prompt, padding="longest", return_tensors="pt"
365
+ ).input_ids
366
+
367
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[
368
+ -1
369
+ ] and not torch.equal(text_input_ids, untruncated_ids):
370
+ removed_text = self.tokenizer.batch_decode(
371
+ untruncated_ids[:, max_length - 1 : -1]
372
+ )
373
+ logger.warning(
374
+ "The following part of your input was truncated because T5 can only handle sequences up to"
375
+ f" {max_length} tokens: {removed_text}"
376
+ )
377
+
378
+ prompt_attention_mask = text_inputs.attention_mask
379
+ prompt_attention_mask = prompt_attention_mask.to(device)
380
+
381
+ prompt_embeds = self.text_encoder(
382
+ text_input_ids.to(device), attention_mask=prompt_attention_mask
383
+ )
384
+ prompt_embeds = prompt_embeds[0]
385
+
386
+ if self.text_encoder is not None:
387
+ dtype = self.text_encoder.dtype
388
+ elif self.transformer is not None:
389
+ dtype = self.transformer.dtype
390
+ else:
391
+ dtype = None
392
+
393
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
394
+
395
+ bs_embed, seq_len, _ = prompt_embeds.shape
396
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
397
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
398
+ prompt_embeds = prompt_embeds.view(
399
+ bs_embed * num_images_per_prompt, seq_len, -1
400
+ )
401
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
402
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
403
+
404
+ # get unconditional embeddings for classifier free guidance
405
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
406
+ uncond_tokens = (
407
+ [negative_prompt] * batch_size
408
+ if isinstance(negative_prompt, str)
409
+ else negative_prompt
410
+ )
411
+ uncond_tokens = self._text_preprocessing(
412
+ uncond_tokens, clean_caption=clean_caption
413
+ )
414
+ max_length = prompt_embeds.shape[1]
415
+ uncond_input = self.tokenizer(
416
+ uncond_tokens,
417
+ padding="max_length",
418
+ max_length=max_length,
419
+ truncation=True,
420
+ return_attention_mask=True,
421
+ add_special_tokens=True,
422
+ return_tensors="pt",
423
+ )
424
+ negative_prompt_attention_mask = uncond_input.attention_mask
425
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
426
+
427
+ negative_prompt_embeds = self.text_encoder(
428
+ uncond_input.input_ids.to(device),
429
+ attention_mask=negative_prompt_attention_mask,
430
+ )
431
+ negative_prompt_embeds = negative_prompt_embeds[0]
432
+
433
+ if do_classifier_free_guidance:
434
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
435
+ seq_len = negative_prompt_embeds.shape[1]
436
+
437
+ negative_prompt_embeds = negative_prompt_embeds.to(
438
+ dtype=dtype, device=device
439
+ )
440
+
441
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
442
+ 1, num_images_per_prompt, 1
443
+ )
444
+ negative_prompt_embeds = negative_prompt_embeds.view(
445
+ batch_size * num_images_per_prompt, seq_len, -1
446
+ )
447
+
448
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(
449
+ bs_embed, -1
450
+ )
451
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(
452
+ num_images_per_prompt, 1
453
+ )
454
+ else:
455
+ negative_prompt_embeds = None
456
+ negative_prompt_attention_mask = None
457
+
458
+ return (
459
+ prompt_embeds,
460
+ prompt_attention_mask,
461
+ negative_prompt_embeds,
462
+ negative_prompt_attention_mask,
463
+ )
464
+
465
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
466
+ def prepare_extra_step_kwargs(self, generator, eta):
467
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
468
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
469
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
470
+ # and should be between [0, 1]
471
+
472
+ accepts_eta = "eta" in set(
473
+ inspect.signature(self.scheduler.step).parameters.keys()
474
+ )
475
+ extra_step_kwargs = {}
476
+ if accepts_eta:
477
+ extra_step_kwargs["eta"] = eta
478
+
479
+ # check if the scheduler accepts generator
480
+ accepts_generator = "generator" in set(
481
+ inspect.signature(self.scheduler.step).parameters.keys()
482
+ )
483
+ if accepts_generator:
484
+ extra_step_kwargs["generator"] = generator
485
+ return extra_step_kwargs
486
+
487
+ # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.check_inputs
488
+ def check_inputs(
489
+ self,
490
+ prompt,
491
+ height,
492
+ width,
493
+ strength,
494
+ num_inference_steps,
495
+ negative_prompt,
496
+ callback_steps,
497
+ prompt_embeds=None,
498
+ negative_prompt_embeds=None,
499
+ prompt_attention_mask=None,
500
+ negative_prompt_attention_mask=None,
501
+ ):
502
+ if strength is None:
503
+ if height % 8 != 0 or width % 8 != 0:
504
+ raise ValueError(
505
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
506
+ )
507
+ else:
508
+ if strength < 0 or strength > 1:
509
+ raise ValueError(
510
+ f"The value of strength should in [0.0, 1.0] but is {strength}"
511
+ )
512
+ if num_inference_steps is None:
513
+ raise ValueError("`num_inference_steps` cannot be None.")
514
+ elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
515
+ raise ValueError(
516
+ f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
517
+ f" {type(num_inference_steps)}."
518
+ )
519
+ if (callback_steps is None) or (
520
+ callback_steps is not None
521
+ and (not isinstance(callback_steps, int) or callback_steps <= 0)
522
+ ):
523
+ raise ValueError(
524
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
525
+ f" {type(callback_steps)}."
526
+ )
527
+
528
+ if prompt is not None and prompt_embeds is not None:
529
+ prompt = None
530
+
531
+ if prompt is None and prompt_embeds is None:
532
+ raise ValueError(
533
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
534
+ )
535
+ elif prompt is not None and (
536
+ not isinstance(prompt, str) and not isinstance(prompt, list)
537
+ ):
538
+ raise ValueError(
539
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
540
+ )
541
+
542
+ if prompt is not None and negative_prompt_embeds is not None:
543
+ raise ValueError(
544
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
545
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
546
+ )
547
+
548
+ if negative_prompt is not None and negative_prompt_embeds is not None:
549
+ negative_prompt = None
550
+
551
+ if prompt_embeds is not None and prompt_attention_mask is None:
552
+ raise ValueError(
553
+ "Must provide `prompt_attention_mask` when specifying `prompt_embeds`."
554
+ )
555
+
556
+ if (
557
+ negative_prompt_embeds is not None
558
+ and negative_prompt_attention_mask is None
559
+ ):
560
+ raise ValueError(
561
+ "Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`."
562
+ )
563
+
564
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
565
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
566
+ raise ValueError(
567
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
568
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
569
+ f" {negative_prompt_embeds.shape}."
570
+ )
571
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
572
+ raise ValueError(
573
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
574
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
575
+ f" {negative_prompt_attention_mask.shape}."
576
+ )
577
+
578
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
579
+ def _text_preprocessing(self, text, clean_caption=False):
580
+ if clean_caption and not is_bs4_available():
581
+ logger.warning(
582
+ BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")
583
+ )
584
+ logger.warning("Setting `clean_caption` to False...")
585
+ clean_caption = False
586
+
587
+ if clean_caption and not is_ftfy_available():
588
+ logger.warning(
589
+ BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")
590
+ )
591
+ logger.warning("Setting `clean_caption` to False...")
592
+ clean_caption = False
593
+
594
+ if not isinstance(text, (tuple, list)):
595
+ text = [text]
596
+
597
+ def process(text: str):
598
+ if clean_caption:
599
+ text = self._clean_caption(text)
600
+ text = self._clean_caption(text)
601
+ else:
602
+ text = text.lower().strip()
603
+ return text
604
+
605
+ return [process(t) for t in text]
606
+
607
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
608
+ def _clean_caption(self, caption):
609
+ caption = str(caption)
610
+ caption = ul.unquote_plus(caption)
611
+ caption = caption.strip().lower()
612
+ caption = re.sub("<person>", "person", caption)
613
+ # urls:
614
+ caption = re.sub(
615
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
616
+ "",
617
+ caption,
618
+ ) # regex for urls
619
+ caption = re.sub(
620
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
621
+ "",
622
+ caption,
623
+ ) # regex for urls
624
+ # html:
625
+ caption = BeautifulSoup(caption, features="html.parser").text
626
+
627
+ # @<nickname>
628
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
629
+
630
+ # 31C0—31EF CJK Strokes
631
+ # 31F0—31FF Katakana Phonetic Extensions
632
+ # 3200—32FF Enclosed CJK Letters and Months
633
+ # 3300—33FF CJK Compatibility
634
+ # 3400—4DBF CJK Unified Ideographs Extension A
635
+ # 4DC0—4DFF Yijing Hexagram Symbols
636
+ # 4E00—9FFF CJK Unified Ideographs
637
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
638
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
639
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
640
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
641
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
642
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
643
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
644
+ #######################################################
645
+
646
+ # все виды тире / all types of dash --> "-"
647
+ caption = re.sub(
648
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
649
+ "-",
650
+ caption,
651
+ )
652
+
653
+ # кавычки к одному стандарту
654
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
655
+ caption = re.sub(r"[‘’]", "'", caption)
656
+
657
+ # &quot;
658
+ caption = re.sub(r"&quot;?", "", caption)
659
+ # &amp
660
+ caption = re.sub(r"&amp", "", caption)
661
+
662
+ # ip adresses:
663
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
664
+
665
+ # article ids:
666
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
667
+
668
+ # \n
669
+ caption = re.sub(r"\\n", " ", caption)
670
+
671
+ # "#123"
672
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
673
+ # "#12345.."
674
+ caption = re.sub(r"#\d{5,}\b", "", caption)
675
+ # "123456.."
676
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
677
+ # filenames:
678
+ caption = re.sub(
679
+ r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption
680
+ )
681
+
682
+ #
683
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
684
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
685
+
686
+ caption = re.sub(
687
+ self.bad_punct_regex, r" ", caption
688
+ ) # ***AUSVERKAUFT***, #AUSVERKAUFT
689
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
690
+
691
+ # this-is-my-cute-cat / this_is_my_cute_cat
692
+ regex2 = re.compile(r"(?:\-|\_)")
693
+ if len(re.findall(regex2, caption)) > 3:
694
+ caption = re.sub(regex2, " ", caption)
695
+
696
+ caption = ftfy.fix_text(caption)
697
+ caption = html.unescape(html.unescape(caption))
698
+
699
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
700
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
701
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
702
+
703
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
704
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
705
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
706
+ caption = re.sub(
707
+ r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption
708
+ )
709
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
710
+
711
+ caption = re.sub(
712
+ r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption
713
+ ) # j2d1a2a...
714
+
715
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
716
+
717
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
718
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
719
+ caption = re.sub(r"\s+", " ", caption)
720
+
721
+ caption.strip()
722
+
723
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
724
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
725
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
726
+ caption = re.sub(r"^\.\S+$", "", caption)
727
+
728
+ return caption.strip()
729
+
730
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
731
+ def prepare_latents(
732
+ self,
733
+ batch_size,
734
+ num_channels_latents,
735
+ height,
736
+ width,
737
+ dtype,
738
+ device,
739
+ generator,
740
+ _latents=None,
741
+ timestep=None,
742
+ add_noise=False,
743
+ image=None,
744
+ ):
745
+ shape = (
746
+ batch_size,
747
+ num_channels_latents,
748
+ int(height) // self.vae_scale_factor,
749
+ int(width) // self.vae_scale_factor,
750
+ )
751
+ if isinstance(generator, list) and len(generator) != batch_size:
752
+ raise ValueError(
753
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
754
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
755
+ )
756
+
757
+ if _latents is not None:
758
+ init_latents = _latents.to(device)
759
+ elif image is None and _latents is None:
760
+ debug_print("Make random latents tensor")
761
+ init_latents = randn_tensor(
762
+ shape, generator=generator, device=device, dtype=dtype
763
+ )
764
+
765
+ latents_mean = latents_std = None
766
+ if (
767
+ hasattr(self.vae.config, "latents_mean")
768
+ and self.vae.config.latents_mean is not None
769
+ ):
770
+ latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
771
+ if (
772
+ hasattr(self.vae.config, "latents_std")
773
+ and self.vae.config.latents_std is not None
774
+ ):
775
+ latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
776
+ if image is not None and hasattr(image, "shape") and image.shape[1] == 4:
777
+ debug_print("Received valid latent image input.")
778
+ init_latents = image
779
+
780
+ if init_latents is not None:
781
+ # scale the initial noise by the standard deviation required by the scheduler
782
+ debug_print(f"Scaling the initial noise by the std required by the scheduler.")
783
+ init_latents = init_latents * self.scheduler.init_noise_sigma
784
+
785
+ if image is not None and image.shape[1] < 4:
786
+ debug_print("Received RGB or similar image. Processing..")
787
+ # make sure the VAE is in float32 mode, as it overflows in float16
788
+ if self.vae.config.force_upcast:
789
+ image = image.float()
790
+ self.vae.to(dtype=torch.float32)
791
+
792
+ if isinstance(generator, list) and len(generator) != batch_size:
793
+ raise ValueError(
794
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
795
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
796
+ )
797
+
798
+ elif isinstance(generator, list):
799
+ init_latents = [
800
+ retrieve_latents(
801
+ self.vae.encode(image[i : i + 1]), generator=generator[i]
802
+ )
803
+ for i in range(batch_size)
804
+ ]
805
+ init_latents = torch.cat(init_latents, dim=0)
806
+ else:
807
+ debug_print("Encode image to latents.")
808
+ init_latents = retrieve_latents(
809
+ self.vae.encode(image), generator=generator
810
+ )
811
+
812
+ if self.vae.config.force_upcast:
813
+ self.vae.to(dtype)
814
+
815
+ debug_print("Set initial latents..")
816
+ init_latents = init_latents.to(dtype)
817
+ if latents_mean is not None and latents_std is not None:
818
+ debug_print("Scaling latents by mean/std")
819
+ latents_mean = latents_mean.to(device=device, dtype=dtype)
820
+ latents_std = latents_std.to(device=device, dtype=dtype)
821
+ init_latents = (
822
+ (init_latents - latents_mean)
823
+ * self.vae.config.scaling_factor
824
+ / latents_std
825
+ )
826
+ else:
827
+ debug_print("Scaling latents only by scaling_factor")
828
+ init_latents = self.vae.config.scaling_factor * init_latents
829
+
830
+ if (
831
+ batch_size > init_latents.shape[0]
832
+ and batch_size % init_latents.shape[0] == 0
833
+ ):
834
+ # expand init_latents for batch_size
835
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
836
+ init_latents = torch.cat(
837
+ [init_latents] * additional_image_per_prompt, dim=0
838
+ )
839
+ elif (
840
+ batch_size > init_latents.shape[0]
841
+ and batch_size % init_latents.shape[0] != 0
842
+ ):
843
+ raise ValueError(
844
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
845
+ )
846
+ else:
847
+ init_latents = torch.cat([init_latents], dim=0)
848
+
849
+ if (
850
+ add_noise
851
+ and timestep is not None
852
+ and (_latents is not None or image is not None)
853
+ ):
854
+ shape = init_latents.shape
855
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
856
+ # get latents
857
+ debug_print(f"Adding noise to tensor for timestep: {timestep}")
858
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
859
+
860
+ return init_latents
861
+
862
+ @property
863
+ def denoising_start(self):
864
+ return self._denoising_start
865
+
866
+ @property
867
+ def denoising_end(self):
868
+ return self._denoising_end
869
+
870
+ @property
871
+ def num_timesteps(self):
872
+ return self._num_timesteps
873
+
874
+ @torch.no_grad()
875
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
876
+ def __call__(
877
+ self,
878
+ prompt: Union[str, List[str]] = None,
879
+ negative_prompt: str = "",
880
+ strength: float = None,
881
+ num_inference_steps: int = 20,
882
+ timesteps: List[int] = None,
883
+ sigmas: List[float] = None,
884
+ denoising_start: Optional[float] = None,
885
+ denoising_end: Optional[float] = None,
886
+ guidance_scale: float = 4.5,
887
+ num_images_per_prompt: Optional[int] = 1,
888
+ height: Optional[int] = None,
889
+ width: Optional[int] = None,
890
+ eta: float = 0.0,
891
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
892
+ image: Optional[PipelineImageInput] = None,
893
+ latents: Optional[torch.Tensor] = None,
894
+ prompt_embeds: Optional[torch.Tensor] = None,
895
+ prompt_attention_mask: Optional[torch.Tensor] = None,
896
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
897
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
898
+ output_type: Optional[str] = "pil",
899
+ return_dict: bool = True,
900
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
901
+ callback_steps: int = 1,
902
+ clean_caption: bool = True,
903
+ use_resolution_binning: bool = True,
904
+ max_sequence_length: int = 300,
905
+ **kwargs,
906
+ ) -> Union[ImagePipelineOutput, Tuple]:
907
+ """
908
+ Function invoked when calling the pipeline for generation.
909
+
910
+ Args:
911
+ prompt (`str` or `List[str]`, *optional*):
912
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
913
+ instead.
914
+ negative_prompt (`str` or `List[str]`, *optional*):
915
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
916
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
917
+ less than `1`).
918
+ strength (`float`, *optional*, defaults to 0.3):
919
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
920
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
921
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
922
+ be maximum and the denoising process will run for the full number of iterations specified in
923
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of
924
+ `denoising_start` being declared as an integer, the value of `strength` will be ignored.
925
+ num_inference_steps (`int`, *optional*, defaults to 100):
926
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
927
+ expense of slower inference.
928
+ denoising_start (`float`, *optional*):
929
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
930
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
931
+ it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
932
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
933
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image
934
+ Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
935
+ denoising_end (`float`, *optional*):
936
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
937
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
938
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
939
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
940
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
941
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
942
+ timesteps (`List[int]`, *optional*):
943
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
944
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
945
+ passed will be used. Must be in descending order.
946
+ sigmas (`List[float]`, *optional*):
947
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
948
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
949
+ will be used.
950
+ guidance_scale (`float`, *optional*, defaults to 4.5):
951
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
952
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
953
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
954
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
955
+ usually at the expense of lower image quality.
956
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
957
+ The number of images to generate per prompt.
958
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
959
+ The height in pixels of the generated image.
960
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
961
+ The width in pixels of the generated image.
962
+ eta (`float`, *optional*, defaults to 0.0):
963
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
964
+ [`schedulers.DDIMScheduler`], will be ignored for others.
965
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
966
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
967
+ to make generation deterministic.
968
+ latents (`torch.Tensor`, *optional*):
969
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
970
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
971
+ tensor will ge generated by sampling using the supplied random `generator`.
972
+ prompt_embeds (`torch.Tensor`, *optional*):
973
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
974
+ provided, text embeddings will be generated from `prompt` input argument.
975
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
976
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
977
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
978
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
979
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
980
+ Pre-generated attention mask for negative text embeddings.
981
+ output_type (`str`, *optional*, defaults to `"pil"`):
982
+ The output format of the generate image. Choose between
983
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
984
+ return_dict (`bool`, *optional*, defaults to `True`):
985
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
986
+ callback (`Callable`, *optional*):
987
+ A function that will be called every `callback_steps` steps during inference. The function will be
988
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
989
+ callback_steps (`int`, *optional*, defaults to 1):
990
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
991
+ called at every step.
992
+ clean_caption (`bool`, *optional*, defaults to `True`):
993
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
994
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
995
+ prompt.
996
+ use_resolution_binning (`bool` defaults to `True`):
997
+ If set to `True`, the requested height and width are first mapped to the closest resolutions using
998
+ `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
999
+ the requested resolution. Useful for generating non-square images.
1000
+ max_sequence_length (`int` defaults to 300): Maximum sequence length to use with the `prompt`.
1001
+
1002
+ Examples:
1003
+
1004
+ Returns:
1005
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
1006
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
1007
+ returned where the first element is a list with the generated images
1008
+ """
1009
+ # 1. Check inputs. Raise error if not correct
1010
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor
1011
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor
1012
+ if use_resolution_binning:
1013
+ if self.transformer.config.sample_size == 256:
1014
+ aspect_ratio_bin = ASPECT_RATIO_2048_BIN
1015
+ elif self.transformer.config.sample_size == 128:
1016
+ aspect_ratio_bin = ASPECT_RATIO_1024_BIN
1017
+ elif self.transformer.config.sample_size == 64:
1018
+ aspect_ratio_bin = ASPECT_RATIO_512_BIN
1019
+ elif self.transformer.config.sample_size == 32:
1020
+ aspect_ratio_bin = ASPECT_RATIO_256_BIN
1021
+ else:
1022
+ raise ValueError("Invalid sample size")
1023
+ orig_height, orig_width = height, width
1024
+ height, width = self.image_processor.classify_height_width_bin(
1025
+ height, width, ratios=aspect_ratio_bin
1026
+ )
1027
+
1028
+ self.check_inputs(
1029
+ prompt,
1030
+ height,
1031
+ width,
1032
+ strength,
1033
+ num_inference_steps,
1034
+ negative_prompt,
1035
+ callback_steps,
1036
+ prompt_embeds,
1037
+ negative_prompt_embeds,
1038
+ prompt_attention_mask,
1039
+ negative_prompt_attention_mask,
1040
+ )
1041
+
1042
+ # 2. Default height and width to transformer
1043
+ if prompt is not None and isinstance(prompt, str):
1044
+ batch_size = 1
1045
+ elif prompt is not None and isinstance(prompt, list):
1046
+ batch_size = len(prompt)
1047
+ else:
1048
+ batch_size = prompt_embeds.shape[0]
1049
+
1050
+ device = self._execution_device
1051
+ self._denoising_start = denoising_start
1052
+ self._num_timesteps = num_inference_steps
1053
+ self._denoising_end = denoising_end
1054
+
1055
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1056
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1057
+ # corresponds to doing no classifier free guidance.
1058
+ do_classifier_free_guidance = guidance_scale > 1.0
1059
+
1060
+ # 3. Encode input prompt
1061
+ (
1062
+ prompt_embeds,
1063
+ prompt_attention_mask,
1064
+ negative_prompt_embeds,
1065
+ negative_prompt_attention_mask,
1066
+ ) = self.encode_prompt(
1067
+ prompt,
1068
+ do_classifier_free_guidance,
1069
+ negative_prompt=negative_prompt,
1070
+ num_images_per_prompt=num_images_per_prompt,
1071
+ device=device,
1072
+ prompt_embeds=prompt_embeds,
1073
+ negative_prompt_embeds=negative_prompt_embeds,
1074
+ prompt_attention_mask=prompt_attention_mask,
1075
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
1076
+ clean_caption=clean_caption,
1077
+ max_sequence_length=max_sequence_length,
1078
+ )
1079
+ if do_classifier_free_guidance:
1080
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1081
+ prompt_attention_mask = torch.cat(
1082
+ [negative_prompt_attention_mask, prompt_attention_mask], dim=0
1083
+ )
1084
+
1085
+ # 4. Prepare timesteps
1086
+ def denoising_value_valid(dnv):
1087
+ return isinstance(dnv, float) and 0 < dnv < 1
1088
+
1089
+ timesteps, num_inference_steps = retrieve_timesteps(
1090
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
1091
+ )
1092
+
1093
+ # 5. Prepare latents.
1094
+ if image is not None:
1095
+ image = self.image_processor.preprocess(image)
1096
+ image = image.to(device=self.vae.device, dtype=self.vae.dtype)
1097
+
1098
+ latent_channels = self.transformer.config.in_channels
1099
+ latent_timestep = None
1100
+ if (
1101
+ denoising_end is not None
1102
+ or denoising_start is not None
1103
+ or strength is not None
1104
+ ):
1105
+ timesteps, num_inference_steps = self.get_timesteps(
1106
+ num_inference_steps,
1107
+ strength,
1108
+ device,
1109
+ denoising_start=(
1110
+ self.denoising_start
1111
+ if denoising_value_valid(self.denoising_start)
1112
+ else None
1113
+ ),
1114
+ )
1115
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1116
+ if latents is not None:
1117
+ height, width = latents.shape[-2:]
1118
+ height = height * self.vae_scale_factor
1119
+ width = width * self.vae_scale_factor
1120
+ add_noise = (
1121
+ True
1122
+ if (
1123
+ self.denoising_start is None
1124
+ and (image is not None or latents is not None)
1125
+ )
1126
+ else False
1127
+ )
1128
+ debug_print(f"Add_noise: {add_noise}")
1129
+ if latents is None:
1130
+ debug_print("Prepare latents..")
1131
+ latents = self.prepare_latents(
1132
+ batch_size * num_images_per_prompt,
1133
+ latent_channels,
1134
+ height,
1135
+ width,
1136
+ prompt_embeds.dtype,
1137
+ device,
1138
+ generator,
1139
+ latents,
1140
+ timestep=latent_timestep,
1141
+ add_noise=add_noise,
1142
+ image=image,
1143
+ )
1144
+
1145
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1146
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1147
+
1148
+ # 6.1 Prepare micro-conditions.
1149
+ added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
1150
+
1151
+ # 7. Denoising loop
1152
+ num_warmup_steps = max(
1153
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
1154
+ )
1155
+ if (
1156
+ self.denoising_end is not None
1157
+ and self.denoising_start is not None
1158
+ and denoising_value_valid(self.denoising_end)
1159
+ and denoising_value_valid(self.denoising_start)
1160
+ and self.denoising_start >= self.denoising_end
1161
+ ):
1162
+ raise ValueError(
1163
+ f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
1164
+ + f" {self.denoising_end} when using type float."
1165
+ )
1166
+ if self.denoising_start is not None:
1167
+ if denoising_value_valid(self.denoising_start):
1168
+ discrete_timestep_cutoff = int(
1169
+ round(
1170
+ self.scheduler.config.num_train_timesteps
1171
+ - (denoising_start * self.scheduler.config.num_train_timesteps)
1172
+ )
1173
+ )
1174
+
1175
+ num_inference_steps = (
1176
+ (timesteps < discrete_timestep_cutoff).sum().item()
1177
+ )
1178
+ debug_print(
1179
+ f"Beginning inference for stage2 with {num_inference_steps} steps."
1180
+ )
1181
+
1182
+ else:
1183
+ raise ValueError(
1184
+ f"`denoising_start` must be a float between 0 and 1: {denoising_start}"
1185
+ )
1186
+ if self.denoising_end is not None:
1187
+ if denoising_value_valid(self.denoising_end):
1188
+ discrete_timestep_cutoff = int(
1189
+ round(
1190
+ self.scheduler.config.num_train_timesteps
1191
+ - (
1192
+ self.denoising_end
1193
+ * self.scheduler.config.num_train_timesteps
1194
+ )
1195
+ )
1196
+ )
1197
+ num_inference_steps = len(
1198
+ list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))
1199
+ )
1200
+ debug_print(
1201
+ f"Beginning inference for stage1 with {num_inference_steps} steps."
1202
+ )
1203
+ timesteps = timesteps[:num_inference_steps]
1204
+ else:
1205
+ raise ValueError(
1206
+ f"`denoising_end` must be a float between 0 and 1: {denoising_end}"
1207
+ )
1208
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1209
+ for i, t in enumerate(timesteps):
1210
+ latent_model_input = (
1211
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1212
+ )
1213
+ latent_model_input = self.scheduler.scale_model_input(
1214
+ latent_model_input, t
1215
+ )
1216
+
1217
+ current_timestep = t
1218
+ if not torch.is_tensor(current_timestep):
1219
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
1220
+ # This would be a good case for the `match` statement (Python 3.10+)
1221
+ is_mps = latent_model_input.device.type == "mps"
1222
+ if isinstance(current_timestep, float):
1223
+ dtype = torch.float32 if is_mps else torch.float64
1224
+ else:
1225
+ dtype = torch.int32 if is_mps else torch.int64
1226
+ current_timestep = torch.tensor(
1227
+ [current_timestep],
1228
+ dtype=dtype,
1229
+ device=latent_model_input.device,
1230
+ )
1231
+ elif len(current_timestep.shape) == 0:
1232
+ current_timestep = current_timestep[None].to(
1233
+ latent_model_input.device
1234
+ )
1235
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1236
+ current_timestep = current_timestep.expand(latent_model_input.shape[0])
1237
+
1238
+ # predict noise model_output
1239
+ noise_pred = self.transformer(
1240
+ latent_model_input.to(
1241
+ device=self.transformer.device, dtype=self.transformer.dtype
1242
+ ),
1243
+ encoder_hidden_states=prompt_embeds,
1244
+ encoder_attention_mask=prompt_attention_mask,
1245
+ timestep=current_timestep,
1246
+ added_cond_kwargs=added_cond_kwargs,
1247
+ return_dict=False,
1248
+ )[0]
1249
+
1250
+ # perform guidance
1251
+ if do_classifier_free_guidance:
1252
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1253
+ noise_pred = noise_pred_uncond + guidance_scale * (
1254
+ noise_pred_text - noise_pred_uncond
1255
+ )
1256
+
1257
+ # learned sigma
1258
+ if self.transformer.config.out_channels // 2 == latent_channels:
1259
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
1260
+ else:
1261
+ noise_pred = noise_pred
1262
+
1263
+ # compute previous image: x_t -> x_t-1
1264
+ latents = self.scheduler.step(
1265
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
1266
+ )[0]
1267
+
1268
+ # call the callback, if provided
1269
+ if i == len(timesteps) - 1 or (
1270
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1271
+ ):
1272
+ progress_bar.update()
1273
+ if callback is not None and i % callback_steps == 0:
1274
+ step_idx = i // getattr(self.scheduler, "order", 1)
1275
+ callback(step_idx, t, latents)
1276
+
1277
+ if not output_type == "latent":
1278
+ image = self.vae.decode(
1279
+ latents.to(device=self.vae.device, dtype=self.vae.dtype)
1280
+ / self.vae.config.scaling_factor,
1281
+ return_dict=False,
1282
+ )[0]
1283
+ if use_resolution_binning:
1284
+ image = self.image_processor.resize_and_crop_tensor(
1285
+ image, orig_width, orig_height
1286
+ )
1287
+ else:
1288
+ image = latents
1289
+
1290
+ if not output_type == "latent":
1291
+ image = self.image_processor.postprocess(image, output_type=output_type)
1292
+
1293
+ # Offload all models
1294
+ self.maybe_free_model_hooks()
1295
+
1296
+ if not return_dict:
1297
+ return (image,)
1298
+
1299
+ return ImagePipelineOutput(images=image)