primecai committed
Commit 3f03890
1 Parent(s): b5c5790
app.py ADDED
@@ -0,0 +1,124 @@
+ import gradio as gr
+ import torch
+ from PIL import Image
+
+ from diffusers.utils import load_image
+ from pipeline import FluxConditionalPipeline
+ from transformer import FluxTransformer2DConditionalModel
+
+ import os
+
+ pipe = None
+
+ CHECKPOINT = "primecai/dsd_model"
+
+ def init_pipeline():
+     global pipe
+     transformer = FluxTransformer2DConditionalModel.from_pretrained(
+         os.path.join(CHECKPOINT, "transformer"),
+         torch_dtype=torch.bfloat16,
+         low_cpu_mem_usage=False,
+         ignore_mismatched_sizes=True,
+     )
+     pipe = FluxConditionalPipeline.from_pretrained(
+         "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
+     )
+     pipe.load_lora_weights(os.path.join(CHECKPOINT, "pytorch_lora_weights.safetensors"))
+     pipe.to("cuda")
+
+
+ def process_image_and_text(image, text, gemini_prompt, guidance, i_guidance, t_guidance):
+     w, h, min_size = image.size[0], image.size[1], min(image.size)
+     image = image.crop(
+         ((w - min_size) // 2, (h - min_size) // 2, (w + min_size) // 2, (h + min_size) // 2)
+     ).resize((512, 512))
+
+     if pipe is None:
+         init_pipeline()
+
+     control_image = load_image(image)
+     result_image = pipe(
+         prompt=text.strip(),
+         negative_prompt="",
+         num_inference_steps=28,
+         height=512,
+         width=1024,
+         guidance_scale=guidance,
+         image=control_image,
+         guidance_scale_real_i=i_guidance,
+         guidance_scale_real_t=t_guidance,
+         gemini_prompt=gemini_prompt,
+     ).images[0]
+
+     return result_image
+
+
+ def get_samples():
+     sample_list = [
+         {
+             "image": "assets/wanrong_character.png",
+             "text": "A chibi-style girl with pink hair, green eyes, wearing a black and gold ornate dress, dancing gracefully in a flower garden, anime art style with clean and detailed lines.",
+         },
+         {
+             "image": "assets/ben_character_squared.png",
+             "text": "A confident green-eyed young woman with platinum blonde hair in a high ponytail, wearing an oversized orange jacket and black pants, is striking a dynamic pose, anime-style with sharp details and vibrant colors.",
+         },
+         {
+             "image": "assets/seededit_example.png",
+             "text": "an adorable small creature with big round orange eyes, fluffy brown fur, wearing a blue scarf with a golden charm, sitting atop a towering stack of colorful books in the middle of a vibrant futuristic city street with towering buildings and glowing neon signs, soft daylight illuminating the scene, detailed and whimsical 3D style.",
+         },
+         {
+             "image": "assets/action_hero_figure.jpeg",
+             "text": "A cartoonish muscular action hero figure with long blue hair and red headband sits on a crowded sidewalk on a Christmas evening, covered in snow and wearing a Christmas hat, holding a sign that reads 'DSD!', dramatic cinematic lighting, close-up view, 3D-rendered in a stylized, vibrant art style.",
+         },
+         {
+             "image": "assets/anime_soldier.jpeg",
+             "text": "An adorable cartoon goat soldier sits under a beach umbrella with 'DSD!' written on it, bright teal background with soft lighting, 3D-rendered in a playful and vibrant art style.",
+         },
+         {
+             "image": "assets/goat_logo.jpeg",
+             "text": "A shirt with this logo on it.",
+         },
+         {
+             "image": "assets/cartoon_cat.png",
+             "text": "A cheerful cartoon orange cat sits under a beach umbrella with 'DSD!' written on it under a sunny sky, simplistic and humorous comic art style.",
+         },
+     ]
+     return [[Image.open(sample["image"]), sample["text"]] for sample in sample_list]
+
+
+ demo = gr.Blocks()
+
+ with demo:
+     gr.Markdown(
+         f"""
+ <div align="center">
+
+ ## Diffusion Self-Distillation (beta)
+
+ <a href="https://primecai.github.io/dsd/" target="_blank"><img src="https://img.shields.io/badge/Project-Website-blue" style="display:inline-block;"></a>
+ <a href="https://github.com/primecai/diffusion-self-distillation" target="_blank"><img src="https://img.shields.io/github/stars/primecai/diffusion-self-distillation?label=GitHub%20%E2%98%85&logo=github&color=C8C" style="display:inline-block;"></a>
+ <a href="https://huggingface.co/papers/2411.18616" target="_blank"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face%20-Space-yellow" style="display:inline-block;"></a>
+ <a href="https://x.com/prime_cai?lang=en" target="_blank"><img src="https://shields.io/twitter/follow/:?label=Subscribe%20for%20updates!" style="display:inline-block;"></a>
+
+ </div>
+ """
+     )
+
+     iface = gr.Interface(
+         fn=process_image_and_text,
+         inputs=[
+             gr.Image(type="pil"),
+             gr.Textbox(lines=2, label="text", info="Could be something as simple as 'this character playing soccer'."),
+             gr.Checkbox(label="Gemini prompt", value=True, info="Use Gemini to enhance the prompt. This is recommended for most cases, unless you have a specific prompt similar to the examples in mind."),
+             gr.Slider(minimum=1.0, maximum=6.0, step=0.5, value=3.5, label="guidance scale (tip: start with 3.5, then gradually increase if the subject consistency is off)"),
+             gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.0, label="real guidance scale for image (tip: increase if the image is not consistent)"),
+             gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.0, label="real guidance scale for prompt (tip: increase if the prompt is not consistent)"),
+         ],
+         outputs=gr.Image(type="pil"),
+         examples=get_samples(),
+     )
+
+ if __name__ == "__main__":
+     init_pipeline()
+     demo.launch(debug=False, share=True)
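For reference, here is a minimal headless sketch of what `process_image_and_text` does, assuming the same `primecai/dsd_model` checkpoint layout used above, a CUDA device with enough memory for FLUX.1-dev in bfloat16, and one of the bundled sample assets (the square center-crop in the Gradio handler is replaced here by a plain resize):

```py
import torch
from PIL import Image

from pipeline import FluxConditionalPipeline
from transformer import FluxTransformer2DConditionalModel

transformer = FluxTransformer2DConditionalModel.from_pretrained(
    "primecai/dsd_model/transformer",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=False,
    ignore_mismatched_sizes=True,
)
pipe = FluxConditionalPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.load_lora_weights("primecai/dsd_model/pytorch_lora_weights.safetensors")
pipe.to("cuda")

# 512x512 condition image; the pipeline places it on the left half of a 512x1024
# canvas and generates the output into the right half.
condition = Image.open("assets/goat_logo.jpeg").convert("RGB").resize((512, 512))
result = pipe(
    prompt="A shirt with this logo on it.",
    negative_prompt="",
    image=condition,
    height=512,
    width=1024,
    num_inference_steps=28,
    guidance_scale=3.5,
    guidance_scale_real_i=1.0,
    guidance_scale_real_t=1.0,
    gemini_prompt=False,  # set True to recaption with Gemini (needs GOOGLE_API_KEY)
).images[0]
result.save("dsd_output.png")
```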
assets/action_hero_figure.jpeg ADDED
assets/anime_soldier.jpeg ADDED
assets/cartoon_cat.png ADDED
assets/goat_logo.jpeg ADDED
assets/wanrong_character.png ADDED
pipeline.py ADDED
@@ -0,0 +1,1009 @@
1
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import (
21
+ CLIPTextModel,
22
+ CLIPTokenizer,
23
+ T5EncoderModel,
24
+ T5TokenizerFast,
25
+ )
26
+
27
+ from diffusers.image_processor import VaeImageProcessor
28
+ from diffusers.loaders import SD3LoraLoaderMixin
29
+ from diffusers.models.autoencoders import AutoencoderKL
30
+ from diffusers.models.transformers import FluxTransformer2DModel
31
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
32
+ from diffusers.utils import (
33
+ USE_PEFT_BACKEND,
34
+ is_torch_xla_available,
35
+ logging,
36
+ replace_example_docstring,
37
+ scale_lora_layers,
38
+ unscale_lora_layers,
39
+ )
40
+ from diffusers.utils.torch_utils import randn_tensor
41
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
42
+ from recaption import enhance_prompt
43
+
44
+
45
+ if is_torch_xla_available():
46
+ import torch_xla.core.xla_model as xm
47
+
48
+ XLA_AVAILABLE = True
49
+ else:
50
+ XLA_AVAILABLE = False
51
+
52
+
53
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
54
+
55
+ EXAMPLE_DOC_STRING = """
56
+ Examples:
57
+ ```py
58
+ >>> import torch
59
+ >>> from diffusers import FluxPipeline
60
+
61
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
62
+ >>> pipe.to("cuda")
63
+ >>> prompt = "A cat holding a sign that says hello world"
64
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
65
+ >>> # Refer to the pipeline documentation for more details.
66
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
67
+ >>> image.save("flux.png")
68
+ ```
69
+ """
70
+
71
+
72
+ def calculate_shift(
73
+ image_seq_len,
74
+ base_seq_len: int = 256,
75
+ max_seq_len: int = 4096,
76
+ base_shift: float = 0.5,
77
+ max_shift: float = 1.16,
78
+ ):
79
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
80
+ b = base_shift - m * base_seq_len
81
+ mu = image_seq_len * m + b
82
+ return mu
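+ # Rough worked example (added note, using the default arguments shown above): for the 512x1024
+ # canvas this pipeline generates, the packed latent sequence has 2048 tokens, so
+ # m = (1.16 - 0.5) / (4096 - 256) ≈ 1.72e-4, b = 0.5 - m * 256 ≈ 0.456, and mu ≈ 2048 * m + b ≈ 0.81.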
83
+
84
+
85
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
86
+ def retrieve_timesteps(
87
+ scheduler,
88
+ num_inference_steps: Optional[int] = None,
89
+ device: Optional[Union[str, torch.device]] = None,
90
+ timesteps: Optional[List[int]] = None,
91
+ sigmas: Optional[List[float]] = None,
92
+ **kwargs,
93
+ ):
94
+ """
95
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
96
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
97
+
98
+ Args:
99
+ scheduler (`SchedulerMixin`):
100
+ The scheduler to get timesteps from.
101
+ num_inference_steps (`int`):
102
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
103
+ must be `None`.
104
+ device (`str` or `torch.device`, *optional*):
105
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
106
+ timesteps (`List[int]`, *optional*):
107
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
108
+ `num_inference_steps` and `sigmas` must be `None`.
109
+ sigmas (`List[float]`, *optional*):
110
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
111
+ `num_inference_steps` and `timesteps` must be `None`.
112
+
113
+ Returns:
114
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
115
+ second element is the number of inference steps.
116
+ """
117
+ if timesteps is not None and sigmas is not None:
118
+ raise ValueError(
119
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
120
+ )
121
+ if timesteps is not None:
122
+ accepts_timesteps = "timesteps" in set(
123
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
124
+ )
125
+ if not accepts_timesteps:
126
+ raise ValueError(
127
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
128
+ f" timestep schedules. Please check whether you are using the correct scheduler."
129
+ )
130
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
131
+ timesteps = scheduler.timesteps
132
+ num_inference_steps = len(timesteps)
133
+ elif sigmas is not None:
134
+ accept_sigmas = "sigmas" in set(
135
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
136
+ )
137
+ if not accept_sigmas:
138
+ raise ValueError(
139
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
140
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
141
+ )
142
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
143
+ timesteps = scheduler.timesteps
144
+ num_inference_steps = len(timesteps)
145
+ else:
146
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
147
+ timesteps = scheduler.timesteps
148
+ return timesteps, num_inference_steps
149
+
150
+
151
+ class FluxConditionalPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
152
+ r"""
153
+ The Flux pipeline for text-to-image generation.
154
+
155
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
156
+
157
+ Args:
158
+ transformer ([`FluxTransformer2DModel`]):
159
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
160
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
161
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
162
+ vae ([`AutoencoderKL`]):
163
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
164
+ text_encoder ([`CLIPTextModelWithProjection`]):
165
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
166
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
167
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
168
+ as its dimension.
169
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
170
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
171
+ specifically the
172
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
173
+ variant.
174
+ tokenizer (`CLIPTokenizer`):
175
+ Tokenizer of class
176
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
177
+ tokenizer_2 (`CLIPTokenizer`):
178
+ Second Tokenizer of class
179
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
180
+ """
181
+
182
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
183
+ _optional_components = []
184
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
185
+
186
+ def __init__(
187
+ self,
188
+ scheduler: FlowMatchEulerDiscreteScheduler,
189
+ vae: AutoencoderKL,
190
+ text_encoder: CLIPTextModel,
191
+ tokenizer: CLIPTokenizer,
192
+ text_encoder_2: T5EncoderModel,
193
+ tokenizer_2: T5TokenizerFast,
194
+ transformer: FluxTransformer2DModel,
195
+ ):
196
+ super().__init__()
197
+
198
+ self.register_modules(
199
+ vae=vae,
200
+ text_encoder=text_encoder,
201
+ text_encoder_2=text_encoder_2,
202
+ tokenizer=tokenizer,
203
+ tokenizer_2=tokenizer_2,
204
+ transformer=transformer,
205
+ scheduler=scheduler,
206
+ )
207
+ self.vae_scale_factor = (
208
+ 2 ** (len(self.vae.config.block_out_channels))
209
+ if hasattr(self, "vae") and self.vae is not None
210
+ else 16
211
+ )
212
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
213
+ self.tokenizer_max_length = (
214
+ self.tokenizer.model_max_length
215
+ if hasattr(self, "tokenizer") and self.tokenizer is not None
216
+ else 77
217
+ )
218
+ self.default_sample_size = 64
219
+
220
+ def _get_t5_prompt_embeds(
221
+ self,
222
+ prompt: Union[str, List[str]] = None,
223
+ num_images_per_prompt: int = 1,
224
+ max_sequence_length: int = 512,
225
+ device: Optional[torch.device] = None,
226
+ dtype: Optional[torch.dtype] = None,
227
+ ):
228
+ device = device or self._execution_device
229
+ dtype = dtype or self.text_encoder.dtype
230
+
231
+ prompt = [prompt] if isinstance(prompt, str) else prompt
232
+ batch_size = len(prompt)
233
+
234
+ text_inputs = self.tokenizer_2(
235
+ prompt,
236
+ padding="max_length",
237
+ max_length=max_sequence_length,
238
+ truncation=True,
239
+ return_length=False,
240
+ return_overflowing_tokens=False,
241
+ return_tensors="pt",
242
+ )
243
+ prompt_attention_mask = text_inputs.attention_mask
244
+ text_input_ids = text_inputs.input_ids
245
+ untruncated_ids = self.tokenizer_2(
246
+ prompt, padding="longest", return_tensors="pt"
247
+ ).input_ids
248
+
249
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
250
+ text_input_ids, untruncated_ids
251
+ ):
252
+ removed_text = self.tokenizer_2.batch_decode(
253
+ untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
254
+ )
255
+ # logger.warning(
256
+ # "The following part of your input was truncated because `max_sequence_length` is set to "
257
+ # f" {max_sequence_length} tokens: {removed_text}"
258
+ # )
259
+
260
+ prompt_embeds = self.text_encoder_2(
261
+ text_input_ids.to(device), output_hidden_states=False
262
+ )[0]
263
+
264
+ dtype = self.text_encoder_2.dtype
265
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
266
+
267
+ _, seq_len, _ = prompt_embeds.shape
268
+
269
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
270
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
271
+ prompt_embeds = prompt_embeds.view(
272
+ batch_size * num_images_per_prompt, seq_len, -1
273
+ )
274
+
275
+ return prompt_embeds, prompt_attention_mask
276
+
277
+ def _get_clip_prompt_embeds(
278
+ self,
279
+ prompt: Union[str, List[str]],
280
+ num_images_per_prompt: int = 1,
281
+ device: Optional[torch.device] = None,
282
+ ):
283
+ device = device or self._execution_device
284
+
285
+ prompt = [prompt] if isinstance(prompt, str) else prompt
286
+ batch_size = len(prompt)
287
+
288
+ text_inputs = self.tokenizer(
289
+ prompt,
290
+ padding="max_length",
291
+ max_length=self.tokenizer_max_length,
292
+ truncation=True,
293
+ return_overflowing_tokens=False,
294
+ return_length=False,
295
+ return_tensors="pt",
296
+ )
297
+
298
+ text_input_ids = text_inputs.input_ids
299
+ untruncated_ids = self.tokenizer(
300
+ prompt, padding="longest", return_tensors="pt"
301
+ ).input_ids
302
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
303
+ text_input_ids, untruncated_ids
304
+ ):
305
+ removed_text = self.tokenizer.batch_decode(
306
+ untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
307
+ )
308
+ logger.warning(
309
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
310
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
311
+ )
312
+ prompt_embeds = self.text_encoder(
313
+ text_input_ids.to(device), output_hidden_states=False
314
+ )
315
+
316
+ # Use pooled output of CLIPTextModel
317
+ prompt_embeds = prompt_embeds.pooler_output
318
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
319
+
320
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
321
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
322
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
323
+
324
+ return prompt_embeds
325
+
326
+ def encode_prompt(
327
+ self,
328
+ prompt: Union[str, List[str]],
329
+ prompt_2: Union[str, List[str]],
330
+ device: Optional[torch.device] = None,
331
+ num_images_per_prompt: int = 1,
332
+ prompt_embeds: Optional[torch.FloatTensor] = None,
333
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
334
+ max_sequence_length: int = 512,
335
+ lora_scale: Optional[float] = None,
336
+ ):
337
+ r"""
338
+
339
+ Args:
340
+ prompt (`str` or `List[str]`, *optional*):
341
+ prompt to be encoded
342
+ prompt_2 (`str` or `List[str]`, *optional*):
343
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
344
+ used in all text-encoders
345
+ device: (`torch.device`):
346
+ torch device
347
+ num_images_per_prompt (`int`):
348
+ number of images that should be generated per prompt
349
+ prompt_embeds (`torch.FloatTensor`, *optional*):
350
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
351
+ provided, text embeddings will be generated from `prompt` input argument.
352
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
353
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
354
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
355
+ clip_skip (`int`, *optional*):
356
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
357
+ the output of the pre-final layer will be used for computing the prompt embeddings.
358
+ lora_scale (`float`, *optional*):
359
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
360
+ """
361
+ device = device or self._execution_device
362
+
363
+ # set lora scale so that monkey patched LoRA
364
+ # function of text encoder can correctly access it
365
+ if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
366
+ self._lora_scale = lora_scale
367
+
368
+ # dynamically adjust the LoRA scale
369
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
370
+ scale_lora_layers(self.text_encoder, lora_scale)
371
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
372
+ scale_lora_layers(self.text_encoder_2, lora_scale)
373
+
374
+ prompt = [prompt] if isinstance(prompt, str) else prompt
375
+ if prompt is not None:
376
+ batch_size = len(prompt)
377
+ else:
378
+ batch_size = prompt_embeds.shape[0]
379
+
380
+ prompt_attention_mask = None
381
+ if prompt_embeds is None:
382
+ prompt_2 = prompt_2 or prompt
383
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
384
+
385
+ # We only use the pooled prompt output from the CLIPTextModel
386
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
387
+ prompt=prompt,
388
+ device=device,
389
+ num_images_per_prompt=num_images_per_prompt,
390
+ )
391
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
392
+ prompt=prompt_2,
393
+ num_images_per_prompt=num_images_per_prompt,
394
+ max_sequence_length=max_sequence_length,
395
+ device=device,
396
+ )
397
+
398
+ if self.text_encoder is not None:
399
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
400
+ # Retrieve the original scale by scaling back the LoRA layers
401
+ unscale_lora_layers(self.text_encoder, lora_scale)
402
+
403
+ if self.text_encoder_2 is not None:
404
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
405
+ # Retrieve the original scale by scaling back the LoRA layers
406
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
407
+
408
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(
409
+ device=device, dtype=prompt_embeds.dtype
410
+ )
411
+
412
+ return prompt_embeds, pooled_prompt_embeds, text_ids, prompt_attention_mask
413
+
414
+ def check_inputs(
415
+ self,
416
+ prompt,
417
+ prompt_2,
418
+ height,
419
+ width,
420
+ prompt_embeds=None,
421
+ pooled_prompt_embeds=None,
422
+ callback_on_step_end_tensor_inputs=None,
423
+ max_sequence_length=None,
424
+ image=None
425
+ ):
426
+ if height % 8 != 0 or width % 8 != 0:
427
+ raise ValueError(
428
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
429
+ )
430
+
431
+ if callback_on_step_end_tensor_inputs is not None and not all(
432
+ k in self._callback_tensor_inputs
433
+ for k in callback_on_step_end_tensor_inputs
434
+ ):
435
+ raise ValueError(
436
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
437
+ )
438
+
439
+ if prompt is not None and prompt_embeds is not None:
440
+ raise ValueError(
441
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
442
+ " only forward one of the two."
443
+ )
444
+ elif prompt_2 is not None and prompt_embeds is not None:
445
+ raise ValueError(
446
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
447
+ " only forward one of the two."
448
+ )
449
+ elif prompt is None and prompt_embeds is None:
450
+ raise ValueError(
451
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
452
+ )
453
+ elif prompt is not None and (
454
+ not isinstance(prompt, str) and not isinstance(prompt, list)
455
+ ):
456
+ raise ValueError(
457
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
458
+ )
459
+ elif prompt_2 is not None and (
460
+ not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
461
+ ):
462
+ raise ValueError(
463
+ f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
464
+ )
465
+
466
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
467
+ raise ValueError(
468
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
469
+ )
470
+
471
+ if max_sequence_length is not None and max_sequence_length > 512:
472
+ raise ValueError(
473
+ f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}"
474
+ )
475
+
476
+ @staticmethod
477
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
478
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
479
+ latent_image_ids[..., 1] = (
480
+ latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
481
+ )
482
+ latent_image_ids[..., 2] = (
483
+ latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
484
+ )
485
+
486
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
487
+ latent_image_ids.shape
488
+ )
489
+
490
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
491
+ latent_image_ids = latent_image_ids.reshape(
492
+ batch_size,
493
+ latent_image_id_height * latent_image_id_width,
494
+ latent_image_id_channels,
495
+ )
496
+
497
+ return latent_image_ids.to(device=device, dtype=dtype)
498
+
499
+ @staticmethod
500
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
501
+ latents = latents.view(
502
+ batch_size, num_channels_latents, height // 2, 2, width // 2, 2
503
+ )
504
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
505
+ latents = latents.reshape(
506
+ batch_size, (height // 2) * (width // 2), num_channels_latents * 4
507
+ )
508
+
509
+ return latents
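+ # Shape note (added for clarity): for a 512x1024 generation the unpacked latents are
+ # (B, 16, 64, 128); packing groups each 2x2 patch of latent pixels into one token, giving
+ # (B, 32 * 64, 16 * 4) = (B, 2048, 64), the sequence length later used for the timestep shift.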
510
+
511
+ @staticmethod
512
+ def _unpack_latents(latents, height, width, vae_scale_factor):
513
+ batch_size, num_patches, channels = latents.shape
514
+
515
+ height = height // vae_scale_factor
516
+ width = width // vae_scale_factor
517
+
518
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
519
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
520
+
521
+ latents = latents.reshape(
522
+ batch_size, channels // (2 * 2), height * 2, width * 2
523
+ )
524
+
525
+ return latents
526
+
527
+ def prepare_latents(
528
+ self,
529
+ batch_size,
530
+ num_channels_latents,
531
+ height,
532
+ width,
533
+ dtype,
534
+ device,
535
+ generator,
536
+ latents=None,
537
+ ):
538
+ height = 2 * (int(height) // self.vae_scale_factor)
539
+ width = 2 * (int(width) // self.vae_scale_factor)
540
+
541
+ shape = (batch_size, num_channels_latents, height, width)
542
+
543
+ if latents is not None:
544
+ latent_image_ids = self._prepare_latent_image_ids(
545
+ batch_size, height, width, device, dtype
546
+ )
547
+ return latents.to(device=device, dtype=dtype), latent_image_ids
548
+
549
+ if isinstance(generator, list) and len(generator) != batch_size:
550
+ raise ValueError(
551
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
552
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
553
+ )
554
+
555
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
556
+ # _pack_latents(latents, batch_size, num_channels_latents, height, width)
557
+ latents = self._pack_latents(
558
+ latents, batch_size, num_channels_latents, height, width
559
+ )
560
+
561
+ latent_image_ids = self._prepare_latent_image_ids(
562
+ batch_size, height, width, device, dtype
563
+ )
564
+
565
+ return latents, latent_image_ids
566
+
567
+ @property
568
+ def guidance_scale(self):
569
+ return self._guidance_scale
570
+
571
+ @property
572
+ def joint_attention_kwargs(self):
573
+ return self._joint_attention_kwargs
574
+
575
+ @property
576
+ def num_timesteps(self):
577
+ return self._num_timesteps
578
+
579
+ @property
580
+ def interrupt(self):
581
+ return self._interrupt
582
+
583
+ @torch.no_grad()
584
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
585
+ def __call__(
586
+ self,
587
+ prompt: Union[str, List[str]] = None,
588
+ prompt_mask: Optional[Union[torch.FloatTensor, List[torch.FloatTensor]]] = None,
589
+ negative_mask: Optional[
590
+ Union[torch.FloatTensor, List[torch.FloatTensor]]
591
+ ] = None,
592
+ prompt_2: Optional[Union[str, List[str]]] = None,
593
+ height: Optional[int] = None,
594
+ width: Optional[int] = None,
595
+ num_inference_steps: int = 28,
596
+ timesteps: List[int] = None,
597
+ guidance_scale: float = 3.5,
598
+ num_images_per_prompt: Optional[int] = 1,
599
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
600
+ latents: Optional[torch.FloatTensor] = None,
601
+ prompt_embeds: Optional[torch.FloatTensor] = None,
602
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
603
+ output_type: Optional[str] = "pil",
604
+ return_dict: bool = True,
605
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
606
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
607
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
608
+ max_sequence_length: int = 512,
609
+ guidance_scale_real_i: float = 1.0,
610
+ guidance_scale_real_t: float = 1.0,
611
+ negative_prompt: Union[str, List[str]] = "",
612
+ negative_prompt_2: Union[str, List[str]] = "",
613
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
614
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
615
+ no_cfg_until_timestep: int = 2,
616
+ image: Optional[torch.FloatTensor] = None,
617
+ image_path = None,
618
+ cut_output = True,
619
+ gemini_prompt = True
620
+ ):
621
+ r"""
622
+ Function invoked when calling the pipeline for generation.
623
+
624
+ Args:
625
+ prompt (`str` or `List[str]`, *optional*):
626
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_mask (`torch.FloatTensor` or `List[torch.FloatTensor]`, *optional*):
+ An optional attention mask for the text tokens, forwarded to the transformer during denoising.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
+ will be used instead.
634
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
635
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
636
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
637
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
638
+ num_inference_steps (`int`, *optional*, defaults to 28):
639
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
640
+ expense of slower inference.
641
+ timesteps (`List[int]`, *optional*):
642
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
643
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
644
+ passed will be used. Must be in descending order.
645
+ guidance_scale (`float`, *optional*, defaults to 3.5):
646
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
647
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
648
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
649
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
650
+ usually at the expense of lower image quality.
651
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
652
+ The number of images to generate per prompt.
653
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
654
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
655
+ to make generation deterministic.
656
+ latents (`torch.FloatTensor`, *optional*):
657
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
658
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
659
+ tensor will be generated by sampling using the supplied random `generator`.
660
+ prompt_embeds (`torch.FloatTensor`, *optional*):
661
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
662
+ provided, text embeddings will be generated from `prompt` input argument.
663
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
664
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
665
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
666
+ output_type (`str`, *optional*, defaults to `"pil"`):
667
+ The output format of the generated image. Choose between
668
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
669
+ return_dict (`bool`, *optional*, defaults to `True`):
670
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
671
+ joint_attention_kwargs (`dict`, *optional*):
672
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
673
+ `self.processor` in
674
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
675
+ callback_on_step_end (`Callable`, *optional*):
676
+ A function that is called at the end of each denoising step during inference. It is called
677
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
678
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
679
+ `callback_on_step_end_tensor_inputs`.
680
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
681
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
682
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
683
+ `._callback_tensor_inputs` attribute of your pipeline class.
684
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
685
+
686
+ Examples:
687
+
688
+ Returns:
689
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
690
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
691
+ images.
692
+ """
693
+
694
+ height = height or self.default_sample_size * self.vae_scale_factor
695
+ width = width or self.default_sample_size * self.vae_scale_factor
696
+
697
+ # 1. Check inputs. Raise error if not correct
698
+ self.check_inputs(
699
+ prompt,
700
+ prompt_2,
701
+ height,
702
+ width,
703
+ prompt_embeds=prompt_embeds,
704
+ pooled_prompt_embeds=pooled_prompt_embeds,
705
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
706
+ max_sequence_length=max_sequence_length,
707
+ )
708
+
709
+ self._guidance_scale = guidance_scale
710
+ self._guidance_scale_real_i = guidance_scale_real_i
711
+ self._guidance_scale_real_t = guidance_scale_real_t
712
+ self._joint_attention_kwargs = joint_attention_kwargs
713
+ self._interrupt = False
714
+
715
+ # 2. Define call parameters
716
+ if prompt is not None and isinstance(prompt, str):
717
+ batch_size = 1
718
+ elif prompt is not None and isinstance(prompt, list):
719
+ batch_size = len(prompt)
720
+ else:
721
+ batch_size = prompt_embeds.shape[0]
722
+
723
+ device = self._execution_device
724
+
725
+ if gemini_prompt:
+     # Recaption the condition image with Gemini and fold it into the user prompt.
+     prompt = enhance_prompt(image, prompt)
733
+
734
+ lora_scale = (
735
+ self.joint_attention_kwargs.get("scale", None)
736
+ if self.joint_attention_kwargs is not None
737
+ else None
738
+ )
739
+ (
740
+ prompt_embeds,
741
+ pooled_prompt_embeds,
742
+ text_ids,
743
+ _,
744
+ ) = self.encode_prompt(
745
+ prompt=prompt,
746
+ prompt_2=prompt_2,
747
+ prompt_embeds=prompt_embeds,
748
+ pooled_prompt_embeds=pooled_prompt_embeds,
749
+ device=device,
750
+ num_images_per_prompt=num_images_per_prompt,
751
+ max_sequence_length=max_sequence_length,
752
+ lora_scale=lora_scale,
753
+ )
754
+
755
+ if negative_prompt_2 == "" and negative_prompt != "":
756
+ negative_prompt_2 = negative_prompt
757
+
758
+ negative_text_ids = text_ids
759
+ if guidance_scale_real_i > 1.0 and (
760
+ negative_prompt_embeds is None or negative_pooled_prompt_embeds is None
761
+ ):
762
+ (
763
+ negative_prompt_embeds,
764
+ negative_pooled_prompt_embeds,
765
+ negative_text_ids,
766
+ _,
767
+ ) = self.encode_prompt(
768
+ prompt=negative_prompt,
769
+ prompt_2=negative_prompt_2,
770
+ prompt_embeds=None,
771
+ pooled_prompt_embeds=None,
772
+ device=device,
773
+ num_images_per_prompt=num_images_per_prompt,
774
+ max_sequence_length=max_sequence_length,
775
+ lora_scale=lora_scale,
776
+ )
777
+
778
+ # 3. Preprocess image
779
+ image = self.image_processor.preprocess(image)
780
+ # image = image[..., :512]
781
+ image = torch.nn.functional.interpolate(image, size=512)
782
+ black_image = torch.full((1, 3, 512, 512), -1.0)
783
+ image = torch.cat([image, black_image], dim=3)
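+ # Layout note (added for clarity): the conditioning image fills the left 512 columns of a
+ # 512x1024 canvas and the right half is set to -1.0 (black in [-1, 1] space); the model
+ # generates into that right half, which `cut_output` keeps via image[..., 512:] after decoding.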
784
+ latents_cond = self.vae.encode(image.to(self.vae.dtype).to(self.vae.device)).latent_dist.sample()
785
+ latents_cond = (
786
+ latents_cond - self.vae.config.shift_factor
787
+ ) * self.vae.config.scaling_factor
788
+ # from customization.utils import mask_random_quadrants
789
+ # latent_cond = mask_random_quadrants(latent_cond)
790
+
791
+ # 4. Prepare latent variables
792
+ num_channels_latents = self.transformer.config.in_channels // 4
793
+ latents, latent_image_ids = self.prepare_latents(
794
+ batch_size * num_images_per_prompt,
795
+ num_channels_latents,
796
+ height,
797
+ width,
798
+ prompt_embeds.dtype,
799
+ device,
800
+ generator,
801
+ latents,
802
+ )
803
+ # _pack_latents(latents, batch_size, num_channels_latents, height, width)
804
+ latents_cond = self._pack_latents(
805
+ latents_cond, batch_size, num_channels_latents, 2 * (int(height) // self.vae_scale_factor), 2 * (int(width) // self.vae_scale_factor)
806
+ )
807
+
808
+ # 5. Prepare timesteps
809
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
810
+ image_seq_len = latents.shape[1]
811
+ mu = calculate_shift(
812
+ image_seq_len,
813
+ self.scheduler.config.base_image_seq_len,
814
+ self.scheduler.config.max_image_seq_len,
815
+ self.scheduler.config.base_shift,
816
+ self.scheduler.config.max_shift,
817
+ )
818
+ timesteps, num_inference_steps = retrieve_timesteps(
819
+ self.scheduler,
820
+ num_inference_steps,
821
+ device,
822
+ timesteps,
823
+ sigmas,
824
+ mu=mu,
825
+ )
826
+ num_warmup_steps = max(
827
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
828
+ )
829
+ self._num_timesteps = len(timesteps)
830
+
831
+ latents = latents.to(self.transformer.device)
832
+ latent_image_ids = latent_image_ids.to(self.transformer.device)
833
+ timesteps = timesteps.to(self.transformer.device)
834
+ text_ids = text_ids.to(self.transformer.device)
835
+
836
+ # 6. Denoising loop
837
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
838
+ for i, t in enumerate(timesteps):
839
+ if self.interrupt:
840
+ continue
841
+
842
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
843
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
844
+
845
+ # handle guidance
846
+ if self.transformer.config.guidance_embeds:
847
+ guidance = torch.tensor(
848
+ [guidance_scale], device=self.transformer.device
849
+ )
850
+ guidance = guidance.expand(latents.shape[0])
851
+ else:
852
+ guidance = None
853
+
854
+ extra_transformer_args = {}
855
+ if prompt_mask is not None:
856
+ extra_transformer_args["attention_mask"] = prompt_mask.to(
857
+ device=self.transformer.device
858
+ )
859
+
860
+ noise_pred = self.transformer(
861
+ hidden_states=latents.to(
862
+ device=self.transformer.device, dtype=self.transformer.dtype
863
+ ),
864
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
865
+ timestep=timestep / 1000,
866
+ guidance=guidance,
867
+ pooled_projections=pooled_prompt_embeds.to(
868
+ device=self.transformer.device, dtype=self.transformer.dtype
869
+ ),
870
+ encoder_hidden_states=prompt_embeds.to(
871
+ device=self.transformer.device, dtype=self.transformer.dtype
872
+ ),
873
+ txt_ids=text_ids,
874
+ img_ids=latent_image_ids,
875
+ joint_attention_kwargs=self.joint_attention_kwargs,
876
+ return_dict=False,
877
+ condition_hidden_states=latents_cond.to(
878
+ device=self.transformer.device, dtype=self.transformer.dtype
879
+ ),
880
+ **extra_transformer_args,
881
+ )[0]
882
+
883
+ # TODO optionally use batch prediction to speed this up.
884
+ if guidance_scale_real_i > 1.0 and i >= no_cfg_until_timestep:
885
+ noise_pred_uncond = self.transformer(
886
+ hidden_states=latents.to(
887
+ device=self.transformer.device, dtype=self.transformer.dtype
888
+ ),
889
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
890
+ timestep=timestep / 1000,
891
+ guidance=guidance,
892
+ pooled_projections=negative_pooled_prompt_embeds.to(
893
+ device=self.transformer.device, dtype=self.transformer.dtype
894
+ ),
895
+ encoder_hidden_states=negative_prompt_embeds.to(
896
+ device=self.transformer.device, dtype=self.transformer.dtype
897
+ ),
898
+ txt_ids=negative_text_ids.to(device=self.transformer.device),
899
+ img_ids=latent_image_ids.to(device=self.transformer.device),
900
+ joint_attention_kwargs=self.joint_attention_kwargs,
901
+ return_dict=False,
902
+ condition_hidden_states=torch.zeros_like(latents_cond).to(
903
+ device=self.transformer.device, dtype=self.transformer.dtype
904
+ ),
905
+ )[0]
906
+ noise_pred_uncond_t = self.transformer(
907
+ hidden_states=latents.to(
908
+ device=self.transformer.device, dtype=self.transformer.dtype
909
+ ),
910
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
911
+ timestep=timestep / 1000,
912
+ guidance=guidance,
913
+ pooled_projections=negative_pooled_prompt_embeds.to(
914
+ device=self.transformer.device, dtype=self.transformer.dtype
915
+ ),
916
+ encoder_hidden_states=negative_prompt_embeds.to(
917
+ device=self.transformer.device, dtype=self.transformer.dtype
918
+ ),
919
+ txt_ids=negative_text_ids.to(device=self.transformer.device),
920
+ img_ids=latent_image_ids.to(device=self.transformer.device),
921
+ joint_attention_kwargs=self.joint_attention_kwargs,
922
+ return_dict=False,
923
+ condition_hidden_states=latents_cond.to(
924
+ device=self.transformer.device, dtype=self.transformer.dtype
925
+ ),
926
+ )[0]
927
+
928
+ # noise_pred = noise_pred_uncond + guidance_scale_real * (
929
+ # noise_pred - noise_pred_uncond
930
+ # )
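+ # Added note on the two-branch "real" CFG below: noise_pred_uncond drops both the text prompt
+ # and the image condition, noise_pred_uncond_t keeps the image condition but uses the negative
+ # prompt, and noise_pred uses both. guidance_scale_real_i therefore scales the image-conditioning
+ # direction, while guidance_scale_real_t scales the text direction on top of it.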
931
+ noise_pred = noise_pred_uncond + \
932
+ guidance_scale_real_i * (noise_pred_uncond_t - noise_pred_uncond) + \
933
+ guidance_scale_real_t * (noise_pred - noise_pred_uncond_t)
934
+
935
+ # compute the previous noisy sample x_t -> x_t-1
936
+ latents_dtype = latents.dtype
937
+ latents = self.scheduler.step(
938
+ noise_pred, t, latents, return_dict=False
939
+ )[0]
940
+
941
+ if latents.dtype != latents_dtype:
942
+ if torch.backends.mps.is_available():
943
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
944
+ latents = latents.to(latents_dtype)
945
+
946
+ if callback_on_step_end is not None:
947
+ callback_kwargs = {}
948
+ for k in callback_on_step_end_tensor_inputs:
949
+ callback_kwargs[k] = locals()[k]
950
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
951
+
952
+ latents = callback_outputs.pop("latents", latents)
953
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
954
+
955
+ # call the callback, if provided
956
+ if i == len(timesteps) - 1 or (
957
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
958
+ ):
959
+ progress_bar.update()
960
+
961
+ if XLA_AVAILABLE:
962
+ xm.mark_step()
963
+
964
+ if output_type == "latent":
965
+ image = latents
966
+
967
+ else:
968
+ latents = self._unpack_latents(
969
+ latents, height, width, self.vae_scale_factor
970
+ )
971
+ latents = (
972
+ latents / self.vae.config.scaling_factor
973
+ ) + self.vae.config.shift_factor
974
+
975
+ image = self.vae.decode(
976
+ latents.to(device=self.vae.device, dtype=self.vae.dtype),
977
+ return_dict=False,
978
+ )[0]
979
+ if cut_output:
980
+ image = image[..., 512:]
981
+ image = self.image_processor.postprocess(image, output_type=output_type)
982
+
983
+ # Offload all models
984
+ self.maybe_free_model_hooks()
985
+
986
+ if not return_dict:
987
+ return (image,)
988
+
989
+ return FluxPipelineOutput(images=image)
990
+
991
+
992
+ from dataclasses import dataclass
993
+ from typing import List, Union
994
+ import PIL.Image
995
+ from diffusers.utils import BaseOutput
996
+
997
+
998
+ @dataclass
999
+ class FluxPipelineOutput(BaseOutput):
1000
+ """
1001
+ Output class for Flux pipelines.
1002
+
1003
+ Args:
1004
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
1005
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
1006
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
1007
+ """
1008
+
1009
+ images: Union[List[PIL.Image.Image], np.ndarray]
recaption.py ADDED
@@ -0,0 +1,33 @@
+ from google import genai
+ import os
+
+
+ def enhance_prompt(image, prompt):
+     input_caption_prompt = (
+         "Please provide a prompt for a Diffusion Model text-to-image generative model for the image I will give you. "
+         "The prompt should be a detailed description of the image, especially the main subject (i.e. the main character/asset/item), the environment, the pose, the lighting, the camera view, the style etc. "
+         "The prompt should be detailed enough to generate the target image. "
+         "The prompt should be short and precise, in one-line format, and should not exceed 77 tokens. "
+         "The prompt should be individually coherent as a description of the image."
+     )
+
+     caption_model = genai.Client(
+         vertexai=False, api_key=os.environ["GOOGLE_API_KEY"]
+     )
+     input_image_prompt = caption_model.models.generate_content(
+         model='gemini-1.5-flash', contents=[input_caption_prompt, image]).text
+     input_image_prompt = input_image_prompt.replace('\r', '').replace('\n', '')
+
+     enhance_instruction = "Enhance this input text prompt: '"
+     enhance_instruction += prompt
+     enhance_instruction += "'. Please extract other details, especially the description of the main subject, from the following reference prompt: '"
+     enhance_instruction += input_image_prompt
+     enhance_instruction += "'. Please keep the details that are mentioned in the input prompt, and enhance the rest. "
+     enhance_instruction += "Respond with only the enhanced prompt. "
+     enhance_instruction += "The enhanced prompt should be short and precise, in one-line format, and should not exceed 77 tokens."
+     enhanced_prompt = caption_model.models.generate_content(
+         model='gemini-1.5-flash', contents=[enhance_instruction]).text.replace('\r', '').replace('\n', '')
+     print("input_image_prompt: ", input_image_prompt)
+     print("prompt: ", prompt)
+     print("enhanced_prompt: ", enhanced_prompt)
+     return enhanced_prompt
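A hedged usage sketch of `enhance_prompt`; it assumes a valid `GOOGLE_API_KEY` in the environment, network access to the Gemini API, and any `PIL.Image.Image` as input (the file path below is one of the bundled assets):

```py
import os
from PIL import Image

from recaption import enhance_prompt

os.environ.setdefault("GOOGLE_API_KEY", "<your-api-key>")  # placeholder, not a real key
image = Image.open("assets/goat_logo.jpeg")
enhanced = enhance_prompt(image, "A shirt with this logo on it.")
print(enhanced)  # one-line enhanced prompt, capped at roughly 77 tokens
```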
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch
+ diffusers
+ transformers
+ sentencepiece
+ accelerate
+ google-genai
+ Pillow
+ protobuf
+ peft
+ xformers
transformer.py ADDED
@@ -0,0 +1,565 @@
1
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Dict, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
25
+ from diffusers.models.attention import FeedForward
26
+ from diffusers.models.attention_processor import (
27
+ Attention,
28
+ AttentionProcessor,
29
+ FluxAttnProcessor2_0,
30
+ FusedFluxAttnProcessor2_0,
31
+ )
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
34
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
35
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
36
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
37
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
38
+
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+
43
+ @maybe_allow_in_graph
44
+ class FluxSingleTransformerBlock(nn.Module):
45
+ r"""
46
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
47
+
48
+ Reference: https://arxiv.org/abs/2403.03206
49
+
50
+ Parameters:
51
+ dim (`int`): The number of channels in the input and output.
52
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
53
+ attention_head_dim (`int`): The number of channels in each head.
54
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
55
+ processing of `context` conditions.
56
+ """
57
+
58
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
59
+ super().__init__()
60
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
61
+
62
+ self.norm = AdaLayerNormZeroSingle(dim)
63
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
64
+ self.act_mlp = nn.GELU(approximate="tanh")
65
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
66
+
67
+ processor = FluxAttnProcessor2_0()
68
+ self.attn = Attention(
69
+ query_dim=dim,
70
+ cross_attention_dim=None,
71
+ dim_head=attention_head_dim,
72
+ heads=num_attention_heads,
73
+ out_dim=dim,
74
+ bias=True,
75
+ processor=processor,
76
+ qk_norm="rms_norm",
77
+ eps=1e-6,
78
+ pre_only=True,
79
+ )
80
+
81
+ def forward(
82
+ self,
83
+ hidden_states: torch.FloatTensor,
84
+ temb: torch.FloatTensor,
85
+ image_rotary_emb=None,
86
+ ):
87
+ residual = hidden_states
88
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
89
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
90
+
91
+ attn_output = self.attn(
92
+ hidden_states=norm_hidden_states,
93
+ image_rotary_emb=image_rotary_emb,
94
+ )
95
+
96
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
97
+ gate = gate.unsqueeze(1)
98
+ hidden_states = gate * self.proj_out(hidden_states)
99
+ hidden_states = residual + hidden_states
100
+ if hidden_states.dtype == torch.float16:
101
+ hidden_states = hidden_states.clip(-65504, 65504)
102
+
103
+ return hidden_states
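+ # Added note: this single-stream block runs attention and the MLP branch in parallel from the
+ # same AdaLayerNormZeroSingle output, concatenates the two results, projects back to `dim`
+ # through `proj_out`, and applies the gate before adding the residual above.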
104
+
105
+
106
+ @maybe_allow_in_graph
107
+ class FluxTransformerBlock(nn.Module):
108
+ r"""
109
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
110
+
111
+ Reference: https://arxiv.org/abs/2403.03206
112
+
113
+ Parameters:
114
+ dim (`int`): The number of channels in the input and output.
115
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
116
+ attention_head_dim (`int`): The number of channels in each head.
117
+ qk_norm (`str`, defaults to `"rms_norm"`): The normalization applied to the query and key projections.
+ eps (`float`, defaults to 1e-6): The epsilon used by the normalization layers.
119
+ """
120
+
121
+ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
122
+ super().__init__()
123
+
124
+ self.norm1 = AdaLayerNormZero(dim)
125
+
126
+ self.norm1_context = AdaLayerNormZero(dim)
127
+
128
+ if hasattr(F, "scaled_dot_product_attention"):
129
+ processor = FluxAttnProcessor2_0()
130
+ else:
131
+ raise ValueError(
132
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
133
+ )
134
+ self.attn = Attention(
135
+ query_dim=dim,
136
+ cross_attention_dim=None,
137
+ added_kv_proj_dim=dim,
138
+ dim_head=attention_head_dim,
139
+ heads=num_attention_heads,
140
+ out_dim=dim,
141
+ context_pre_only=False,
142
+ bias=True,
143
+ processor=processor,
144
+ qk_norm=qk_norm,
145
+ eps=eps,
146
+ )
147
+
148
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
149
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
150
+
151
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
152
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
153
+
154
+ # let chunk size default to None
155
+ self._chunk_size = None
156
+ self._chunk_dim = 0
157
+
158
+ def forward(
159
+ self,
160
+ hidden_states: torch.FloatTensor,
161
+ encoder_hidden_states: torch.FloatTensor,
162
+ temb: torch.FloatTensor,
163
+ image_rotary_emb=None,
164
+ ):
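+ # AdaLayerNormZero produces per-sample modulation parameters (attention gate, MLP shift/scale/gate)
+ # for both the image stream and the text (`context`) stream from the combined timestep/text embedding `temb`.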
165
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
166
+
167
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
168
+ encoder_hidden_states, emb=temb
169
+ )
170
+
171
+ # Attention.
172
+ attn_output, context_attn_output = self.attn(
173
+ hidden_states=norm_hidden_states,
174
+ encoder_hidden_states=norm_encoder_hidden_states,
175
+ image_rotary_emb=image_rotary_emb,
176
+ )
177
+
178
+ # Process attention outputs for the `hidden_states`.
179
+ attn_output = gate_msa.unsqueeze(1) * attn_output
180
+ hidden_states = hidden_states + attn_output
181
+
182
+ norm_hidden_states = self.norm2(hidden_states)
183
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
184
+
185
+ ff_output = self.ff(norm_hidden_states)
186
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
187
+
188
+ hidden_states = hidden_states + ff_output
189
+
190
+ # Process attention outputs for the `encoder_hidden_states`.
191
+
192
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
193
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
194
+
195
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
196
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
197
+
198
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
199
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
200
+ if encoder_hidden_states.dtype == torch.float16:
201
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
202
+
203
+ return encoder_hidden_states, hidden_states
204
+
205
+
206
+ class FluxTransformer2DConditionalModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
207
+ """
208
+ A conditional variant of the Transformer model introduced in Flux, extended with a zero-initialized condition embedder (`c_embedder`) that injects image-condition latents into the token stream.
209
+
210
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
211
+
212
+ Parameters:
213
+ patch_size (`int`): Patch size to turn the input data into small patches.
214
+ in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
215
+ num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
216
+ num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
217
+ attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
218
+ num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
219
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
220
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
221
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
222
+ """
223
+
224
+ _supports_gradient_checkpointing = True
225
+ _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
226
+
227
+ @register_to_config
228
+ def __init__(
229
+ self,
230
+ patch_size: int = 1,
231
+ in_channels: int = 64,
232
+ num_layers: int = 19,
233
+ num_single_layers: int = 38,
234
+ attention_head_dim: int = 128,
235
+ num_attention_heads: int = 24,
236
+ joint_attention_dim: int = 4096,
237
+ pooled_projection_dim: int = 768,
238
+ guidance_embeds: bool = False,
239
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
240
+ ):
241
+ super().__init__()
242
+ self.out_channels = in_channels
243
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
244
+
245
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
246
+
247
+ text_time_guidance_cls = (
248
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
249
+ )
250
+ self.time_text_embed = text_time_guidance_cls(
251
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
252
+ )
253
+
254
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
255
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
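+ # Conditioning branch: a zero-initialized linear embedder for the condition latents (see `zero_module` below),
+ # so the condition pathway contributes nothing at initialization, in the spirit of ControlNet-style zero init.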
256
+ self.c_embedder = zero_module(torch.nn.Linear(self.config.in_channels, self.inner_dim))
257
+
258
+ self.transformer_blocks = nn.ModuleList(
259
+ [
260
+ FluxTransformerBlock(
261
+ dim=self.inner_dim,
262
+ num_attention_heads=self.config.num_attention_heads,
263
+ attention_head_dim=self.config.attention_head_dim,
264
+ )
265
+ for i in range(self.config.num_layers)
266
+ ]
267
+ )
268
+
269
+ self.single_transformer_blocks = nn.ModuleList(
270
+ [
271
+ FluxSingleTransformerBlock(
272
+ dim=self.inner_dim,
273
+ num_attention_heads=self.config.num_attention_heads,
274
+ attention_head_dim=self.config.attention_head_dim,
275
+ )
276
+ for i in range(self.config.num_single_layers)
277
+ ]
278
+ )
279
+
280
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
281
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
282
+
283
+ self.gradient_checkpointing = False
284
+
285
+ @property
286
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
287
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
288
+ r"""
289
+ Returns:
290
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
292
+ """
293
+ # set recursively
294
+ processors = {}
295
+
296
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
297
+ if hasattr(module, "get_processor"):
298
+ processors[f"{name}.processor"] = module.get_processor()
299
+
300
+ for sub_name, child in module.named_children():
301
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
302
+
303
+ return processors
304
+
305
+ for name, module in self.named_children():
306
+ fn_recursive_add_processors(name, module, processors)
307
+
308
+ return processors
309
+
310
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
311
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
312
+ r"""
313
+ Sets the attention processor to use to compute attention.
314
+
315
+ Parameters:
316
+ processor (`dict` of `AttentionProcessor` or a single `AttentionProcessor`):
317
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
318
+ for **all** `Attention` layers.
319
+
320
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
321
+ processor. This is strongly recommended when setting trainable attention processors.
322
+
323
+ """
324
+ count = len(self.attn_processors.keys())
325
+
326
+ if isinstance(processor, dict) and len(processor) != count:
327
+ raise ValueError(
328
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
329
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
330
+ )
331
+
332
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
333
+ if hasattr(module, "set_processor"):
334
+ if not isinstance(processor, dict):
335
+ module.set_processor(processor)
336
+ else:
337
+ module.set_processor(processor.pop(f"{name}.processor"))
338
+
339
+ for sub_name, child in module.named_children():
340
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
341
+
342
+ for name, module in self.named_children():
343
+ fn_recursive_attn_processor(name, module, processor)
344
+
345
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
346
+ def fuse_qkv_projections(self):
347
+ """
348
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
349
+ are fused. For cross-attention modules, key and value projection matrices are fused.
350
+
351
+ <Tip warning={true}>
352
+
353
+ This API is 🧪 experimental.
354
+
355
+ </Tip>
356
+ """
357
+ self.original_attn_processors = None
358
+
359
+ for _, attn_processor in self.attn_processors.items():
360
+ if "Added" in str(attn_processor.__class__.__name__):
361
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
362
+
363
+ self.original_attn_processors = self.attn_processors
364
+
365
+ for module in self.modules():
366
+ if isinstance(module, Attention):
367
+ module.fuse_projections(fuse=True)
368
+
369
+ self.set_attn_processor(FusedFluxAttnProcessor2_0())
370
+
371
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
372
+ def unfuse_qkv_projections(self):
373
+ """Disables the fused QKV projection if enabled.
374
+
375
+ <Tip warning={true}>
376
+
377
+ This API is 🧪 experimental.
378
+
379
+ </Tip>
380
+
381
+ """
382
+ if self.original_attn_processors is not None:
383
+ self.set_attn_processor(self.original_attn_processors)
384
+
385
+ def _set_gradient_checkpointing(self, module, value=False):
386
+ if hasattr(module, "gradient_checkpointing"):
387
+ module.gradient_checkpointing = value
388
+
389
+ def forward(
390
+ self,
391
+ hidden_states: torch.Tensor,
392
+ encoder_hidden_states: torch.Tensor = None,
393
+ pooled_projections: torch.Tensor = None,
394
+ timestep: torch.LongTensor = None,
395
+ img_ids: torch.Tensor = None,
396
+ txt_ids: torch.Tensor = None,
397
+ guidance: torch.Tensor = None,
398
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
399
+ controlnet_block_samples=None,
400
+ controlnet_single_block_samples=None,
401
+ condition_hidden_states: torch.Tensor = None,
402
+ return_dict: bool = True,
403
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
404
+ """
405
+ The [`FluxTransformer2DConditionalModel`] forward method.
406
+
407
+ Args:
408
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
409
+ Input `hidden_states`.
410
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
411
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
412
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
413
+ from the embeddings of input conditions.
414
+ timestep (`torch.LongTensor`):
415
+ Used to indicate the current denoising step.
416
+ controlnet_block_samples (`list` of `torch.Tensor`, *optional*):
+ Residuals added to the hidden states after the dual-stream transformer blocks.
+ controlnet_single_block_samples (`list` of `torch.Tensor`, *optional*):
+ Residuals added to the image tokens after the single-stream transformer blocks.
+ condition_hidden_states (`torch.FloatTensor`):
+ Latents of the conditioning input (same packed shape as `hidden_states`), embedded by `c_embedder` and added to the embedded `hidden_states`.
418
+ joint_attention_kwargs (`dict`, *optional*):
419
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
420
+ `self.processor` in
421
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
422
+ return_dict (`bool`, *optional*, defaults to `True`):
423
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
424
+ tuple.
425
+
426
+ Returns:
427
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
428
+ `tuple` where the first element is the sample tensor.
429
+ """
430
+ if joint_attention_kwargs is not None:
431
+ joint_attention_kwargs = joint_attention_kwargs.copy()
432
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
433
+ else:
434
+ lora_scale = 1.0
435
+
436
+ if USE_PEFT_BACKEND:
437
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
438
+ scale_lora_layers(self, lora_scale)
439
+ else:
440
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
441
+ logger.warning(
442
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
443
+ )
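+ # Inject the image condition: condition latents are embedded by the zero-initialized `c_embedder`
+ # and added element-wise to the embedded input latents.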
444
+ hidden_states = self.x_embedder(hidden_states) + self.c_embedder(condition_hidden_states)
445
+
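+ # Timesteps (and guidance values) are scaled by 1000 before the sinusoidal embedding; the pipeline
+ # typically passes them normalized to [0, 1].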
446
+ timestep = timestep.to(hidden_states.dtype) * 1000
447
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype) * 1000
451
+ temb = (
452
+ self.time_text_embed(timestep, pooled_projections)
453
+ if guidance is None
454
+ else self.time_text_embed(timestep, guidance, pooled_projections)
455
+ )
456
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
457
+
458
+ if txt_ids.ndim == 3:
459
+ # logger.warning(
460
+ # "Passing `txt_ids` 3d torch.Tensor is deprecated."
461
+ # "Please remove the batch dimension and pass it as a 2d torch Tensor"
462
+ # )
463
+ txt_ids = txt_ids[0]
464
+ if img_ids.ndim == 3:
465
+ # logger.warning(
466
+ # "Passing `img_ids` 3d torch.Tensor is deprecated."
467
+ # "Please remove the batch dimension and pass it as a 2d torch Tensor"
468
+ # )
469
+ img_ids = img_ids[0]
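+ # Rotary positional embeddings are computed over the concatenated text and image position ids.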
470
+ ids = torch.cat((txt_ids, img_ids), dim=0)
471
+ image_rotary_emb = self.pos_embed(ids)
472
+
473
+ for index_block, block in enumerate(self.transformer_blocks):
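+ # With gradient checkpointing enabled during training, block activations are recomputed in the
+ # backward pass to trade compute for memory.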
474
+ if self.training and self.gradient_checkpointing:
475
+
476
+ def create_custom_forward(module, return_dict=None):
477
+ def custom_forward(*inputs):
478
+ if return_dict is not None:
479
+ return module(*inputs, return_dict=return_dict)
480
+ else:
481
+ return module(*inputs)
482
+
483
+ return custom_forward
484
+
485
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
486
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
487
+ create_custom_forward(block),
488
+ hidden_states,
489
+ encoder_hidden_states,
490
+ temb,
491
+ image_rotary_emb,
492
+ **ckpt_kwargs,
493
+ )
494
+
495
+ else:
496
+ encoder_hidden_states, hidden_states = block(
497
+ hidden_states=hidden_states,
498
+ encoder_hidden_states=encoder_hidden_states,
499
+ temb=temb,
500
+ image_rotary_emb=image_rotary_emb,
501
+ )
502
+
503
+ # controlnet residual
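+ # (each of the N provided controlnet samples is reused for ceil(num_blocks / N) consecutive blocks)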
504
+ if controlnet_block_samples is not None:
505
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
506
+ interval_control = int(np.ceil(interval_control))
507
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
508
+
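+ # For the single-stream blocks, text tokens and image tokens are concatenated into one sequence.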
509
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
510
+
511
+ for index_block, block in enumerate(self.single_transformer_blocks):
512
+ if self.training and self.gradient_checkpointing:
513
+
514
+ def create_custom_forward(module, return_dict=None):
515
+ def custom_forward(*inputs):
516
+ if return_dict is not None:
517
+ return module(*inputs, return_dict=return_dict)
518
+ else:
519
+ return module(*inputs)
520
+
521
+ return custom_forward
522
+
523
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
524
+ hidden_states = torch.utils.checkpoint.checkpoint(
525
+ create_custom_forward(block),
526
+ hidden_states,
527
+ temb,
528
+ image_rotary_emb,
529
+ **ckpt_kwargs,
530
+ )
531
+
532
+ else:
533
+ hidden_states = block(
534
+ hidden_states=hidden_states,
535
+ temb=temb,
536
+ image_rotary_emb=image_rotary_emb,
537
+ )
538
+
539
+ # controlnet residual
540
+ if controlnet_single_block_samples is not None:
541
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
542
+ interval_control = int(np.ceil(interval_control))
543
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
544
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
545
+ + controlnet_single_block_samples[index_block // interval_control]
546
+ )
547
+
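+ # Drop the text tokens and keep only the image tokens for the final norm and projection.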
548
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
549
+
550
+ hidden_states = self.norm_out(hidden_states, temb)
551
+ output = self.proj_out(hidden_states)
552
+
553
+ if USE_PEFT_BACKEND:
554
+ # remove `lora_scale` from each PEFT layer
555
+ unscale_lora_layers(self, lora_scale)
556
+
557
+ if not return_dict:
558
+ return (output,)
559
+
560
+ return Transformer2DModelOutput(sample=output)
561
+
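+ # Zeroes all parameters of `module`; used above for `c_embedder` so the condition branch starts as a no-op.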
562
+ def zero_module(module):
563
+ for p in module.parameters():
564
+ torch.nn.init.zeros_(p)
565
+ return module
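+
+ # Usage sketch (comment only; names such as `condition_latents` are placeholders): with the default config,
+ # the forward call mirrors FluxTransformer2DModel but additionally takes the packed condition latents:
+ #   out = model(hidden_states, encoder_hidden_states, pooled_projections, timestep,
+ #               img_ids, txt_ids, condition_hidden_states=condition_latents).sample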