Surn committed on
Commit 6ef117e · 1 Parent(s): 1cb68a6

Merge from Main repository

LUT/BlackWhite.cube ADDED
The diff for this file is too large to render. See raw diff
 
LUT/CineCold.cube ADDED
The diff for this file is too large to render. See raw diff
 
LUT/CineDrama.cube ADDED
The diff for this file is too large to render. See raw diff
 
LUT/CineVibrant.cube ADDED
The diff for this file is too large to render. See raw diff
 
LUT/CineWarm.cube ADDED
The diff for this file is too large to render. See raw diff
 
LUT/Depth_of_Field.cube ADDED
The diff for this file is too large to render. See raw diff
 
LUT/Glow_Highlights.cube CHANGED
The diff for this file is too large to render. See raw diff
 
LUT/RedWhiteBlue.cube ADDED
The diff for this file is too large to render. See raw diff
 
assets/logo.png → LUT/daisy.jpg RENAMED
File without changes
LUT/grayscale.cube CHANGED
The diff for this file is too large to render. See raw diff
 
LUT/scenery01.cube CHANGED
The diff for this file is too large to render. See raw diff
 
app.py CHANGED
@@ -6,6 +6,7 @@ from tempfile import NamedTemporaryFile
 from pathlib import Path
 import atexit
 import random
+ import spaces
 # Import constants
 import utils.constants as constants
 
@@ -308,16 +309,16 @@ with gr.Blocks(css_paths="style_20250128.css", title="HexaGrid Creator", theme='
 )
 with gr.Column():
 with gr.Accordion("Hex Coloring and Exclusion", open = False):
- with gr.Row():
- with gr.Column():
- color_picker = gr.ColorPicker(label="Pick a color to exclude",value="#505050")
- with gr.Column():
- filter_color = gr.Checkbox(label="Filter Excluded Colors from Sampling", value=False,)
- exclude_color_button = gr.Button("Exclude Color", elem_id="exlude_color_button", elem_classes="solid")
- color_display = gr.DataFrame(label="List of Excluded RGBA Colors", headers=["R", "G", "B", "A"], elem_id="excluded_colors", type="array", value=build_dataframe(excluded_color_list), interactive=True, elem_classes="solid centered")
- selected_row = gr.Number(0, label="Selected Row", visible=False)
- delete_button = gr.Button("Delete Row", elem_id="delete_exclusion_button", elem_classes="solid")
- fill_hex = gr.Checkbox(label="Fill Hex with color from Image", value=True)
+ with gr.Row():
+ with gr.Column():
+ color_picker = gr.ColorPicker(label="Pick a color to exclude",value="#505050")
+ with gr.Column():
+ filter_color = gr.Checkbox(label="Filter Excluded Colors from Sampling", value=False,)
+ exclude_color_button = gr.Button("Exclude Color", elem_id="exlude_color_button", elem_classes="solid")
+ color_display = gr.DataFrame(label="List of Excluded RGBA Colors", headers=["R", "G", "B", "A"], elem_id="excluded_colors", type="array", value=build_dataframe(excluded_color_list), interactive=True, elem_classes="solid centered")
+ selected_row = gr.Number(0, label="Selected Row", visible=False)
+ delete_button = gr.Button("Delete Row", elem_id="delete_exclusion_button", elem_classes="solid")
+ fill_hex = gr.Checkbox(label="Fill Hex with color from Image", value=True)
 with gr.Accordion("Image Filters", open = False):
 with gr.Row():
 with gr.Column():
@@ -468,15 +469,15 @@ with gr.Blocks(css_paths="style_20250128.css", title="HexaGrid Creator", theme='
 ### The custom color list is a comma separated list of hex colors.
 #### Example: "A,2,3,4,5,6,7,8,9,10,J,Q,K", "red,#0000FF,#00FF00,red,#FFFF00,#00FFFF,#FF8000,#FF00FF,#FF0080,#FF8000,#FF0080,lightblue"
 """, elem_id="hex_text_info", visible=False)
- add_hex_text.change(
- fn=lambda x: (
- gr.update(visible=(x == "Custom List")),
- gr.update(visible=(x == "Custom List")),
- gr.update(visible=(x != None))
- ),
- inputs=add_hex_text,
- outputs=[custom_text_list, custom_text_color_list, hex_text_info]
- )
+ add_hex_text.change(
+ fn=lambda x: (
+ gr.update(visible=(x == "Custom List")),
+ gr.update(visible=(x == "Custom List")),
+ gr.update(visible=(x != None))
+ ),
+ inputs=add_hex_text,
+ outputs=[custom_text_list, custom_text_color_list, hex_text_info]
+ )
 with gr.Row():
 hex_size = gr.Number(label="Hexagon Size", value=32, minimum=1, maximum=768)
 border_size = gr.Slider(-5,25,value=0,step=1,label="Border Size")
assets/logo_hex.png DELETED

Git LFS Details

  • SHA256: 9c0f91c488296e7234f829effe6da9d997704fa9b4e95739af7049d8d91db72b
  • Pointer size: 131 Bytes
  • Size of remote file: 547 kB
assets/logo_old.png DELETED

Git LFS Details

  • SHA256: 77dc2f8c3d405d0a11cde6dbca3ab97e5a3bbd23f89872fa94483953c8505799
  • Pointer size: 130 Bytes
  • Size of remote file: 49.9 kB
assets/logo_hex.gif → images/prerendered/grid_1.png RENAMED
File without changes
src/block.py ADDED
@@ -0,0 +1,333 @@
1
+ import torch
2
+ from typing import List, Union, Optional, Dict, Any, Callable
3
+ from diffusers.models.attention_processor import Attention, F
4
+ from .lora_controller import enable_lora
5
+
6
+
7
+ def attn_forward(
8
+ attn: Attention,
9
+ hidden_states: torch.FloatTensor,
10
+ encoder_hidden_states: torch.FloatTensor = None,
11
+ condition_latents: torch.FloatTensor = None,
12
+ attention_mask: Optional[torch.FloatTensor] = None,
13
+ image_rotary_emb: Optional[torch.Tensor] = None,
14
+ cond_rotary_emb: Optional[torch.Tensor] = None,
15
+ model_config: Optional[Dict[str, Any]] = {},
16
+ ) -> torch.FloatTensor:
17
+ batch_size, _, _ = (
18
+ hidden_states.shape
19
+ if encoder_hidden_states is None
20
+ else encoder_hidden_states.shape
21
+ )
22
+
23
+ with enable_lora(
24
+ (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False)
25
+ ):
26
+ # `sample` projections.
27
+ query = attn.to_q(hidden_states)
28
+ key = attn.to_k(hidden_states)
29
+ value = attn.to_v(hidden_states)
30
+
31
+ inner_dim = key.shape[-1]
32
+ head_dim = inner_dim // attn.heads
33
+
34
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
35
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
36
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
37
+
38
+ if attn.norm_q is not None:
39
+ query = attn.norm_q(query)
40
+ if attn.norm_k is not None:
41
+ key = attn.norm_k(key)
42
+
43
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
44
+ if encoder_hidden_states is not None:
45
+ # `context` projections.
46
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
47
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
48
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
49
+
50
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
51
+ batch_size, -1, attn.heads, head_dim
52
+ ).transpose(1, 2)
53
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
54
+ batch_size, -1, attn.heads, head_dim
55
+ ).transpose(1, 2)
56
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
57
+ batch_size, -1, attn.heads, head_dim
58
+ ).transpose(1, 2)
59
+
60
+ if attn.norm_added_q is not None:
61
+ encoder_hidden_states_query_proj = attn.norm_added_q(
62
+ encoder_hidden_states_query_proj
63
+ )
64
+ if attn.norm_added_k is not None:
65
+ encoder_hidden_states_key_proj = attn.norm_added_k(
66
+ encoder_hidden_states_key_proj
67
+ )
68
+
69
+ # attention
70
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
71
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
72
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
73
+
74
+ if image_rotary_emb is not None:
75
+ from diffusers.models.embeddings import apply_rotary_emb
76
+
77
+ query = apply_rotary_emb(query, image_rotary_emb)
78
+ key = apply_rotary_emb(key, image_rotary_emb)
79
+
80
+ if condition_latents is not None:
81
+ cond_query = attn.to_q(condition_latents)
82
+ cond_key = attn.to_k(condition_latents)
83
+ cond_value = attn.to_v(condition_latents)
84
+
85
+ cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
86
+ 1, 2
87
+ )
88
+ cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
89
+ cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
90
+ 1, 2
91
+ )
92
+ if attn.norm_q is not None:
93
+ cond_query = attn.norm_q(cond_query)
94
+ if attn.norm_k is not None:
95
+ cond_key = attn.norm_k(cond_key)
96
+
97
+ if cond_rotary_emb is not None:
98
+ cond_query = apply_rotary_emb(cond_query, cond_rotary_emb)
99
+ cond_key = apply_rotary_emb(cond_key, cond_rotary_emb)
100
+
101
+ if condition_latents is not None:
102
+ query = torch.cat([query, cond_query], dim=2)
103
+ key = torch.cat([key, cond_key], dim=2)
104
+ value = torch.cat([value, cond_value], dim=2)
105
+
106
+ if not model_config.get("union_cond_attn", True):
107
+ # If we don't want to use the union condition attention, we need to mask the attention
108
+ # between the hidden states and the condition latents
109
+ attention_mask = torch.ones(
110
+ query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
111
+ )
112
+ condition_n = cond_query.shape[2]
113
+ attention_mask[-condition_n:, :-condition_n] = False
114
+ attention_mask[:-condition_n, -condition_n:] = False
115
+ if hasattr(attn, "c_factor"):
116
+ attention_mask = torch.zeros(
117
+ query.shape[2], key.shape[2], device=query.device, dtype=query.dtype
118
+ )
119
+ condition_n = cond_query.shape[2]
120
+ bias = torch.log(attn.c_factor[0])
121
+ attention_mask[-condition_n:, :-condition_n] = bias
122
+ attention_mask[:-condition_n, -condition_n:] = bias
123
+ hidden_states = F.scaled_dot_product_attention(
124
+ query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask
125
+ )
126
+ hidden_states = hidden_states.transpose(1, 2).reshape(
127
+ batch_size, -1, attn.heads * head_dim
128
+ )
129
+ hidden_states = hidden_states.to(query.dtype)
130
+
131
+ if encoder_hidden_states is not None:
132
+ if condition_latents is not None:
133
+ encoder_hidden_states, hidden_states, condition_latents = (
134
+ hidden_states[:, : encoder_hidden_states.shape[1]],
135
+ hidden_states[
136
+ :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]
137
+ ],
138
+ hidden_states[:, -condition_latents.shape[1] :],
139
+ )
140
+ else:
141
+ encoder_hidden_states, hidden_states = (
142
+ hidden_states[:, : encoder_hidden_states.shape[1]],
143
+ hidden_states[:, encoder_hidden_states.shape[1] :],
144
+ )
145
+
146
+ with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)):
147
+ # linear proj
148
+ hidden_states = attn.to_out[0](hidden_states)
149
+ # dropout
150
+ hidden_states = attn.to_out[1](hidden_states)
151
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
152
+
153
+ if condition_latents is not None:
154
+ condition_latents = attn.to_out[0](condition_latents)
155
+ condition_latents = attn.to_out[1](condition_latents)
156
+
157
+ return (
158
+ (hidden_states, encoder_hidden_states, condition_latents)
159
+ if condition_latents is not None
160
+ else (hidden_states, encoder_hidden_states)
161
+ )
162
+ elif condition_latents is not None:
163
+ # if there are condition_latents, we need to separate the hidden_states and the condition_latents
164
+ hidden_states, condition_latents = (
165
+ hidden_states[:, : -condition_latents.shape[1]],
166
+ hidden_states[:, -condition_latents.shape[1] :],
167
+ )
168
+ return hidden_states, condition_latents
169
+ else:
170
+ return hidden_states
171
+
172
+
173
+ def block_forward(
174
+ self,
175
+ hidden_states: torch.FloatTensor,
176
+ encoder_hidden_states: torch.FloatTensor,
177
+ condition_latents: torch.FloatTensor,
178
+ temb: torch.FloatTensor,
179
+ cond_temb: torch.FloatTensor,
180
+ cond_rotary_emb=None,
181
+ image_rotary_emb=None,
182
+ model_config: Optional[Dict[str, Any]] = {},
183
+ ):
184
+ use_cond = condition_latents is not None
185
+ with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)):
186
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
187
+ hidden_states, emb=temb
188
+ )
189
+
190
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
191
+ self.norm1_context(encoder_hidden_states, emb=temb)
192
+ )
193
+
194
+ if use_cond:
195
+ (
196
+ norm_condition_latents,
197
+ cond_gate_msa,
198
+ cond_shift_mlp,
199
+ cond_scale_mlp,
200
+ cond_gate_mlp,
201
+ ) = self.norm1(condition_latents, emb=cond_temb)
202
+
203
+ # Attention.
204
+ result = attn_forward(
205
+ self.attn,
206
+ model_config=model_config,
207
+ hidden_states=norm_hidden_states,
208
+ encoder_hidden_states=norm_encoder_hidden_states,
209
+ condition_latents=norm_condition_latents if use_cond else None,
210
+ image_rotary_emb=image_rotary_emb,
211
+ cond_rotary_emb=cond_rotary_emb if use_cond else None,
212
+ )
213
+ attn_output, context_attn_output = result[:2]
214
+ cond_attn_output = result[2] if use_cond else None
215
+
216
+ # Process attention outputs for the `hidden_states`.
217
+ # 1. hidden_states
218
+ attn_output = gate_msa.unsqueeze(1) * attn_output
219
+ hidden_states = hidden_states + attn_output
220
+ # 2. encoder_hidden_states
221
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
222
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
223
+ # 3. condition_latents
224
+ if use_cond:
225
+ cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
226
+ condition_latents = condition_latents + cond_attn_output
227
+ if model_config.get("add_cond_attn", False):
228
+ hidden_states += cond_attn_output
229
+
230
+ # LayerNorm + MLP.
231
+ # 1. hidden_states
232
+ norm_hidden_states = self.norm2(hidden_states)
233
+ norm_hidden_states = (
234
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
235
+ )
236
+ # 2. encoder_hidden_states
237
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
238
+ norm_encoder_hidden_states = (
239
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
240
+ )
241
+ # 3. condition_latents
242
+ if use_cond:
243
+ norm_condition_latents = self.norm2(condition_latents)
244
+ norm_condition_latents = (
245
+ norm_condition_latents * (1 + cond_scale_mlp[:, None])
246
+ + cond_shift_mlp[:, None]
247
+ )
248
+
249
+ # Feed-forward.
250
+ with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)):
251
+ # 1. hidden_states
252
+ ff_output = self.ff(norm_hidden_states)
253
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
254
+ # 2. encoder_hidden_states
255
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
256
+ context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
257
+ # 3. condition_latents
258
+ if use_cond:
259
+ cond_ff_output = self.ff(norm_condition_latents)
260
+ cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
261
+
262
+ # Process feed-forward outputs.
263
+ hidden_states = hidden_states + ff_output
264
+ encoder_hidden_states = encoder_hidden_states + context_ff_output
265
+ if use_cond:
266
+ condition_latents = condition_latents + cond_ff_output
267
+
268
+ # Clip to avoid overflow.
269
+ if encoder_hidden_states.dtype == torch.float16:
270
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
271
+
272
+ return encoder_hidden_states, hidden_states, condition_latents if use_cond else None
273
+
274
+
275
+ def single_block_forward(
276
+ self,
277
+ hidden_states: torch.FloatTensor,
278
+ temb: torch.FloatTensor,
279
+ image_rotary_emb=None,
280
+ condition_latents: torch.FloatTensor = None,
281
+ cond_temb: torch.FloatTensor = None,
282
+ cond_rotary_emb=None,
283
+ model_config: Optional[Dict[str, Any]] = {},
284
+ ):
285
+
286
+ using_cond = condition_latents is not None
287
+ residual = hidden_states
288
+ with enable_lora(
289
+ (
290
+ self.norm.linear,
291
+ self.proj_mlp,
292
+ ),
293
+ model_config.get("latent_lora", False),
294
+ ):
295
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
296
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
297
+ if using_cond:
298
+ residual_cond = condition_latents
299
+ norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
300
+ mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))
301
+
302
+ attn_output = attn_forward(
303
+ self.attn,
304
+ model_config=model_config,
305
+ hidden_states=norm_hidden_states,
306
+ image_rotary_emb=image_rotary_emb,
307
+ **(
308
+ {
309
+ "condition_latents": norm_condition_latents,
310
+ "cond_rotary_emb": cond_rotary_emb if using_cond else None,
311
+ }
312
+ if using_cond
313
+ else {}
314
+ ),
315
+ )
316
+ if using_cond:
317
+ attn_output, cond_attn_output = attn_output
318
+
319
+ with enable_lora((self.proj_out,), model_config.get("latent_lora", False)):
320
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
321
+ gate = gate.unsqueeze(1)
322
+ hidden_states = gate * self.proj_out(hidden_states)
323
+ hidden_states = residual + hidden_states
324
+ if using_cond:
325
+ condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
326
+ cond_gate = cond_gate.unsqueeze(1)
327
+ condition_latents = cond_gate * self.proj_out(condition_latents)
328
+ condition_latents = residual_cond + condition_latents
329
+
330
+ if hidden_states.dtype == torch.float16:
331
+ hidden_states = hidden_states.clip(-65504, 65504)
332
+
333
+ return hidden_states if not using_cond else (hidden_states, condition_latents)
src/condition.py ADDED
@@ -0,0 +1,116 @@
1
+ import torch
2
+ from typing import Optional, Union, List, Tuple
3
+ from diffusers.pipelines import FluxPipeline
4
+ from PIL import Image, ImageFilter
5
+ import numpy as np
6
+ import cv2
7
+
8
+ condition_dict = {
9
+ "depth": 0,
10
+ "canny": 1,
11
+ "subject": 4,
12
+ "coloring": 6,
13
+ "deblurring": 7,
14
+ "fill": 9,
15
+ }
16
+ class Condition(object):
17
+ def __init__(
18
+ self,
19
+ condition_type: str,
20
+ raw_img: Union[Image.Image, torch.Tensor] = None,
21
+ condition: Union[Image.Image, torch.Tensor] = None,
22
+ mask=None,
23
+ ) -> None:
24
+ self.condition_type = condition_type
25
+ assert raw_img is not None or condition is not None
26
+ if raw_img is not None:
27
+ self.condition = self.get_condition(condition_type, raw_img)
28
+ else:
29
+ self.condition = condition
30
+ # TODO: Add mask support
31
+ assert mask is None, "Mask not supported yet"
32
+ def get_condition(
33
+ self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
34
+ ) -> Union[Image.Image, torch.Tensor]:
35
+ """
36
+ Returns the condition image.
37
+ """
38
+ if condition_type == "depth":
39
+ from transformers import pipeline
40
+ depth_pipe = pipeline(
41
+ task="depth-estimation",
42
+ model="LiheYoung/depth-anything-small-hf",
43
+ device="cuda",
44
+ )
45
+ source_image = raw_img.convert("RGB")
46
+ condition_img = depth_pipe(source_image)["depth"].convert("RGB")
47
+ return condition_img
48
+ elif condition_type == "canny":
49
+ img = np.array(raw_img)
50
+ edges = cv2.Canny(img, 100, 200)
51
+ edges = Image.fromarray(edges).convert("RGB")
52
+ return edges
53
+ elif condition_type == "subject":
54
+ return raw_img
55
+ elif condition_type == "coloring":
56
+ return raw_img.convert("L").convert("RGB")
57
+ elif condition_type == "deblurring":
58
+ condition_image = (
59
+ raw_img.convert("RGB")
60
+ .filter(ImageFilter.GaussianBlur(10))
61
+ .convert("RGB")
62
+ )
63
+ return condition_image
64
+ elif condition_type == "fill":
65
+ return raw_img.convert("RGB")
66
+ return self.condition
67
+ @property
68
+ def type_id(self) -> int:
69
+ """
70
+ Returns the type id of the condition.
71
+ """
72
+ return condition_dict[self.condition_type]
73
+ @classmethod
74
+ def get_type_id(cls, condition_type: str) -> int:
75
+ """
76
+ Returns the type id of the condition.
77
+ """
78
+ return condition_dict[condition_type]
79
+ def _encode_image(self, pipe: FluxPipeline, cond_img: Image.Image) -> torch.Tensor:
80
+ """
81
+ Encodes an image condition into tokens using the pipeline.
82
+ """
83
+ cond_img = pipe.image_processor.preprocess(cond_img)
84
+ cond_img = cond_img.to(pipe.device).to(pipe.dtype)
85
+ cond_img = pipe.vae.encode(cond_img).latent_dist.sample()
86
+ cond_img = (
87
+ cond_img - pipe.vae.config.shift_factor
88
+ ) * pipe.vae.config.scaling_factor
89
+ cond_tokens = pipe._pack_latents(cond_img, *cond_img.shape)
90
+ cond_ids = pipe._prepare_latent_image_ids(
91
+ cond_img.shape[0],
92
+ cond_img.shape[2]//2,
93
+ cond_img.shape[3]//2,
94
+ pipe.device,
95
+ pipe.dtype,
96
+ )
97
+ return cond_tokens, cond_ids
98
+ def encode(self, pipe: FluxPipeline) -> Tuple[torch.Tensor, torch.Tensor, int]:
99
+ """
100
+ Encodes the condition into tokens, ids and type_id.
101
+ """
102
+ if self.condition_type in [
103
+ "depth",
104
+ "canny",
105
+ "subject",
106
+ "coloring",
107
+ "deblurring",
108
+ "fill",
109
+ ]:
110
+ tokens, ids = self._encode_image(pipe, self.condition)
111
+ else:
112
+ raise NotImplementedError(
113
+ f"Condition type {self.condition_type} not implemented"
114
+ )
115
+ type_id = torch.ones_like(ids[:, :1]) * self.type_id
116
+ return tokens, ids, type_id
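
Not part of the commit: a minimal usage sketch of the new Condition class added above. The reference image path and the pipeline setup are assumptions for illustration; in this repository the encoding step normally happens inside src/generate.py.

import torch
from PIL import Image
from diffusers import FluxPipeline
from src.condition import Condition

# Assumed inputs: any RGB image and an already-downloaded FLUX.1 checkpoint.
ref = Image.open("reference.jpg").convert("RGB")
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

cond = Condition("canny", raw_img=ref)    # builds a Canny edge map via cv2 (type id 1 in condition_dict)
tokens, ids, type_id = cond.encode(pipe)  # packed VAE latents, latent position ids, and the condition type id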
src/generate.py ADDED
@@ -0,0 +1,294 @@
1
+ import torch
2
+ import yaml, os
3
+ from diffusers.pipelines import FluxPipeline
4
+ from typing import List, Union, Optional, Dict, Any, Callable
5
+ from .transformer import tranformer_forward
6
+ from .condition import Condition
7
+
8
+ from diffusers.pipelines.flux.pipeline_flux import (
9
+ FluxPipelineOutput,
10
+ calculate_shift,
11
+ retrieve_timesteps,
12
+ np,
13
+ )
14
+
15
+
16
+ def prepare_params(
17
+ prompt: Union[str, List[str]] = None,
18
+ prompt_2: Optional[Union[str, List[str]]] = None,
19
+ height: Optional[int] = 512,
20
+ width: Optional[int] = 512,
21
+ num_inference_steps: int = 28,
22
+ timesteps: List[int] = None,
23
+ guidance_scale: float = 3.5,
24
+ num_images_per_prompt: Optional[int] = 1,
25
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
26
+ latents: Optional[torch.FloatTensor] = None,
27
+ prompt_embeds: Optional[torch.FloatTensor] = None,
28
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
29
+ output_type: Optional[str] = "pil",
30
+ return_dict: bool = True,
31
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
32
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
33
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
34
+ max_sequence_length: int = 512,
35
+ **kwargs: dict,
36
+ ):
37
+ return (
38
+ prompt,
39
+ prompt_2,
40
+ height,
41
+ width,
42
+ num_inference_steps,
43
+ timesteps,
44
+ guidance_scale,
45
+ num_images_per_prompt,
46
+ generator,
47
+ latents,
48
+ prompt_embeds,
49
+ pooled_prompt_embeds,
50
+ output_type,
51
+ return_dict,
52
+ joint_attention_kwargs,
53
+ callback_on_step_end,
54
+ callback_on_step_end_tensor_inputs,
55
+ max_sequence_length,
56
+ )
57
+
58
+
59
+ def seed_everything(seed: int = 42):
60
+ torch.backends.cudnn.deterministic = True
61
+ torch.manual_seed(seed)
62
+ np.random.seed(seed)
63
+
64
+
65
+ @torch.no_grad()
66
+ def generate(
67
+ pipeline: FluxPipeline,
68
+ conditions: List[Condition] = None,
69
+ model_config: Optional[Dict[str, Any]] = {},
70
+ condition_scale: float = 1.0,
71
+ **params: dict,
72
+ ):
73
+ # model_config = model_config or get_config(config_path).get("model", {})
74
+ if condition_scale != 1:
75
+ for name, module in pipeline.transformer.named_modules():
76
+ if not name.endswith(".attn"):
77
+ continue
78
+ module.c_factor = torch.ones(1, 1) * condition_scale
79
+
80
+ self = pipeline
81
+ (
82
+ prompt,
83
+ prompt_2,
84
+ height,
85
+ width,
86
+ num_inference_steps,
87
+ timesteps,
88
+ guidance_scale,
89
+ num_images_per_prompt,
90
+ generator,
91
+ latents,
92
+ prompt_embeds,
93
+ pooled_prompt_embeds,
94
+ output_type,
95
+ return_dict,
96
+ joint_attention_kwargs,
97
+ callback_on_step_end,
98
+ callback_on_step_end_tensor_inputs,
99
+ max_sequence_length,
100
+ ) = prepare_params(**params)
101
+
102
+ height = height or self.default_sample_size * self.vae_scale_factor
103
+ width = width or self.default_sample_size * self.vae_scale_factor
104
+
105
+ # 1. Check inputs. Raise error if not correct
106
+ self.check_inputs(
107
+ prompt,
108
+ prompt_2,
109
+ height,
110
+ width,
111
+ prompt_embeds=prompt_embeds,
112
+ pooled_prompt_embeds=pooled_prompt_embeds,
113
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
114
+ max_sequence_length=max_sequence_length,
115
+ )
116
+
117
+ self._guidance_scale = guidance_scale
118
+ self._joint_attention_kwargs = joint_attention_kwargs
119
+ self._interrupt = False
120
+
121
+ # 2. Define call parameters
122
+ if prompt is not None and isinstance(prompt, str):
123
+ batch_size = 1
124
+ elif prompt is not None and isinstance(prompt, list):
125
+ batch_size = len(prompt)
126
+ else:
127
+ batch_size = prompt_embeds.shape[0]
128
+
129
+ device = self._execution_device
130
+
131
+ lora_scale = (
132
+ self.joint_attention_kwargs.get("scale", None)
133
+ if self.joint_attention_kwargs is not None
134
+ else None
135
+ )
136
+ (
137
+ prompt_embeds,
138
+ pooled_prompt_embeds,
139
+ text_ids,
140
+ ) = self.encode_prompt(
141
+ prompt=prompt,
142
+ prompt_2=prompt_2,
143
+ prompt_embeds=prompt_embeds,
144
+ pooled_prompt_embeds=pooled_prompt_embeds,
145
+ device=device,
146
+ num_images_per_prompt=num_images_per_prompt,
147
+ max_sequence_length=max_sequence_length,
148
+ lora_scale=lora_scale,
149
+ )
150
+
151
+ # 4. Prepare latent variables
152
+ num_channels_latents = self.transformer.config.in_channels // 4
153
+ latents, latent_image_ids = self.prepare_latents(
154
+ batch_size * num_images_per_prompt,
155
+ num_channels_latents,
156
+ height,
157
+ width,
158
+ prompt_embeds.dtype,
159
+ device,
160
+ generator,
161
+ latents,
162
+ )
163
+
164
+ # 4.1. Prepare conditions
165
+ condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3))
166
+ use_condition = conditions is not None or []
167
+ if use_condition:
168
+ assert len(conditions) <= 1, "Only one condition is supported for now."
169
+ pipeline.set_adapters(
170
+ {
171
+ 512: "subject_512",
172
+ 1024: "subject_1024",
173
+ }[height]
174
+ )
175
+ for condition in conditions:
176
+ tokens, ids, type_id = condition.encode(self)
177
+ condition_latents.append(tokens) # [batch_size, token_n, token_dim]
178
+ condition_ids.append(ids) # [token_n, id_dim(3)]
179
+ condition_type_ids.append(type_id) # [token_n, 1]
180
+ condition_latents = torch.cat(condition_latents, dim=1)
181
+ condition_ids = torch.cat(condition_ids, dim=0)
182
+ if condition.condition_type == "subject":
183
+ delta = 32 if height == 512 else -32
184
+ # print(f"Condition delta: {delta}")
185
+ condition_ids[:, 2] += delta
186
+
187
+ condition_type_ids = torch.cat(condition_type_ids, dim=0)
188
+
189
+ # 5. Prepare timesteps
190
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
191
+ image_seq_len = latents.shape[1]
192
+ mu = calculate_shift(
193
+ image_seq_len,
194
+ self.scheduler.config.base_image_seq_len,
195
+ self.scheduler.config.max_image_seq_len,
196
+ self.scheduler.config.base_shift,
197
+ self.scheduler.config.max_shift,
198
+ )
199
+ timesteps, num_inference_steps = retrieve_timesteps(
200
+ self.scheduler,
201
+ num_inference_steps,
202
+ device,
203
+ timesteps,
204
+ sigmas,
205
+ mu=mu,
206
+ )
207
+ num_warmup_steps = max(
208
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
209
+ )
210
+ self._num_timesteps = len(timesteps)
211
+
212
+ # 6. Denoising loop
213
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
214
+ for i, t in enumerate(timesteps):
215
+ if self.interrupt:
216
+ continue
217
+
218
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
219
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
220
+
221
+ # handle guidance
222
+ if self.transformer.config.guidance_embeds:
223
+ guidance = torch.tensor([guidance_scale], device=device)
224
+ guidance = guidance.expand(latents.shape[0])
225
+ else:
226
+ guidance = None
227
+ noise_pred = tranformer_forward(
228
+ self.transformer,
229
+ model_config=model_config,
230
+ # Inputs of the condition (new feature)
231
+ condition_latents=condition_latents if use_condition else None,
232
+ condition_ids=condition_ids if use_condition else None,
233
+ condition_type_ids=condition_type_ids if use_condition else None,
234
+ # Inputs to the original transformer
235
+ hidden_states=latents,
236
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
237
+ timestep=timestep / 1000,
238
+ guidance=guidance,
239
+ pooled_projections=pooled_prompt_embeds,
240
+ encoder_hidden_states=prompt_embeds,
241
+ txt_ids=text_ids,
242
+ img_ids=latent_image_ids,
243
+ joint_attention_kwargs=self.joint_attention_kwargs,
244
+ return_dict=False,
245
+ )[0]
246
+
247
+ # compute the previous noisy sample x_t -> x_t-1
248
+ latents_dtype = latents.dtype
249
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
250
+
251
+ if latents.dtype != latents_dtype:
252
+ if torch.backends.mps.is_available():
253
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
254
+ latents = latents.to(latents_dtype)
255
+
256
+ if callback_on_step_end is not None:
257
+ callback_kwargs = {}
258
+ for k in callback_on_step_end_tensor_inputs:
259
+ callback_kwargs[k] = locals()[k]
260
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
261
+
262
+ latents = callback_outputs.pop("latents", latents)
263
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
264
+
265
+ # call the callback, if provided
266
+ if i == len(timesteps) - 1 or (
267
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
268
+ ):
269
+ progress_bar.update()
270
+
271
+ if output_type == "latent":
272
+ image = latents
273
+
274
+ else:
275
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
276
+ latents = (
277
+ latents / self.vae.config.scaling_factor
278
+ ) + self.vae.config.shift_factor
279
+ image = self.vae.decode(latents, return_dict=False)[0]
280
+ image = self.image_processor.postprocess(image, output_type=output_type)
281
+
282
+ # Offload all models
283
+ self.maybe_free_model_hooks()
284
+
285
+ if condition_scale != 1:
286
+ for name, module in pipeline.transformer.named_modules():
287
+ if not name.endswith(".attn"):
288
+ continue
289
+ del module.c_factor
290
+
291
+ if not return_dict:
292
+ return (image,)
293
+
294
+ return FluxPipelineOutput(images=image)
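
A sketch (not in the commit) of how the generate() helper above is expected to be called end to end. It assumes LoRA adapters named "subject_512"/"subject_1024" have already been loaded onto the pipeline, because the condition branch calls pipeline.set_adapters() with those names; the reference image and output path are placeholders.

import torch
from PIL import Image
from diffusers import FluxPipeline
from src.condition import Condition
from src.generate import generate, seed_everything

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
# Assumed: pipe.load_lora_weights(..., adapter_name="subject_512") was called beforehand.

seed_everything(777)
subject = Image.open("subject.jpg").convert("RGB")  # placeholder reference image
result = generate(
    pipe,
    conditions=[Condition("subject", raw_img=subject)],
    prompt="a hexagon-tiled game board in the style of the subject",
    height=512,
    width=512,
    num_inference_steps=28,
)
result.images[0].save("output.png")  # generate() returns a FluxPipelineOutput by default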
src/lora_controller.py ADDED
@@ -0,0 +1,75 @@
1
+ from peft.tuners.tuners_utils import BaseTunerLayer
2
+ from typing import List, Any, Optional, Type
3
+
4
+
5
+ class enable_lora:
6
+ def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
7
+ self.activated: bool = activated
8
+ if activated:
9
+ return
10
+ self.lora_modules: List[BaseTunerLayer] = [
11
+ each for each in lora_modules if isinstance(each, BaseTunerLayer)
12
+ ]
13
+ self.scales = [
14
+ {
15
+ active_adapter: lora_module.scaling[active_adapter]
16
+ for active_adapter in lora_module.active_adapters
17
+ }
18
+ for lora_module in self.lora_modules
19
+ ]
20
+
21
+ def __enter__(self) -> None:
22
+ if self.activated:
23
+ return
24
+
25
+ for lora_module in self.lora_modules:
26
+ if not isinstance(lora_module, BaseTunerLayer):
27
+ continue
28
+ lora_module.scale_layer(0)
29
+
30
+ def __exit__(
31
+ self,
32
+ exc_type: Optional[Type[BaseException]],
33
+ exc_val: Optional[BaseException],
34
+ exc_tb: Optional[Any],
35
+ ) -> None:
36
+ if self.activated:
37
+ return
38
+ for i, lora_module in enumerate(self.lora_modules):
39
+ if not isinstance(lora_module, BaseTunerLayer):
40
+ continue
41
+ for active_adapter in lora_module.active_adapters:
42
+ lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
43
+
44
+
45
+ class set_lora_scale:
46
+ def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
47
+ self.lora_modules: List[BaseTunerLayer] = [
48
+ each for each in lora_modules if isinstance(each, BaseTunerLayer)
49
+ ]
50
+ self.scales = [
51
+ {
52
+ active_adapter: lora_module.scaling[active_adapter]
53
+ for active_adapter in lora_module.active_adapters
54
+ }
55
+ for lora_module in self.lora_modules
56
+ ]
57
+ self.scale = scale
58
+
59
+ def __enter__(self) -> None:
60
+ for lora_module in self.lora_modules:
61
+ if not isinstance(lora_module, BaseTunerLayer):
62
+ continue
63
+ lora_module.scale_layer(self.scale)
64
+
65
+ def __exit__(
66
+ self,
67
+ exc_type: Optional[Type[BaseException]],
68
+ exc_val: Optional[BaseException],
69
+ exc_tb: Optional[Any],
70
+ ) -> None:
71
+ for i, lora_module in enumerate(self.lora_modules):
72
+ if not isinstance(lora_module, BaseTunerLayer):
73
+ continue
74
+ for active_adapter in lora_module.active_adapters:
75
+ lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
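
A short sketch (not in the commit) of the two context managers above, mirroring how src/block.py wraps attention projections with enable_lora. The adapter repo name and the sample tensor are placeholders; a Flux pipeline with at least one LoRA adapter attached is assumed.

import torch
from diffusers import FluxPipeline
from src.lora_controller import enable_lora, set_lora_scale

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("some/flux-lora")  # placeholder adapter repository
attn = pipe.transformer.transformer_blocks[0].attn
sample = torch.randn(1, 16, attn.to_q.in_features, dtype=torch.bfloat16)

# Zero the LoRA scaling on q/k/v for the duration of the block (restored on exit);
# passing activated=True instead makes the context manager a no-op.
with enable_lora((attn.to_q, attn.to_k, attn.to_v), False):
    q_without_lora = attn.to_q(sample)

# Multiply the current LoRA scaling by 0.5 inside the block, then restore it on exit.
with set_lora_scale((attn.to_q, attn.to_k, attn.to_v), 0.5):
    q_half_lora = attn.to_q(sample)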
src/transformer.py ADDED
@@ -0,0 +1,270 @@
1
+ import torch
2
+ from diffusers.pipelines import FluxPipeline
3
+ from typing import List, Union, Optional, Dict, Any, Callable
4
+ from .block import block_forward, single_block_forward
5
+ from .lora_controller import enable_lora
6
+ from diffusers.models.transformers.transformer_flux import (
7
+ FluxTransformer2DModel,
8
+ Transformer2DModelOutput,
9
+ USE_PEFT_BACKEND,
10
+ is_torch_version,
11
+ scale_lora_layers,
12
+ unscale_lora_layers,
13
+ logger,
14
+ )
15
+ import numpy as np
16
+
17
+
18
+ def prepare_params(
19
+ hidden_states: torch.Tensor,
20
+ encoder_hidden_states: torch.Tensor = None,
21
+ pooled_projections: torch.Tensor = None,
22
+ timestep: torch.LongTensor = None,
23
+ img_ids: torch.Tensor = None,
24
+ txt_ids: torch.Tensor = None,
25
+ guidance: torch.Tensor = None,
26
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
27
+ controlnet_block_samples=None,
28
+ controlnet_single_block_samples=None,
29
+ return_dict: bool = True,
30
+ **kwargs: dict,
31
+ ):
32
+ return (
33
+ hidden_states,
34
+ encoder_hidden_states,
35
+ pooled_projections,
36
+ timestep,
37
+ img_ids,
38
+ txt_ids,
39
+ guidance,
40
+ joint_attention_kwargs,
41
+ controlnet_block_samples,
42
+ controlnet_single_block_samples,
43
+ return_dict,
44
+ )
45
+
46
+
47
+ def tranformer_forward(
48
+ transformer: FluxTransformer2DModel,
49
+ condition_latents: torch.Tensor,
50
+ condition_ids: torch.Tensor,
51
+ condition_type_ids: torch.Tensor,
52
+ model_config: Optional[Dict[str, Any]] = {},
53
+ return_conditional_latents: bool = False,
54
+ c_t=0,
55
+ **params: dict,
56
+ ):
57
+ self = transformer
58
+ use_condition = condition_latents is not None
59
+ use_condition_in_single_blocks = model_config.get(
60
+ "use_condition_in_single_blocks", True
61
+ )
62
+ # if return_conditional_latents is True, use_condition and use_condition_in_single_blocks must be True
63
+ assert not return_conditional_latents or (
64
+ use_condition and use_condition_in_single_blocks
65
+ ), "`return_conditional_latents` is True, `use_condition` and `use_condition_in_single_blocks` must be True"
66
+
67
+ (
68
+ hidden_states,
69
+ encoder_hidden_states,
70
+ pooled_projections,
71
+ timestep,
72
+ img_ids,
73
+ txt_ids,
74
+ guidance,
75
+ joint_attention_kwargs,
76
+ controlnet_block_samples,
77
+ controlnet_single_block_samples,
78
+ return_dict,
79
+ ) = prepare_params(**params)
80
+
81
+ if joint_attention_kwargs is not None:
82
+ joint_attention_kwargs = joint_attention_kwargs.copy()
83
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
84
+ else:
85
+ lora_scale = 1.0
86
+
87
+ if USE_PEFT_BACKEND:
88
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
89
+ scale_lora_layers(self, lora_scale)
90
+ else:
91
+ if (
92
+ joint_attention_kwargs is not None
93
+ and joint_attention_kwargs.get("scale", None) is not None
94
+ ):
95
+ logger.warning(
96
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
97
+ )
98
+ with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
99
+ hidden_states = self.x_embedder(hidden_states)
100
+ condition_latents = self.x_embedder(condition_latents) if use_condition else None
101
+
102
+ timestep = timestep.to(hidden_states.dtype) * 1000
103
+ if guidance is not None:
104
+ guidance = guidance.to(hidden_states.dtype) * 1000
105
+ else:
106
+ guidance = None
107
+ temb = (
108
+ self.time_text_embed(timestep, pooled_projections)
109
+ if guidance is None
110
+ else self.time_text_embed(timestep, guidance, pooled_projections)
111
+ )
112
+ cond_temb = (
113
+ self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
114
+ if guidance is None
115
+ else self.time_text_embed(
116
+ torch.ones_like(timestep) * c_t * 1000, guidance, pooled_projections
117
+ )
118
+ )
119
+ if hasattr(self, "cond_type_embed") and condition_type_ids is not None:
120
+ cond_type_proj = self.time_text_embed.time_proj(condition_type_ids[0])
121
+ cond_type_emb = self.cond_type_embed(cond_type_proj.to(dtype=cond_temb.dtype))
122
+ cond_temb = cond_temb + cond_type_emb
123
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
124
+
125
+ if txt_ids.ndim == 3:
126
+ logger.warning(
127
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
128
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
129
+ )
130
+ txt_ids = txt_ids[0]
131
+ if img_ids.ndim == 3:
132
+ logger.warning(
133
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
134
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
135
+ )
136
+ img_ids = img_ids[0]
137
+
138
+ ids = torch.cat((txt_ids, img_ids), dim=0)
139
+ image_rotary_emb = self.pos_embed(ids)
140
+ if use_condition:
141
+ cond_ids = condition_ids
142
+ cond_rotary_emb = self.pos_embed(cond_ids)
143
+
144
+ # hidden_states = torch.cat([hidden_states, condition_latents], dim=1)
145
+
146
+ for index_block, block in enumerate(self.transformer_blocks):
147
+ if self.training and self.gradient_checkpointing:
148
+
149
+ def create_custom_forward(module, return_dict=None):
150
+ def custom_forward(*inputs):
151
+ if return_dict is not None:
152
+ return module(*inputs, return_dict=return_dict)
153
+ else:
154
+ return module(*inputs)
155
+
156
+ return custom_forward
157
+
158
+ ckpt_kwargs: Dict[str, Any] = (
159
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
160
+ )
161
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
162
+ create_custom_forward(block),
163
+ hidden_states,
164
+ encoder_hidden_states,
165
+ temb,
166
+ image_rotary_emb,
167
+ **ckpt_kwargs,
168
+ )
169
+
170
+ else:
171
+ encoder_hidden_states, hidden_states, condition_latents = block_forward(
172
+ block,
173
+ model_config=model_config,
174
+ hidden_states=hidden_states,
175
+ encoder_hidden_states=encoder_hidden_states,
176
+ condition_latents=condition_latents if use_condition else None,
177
+ temb=temb,
178
+ cond_temb=cond_temb if use_condition else None,
179
+ cond_rotary_emb=cond_rotary_emb if use_condition else None,
180
+ image_rotary_emb=image_rotary_emb,
181
+ )
182
+
183
+ # controlnet residual
184
+ if controlnet_block_samples is not None:
185
+ interval_control = len(self.transformer_blocks) / len(
186
+ controlnet_block_samples
187
+ )
188
+ interval_control = int(np.ceil(interval_control))
189
+ hidden_states = (
190
+ hidden_states
191
+ + controlnet_block_samples[index_block // interval_control]
192
+ )
193
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
194
+
195
+ for index_block, block in enumerate(self.single_transformer_blocks):
196
+ if self.training and self.gradient_checkpointing:
197
+
198
+ def create_custom_forward(module, return_dict=None):
199
+ def custom_forward(*inputs):
200
+ if return_dict is not None:
201
+ return module(*inputs, return_dict=return_dict)
202
+ else:
203
+ return module(*inputs)
204
+
205
+ return custom_forward
206
+
207
+ ckpt_kwargs: Dict[str, Any] = (
208
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
209
+ )
210
+ hidden_states = torch.utils.checkpoint.checkpoint(
211
+ create_custom_forward(block),
212
+ hidden_states,
213
+ temb,
214
+ image_rotary_emb,
215
+ **ckpt_kwargs,
216
+ )
217
+
218
+ else:
219
+ result = single_block_forward(
220
+ block,
221
+ model_config=model_config,
222
+ hidden_states=hidden_states,
223
+ temb=temb,
224
+ image_rotary_emb=image_rotary_emb,
225
+ **(
226
+ {
227
+ "condition_latents": condition_latents,
228
+ "cond_temb": cond_temb,
229
+ "cond_rotary_emb": cond_rotary_emb,
230
+ }
231
+ if use_condition_in_single_blocks and use_condition
232
+ else {}
233
+ ),
234
+ )
235
+ if use_condition_in_single_blocks and use_condition:
236
+ hidden_states, condition_latents = result
237
+ else:
238
+ hidden_states = result
239
+
240
+ # controlnet residual
241
+ if controlnet_single_block_samples is not None:
242
+ interval_control = len(self.single_transformer_blocks) / len(
243
+ controlnet_single_block_samples
244
+ )
245
+ interval_control = int(np.ceil(interval_control))
246
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
247
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
248
+ + controlnet_single_block_samples[index_block // interval_control]
249
+ )
250
+
251
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
252
+
253
+ hidden_states = self.norm_out(hidden_states, temb)
254
+ output = self.proj_out(hidden_states)
255
+ if return_conditional_latents:
256
+ condition_latents = (
257
+ self.norm_out(condition_latents, cond_temb) if use_condition else None
258
+ )
259
+ condition_output = self.proj_out(condition_latents) if use_condition else None
260
+
261
+ if USE_PEFT_BACKEND:
262
+ # remove `lora_scale` from each PEFT layer
263
+ unscale_lora_layers(self, lora_scale)
264
+
265
+ if not return_dict:
266
+ return (
267
+ (output,) if not return_conditional_latents else (output, condition_output)
268
+ )
269
+
270
+ return Transformer2DModelOutput(sample=output)
utils/ai_generator.py CHANGED
@@ -1,7 +1,8 @@
 # utils/ai_generator.py
 
 import os
- import time # Added for implementing delays
+ import time
+ from turtle import width # Added for implementing delays
 import torch
 import random
 from utils.ai_generator_diffusers_flux import generate_ai_image_local
@@ -34,6 +35,9 @@ def generate_ai_image(
 lora_weights=None,
 conditioned_image=None,
 pipeline = "FluxPipeline",
+ width=912,
+ height=512,
+ strength=0.5,
 *args,
 **kwargs
 ):
@@ -51,7 +55,9 @@ def generate_ai_image(
 seed=seed,
 conditioned_image=conditioned_image,
 pipeline_name=pipeline,
- strength=0.5
+ strength=strength,
+ height=height,
+ width=width
 )
 else:
 print("No local GPU available. Sending request to Hugging Face API.")
@@ -59,10 +65,12 @@ def generate_ai_image(
 map_option,
 prompt_textbox_value,
 neg_prompt_textbox_value,
- model
+ model,
+ height=height,
+ width=width
 )
 
- def generate_ai_image_remote(map_option, prompt_textbox_value, neg_prompt_textbox_value, model, height=512, width=896, num_inference_steps=50, guidance_scale=3.5, seed=777):
+ def generate_ai_image_remote(map_option, prompt_textbox_value, neg_prompt_textbox_value, model, height=512, width=912, num_inference_steps=30, guidance_scale=3.5, seed=777):
 max_retries = 3
 retry_delay = 4 # Initial delay in seconds
 
utils/ai_generator_diffusers_flux.py CHANGED
@@ -1,13 +1,13 @@
1
  # utils/ai_generator_diffusers_flux.py
2
  import os
3
  import torch
4
- from diffusers import FluxPipeline,FluxImg2ImgPipeline
5
  import accelerate
6
  import transformers
7
  import safetensors
8
  import xformers
9
  from diffusers.utils import load_image
10
- # from huggingface_hub import hf_hub_download
11
  from PIL import Image
12
  from tempfile import NamedTemporaryFile
13
  from src.condition import Condition
@@ -16,15 +16,14 @@ from utils.image_utils import (
16
  crop_and_resize_image,
17
  )
18
  from utils.version_info import (
19
- versions_html,
20
  get_torch_info,
21
  get_diffusers_version,
22
  get_transformers_version,
23
  get_xformers_version
24
  )
25
- from utils.lora_details import get_trigger_words
26
  from utils.color_utils import detect_color_format
27
- # import utils.misc as misc
28
  from pathlib import Path
29
  import warnings
30
  warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*")
@@ -93,6 +92,7 @@ def generate_image_from_text(
93
  generate_params = {k: v for k, v in generate_params.items() if v is not None}
94
  result = pipe(**generate_params)
95
  image = result.images[0]
 
96
  return image
97
 
98
  def generate_image_lowmem(
@@ -101,10 +101,10 @@ def generate_image_lowmem(
101
  model_name="black-forest-labs/FLUX.1-dev",
102
  lora_weights=None,
103
  conditioned_image=None,
104
- image_width=1344,
105
  image_height=848,
106
  guidance_scale=3.5,
107
- num_inference_steps=50,
108
  seed=0,
109
  true_cfg_scale=1.0,
110
  pipeline_name="FluxPipeline",
@@ -117,7 +117,7 @@ def generate_image_lowmem(
117
  raise ValueError(f"Unsupported pipeline type '{pipeline_name}'. "
118
  f"Available options: {list(PIPELINE_CLASSES.keys())}")
119
  device = "cuda" if torch.cuda.is_available() else "cpu"
120
- print(f"device:{device}\nmodel_name:{model_name}\n")
121
  print(f"\n {get_torch_info()}\n")
122
  # Disable gradient calculations
123
  with torch.no_grad():
@@ -141,27 +141,59 @@ def generate_image_lowmem(
141
  if pipeline_name == "FluxPipeline":
142
  pipe.enable_vae_tiling()
143
  # Load LoRA weights
 
144
  if lora_weights:
145
  for lora_weight in lora_weights:
146
  lora_configs = constants.LORA_DETAILS.get(lora_weight, [])
 
147
  if lora_configs:
148
  for config in lora_configs:
149
  # Load LoRA weights with optional weight_name and adapter_name
150
- weight_name = config.get("weight_name")
151
- adapter_name = config.get("adapter_name")
152
- if weight_name and adapter_name:
153
- pipe.load_lora_weights(
154
- lora_weight,
155
- weight_name=weight_name,
156
- adapter_name=adapter_name,
157
- use_auth_token=constants.HF_API_TOKEN
158
- )
159
- else:
160
- pipe.load_lora_weights(
161
- lora_weight,
162
- use_auth_token=constants.HF_API_TOKEN
163
- )
164
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  # Apply 'pipe' configurations if present
166
  if 'pipe' in config:
167
  pipe_config = config['pipe']
@@ -174,6 +206,7 @@ def generate_image_lowmem(
174
  print(f"Method {method_name} not found in pipe.")
175
  else:
176
  pipe.load_lora_weights(lora_weight, use_auth_token=constants.HF_API_TOKEN)
 
177
  generator = torch.Generator(device=device).manual_seed(seed)
178
  conditions = []
179
  if conditioned_image is not None:
@@ -194,8 +227,20 @@ def generate_image_lowmem(
194
  "negative_prompt": neg_prompt,
195
  "true_cfg_scale": true_cfg_scale,
196
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  generate_params = {
198
- "prompt": text,
199
  "height": image_height,
200
  "width": image_width,
201
  "guidance_scale": guidance_scale,
@@ -204,6 +249,7 @@ def generate_image_lowmem(
204
  if additional_parameters:
205
  generate_params.update(additional_parameters)
206
  generate_params = {k: v for k, v in generate_params.items() if v is not None}
 
207
  # Generate the image
208
  result = pipe(**generate_params)
209
  image = result.images[0]
@@ -214,6 +260,7 @@ def generate_image_lowmem(
214
  # Delete the pipeline and clear cache
215
  del pipe
216
  torch.cuda.empty_cache()
 
217
  print(torch.cuda.memory_summary(device=None, abbreviated=False))
218
  return image
219
 
@@ -225,8 +272,8 @@ def generate_ai_image_local (
225
  lora_weights=None,
226
  conditioned_image=None,
227
  height=512,
228
- width=896,
229
- num_inference_steps=50,
230
  guidance_scale=3.5,
231
  seed=777,
232
  pipeline_name="FluxPipeline",
@@ -293,4 +340,20 @@ def generate_ai_image_local (
293
  return tmp.name
294
  except Exception as e:
295
  print(f"Error generating AI image: {e}")
296
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # utils/ai_generator_diffusers_flux.py
2
  import os
3
  import torch
4
+ from diffusers import FluxPipeline,FluxImg2ImgPipeline,FluxControlPipeline
5
  import accelerate
6
  import transformers
7
  import safetensors
8
  import xformers
9
  from diffusers.utils import load_image
10
+ from huggingface_hub import hf_hub_download
11
  from PIL import Image
12
  from tempfile import NamedTemporaryFile
13
  from src.condition import Condition
 
16
  crop_and_resize_image,
17
  )
18
  from utils.version_info import (
 
19
  get_torch_info,
20
  get_diffusers_version,
21
  get_transformers_version,
22
  get_xformers_version
23
  )
24
+ from utils.lora_details import get_trigger_words, approximate_token_count, split_prompt_precisely
25
  from utils.color_utils import detect_color_format
26
+ import utils.misc as misc
27
  from pathlib import Path
28
  import warnings
29
  warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*")
 
92
  generate_params = {k: v for k, v in generate_params.items() if v is not None}
93
  result = pipe(**generate_params)
94
  image = result.images[0]
95
+ pipe.unload_lora_weights()
96
  return image
97
 
98
  def generate_image_lowmem(
 
101
  model_name="black-forest-labs/FLUX.1-dev",
102
  lora_weights=None,
103
  conditioned_image=None,
104
+ image_width=1368,
105
  image_height=848,
106
  guidance_scale=3.5,
107
+ num_inference_steps=30,
108
  seed=0,
109
  true_cfg_scale=1.0,
110
  pipeline_name="FluxPipeline",
 
117
  raise ValueError(f"Unsupported pipeline type '{pipeline_name}'. "
118
  f"Available options: {list(PIPELINE_CLASSES.keys())}")
119
  device = "cuda" if torch.cuda.is_available() else "cpu"
120
+ print(f"device:{device}\nmodel_name:{model_name}\nlora_weights:{lora_weights}\n")
121
  print(f"\n {get_torch_info()}\n")
122
  # Disable gradient calculations
123
  with torch.no_grad():
 
141
  if pipeline_name == "FluxPipeline":
142
  pipe.enable_vae_tiling()
143
  # Load LoRA weights
144
+ # note: does not yet handle multiple LoRA weights with different names, needs .set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125])
145
  if lora_weights:
146
  for lora_weight in lora_weights:
147
  lora_configs = constants.LORA_DETAILS.get(lora_weight, [])
148
+ lora_weight_set = False
149
  if lora_configs:
150
  for config in lora_configs:
151
  # Load LoRA weights with optional weight_name and adapter_name
152
+ if 'weight_name' in config:
153
+ weight_name = config.get("weight_name")
154
+ adapter_name = config.get("adapter_name")
155
+ lora_collection = config.get("lora_collection")
156
+ if weight_name and adapter_name and lora_collection and lora_weight_set == False:
157
+ pipe.load_lora_weights(
158
+ lora_collection,
159
+ weight_name=weight_name,
160
+ adapter_name=adapter_name,
161
+ token=constants.HF_API_TOKEN
162
+ )
163
+ lora_weight_set = True
164
+ print(f"\npipe.load_lora_weights({lora_weight}, weight_name={weight_name}, adapter_name={adapter_name}, lora_collection={lora_collection}\n")
165
+ elif weight_name and adapter_name==None and lora_collection and lora_weight_set == False:
166
+ pipe.load_lora_weights(
167
+ lora_collection,
168
+ weight_name=weight_name,
169
+ token=constants.HF_API_TOKEN
170
+ )
171
+ lora_weight_set = True
172
+ print(f"\npipe.load_lora_weights({lora_weight}, weight_name={weight_name}, adapter_name={adapter_name}, lora_collection={lora_collection}\n")
173
+ elif weight_name and adapter_name and lora_weight_set == False:
174
+ pipe.load_lora_weights(
175
+ lora_weight,
176
+ weight_name=weight_name,
177
+ adapter_name=adapter_name,
178
+ token=constants.HF_API_TOKEN
179
+ )
180
+ lora_weight_set = True
181
+ print(f"\npipe.load_lora_weights({lora_weight}, weight_name={weight_name}, adapter_name={adapter_name}\n")
182
+ elif weight_name and adapter_name==None and lora_weight_set == False:
183
+ pipe.load_lora_weights(
184
+ lora_weight,
185
+ weight_name=weight_name,
186
+ token=constants.HF_API_TOKEN
187
+ )
188
+ lora_weight_set = True
189
+ print(f"\npipe.load_lora_weights({lora_weight}, weight_name={weight_name}, adapter_name={adapter_name}\n")
190
+ elif lora_weight_set == False:
191
+ pipe.load_lora_weights(
192
+ lora_weight,
193
+ token=constants.HF_API_TOKEN
194
+ )
195
+ lora_weight_set = True
196
+ print(f"\npipe.load_lora_weights({lora_weight}, weight_name={weight_name}, adapter_name={adapter_name}\n")
197
  # Apply 'pipe' configurations if present
198
  if 'pipe' in config:
199
  pipe_config = config['pipe']
 
206
  print(f"Method {method_name} not found in pipe.")
207
  else:
208
  pipe.load_lora_weights(lora_weight, use_auth_token=constants.HF_API_TOKEN)
209
+ # Set the random seed for reproducibility
210
  generator = torch.Generator(device=device).manual_seed(seed)
211
  conditions = []
212
  if conditioned_image is not None:
 
227
  "negative_prompt": neg_prompt,
228
  "true_cfg_scale": true_cfg_scale,
229
  }
230
+ # handle long prompts by splitting them
231
+ if approximate_token_count(text) > 76:
232
+ prompt, prompt2 = split_prompt_precisely(text)
233
+ prompt_parameters = {
234
+ "prompt" : prompt,
235
+ "prompt_2": prompt2
236
+ }
237
+ else:
238
+ prompt_parameters = {
239
+ "prompt" :text
240
+ }
241
+ additional_parameters.update(prompt_parameters)
242
+ # Combine all parameters
243
  generate_params = {
 
244
  "height": image_height,
245
  "width": image_width,
246
  "guidance_scale": guidance_scale,
 
249
  if additional_parameters:
250
  generate_params.update(additional_parameters)
251
  generate_params = {k: v for k, v in generate_params.items() if v is not None}
252
+ print(f"generate_params: {generate_params}")
253
  # Generate the image
254
  result = pipe(**generate_params)
255
  image = result.images[0]
 
260
  # Delete the pipeline and clear cache
261
  del pipe
262
  torch.cuda.empty_cache()
263
+ torch.cuda.ipc_collect()
264
  print(torch.cuda.memory_summary(device=None, abbreviated=False))
265
  return image
266
 
 
272
  lora_weights=None,
273
  conditioned_image=None,
274
  height=512,
275
+ width=912,
276
+ num_inference_steps=30,
277
  guidance_scale=3.5,
278
  seed=777,
279
  pipeline_name="FluxPipeline",
 
340
  return tmp.name
341
  except Exception as e:
342
  print(f"Error generating AI image: {e}")
343
+ return None
344
+
345
+ # does not work
346
+ def merge_LoRA_weights(model="black-forest-labs/FLUX.1-dev",
347
+ lora_weights="Borcherding/FLUX.1-dev-LoRA-FractalLand-v0.1"):
348
+
349
+ model_suffix = model.split("/")[-1]
350
+ if model_suffix not in lora_weights:
351
+ raise ValueError(f"The model suffix '{model_suffix}' must be in the lora_weights string '{lora_weights}' to proceed.")
352
+
353
+ pipe = FluxPipeline.from_pretrained(model, torch_dtype=torch.bfloat16)
354
+ pipe.load_lora_weights(lora_weights)
355
+ pipe.save_lora_weights(os.getenv("TMPDIR"))
356
+ lora_name = lora_weights.split("/")[-1] + "-merged"
357
+ pipe.save_pretrained(lora_name)
358
+ pipe.unload_lora_weights()
359
+
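The five nearly identical load_lora_weights branches above differ only in which of weight_name, adapter_name, and lora_collection happen to be present in the config. As a hedged sketch (not part of this commit, and assuming the same diffusers keyword API and constants.HF_API_TOKEN), the same behaviour could be expressed by building the keyword set once:

# Hypothetical consolidation of the branch chain above; load_single_lora is not a real helper in this repo.
def load_single_lora(pipe, lora_weight, config, hf_token):
    repo_id = config.get("lora_collection") or lora_weight  # prefer the collection repo when one is given
    kwargs = {"token": hf_token}
    if config.get("weight_name"):
        kwargs["weight_name"] = config["weight_name"]
    if config.get("adapter_name"):
        kwargs["adapter_name"] = config["adapter_name"]
    pipe.load_lora_weights(repo_id, **kwargs)
    print(f"pipe.load_lora_weights({repo_id}, {kwargs})")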
utils/color_utils.py ADDED
@@ -0,0 +1,214 @@
1
+ # utils/color_utils.py
2
+
3
+ from PIL import Image, ImageColor
4
+ import re
5
+ import cairocffi as cairo
6
+ import pangocffi
7
+ import pangocairocffi
8
+
9
+
10
+ def multiply_and_clamp(value, scale, min_value=0, max_value=255):
11
+ return min(max(value * scale, min_value), max_value)
12
+
13
+ # Convert decimal color to hexadecimal color (rgb or rgba)
14
+ def rgb_to_hex(rgb):
15
+ color = "#"
16
+ for i in rgb:
17
+ num = int(i)
18
+ color += str(hex(num))[-2:].replace("x", "0").upper()
19
+ return color
20
+
21
+ def parse_hex_color(hex_color, base = 1):
22
+ """
23
+ Parses a hex color string or tuple into RGBA components.
24
+ By default the components are returned normalized to the 0-1 range;
25
+ pass base=255 to get them in the 0-255 range.
26
+ Accepts colors specified in several formats and converts them into RGBA components
27
+ suitable for use in color calculations, rendering, or manipulation.
28
+
29
+ Supports:
30
+ - #RRGGBBAA
31
+ - #RRGGBB (assumes full opacity)
32
+ - (r, g, b, a) tuple
33
+ """
34
+ if isinstance(hex_color, tuple):
35
+ if len(hex_color) == 4:
36
+ r, g, b, a = hex_color
37
+ elif len(hex_color) == 3:
38
+ r, g, b = hex_color
39
+ a = 1.0 # Full opacity
40
+ else:
41
+ raise ValueError("Tuple must be in the format (r, g, b) or (r, g, b, a)")
42
+ return r / 255.0, g / 255.0, b / 255.0, (a if a <= 1 else a / 255.0)
43
+
44
+ if hex_color.startswith("#"):
45
+ if len(hex_color) == 7:  # "#RRGGBB"
46
+ r = int(hex_color[1:3], 16) / 255.0
47
+ g = int(hex_color[3:5], 16) / 255.0
48
+ b = int(hex_color[5:7], 16) / 255.0
49
+ a = 1.0 # Full opacity
50
+ elif len(hex_color) == 9:  # "#RRGGBBAA"
51
+ r = int(hex_color[1:3], 16) / 255.0
52
+ g = int(hex_color[3:5], 16) / 255.0
53
+ b = int(hex_color[5:7], 16) / 255.0
54
+ a = int(hex_color[7:9], 16) / 255.0
55
+ else:
56
+ try:
57
+ r, g, b, a = ImageColor.getcolor(hex_color, "RGBA")
58
+ r = r / 255
59
+ g = g / 255
60
+ b = b / 255
61
+ a = a / 255
62
+ except:
63
+ raise ValueError("Hex color must be in the format #RRGGBB, #RRGGBBAA, (r, g, b, a), or a common color name")
64
+ return multiply_and_clamp(r, base, max_value=base), multiply_and_clamp(g, base, max_value=base), multiply_and_clamp(b, base, max_value=base), multiply_and_clamp(a, base, max_value=base)
65
+
66
+ # Define a function to convert a hexadecimal color code to an RGB(A) tuple
67
+ def hex_to_rgb(hex):
68
+ if hex.startswith("#"):
69
+ clean_hex = hex.replace('#','')
70
+ # Use a generator expression to convert pairs of hexadecimal digits to integers and create a tuple
71
+ return tuple(int(clean_hex[i:i+2], 16) for i in range(0, len(clean_hex),2))
72
+ else:
73
+ return detect_color_format(hex)
74
+
75
+ def detect_color_format(color):
76
+ """
77
+ Detects if the color is in RGB, RGBA, or hex format,
78
+ and converts it to an RGBA tuple with integer components.
79
+
80
+ Args:
81
+ color (str or tuple): The color to detect.
82
+
83
+ Returns:
84
+ tuple: The color in RGBA format as a tuple of 4 integers.
85
+
86
+ Raises:
87
+ ValueError: If the input color is not in a recognized format.
88
+ """
89
+ # Handle color as a tuple of floats or integers
90
+ if isinstance(color, tuple):
91
+ if len(color) == 3 or len(color) == 4:
92
+ # Ensure all components are numbers
93
+ if all(isinstance(c, (int, float)) for c in color):
94
+ r, g, b = color[:3]
95
+ a = color[3] if len(color) == 4 else 255
96
+ return (
97
+ max(0, min(255, int(round(r)))),
98
+ max(0, min(255, int(round(g)))),
99
+ max(0, min(255, int(round(b)))),
100
+ max(0, min(255, int(round(a * 255)) if a <= 1 else round(a))),
101
+ )
102
+ else:
103
+ raise ValueError(f"Invalid color tuple length: {len(color)}")
104
+ # Handle hex color codes
105
+ if isinstance(color, str):
106
+ color = color.strip()
107
+ # Try to use PIL's ImageColor
108
+ try:
109
+ rgba = ImageColor.getcolor(color, "RGBA")
110
+ return rgba
111
+ except ValueError:
112
+ pass
113
+ # Handle 'rgba(r, g, b, a)' string format
114
+ rgba_match = re.match(r'rgba\(\s*([0-9.]+),\s*([0-9.]+),\s*([0-9.]+),\s*([0-9.]+)\s*\)', color)
115
+ if rgba_match:
116
+ r, g, b, a = map(float, rgba_match.groups())
117
+ return (
118
+ max(0, min(255, int(round(r)))),
119
+ max(0, min(255, int(round(g)))),
120
+ max(0, min(255, int(round(b)))),
121
+ max(0, min(255, int(round(a * 255)) if a <= 1 else round(a))),
122
+ )
123
+ # Handle 'rgb(r, g, b)' string format
124
+ rgb_match = re.match(r'rgb\(\s*([0-9.]+),\s*([0-9.]+),\s*([0-9.]+)\s*\)', color)
125
+ if rgb_match:
126
+ r, g, b = map(float, rgb_match.groups())
127
+ return (
128
+ max(0, min(255, int(round(r)))),
129
+ max(0, min(255, int(round(g)))),
130
+ max(0, min(255, int(round(b)))),
131
+ 255,
132
+ )
133
+
134
+ # If none of the above conversions work, raise an error
135
+ raise ValueError(f"Invalid color format: {color}")
136
+
137
+
138
+ def update_color_opacity(color, opacity):
139
+ """
140
+ Updates the opacity of a color value.
141
+
142
+ Parameters:
143
+ color (tuple): A color represented as an RGB or RGBA tuple.
144
+ opacity (int): An integer between 0 and 255 representing the desired opacity.
145
+
146
+ Returns:
147
+ tuple: The color as an RGBA tuple with the updated opacity.
148
+ """
149
+ # Ensure opacity is within the valid range
150
+ opacity = max(0, min(255, int(opacity)))
151
+
152
+ if len(color) == 3:
153
+ # Color is RGB, add the opacity to make it RGBA
154
+ return color + (opacity,)
155
+ elif len(color) == 4:
156
+ # Color is RGBA, replace the alpha value with the new opacity
157
+ return color[:3] + (opacity,)
158
+ else:
159
+ raise ValueError(f"Invalid color format: {color}. Must be an RGB or RGBA tuple.")
160
+
161
+ def draw_text_with_emojis(image, text, font_color, offset_x, offset_y, font_name, font_size):
162
+ """
163
+ Draws text with emojis directly onto the given PIL image at specified coordinates with the specified color.
164
+ Parameters:
165
+ image (PIL.Image.Image): The RGBA image to draw on.
166
+ text (str): The text to draw, including emojis.
167
+ font_color (tuple): RGBA color tuple for the text (e.g., (255, 0, 0, 255)).
168
+ offset_x (int): The x-coordinate for the text center position.
169
+ offset_y (int): The y-coordinate for the text center position.
170
+ font_name (str): The name of the font family.
171
+ font_size (int): Size of the font.
172
+ Returns:
173
+ None: The function modifies the image in place.
174
+ """
175
+ if image.mode != 'RGBA':
176
+ raise ValueError("Image must be in RGBA mode.")
177
+ # Convert PIL image to a mutable bytearray
178
+ img_data = bytearray(image.tobytes("raw", "BGRA"))
179
+ # Create a Cairo ImageSurface that wraps the image's data
180
+ surface = cairo.ImageSurface.create_for_data(
181
+ img_data,
182
+ cairo.FORMAT_ARGB32,
183
+ image.width,
184
+ image.height,
185
+ image.width * 4
186
+ )
187
+ context = cairo.Context(surface)
188
+ # Create Pango layout
189
+ layout = pangocairocffi.create_layout(context)
190
+ layout._set_text(text)
191
+ # Set font description
192
+ desc = pangocffi.FontDescription()
193
+ desc._set_family(font_name)
194
+ desc._set_size(pangocffi.units_from_double(font_size))
195
+ layout._set_font_description(desc)
196
+ # Set text color
197
+ r, g, b, a = parse_hex_color(font_color)
198
+ context.set_source_rgba(r , g , b , a )
199
+ # Move to the position (top-left corner adjusted to center the text)
200
+ context.move_to(offset_x, offset_y)
201
+ # Render the text
202
+ pangocairocffi.show_layout(context, layout)
203
+ # Flush the surface to ensure all drawing operations are complete
204
+ surface.flush()
205
+ # Convert the modified bytearray back to a PIL Image
206
+ modified_image = Image.frombuffer(
207
+ "RGBA",
208
+ (image.width, image.height),
209
+ bytes(img_data),
210
+ "raw",
211
+ "BGRA", # Cairo stores data in BGRA order
212
+ surface.get_stride(),
213
+ ).convert("RGBA")
214
+ return modified_image
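A short, illustrative usage sketch for the helpers above (values are arbitrary, and the module's cairo/pango dependencies are assumed to be installed). parse_hex_color returns normalized 0-1 components by default, while detect_color_format always returns a 0-255 RGBA tuple:

from utils.color_utils import parse_hex_color, detect_color_format, update_color_opacity, rgb_to_hex

print(parse_hex_color("#FF8000"))                # ~ (1.0, 0.502, 0.0, 1.0)
print(detect_color_format("rgb(255, 128, 0)"))   # (255, 128, 0, 255)
print(update_color_opacity((255, 128, 0), 128))  # (255, 128, 0, 128)
print(rgb_to_hex((255, 128, 0)))                 # "#FF8000"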
utils/constants.py CHANGED
@@ -1,8 +1,8 @@
1
- import os
2
  #Set the environment variables
3
  os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
4
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256,expandable_segments:True"
5
- IS_SHARED_SPACE = "Surn/HexaGridCreator" in os.environ.get('SPACE_ID', '')
6
 
7
  # Set the temporary folder location
8
  os.environ['TEMP'] = r'e:\\TMP'
@@ -292,3 +292,29 @@ lut_folder = "./LUT"
292
  lut_files = [os.path.join(lut_folder, f).replace("\\", "/") for f in os.listdir(lut_folder) if f.endswith(".cube")]
293
 
294
  temp_files = []
 
1
+ import os
2
  #Set the environment variables
3
  os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
4
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256,expandable_segments:True"
5
+ IS_SHARED_SPACE = "Surn/HexaGrid" in os.environ.get('SPACE_ID', '')
6
 
7
  # Set the temporary folder location
8
  os.environ['TEMP'] = r'e:\\TMP'
 
292
  lut_files = [os.path.join(lut_folder, f).replace("\\", "/") for f in os.listdir(lut_folder) if f.endswith(".cube")]
293
 
294
  temp_files = []
295
+
296
+
297
+ cards = [
298
+ "2♥️", "3♥️", "4♥️", "5♥️", "6♥️", "7♥️", "8♥️", "9♥️", "10♥️", "J♥️", "Q♥️", "K♥️", "A♥️",
299
+ "2♦️", "3♦️", "4♦️", "5♦️", "6♦️", "7♦️", "8♦️", "9♦️", "10♦️", "J♦️", "Q♦️", "K♦️", "A♦️",
300
+ "2♣️", "3♣️", "4♣️", "5♣️", "6♣️", "7♣️", "8♣️", "9♣️", "10♣️", "J♣️", "Q♣️", "K♣️", "A♣️",
301
+ "2♠️", "3♠️", "4♠️", "5♠️", "6♠️", "7♠️", "8♠️", "9♠️", "10♠️", "J♠️", "Q♠️", "K♠️", "A♠️"
302
+ ]
303
+ cards_alternating = [
304
+ "2♥️", "3♥️", "4♥️", "5♥️", "6♥️", "7♥️", "8♥️", "9♥️", "10♥️", "J♥️", "Q♥️", "K♥️", "A♥️",
305
+ "2♣️", "3♣️", "4♣️", "5♣️", "6♣️", "7♣️", "8♣️", "9♣️", "10♣️", "J♣️", "Q♣️", "K♣️", "A♣️",
306
+ "2♦️", "3♦️", "4♦️", "5♦️", "6♦️", "7♦️", "8♦️", "9♦️", "10♦️", "J♦️", "Q♦️", "K♦️", "A♦️",
307
+ "2♠️", "3♠️", "4♠️", "5♠️", "6♠️", "7♠️", "8♠️", "9♠️", "10♠️", "J♠️", "Q♠️", "K♠️", "A♠️"
308
+ ]
309
+ card_colors = [
310
+ "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", # Hearts
311
+ "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", # Diamonds
312
+ "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", # Clubs
313
+ "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000" # Spades
314
+ ]
315
+ card_colors_alternating = [
316
+ "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", # Hearts
317
+ "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", # Clubs
318
+ "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", "#FF0000", # Diamonds
319
+ "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000", "#000000" # Spades
320
+ ]
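Each entry of cards pairs positionally with the same index in card_colors (red for hearts and diamonds, black for clubs and spades), and likewise for the alternating variants. A quick, illustrative sanity check (run from the repository root so the LUT folder scan in constants succeeds):

import utils.constants as constants

assert len(constants.cards) == len(constants.card_colors) == 52
for face, color in zip(constants.cards_alternating, constants.card_colors_alternating):
    print(face, color)  # e.g. "2♥️ #FF0000" for hearts, "2♣️ #000000" for clubs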
utils/depth_estimation.py ADDED
@@ -0,0 +1,121 @@
1
+ # utils/depth_estimation.py
2
+
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ import open3d as o3d
7
+ from transformers import DPTImageProcessor, DPTForDepthEstimation
8
+ from pathlib import Path
9
+ import logging
10
+ logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
11
+ from utils.image_utils import (
12
+ change_color,
13
+ open_image,
14
+ build_prerendered_images,
15
+ upscale_image,
16
+ crop_and_resize_image,
17
+ resize_image_with_aspect_ratio,
18
+ show_lut,
19
+ apply_lut_to_image_path
20
+ )
21
+
22
+ # Load models once during module import
23
+ image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
24
+ depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large", ignore_mismatched_sizes=True)
25
+
26
+ def estimate_depth(image):
27
+ # Ensure image is in RGB mode
28
+ if image.mode != "RGB":
29
+ image = image.convert("RGB")
30
+
31
+ # Resize the image for the model
32
+ image_resized = image.resize(
33
+ (image.width, image.height),
34
+ Image.Resampling.LANCZOS
35
+ )
36
+
37
+ # Prepare image for the model
38
+ encoding = image_processor(image_resized, return_tensors="pt")
39
+
40
+ # Forward pass
41
+ with torch.no_grad():
42
+ outputs = depth_model(**encoding)
43
+ predicted_depth = outputs.predicted_depth
44
+
45
+ # Interpolate to original size
46
+ prediction = torch.nn.functional.interpolate(
47
+ predicted_depth.unsqueeze(1),
48
+ size=(image.height, image.width),
49
+ mode="bicubic",
50
+ align_corners=False,
51
+ ).squeeze()
52
+
53
+ # Convert to depth image
54
+ output = prediction.cpu().numpy()
55
+ depth_min = output.min()
56
+ depth_max = output.max()
57
+ max_val = (2**8) - 1
58
+
59
+ # Normalize and convert to 8-bit image
60
+ depth_image = max_val * (output - depth_min) / (depth_max - depth_min)
61
+ depth_image = depth_image.astype("uint8")
62
+
63
+ depth_pil = Image.fromarray(depth_image)
64
+
65
+ return depth_pil, output
66
+
67
+ def create_3d_model(rgb_image, depth_array, voxel_size_factor=0.01):
68
+ depth_o3d = o3d.geometry.Image(depth_array.astype(np.float32))
69
+ rgb_o3d = o3d.geometry.Image(np.array(rgb_image))
70
+
71
+ rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
72
+ rgb_o3d,
73
+ depth_o3d,
74
+ convert_rgb_to_intensity=False
75
+ )
76
+
77
+ # Create a point cloud from the RGBD image
78
+ camera_intrinsic = o3d.camera.PinholeCameraIntrinsic(
79
+ rgb_image.width,
80
+ rgb_image.height,
81
+ fx=1.0,
82
+ fy=1.0,
83
+ cx=rgb_image.width / 2.0,
84
+ cy=rgb_image.height / 2.0,
85
+ )
86
+
87
+ pcd = o3d.geometry.PointCloud.create_from_rgbd_image(
88
+ rgbd_image,
89
+ camera_intrinsic
90
+ )
91
+
92
+ # Voxel downsample
93
+ voxel_size = max(pcd.get_max_bound() - pcd.get_min_bound()) * voxel_size_factor
94
+ voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(pcd, voxel_size=voxel_size)
95
+
96
+ # Save the 3D model to a temporary file
97
+ temp_dir = Path.cwd() / "temp_models"
98
+ temp_dir.mkdir(exist_ok=True)
99
+ model_path = temp_dir / "model.ply"
100
+ o3d.io.write_voxel_grid(str(model_path), voxel_grid)
101
+
102
+ return str(model_path)
103
+
104
+ def generate_depth_and_3d(input_image_path, voxel_size_factor):
105
+ image = Image.open(input_image_path).convert("RGB")
106
+ resized_image = resize_image_with_aspect_ratio(image, 2688, 1680)
107
+ depth_image, depth_array = estimate_depth(resized_image)
108
+ model_path = create_3d_model(resized_image, depth_array, voxel_size_factor=voxel_size_factor)
109
+ return depth_image, model_path
110
+
111
+ def generate_depth_button_click(depth_image_source, voxel_size_factor, input_image, output_image, overlay_image, bordered_image_output):
112
+ if depth_image_source == "Input Image":
113
+ image_path = input_image
114
+ elif depth_image_source == "Output Image":
115
+ image_path = output_image
116
+ elif depth_image_source == "Image with Margins":
117
+ image_path = bordered_image_output
118
+ else:
119
+ image_path = overlay_image
120
+
121
+ return generate_depth_and_3d(image_path, voxel_size_factor)
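generate_depth_and_3d can also be exercised on its own; a minimal sketch, assuming a local image file (the path below is a placeholder) and that the Intel/dpt-large weights can be downloaded:

from utils.depth_estimation import generate_depth_and_3d

depth_image, model_path = generate_depth_and_3d("photo.jpg", voxel_size_factor=0.01)
depth_image.save("photo_depth.png")           # the normalized 8-bit depth map
print("voxel model written to:", model_path)  # a .ply file under ./temp_models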
utils/excluded_colors.py ADDED
@@ -0,0 +1,56 @@
1
+ # utils/excluded_colors.py
2
+ import gradio as gr
3
+
4
+ from utils.color_utils import (
5
+ hex_to_rgb,
6
+ )
7
+ from utils.image_utils import (
8
+ convert_str_to_int_or_zero,
9
+ )
10
+
11
+ excluded_color_list = gr.State([(0,0,0,0),(255,255,255,0)])
12
+
13
+ def add_color(color, excluded_colors_var):
14
+ excluded_colors = excluded_colors_var.value
15
+ # Convert the color from hex to RGBA
16
+ color = hex_to_rgb(color) + (255,)
17
+ if color not in [tuple(lst) for lst in excluded_colors]:
18
+ excluded_colors.append(color)
19
+ excluded_color_lst = [tuple(lst) for lst in excluded_colors]
20
+ else:
21
+ excluded_color_lst = [tuple(lst) for lst in excluded_colors]
22
+ return excluded_color_lst, excluded_color_lst
23
+
24
+ def delete_color(row, excluded_colors_var):
25
+ global excluded_color_list
26
+ excluded_colors = list(excluded_colors_var)
27
+ row_index = convert_str_to_int_or_zero(row)
28
+ print(f"Delete Excluded Color {row_index} of {len(excluded_colors) - 1}")
29
+ if row_index <= len(excluded_colors) - 1:
30
+ del excluded_colors[row_index]
31
+ excluded_color_lst = [tuple(lst) for lst in excluded_colors]
32
+ excluded_color_list = excluded_color_lst
33
+ return excluded_color_lst
34
+ else:
35
+ excluded_color_lst = [tuple(lst) for lst in excluded_color_list]
36
+ print(f"Row index {row_index} not found in the list:{excluded_color_lst}")
37
+ excluded_color_list = excluded_color_lst
38
+ return excluded_color_lst
39
+
40
+ def build_dataframe(excluded_colors_var):
41
+ excluded_colors = [tuple(lst) for lst in excluded_colors_var.value]
42
+ #print(f"input: {excluded_colors}")
43
+ return excluded_colors
44
+
45
+ def on_input(excluded_colors):
46
+ print(f"input: {excluded_colors}")
47
+ excluded_color_lst = [tuple(lst) for lst in excluded_colors]
48
+ print(f"output: {excluded_color_lst}")
49
+ return excluded_color_lst, excluded_color_lst
50
+
51
+ # Event listener for when the user selects a row
52
+ def on_color_display_select(selected_rows, event: gr.SelectData):
53
+ # Get the selected row
54
+ selected_index = event.index[0]
55
+ print(f"Selected row index:{selected_rows[selected_index]}, index: {selected_index}")
56
+ return selected_index
utils/file_utils.py ADDED
@@ -0,0 +1,10 @@
1
+ # file_utils
2
+ import os
3
+ import utils.constants as constants
4
+
5
+ def cleanup_temp_files():
6
+ for file_path in constants.temp_files:
7
+ try:
8
+ os.remove(file_path)
9
+ except Exception as e:
10
+ print(f"Failed to delete temp file {file_path}: {e}")
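cleanup_temp_files only removes whatever paths have been appended to constants.temp_files, so something has to register it; one plausible pattern (an assumption, not shown in this file) is to hook it to interpreter shutdown:

import atexit
import utils.constants as constants
from utils.file_utils import cleanup_temp_files

constants.temp_files.append("/tmp/example_overlay.png")  # illustrative entry only
atexit.register(cleanup_temp_files)  # delete the tracked files when the process exits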
utils/hex_grid.py CHANGED
@@ -193,9 +193,7 @@ def generate_hexagon_grid_with_text(hex_size, border_size, input_image=None, ima
193
  # Prepare the text and color lists
194
  text_list = []
195
  color_list = []
196
- if add_hex_text_option == "Row-Column Coordinates":
197
- pass # Coordinates will be generated dynamically
198
- elif add_hex_text_option == "Playing Cards Sequential":
199
  text_list = constants.cards
200
  color_list = constants.card_colors
201
  elif add_hex_text_option == "Playing Cards Alternate Red and Black":
@@ -204,13 +202,13 @@ def generate_hexagon_grid_with_text(hex_size, border_size, input_image=None, ima
204
  elif add_hex_text_option == "Custom List":
205
  if custom_text_list:
206
  #text_list = [text.strip() for text in custom_text_list.split(",")]
207
- text_list = ast.literal_eval(custom_text_list) if custom_text_list else None
208
  if custom_text_color_list:
209
  #color_list = [color.strip() for color in custom_text_color_list.split(",")]
210
  color_list = ast.literal_eval(custom_text_color_list) if custom_text_color_list else None
211
  else:
212
- text_list = []
213
- color_list = []
214
  hex_index = -1 # Initialize hex index
215
  def draw_hexagon(x, y, color="#FFFFFFFF", rotation=0, outline_color="#12165380", outline_width=0, sides=6):
216
  side_length = (hex_size * 2) / math.sqrt(3)
@@ -277,10 +275,12 @@ def generate_hexagon_grid_with_text(hex_size, border_size, input_image=None, ima
277
  # Determine the text to draw
278
  if add_hex_text_option == "Row-Column Coordinates":
279
  text = f"{col},{row}"
 
 
280
  elif text_list:
281
  text = text_list[hex_index % len(text_list)]
282
  else:
283
- text = ""
284
  # Determine the text color
285
  if color_list:
286
  # Extract the opacity from the border color and add to the color_list
@@ -296,7 +296,7 @@ def generate_hexagon_grid_with_text(hex_size, border_size, input_image=None, ima
296
  text_color = border_color
297
  #text_color = "#{:02x}{:02x}{:02x}{:02x}".format(*text_color)
298
  # Skip if text is empty
299
- if text != "":
300
  print(f"Drawing Text: {text} color: {text_color} size: {font_size}")
301
  # Calculate text size using Pango
302
  # Create a temporary surface to calculate text size
 
193
  # Prepare the text and color lists
194
  text_list = []
195
  color_list = []
196
+ if add_hex_text_option == "Playing Cards Sequential":
 
 
197
  text_list = constants.cards
198
  color_list = constants.card_colors
199
  elif add_hex_text_option == "Playing Cards Alternate Red and Black":
 
202
  elif add_hex_text_option == "Custom List":
203
  if custom_text_list:
204
  #text_list = [text.strip() for text in custom_text_list.split(",")]
205
+ text_list = ast.literal_eval(custom_text_list) if custom_text_list else None
206
  if custom_text_color_list:
207
  #color_list = [color.strip() for color in custom_text_color_list.split(",")]
208
  color_list = ast.literal_eval(custom_text_color_list) if custom_text_color_list else None
209
  else:
210
+ # Coordinates will be generated dynamically
211
+ pass
212
  hex_index = -1 # Initialize hex index
213
  def draw_hexagon(x, y, color="#FFFFFFFF", rotation=0, outline_color="#12165380", outline_width=0, sides=6):
214
  side_length = (hex_size * 2) / math.sqrt(3)
 
275
  # Determine the text to draw
276
  if add_hex_text_option == "Row-Column Coordinates":
277
  text = f"{col},{row}"
278
+ elif add_hex_text_option == "Sequential Numbers":
279
+ text = f"{hex_index}"
280
  elif text_list:
281
  text = text_list[hex_index % len(text_list)]
282
  else:
283
+ text = None
284
  # Determine the text color
285
  if color_list:
286
  # Extract the opacity from the border color and add to the color_list
 
296
  text_color = border_color
297
  #text_color = "#{:02x}{:02x}{:02x}{:02x}".format(*text_color)
298
  # Skip if text is empty
299
+ if text is not None:
300
  print(f"Drawing Text: {text} color: {text_color} size: {font_size}")
301
  # Calculate text size using Pango
302
  # Create a temporary surface to calculate text size
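Because the change above keeps ast.literal_eval (rather than the commented-out split(",") approach), the custom text and color strings must be valid Python literals such as a tuple or list of strings. An illustrative check (the example values are not taken from the app):

import ast

custom_text_list = '("A", "2", "3", "J", "Q", "K")'
custom_text_color_list = '("red", "#0000FF", "#00FF00", "red", "#FFFF00", "#00FFFF")'

print(ast.literal_eval(custom_text_list))        # ('A', '2', '3', 'J', 'Q', 'K')
print(ast.literal_eval(custom_text_color_list))  # ('red', '#0000FF', '#00FF00', 'red', '#FFFF00', '#00FFFF')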
utils/lora_details.py CHANGED
@@ -21,7 +21,7 @@ def upd_prompt_notes(model_textbox_value):
21
  notes = item['notes']
22
  break
23
  else:
24
- notes = "Enter Prompt description of your image"
25
  return gr.update(value=notes)
26
 
27
  def get_trigger_words(model_textbox_value):
@@ -57,3 +57,48 @@ def upd_trigger_words(model_textbox_value):
57
  """
58
  trigger_words = get_trigger_words(model_textbox_value)
59
  return gr.update(value=trigger_words)
 
21
  notes = item['notes']
22
  break
23
  else:
24
+ notes = "Enter a prompt description of your image.\nUsing models without a LoRA may take up to 30 minutes."
25
  return gr.update(value=notes)
26
 
27
  def get_trigger_words(model_textbox_value):
 
57
  """
58
  trigger_words = get_trigger_words(model_textbox_value)
59
  return gr.update(value=trigger_words)
60
+
61
+ def approximate_token_count(prompt):
62
+ """
63
+ Approximates the number of tokens in a prompt based on word count.
64
+
65
+ Parameters:
66
+ prompt (str): The text prompt.
67
+
68
+ Returns:
69
+ int: The approximate number of tokens.
70
+ """
71
+ words = prompt.split()
72
+ # Average tokens per word (can vary based on language and model)
73
+ tokens_per_word = 1.3
74
+ return int(len(words) * tokens_per_word)
75
+
76
+ def split_prompt_by_tokens(prompt, token_number):
77
+ words = prompt.split()
78
+ # Average tokens per word (can vary based on language and model)
79
+ tokens_per_word = 1.3
80
+ return ' '.join(words[:int(token_number / tokens_per_word)]), ' '.join(words[int(token_number / tokens_per_word):])
81
+
82
+ # Split prompt precisely by token count
83
+ import tiktoken
84
+
85
+ def split_prompt_precisely(prompt, max_tokens=77, model="gpt-3.5-turbo"):
86
+ try:
87
+ encoding = tiktoken.encoding_for_model(model)
88
+ except KeyError:
89
+ encoding = tiktoken.get_encoding("cl100k_base")
90
+
91
+ tokens = encoding.encode(prompt)
92
+
93
+ if len(tokens) <= max_tokens:
94
+ return prompt, ""
95
+
96
+ # Find the split point
97
+ split_point = max_tokens
98
+ split_tokens = tokens[:split_point]
99
+ remaining_tokens = tokens[split_point:]
100
+
101
+ split_prompt = encoding.decode(split_tokens)
102
+ remaining_prompt = encoding.decode(remaining_tokens)
103
+
104
+ return split_prompt, remaining_prompt
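These helpers back the long-prompt handling in generate_image_lowmem: approximate_token_count gates the split, and split_prompt_precisely does the actual cut (the 77-token default presumably mirrors the CLIP text-encoder limit, with tiktoken's encoding used only as an approximation). An illustrative run:

from utils.lora_details import approximate_token_count, split_prompt_precisely

prompt = "a hexagonal mosaic of a mountain lake at sunrise, volumetric light, " * 10  # deliberately long
print(approximate_token_count(prompt))  # rough word-count-based estimate

head, tail = split_prompt_precisely(prompt, max_tokens=77)
print(len(head), "characters kept;", len(tail), "characters overflow into prompt_2")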
utils/version_info.py ADDED
@@ -0,0 +1,81 @@
1
+ # utils/version_info.py
2
+
3
+ import subprocess
4
+ import os
5
+ import torch
6
+ import sys
7
+ import gradio as gr
8
+
9
+ git = os.environ.get('GIT', "git")
10
+
11
+ def commit_hash():
12
+ try:
13
+ return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
14
+ except Exception:
15
+ return "<none>"
16
+
17
+ def get_xformers_version():
18
+ try:
19
+ import xformers
20
+ return xformers.__version__
21
+ except Exception:
22
+ return "<none>"
23
+ def get_transformers_version():
24
+ try:
25
+ import transformers
26
+ return transformers.__version__
27
+ except Exception:
28
+ return "<none>"
29
+
30
+ def get_accelerate_version():
31
+ try:
32
+ import accelerate
33
+ return accelerate.__version__
34
+ except Exception:
35
+ return "<none>"
36
+ def get_safetensors_version():
37
+ try:
38
+ import safetensors
39
+ return safetensors.__version__
40
+ except Exception:
41
+ return "<none>"
42
+ def get_diffusers_version():
43
+ try:
44
+ import diffusers
45
+ return diffusers.__version__
46
+ except Exception:
47
+ return "<none>"
48
+
49
+ def get_torch_info():
50
+ try:
51
+ return [torch.__version__, f"CUDA Version:{torch.version.cuda}", f"Available:{torch.cuda.is_available()}", f"flash attention enabled: {torch.backends.cuda.flash_sdp_enabled()}", f"Capabilities: {torch.cuda.get_device_capability(0)}", f"Device Name: {torch.cuda.get_device_name(0)}"]
52
+ except Exception:
53
+ return "<none>"
54
+
55
+ def versions_html():
56
+ python_version = ".".join([str(x) for x in sys.version_info[0:3]])
57
+ commit = commit_hash()
58
+
59
+ # Define the Toggle Dark Mode link with JavaScript
60
+ toggle_dark_link = '''
61
+ <a href="#" onclick="document.body.classList.toggle('dark'); return false;" style="cursor: pointer; text-decoration: underline; color: #1a0dab;">
62
+ Toggle Dark Mode
63
+ </a>
64
+ '''
65
+
66
+ # version: <a href="https://github.com/Oncorporation/audiocraft/commit/{"" if commit == "<none>" else commit}" target="_blank">{"click" if commit == "<none>" else commit}</a>
67
+ return f"""
68
+ version: <a href="https://github.com/Oncorporation/audiocraft/commit/{"" if commit == "<none>" else commit}" target="_blank">{"click" if commit == "<none>" else commit}</a>
69
+ &#x2000;•&#x2000;
70
+ python: <span title="{sys.version}">{python_version}</span>
71
+ &#x2000;•&#x2000;
72
+ torch: {getattr(torch, '__long_version__',torch.__version__)}
73
+ &#x2000;•&#x2000;
74
+ diffusers: {get_diffusers_version()}
75
+ &#x2000;•&#x2000;
76
+ transformers: {get_transformers_version()}
77
+ &#x2000;•&#x2000;
78
+ gradio: {gr.__version__}
79
+ &#x2000;•&#x2000;
80
+ {toggle_dark_link}
81
+ """
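versions_html returns a single HTML snippet, so a natural way to surface it (an assumption, not shown in this diff) is a gr.HTML component in the app layout:

import gradio as gr
from utils.version_info import versions_html

with gr.Blocks() as demo:
    gr.HTML(versions_html(), elem_id="versions")  # renders the version / commit footer

demo.launch()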