Surn committed
Commit ab3ef5d · 1 Parent(s): 80b040e

Working on Hugging Face.

Files changed (5)
  1. README.md +1 -1
  2. src/block.py +0 -333
  3. src/generate.py +0 -294
  4. src/lora_controller.py +0 -75
  5. src/transformer.py +0 -270
README.md CHANGED
@@ -7,7 +7,7 @@ sdk: gradio
  python_version: 3.10.13
  sdk_version: 5.16.0
  app_file: app.py
- pinned: false
+ pinned: true
  short_description: Transform Your Images into Mesmerizing Hexagon Grids
  license: apache-2.0
  tags:
src/block.py DELETED
@@ -1,333 +0,0 @@
- import torch
- from typing import List, Union, Optional, Dict, Any, Callable
- from diffusers.models.attention_processor import Attention, F
- from .lora_controller import enable_lora
-
-
- def attn_forward(
-     attn: Attention,
-     hidden_states: torch.FloatTensor,
-     encoder_hidden_states: torch.FloatTensor = None,
-     condition_latents: torch.FloatTensor = None,
-     attention_mask: Optional[torch.FloatTensor] = None,
-     image_rotary_emb: Optional[torch.Tensor] = None,
-     cond_rotary_emb: Optional[torch.Tensor] = None,
-     model_config: Optional[Dict[str, Any]] = {},
- ) -> torch.FloatTensor:
-     batch_size, _, _ = (
-         hidden_states.shape
-         if encoder_hidden_states is None
-         else encoder_hidden_states.shape
-     )
-
-     with enable_lora(
-         (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False)
-     ):
-         # `sample` projections.
-         query = attn.to_q(hidden_states)
-         key = attn.to_k(hidden_states)
-         value = attn.to_v(hidden_states)
-
-     inner_dim = key.shape[-1]
-     head_dim = inner_dim // attn.heads
-
-     query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-     key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-     value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-     if attn.norm_q is not None:
-         query = attn.norm_q(query)
-     if attn.norm_k is not None:
-         key = attn.norm_k(key)
-
-     # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-     if encoder_hidden_states is not None:
-         # `context` projections.
-         encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-         encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-         encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-             batch_size, -1, attn.heads, head_dim
-         ).transpose(1, 2)
-         encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-             batch_size, -1, attn.heads, head_dim
-         ).transpose(1, 2)
-         encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-             batch_size, -1, attn.heads, head_dim
-         ).transpose(1, 2)
-
-         if attn.norm_added_q is not None:
-             encoder_hidden_states_query_proj = attn.norm_added_q(
-                 encoder_hidden_states_query_proj
-             )
-         if attn.norm_added_k is not None:
-             encoder_hidden_states_key_proj = attn.norm_added_k(
-                 encoder_hidden_states_key_proj
-             )
-
-         # attention
-         query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-         key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-         value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-     if image_rotary_emb is not None:
-         from diffusers.models.embeddings import apply_rotary_emb
-
-         query = apply_rotary_emb(query, image_rotary_emb)
-         key = apply_rotary_emb(key, image_rotary_emb)
-
-     if condition_latents is not None:
-         cond_query = attn.to_q(condition_latents)
-         cond_key = attn.to_k(condition_latents)
-         cond_value = attn.to_v(condition_latents)
-
-         cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
-             1, 2
-         )
-         cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-         cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
-             1, 2
-         )
-         if attn.norm_q is not None:
-             cond_query = attn.norm_q(cond_query)
-         if attn.norm_k is not None:
-             cond_key = attn.norm_k(cond_key)
-
-         if cond_rotary_emb is not None:
-             cond_query = apply_rotary_emb(cond_query, cond_rotary_emb)
-             cond_key = apply_rotary_emb(cond_key, cond_rotary_emb)
-
-     if condition_latents is not None:
-         query = torch.cat([query, cond_query], dim=2)
-         key = torch.cat([key, cond_key], dim=2)
-         value = torch.cat([value, cond_value], dim=2)
-
-     if not model_config.get("union_cond_attn", True):
-         # If we don't want to use the union condition attention, we need to mask the attention
-         # between the hidden states and the condition latents
-         attention_mask = torch.ones(
-             query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
-         )
-         condition_n = cond_query.shape[2]
-         attention_mask[-condition_n:, :-condition_n] = False
-         attention_mask[:-condition_n, -condition_n:] = False
-     if hasattr(attn, "c_factor"):
-         attention_mask = torch.zeros(
-             query.shape[2], key.shape[2], device=query.device, dtype=query.dtype
-         )
-         condition_n = cond_query.shape[2]
-         bias = torch.log(attn.c_factor[0])
-         attention_mask[-condition_n:, :-condition_n] = bias
-         attention_mask[:-condition_n, -condition_n:] = bias
-     hidden_states = F.scaled_dot_product_attention(
-         query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask
-     )
-     hidden_states = hidden_states.transpose(1, 2).reshape(
-         batch_size, -1, attn.heads * head_dim
-     )
-     hidden_states = hidden_states.to(query.dtype)
-
-     if encoder_hidden_states is not None:
-         if condition_latents is not None:
-             encoder_hidden_states, hidden_states, condition_latents = (
-                 hidden_states[:, : encoder_hidden_states.shape[1]],
-                 hidden_states[
-                     :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]
-                 ],
-                 hidden_states[:, -condition_latents.shape[1] :],
-             )
-         else:
-             encoder_hidden_states, hidden_states = (
-                 hidden_states[:, : encoder_hidden_states.shape[1]],
-                 hidden_states[:, encoder_hidden_states.shape[1] :],
-             )
-
-         with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)):
-             # linear proj
-             hidden_states = attn.to_out[0](hidden_states)
-             # dropout
-             hidden_states = attn.to_out[1](hidden_states)
-         encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-         if condition_latents is not None:
-             condition_latents = attn.to_out[0](condition_latents)
-             condition_latents = attn.to_out[1](condition_latents)
-
-         return (
-             (hidden_states, encoder_hidden_states, condition_latents)
-             if condition_latents is not None
-             else (hidden_states, encoder_hidden_states)
-         )
-     elif condition_latents is not None:
-         # if there are condition_latents, we need to separate the hidden_states and the condition_latents
-         hidden_states, condition_latents = (
-             hidden_states[:, : -condition_latents.shape[1]],
-             hidden_states[:, -condition_latents.shape[1] :],
-         )
-         return hidden_states, condition_latents
-     else:
-         return hidden_states
-
-
- def block_forward(
-     self,
-     hidden_states: torch.FloatTensor,
-     encoder_hidden_states: torch.FloatTensor,
-     condition_latents: torch.FloatTensor,
-     temb: torch.FloatTensor,
-     cond_temb: torch.FloatTensor,
-     cond_rotary_emb=None,
-     image_rotary_emb=None,
-     model_config: Optional[Dict[str, Any]] = {},
- ):
-     use_cond = condition_latents is not None
-     with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)):
-         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
-             hidden_states, emb=temb
-         )
-
-     norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
-         self.norm1_context(encoder_hidden_states, emb=temb)
-     )
-
-     if use_cond:
-         (
-             norm_condition_latents,
-             cond_gate_msa,
-             cond_shift_mlp,
-             cond_scale_mlp,
-             cond_gate_mlp,
-         ) = self.norm1(condition_latents, emb=cond_temb)
-
-     # Attention.
-     result = attn_forward(
-         self.attn,
-         model_config=model_config,
-         hidden_states=norm_hidden_states,
-         encoder_hidden_states=norm_encoder_hidden_states,
-         condition_latents=norm_condition_latents if use_cond else None,
-         image_rotary_emb=image_rotary_emb,
-         cond_rotary_emb=cond_rotary_emb if use_cond else None,
-     )
-     attn_output, context_attn_output = result[:2]
-     cond_attn_output = result[2] if use_cond else None
-
-     # Process attention outputs for the `hidden_states`.
-     # 1. hidden_states
-     attn_output = gate_msa.unsqueeze(1) * attn_output
-     hidden_states = hidden_states + attn_output
-     # 2. encoder_hidden_states
-     context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
-     encoder_hidden_states = encoder_hidden_states + context_attn_output
-     # 3. condition_latents
-     if use_cond:
-         cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
-         condition_latents = condition_latents + cond_attn_output
-         if model_config.get("add_cond_attn", False):
-             hidden_states += cond_attn_output
-
-     # LayerNorm + MLP.
-     # 1. hidden_states
-     norm_hidden_states = self.norm2(hidden_states)
-     norm_hidden_states = (
-         norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
-     )
-     # 2. encoder_hidden_states
-     norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
-     norm_encoder_hidden_states = (
-         norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
-     )
-     # 3. condition_latents
-     if use_cond:
-         norm_condition_latents = self.norm2(condition_latents)
-         norm_condition_latents = (
-             norm_condition_latents * (1 + cond_scale_mlp[:, None])
-             + cond_shift_mlp[:, None]
-         )
-
-     # Feed-forward.
-     with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)):
-         # 1. hidden_states
-         ff_output = self.ff(norm_hidden_states)
-         ff_output = gate_mlp.unsqueeze(1) * ff_output
-     # 2. encoder_hidden_states
-     context_ff_output = self.ff_context(norm_encoder_hidden_states)
-     context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
-     # 3. condition_latents
-     if use_cond:
-         cond_ff_output = self.ff(norm_condition_latents)
-         cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
-
-     # Process feed-forward outputs.
-     hidden_states = hidden_states + ff_output
-     encoder_hidden_states = encoder_hidden_states + context_ff_output
-     if use_cond:
-         condition_latents = condition_latents + cond_ff_output
-
-     # Clip to avoid overflow.
-     if encoder_hidden_states.dtype == torch.float16:
-         encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
-
-     return encoder_hidden_states, hidden_states, condition_latents if use_cond else None
-
-
- def single_block_forward(
-     self,
-     hidden_states: torch.FloatTensor,
-     temb: torch.FloatTensor,
-     image_rotary_emb=None,
-     condition_latents: torch.FloatTensor = None,
-     cond_temb: torch.FloatTensor = None,
-     cond_rotary_emb=None,
-     model_config: Optional[Dict[str, Any]] = {},
- ):
-
-     using_cond = condition_latents is not None
-     residual = hidden_states
-     with enable_lora(
-         (
-             self.norm.linear,
-             self.proj_mlp,
-         ),
-         model_config.get("latent_lora", False),
-     ):
-         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
-         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
-         if using_cond:
-             residual_cond = condition_latents
-             norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
-             mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))
-
-     attn_output = attn_forward(
-         self.attn,
-         model_config=model_config,
-         hidden_states=norm_hidden_states,
-         image_rotary_emb=image_rotary_emb,
-         **(
-             {
-                 "condition_latents": norm_condition_latents,
-                 "cond_rotary_emb": cond_rotary_emb if using_cond else None,
-             }
-             if using_cond
-             else {}
-         ),
-     )
-     if using_cond:
-         attn_output, cond_attn_output = attn_output
-
-     with enable_lora((self.proj_out,), model_config.get("latent_lora", False)):
-         hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
-         gate = gate.unsqueeze(1)
-         hidden_states = gate * self.proj_out(hidden_states)
-         hidden_states = residual + hidden_states
-         if using_cond:
-             condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
-             cond_gate = cond_gate.unsqueeze(1)
-             condition_latents = cond_gate * self.proj_out(condition_latents)
-             condition_latents = residual_cond + condition_latents
-
-     if hidden_states.dtype == torch.float16:
-         hidden_states = hidden_states.clip(-65504, 65504)
-
-     return hidden_states if not using_cond else (hidden_states, condition_latents)

src/generate.py DELETED
@@ -1,294 +0,0 @@
- import torch
- import yaml, os
- from diffusers.pipelines import FluxPipeline
- from typing import List, Union, Optional, Dict, Any, Callable
- from .transformer import tranformer_forward
- from .condition import Condition
-
- from diffusers.pipelines.flux.pipeline_flux import (
-     FluxPipelineOutput,
-     calculate_shift,
-     retrieve_timesteps,
-     np,
- )
-
-
- def prepare_params(
-     prompt: Union[str, List[str]] = None,
-     prompt_2: Optional[Union[str, List[str]]] = None,
-     height: Optional[int] = 512,
-     width: Optional[int] = 512,
-     num_inference_steps: int = 28,
-     timesteps: List[int] = None,
-     guidance_scale: float = 3.5,
-     num_images_per_prompt: Optional[int] = 1,
-     generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-     latents: Optional[torch.FloatTensor] = None,
-     prompt_embeds: Optional[torch.FloatTensor] = None,
-     pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-     output_type: Optional[str] = "pil",
-     return_dict: bool = True,
-     joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-     callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
-     callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-     max_sequence_length: int = 512,
-     **kwargs: dict,
- ):
-     return (
-         prompt,
-         prompt_2,
-         height,
-         width,
-         num_inference_steps,
-         timesteps,
-         guidance_scale,
-         num_images_per_prompt,
-         generator,
-         latents,
-         prompt_embeds,
-         pooled_prompt_embeds,
-         output_type,
-         return_dict,
-         joint_attention_kwargs,
-         callback_on_step_end,
-         callback_on_step_end_tensor_inputs,
-         max_sequence_length,
-     )
-
-
- def seed_everything(seed: int = 42):
-     torch.backends.cudnn.deterministic = True
-     torch.manual_seed(seed)
-     np.random.seed(seed)
-
-
- @torch.no_grad()
- def generate(
-     pipeline: FluxPipeline,
-     conditions: List[Condition] = None,
-     model_config: Optional[Dict[str, Any]] = {},
-     condition_scale: float = 1.0,
-     **params: dict,
- ):
-     # model_config = model_config or get_config(config_path).get("model", {})
-     if condition_scale != 1:
-         for name, module in pipeline.transformer.named_modules():
-             if not name.endswith(".attn"):
-                 continue
-             module.c_factor = torch.ones(1, 1) * condition_scale
-
-     self = pipeline
-     (
-         prompt,
-         prompt_2,
-         height,
-         width,
-         num_inference_steps,
-         timesteps,
-         guidance_scale,
-         num_images_per_prompt,
-         generator,
-         latents,
-         prompt_embeds,
-         pooled_prompt_embeds,
-         output_type,
-         return_dict,
-         joint_attention_kwargs,
-         callback_on_step_end,
-         callback_on_step_end_tensor_inputs,
-         max_sequence_length,
-     ) = prepare_params(**params)
-
-     height = height or self.default_sample_size * self.vae_scale_factor
-     width = width or self.default_sample_size * self.vae_scale_factor
-
-     # 1. Check inputs. Raise error if not correct
-     self.check_inputs(
-         prompt,
-         prompt_2,
-         height,
-         width,
-         prompt_embeds=prompt_embeds,
-         pooled_prompt_embeds=pooled_prompt_embeds,
-         callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
-         max_sequence_length=max_sequence_length,
-     )
-
-     self._guidance_scale = guidance_scale
-     self._joint_attention_kwargs = joint_attention_kwargs
-     self._interrupt = False
-
-     # 2. Define call parameters
-     if prompt is not None and isinstance(prompt, str):
-         batch_size = 1
-     elif prompt is not None and isinstance(prompt, list):
-         batch_size = len(prompt)
-     else:
-         batch_size = prompt_embeds.shape[0]
-
-     device = self._execution_device
-
-     lora_scale = (
-         self.joint_attention_kwargs.get("scale", None)
-         if self.joint_attention_kwargs is not None
-         else None
-     )
-     (
-         prompt_embeds,
-         pooled_prompt_embeds,
-         text_ids,
-     ) = self.encode_prompt(
-         prompt=prompt,
-         prompt_2=prompt_2,
-         prompt_embeds=prompt_embeds,
-         pooled_prompt_embeds=pooled_prompt_embeds,
-         device=device,
-         num_images_per_prompt=num_images_per_prompt,
-         max_sequence_length=max_sequence_length,
-         lora_scale=lora_scale,
-     )
-
-     # 4. Prepare latent variables
-     num_channels_latents = self.transformer.config.in_channels // 4
-     latents, latent_image_ids = self.prepare_latents(
-         batch_size * num_images_per_prompt,
-         num_channels_latents,
-         height,
-         width,
-         prompt_embeds.dtype,
-         device,
-         generator,
-         latents,
-     )
-
-     # 4.1. Prepare conditions
-     condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3))
-     use_condition = conditions is not None or []
-     if use_condition:
-         assert len(conditions) <= 1, "Only one condition is supported for now."
-         pipeline.set_adapters(
-             {
-                 512: "subject_512",
-                 1024: "subject_1024",
-             }[height]
-         )
-         for condition in conditions:
-             tokens, ids, type_id = condition.encode(self)
-             condition_latents.append(tokens)  # [batch_size, token_n, token_dim]
-             condition_ids.append(ids)  # [token_n, id_dim(3)]
-             condition_type_ids.append(type_id)  # [token_n, 1]
-         condition_latents = torch.cat(condition_latents, dim=1)
-         condition_ids = torch.cat(condition_ids, dim=0)
-         if condition.condition_type == "subject":
-             delta = 32 if height == 512 else -32
-             # print(f"Condition delta: {delta}")
-             condition_ids[:, 2] += delta
-
-         condition_type_ids = torch.cat(condition_type_ids, dim=0)
-
-     # 5. Prepare timesteps
-     sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-     image_seq_len = latents.shape[1]
-     mu = calculate_shift(
-         image_seq_len,
-         self.scheduler.config.base_image_seq_len,
-         self.scheduler.config.max_image_seq_len,
-         self.scheduler.config.base_shift,
-         self.scheduler.config.max_shift,
-     )
-     timesteps, num_inference_steps = retrieve_timesteps(
-         self.scheduler,
-         num_inference_steps,
-         device,
-         timesteps,
-         sigmas,
-         mu=mu,
-     )
-     num_warmup_steps = max(
-         len(timesteps) - num_inference_steps * self.scheduler.order, 0
-     )
-     self._num_timesteps = len(timesteps)
-
-     # 6. Denoising loop
-     with self.progress_bar(total=num_inference_steps) as progress_bar:
-         for i, t in enumerate(timesteps):
-             if self.interrupt:
-                 continue
-
-             # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-             timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
-             # handle guidance
-             if self.transformer.config.guidance_embeds:
-                 guidance = torch.tensor([guidance_scale], device=device)
-                 guidance = guidance.expand(latents.shape[0])
-             else:
-                 guidance = None
-             noise_pred = tranformer_forward(
-                 self.transformer,
-                 model_config=model_config,
-                 # Inputs of the condition (new feature)
-                 condition_latents=condition_latents if use_condition else None,
-                 condition_ids=condition_ids if use_condition else None,
-                 condition_type_ids=condition_type_ids if use_condition else None,
-                 # Inputs to the original transformer
-                 hidden_states=latents,
-                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
-                 timestep=timestep / 1000,
-                 guidance=guidance,
-                 pooled_projections=pooled_prompt_embeds,
-                 encoder_hidden_states=prompt_embeds,
-                 txt_ids=text_ids,
-                 img_ids=latent_image_ids,
-                 joint_attention_kwargs=self.joint_attention_kwargs,
-                 return_dict=False,
-             )[0]
-
-             # compute the previous noisy sample x_t -> x_t-1
-             latents_dtype = latents.dtype
-             latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-
-             if latents.dtype != latents_dtype:
-                 if torch.backends.mps.is_available():
-                     # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                     latents = latents.to(latents_dtype)
-
-             if callback_on_step_end is not None:
-                 callback_kwargs = {}
-                 for k in callback_on_step_end_tensor_inputs:
-                     callback_kwargs[k] = locals()[k]
-                 callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                 latents = callback_outputs.pop("latents", latents)
-                 prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-
-             # call the callback, if provided
-             if i == len(timesteps) - 1 or (
-                 (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
-             ):
-                 progress_bar.update()
-
-     if output_type == "latent":
-         image = latents
-
-     else:
-         latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-         latents = (
-             latents / self.vae.config.scaling_factor
-         ) + self.vae.config.shift_factor
-         image = self.vae.decode(latents, return_dict=False)[0]
-         image = self.image_processor.postprocess(image, output_type=output_type)
-
-     # Offload all models
-     self.maybe_free_model_hooks()
-
-     if condition_scale != 1:
-         for name, module in pipeline.transformer.named_modules():
-             if not name.endswith(".attn"):
-                 continue
-             del module.c_factor
-
-     if not return_dict:
-         return (image,)
-
-     return FluxPipelineOutput(images=image)

src/lora_controller.py DELETED
@@ -1,75 +0,0 @@
- from peft.tuners.tuners_utils import BaseTunerLayer
- from typing import List, Any, Optional, Type
-
-
- class enable_lora:
-     def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
-         self.activated: bool = activated
-         if activated:
-             return
-         self.lora_modules: List[BaseTunerLayer] = [
-             each for each in lora_modules if isinstance(each, BaseTunerLayer)
-         ]
-         self.scales = [
-             {
-                 active_adapter: lora_module.scaling[active_adapter]
-                 for active_adapter in lora_module.active_adapters
-             }
-             for lora_module in self.lora_modules
-         ]
-
-     def __enter__(self) -> None:
-         if self.activated:
-             return
-
-         for lora_module in self.lora_modules:
-             if not isinstance(lora_module, BaseTunerLayer):
-                 continue
-             lora_module.scale_layer(0)
-
-     def __exit__(
-         self,
-         exc_type: Optional[Type[BaseException]],
-         exc_val: Optional[BaseException],
-         exc_tb: Optional[Any],
-     ) -> None:
-         if self.activated:
-             return
-         for i, lora_module in enumerate(self.lora_modules):
-             if not isinstance(lora_module, BaseTunerLayer):
-                 continue
-             for active_adapter in lora_module.active_adapters:
-                 lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
-
-
- class set_lora_scale:
-     def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
-         self.lora_modules: List[BaseTunerLayer] = [
-             each for each in lora_modules if isinstance(each, BaseTunerLayer)
-         ]
-         self.scales = [
-             {
-                 active_adapter: lora_module.scaling[active_adapter]
-                 for active_adapter in lora_module.active_adapters
-             }
-             for lora_module in self.lora_modules
-         ]
-         self.scale = scale
-
-     def __enter__(self) -> None:
-         for lora_module in self.lora_modules:
-             if not isinstance(lora_module, BaseTunerLayer):
-                 continue
-             lora_module.scale_layer(self.scale)
-
-     def __exit__(
-         self,
-         exc_type: Optional[Type[BaseException]],
-         exc_val: Optional[BaseException],
-         exc_tb: Optional[Any],
-     ) -> None:
-         for i, lora_module in enumerate(self.lora_modules):
-             if not isinstance(lora_module, BaseTunerLayer):
-                 continue
-             for active_adapter in lora_module.active_adapters:
-                 lora_module.scaling[active_adapter] = self.scales[i][active_adapter]

src/transformer.py DELETED
@@ -1,270 +0,0 @@
- import torch
- from diffusers.pipelines import FluxPipeline
- from typing import List, Union, Optional, Dict, Any, Callable
- from .block import block_forward, single_block_forward
- from .lora_controller import enable_lora
- from diffusers.models.transformers.transformer_flux import (
-     FluxTransformer2DModel,
-     Transformer2DModelOutput,
-     USE_PEFT_BACKEND,
-     is_torch_version,
-     scale_lora_layers,
-     unscale_lora_layers,
-     logger,
- )
- import numpy as np
-
-
- def prepare_params(
-     hidden_states: torch.Tensor,
-     encoder_hidden_states: torch.Tensor = None,
-     pooled_projections: torch.Tensor = None,
-     timestep: torch.LongTensor = None,
-     img_ids: torch.Tensor = None,
-     txt_ids: torch.Tensor = None,
-     guidance: torch.Tensor = None,
-     joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-     controlnet_block_samples=None,
-     controlnet_single_block_samples=None,
-     return_dict: bool = True,
-     **kwargs: dict,
- ):
-     return (
-         hidden_states,
-         encoder_hidden_states,
-         pooled_projections,
-         timestep,
-         img_ids,
-         txt_ids,
-         guidance,
-         joint_attention_kwargs,
-         controlnet_block_samples,
-         controlnet_single_block_samples,
-         return_dict,
-     )
-
-
- def tranformer_forward(
-     transformer: FluxTransformer2DModel,
-     condition_latents: torch.Tensor,
-     condition_ids: torch.Tensor,
-     condition_type_ids: torch.Tensor,
-     model_config: Optional[Dict[str, Any]] = {},
-     return_conditional_latents: bool = False,
-     c_t=0,
-     **params: dict,
- ):
-     self = transformer
-     use_condition = condition_latents is not None
-     use_condition_in_single_blocks = model_config.get(
-         "use_condition_in_single_blocks", True
-     )
-     # if return_conditional_latents is True, use_condition and use_condition_in_single_blocks must be True
-     assert not return_conditional_latents or (
-         use_condition and use_condition_in_single_blocks
-     ), "`return_conditional_latents` is True, `use_condition` and `use_condition_in_single_blocks` must be True"
-
-     (
-         hidden_states,
-         encoder_hidden_states,
-         pooled_projections,
-         timestep,
-         img_ids,
-         txt_ids,
-         guidance,
-         joint_attention_kwargs,
-         controlnet_block_samples,
-         controlnet_single_block_samples,
-         return_dict,
-     ) = prepare_params(**params)
-
-     if joint_attention_kwargs is not None:
-         joint_attention_kwargs = joint_attention_kwargs.copy()
-         lora_scale = joint_attention_kwargs.pop("scale", 1.0)
-     else:
-         lora_scale = 1.0
-
-     if USE_PEFT_BACKEND:
-         # weight the lora layers by setting `lora_scale` for each PEFT layer
-         scale_lora_layers(self, lora_scale)
-     else:
-         if (
-             joint_attention_kwargs is not None
-             and joint_attention_kwargs.get("scale", None) is not None
-         ):
-             logger.warning(
-                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
-             )
-     with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
-         hidden_states = self.x_embedder(hidden_states)
-     condition_latents = self.x_embedder(condition_latents) if use_condition else None
-
-     timestep = timestep.to(hidden_states.dtype) * 1000
-     if guidance is not None:
-         guidance = guidance.to(hidden_states.dtype) * 1000
-     else:
-         guidance = None
-     temb = (
-         self.time_text_embed(timestep, pooled_projections)
-         if guidance is None
-         else self.time_text_embed(timestep, guidance, pooled_projections)
-     )
-     cond_temb = (
-         self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
-         if guidance is None
-         else self.time_text_embed(
-             torch.ones_like(timestep) * c_t * 1000, guidance, pooled_projections
-         )
-     )
-     if hasattr(self, "cond_type_embed") and condition_type_ids is not None:
-         cond_type_proj = self.time_text_embed.time_proj(condition_type_ids[0])
-         cond_type_emb = self.cond_type_embed(cond_type_proj.to(dtype=cond_temb.dtype))
-         cond_temb = cond_temb + cond_type_emb
-     encoder_hidden_states = self.context_embedder(encoder_hidden_states)
-
-     if txt_ids.ndim == 3:
-         logger.warning(
-             "Passing `txt_ids` 3d torch.Tensor is deprecated."
-             "Please remove the batch dimension and pass it as a 2d torch Tensor"
-         )
-         txt_ids = txt_ids[0]
-     if img_ids.ndim == 3:
-         logger.warning(
-             "Passing `img_ids` 3d torch.Tensor is deprecated."
-             "Please remove the batch dimension and pass it as a 2d torch Tensor"
-         )
-         img_ids = img_ids[0]
-
-     ids = torch.cat((txt_ids, img_ids), dim=0)
-     image_rotary_emb = self.pos_embed(ids)
-     if use_condition:
-         cond_ids = condition_ids
-         cond_rotary_emb = self.pos_embed(cond_ids)
-
-     # hidden_states = torch.cat([hidden_states, condition_latents], dim=1)
-
-     for index_block, block in enumerate(self.transformer_blocks):
-         if self.training and self.gradient_checkpointing:
-
-             def create_custom_forward(module, return_dict=None):
-                 def custom_forward(*inputs):
-                     if return_dict is not None:
-                         return module(*inputs, return_dict=return_dict)
-                     else:
-                         return module(*inputs)
-
-                 return custom_forward
-
-             ckpt_kwargs: Dict[str, Any] = (
-                 {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-             )
-             encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
-                 create_custom_forward(block),
-                 hidden_states,
-                 encoder_hidden_states,
-                 temb,
-                 image_rotary_emb,
-                 **ckpt_kwargs,
-             )
-
-         else:
-             encoder_hidden_states, hidden_states, condition_latents = block_forward(
-                 block,
-                 model_config=model_config,
-                 hidden_states=hidden_states,
-                 encoder_hidden_states=encoder_hidden_states,
-                 condition_latents=condition_latents if use_condition else None,
-                 temb=temb,
-                 cond_temb=cond_temb if use_condition else None,
-                 cond_rotary_emb=cond_rotary_emb if use_condition else None,
-                 image_rotary_emb=image_rotary_emb,
-             )
-
-         # controlnet residual
-         if controlnet_block_samples is not None:
-             interval_control = len(self.transformer_blocks) / len(
-                 controlnet_block_samples
-             )
-             interval_control = int(np.ceil(interval_control))
-             hidden_states = (
-                 hidden_states
-                 + controlnet_block_samples[index_block // interval_control]
-             )
-     hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
-     for index_block, block in enumerate(self.single_transformer_blocks):
-         if self.training and self.gradient_checkpointing:
-
-             def create_custom_forward(module, return_dict=None):
-                 def custom_forward(*inputs):
-                     if return_dict is not None:
-                         return module(*inputs, return_dict=return_dict)
-                     else:
-                         return module(*inputs)
-
-                 return custom_forward
-
-             ckpt_kwargs: Dict[str, Any] = (
-                 {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-             )
-             hidden_states = torch.utils.checkpoint.checkpoint(
-                 create_custom_forward(block),
-                 hidden_states,
-                 temb,
-                 image_rotary_emb,
-                 **ckpt_kwargs,
-             )
-
-         else:
-             result = single_block_forward(
-                 block,
-                 model_config=model_config,
-                 hidden_states=hidden_states,
-                 temb=temb,
-                 image_rotary_emb=image_rotary_emb,
-                 **(
-                     {
-                         "condition_latents": condition_latents,
-                         "cond_temb": cond_temb,
-                         "cond_rotary_emb": cond_rotary_emb,
-                     }
-                     if use_condition_in_single_blocks and use_condition
-                     else {}
-                 ),
-             )
-             if use_condition_in_single_blocks and use_condition:
-                 hidden_states, condition_latents = result
-             else:
-                 hidden_states = result
-
-         # controlnet residual
-         if controlnet_single_block_samples is not None:
-             interval_control = len(self.single_transformer_blocks) / len(
-                 controlnet_single_block_samples
-             )
-             interval_control = int(np.ceil(interval_control))
-             hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
-                 hidden_states[:, encoder_hidden_states.shape[1] :, ...]
-                 + controlnet_single_block_samples[index_block // interval_control]
-             )
-
-     hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
-
-     hidden_states = self.norm_out(hidden_states, temb)
-     output = self.proj_out(hidden_states)
-     if return_conditional_latents:
-         condition_latents = (
-             self.norm_out(condition_latents, cond_temb) if use_condition else None
-         )
-         condition_output = self.proj_out(condition_latents) if use_condition else None
-
-     if USE_PEFT_BACKEND:
-         # remove `lora_scale` from each PEFT layer
-         unscale_lora_layers(self, lora_scale)
-
-     if not return_dict:
-         return (
-             (output,) if not return_conditional_latents else (output, condition_output)
-         )
-
-     return Transformer2DModelOutput(sample=output)