multimodalart HF staff commited on
Commit
11166a4
1 Parent(s): f56a518

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -85
app.py CHANGED
@@ -55,11 +55,30 @@ class calculateDuration:
55
  else:
56
  print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
57
 
58
- def update_selection(evt: gr.SelectData, width, height):
59
  selected_lora = loras[evt.index]
60
  new_placeholder = f"Type a prompt for {selected_lora['title']}"
61
- lora_repo = selected_lora["repo"]
62
- updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  if "aspect" in selected_lora:
64
  if selected_lora["aspect"] == "portrait":
65
  width = 768
@@ -70,16 +89,21 @@ def update_selection(evt: gr.SelectData, width, height):
70
  else:
71
  width = 1024
72
  height = 1024
73
- return (
74
- gr.update(placeholder=new_placeholder),
75
- updated_text,
76
- evt.index,
77
- width,
78
- height,
79
- )
 
 
 
 
 
80
 
81
  @spaces.GPU(duration=70)
82
- def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress):
83
  pipe.to("cuda")
84
  generator = torch.Generator(device="cuda").manual_seed(seed)
85
  with calculateDuration("Generating image"):
@@ -91,14 +115,13 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scal
91
  width=width,
92
  height=height,
93
  generator=generator,
94
- joint_attention_kwargs={"scale": lora_scale},
95
  output_type="pil",
96
  good_vae=good_vae,
97
  ):
98
  yield img
99
 
100
  @spaces.GPU(duration=70)
101
- def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed):
102
  generator = torch.Generator(device="cuda").manual_seed(seed)
103
  pipe_i2i.to("cuda")
104
  image_input = load_image(image_input_path)
@@ -111,93 +134,99 @@ def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps
111
  width=width,
112
  height=height,
113
  generator=generator,
114
- joint_attention_kwargs={"scale": lora_scale},
115
  output_type="pil",
116
  ).images[0]
117
  return final_image
 
 
 
 
118
 
119
- def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
120
- if selected_index is None:
121
- raise gr.Error("You must select a LoRA before proceeding.")
122
- selected_lora = loras[selected_index]
123
- lora_path = selected_lora["repo"]
124
- trigger_word = selected_lora["trigger_word"]
125
- if(trigger_word):
126
- if "trigger_position" in selected_lora:
127
- if selected_lora["trigger_position"] == "prepend":
128
- prompt_mash = f"{trigger_word} {prompt}"
129
  else:
130
- prompt_mash = f"{prompt} {trigger_word}"
131
- else:
132
- prompt_mash = f"{trigger_word} {prompt}"
133
- else:
134
- prompt_mash = prompt
 
 
 
 
 
 
135
 
136
- with calculateDuration("Unloading LoRA"):
137
  pipe.unload_lora_weights()
138
  pipe_i2i.unload_lora_weights()
139
 
140
- # Load LoRA weights
141
- with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
142
- if(image_input is not None):
143
- if "weights" in selected_lora:
144
- pipe_i2i.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
145
- else:
146
- pipe_i2i.load_lora_weights(lora_path)
147
  else:
148
- if "weights" in selected_lora:
149
- pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
150
- else:
151
- pipe.load_lora_weights(lora_path)
152
 
153
  # Set random seed for reproducibility
154
  with calculateDuration("Randomizing seed"):
155
  if randomize_seed:
156
  seed = random.randint(0, MAX_SEED)
157
 
158
- if(image_input is not None):
159
-
160
- final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed)
161
  yield final_image, seed, gr.update(visible=False)
162
  else:
163
- image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
164
-
165
  # Consume the generator to get the final image
166
  final_image = None
167
  step_counter = 0
168
  for image in image_generator:
169
- step_counter+=1
170
  final_image = image
171
  progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
172
  yield image, seed, gr.update(value=progress_bar, visible=True)
173
-
174
  yield final_image, seed, gr.update(value=progress_bar, visible=False)
175
 
176
  def get_huggingface_safetensors(link):
177
- split_link = link.split("/")
178
- if(len(split_link) == 2):
179
- model_card = ModelCard.load(link)
180
- base_model = model_card.data.get("base_model")
181
- print(base_model)
182
- if((base_model != "black-forest-labs/FLUX.1-dev") and (base_model != "black-forest-labs/FLUX.1-schnell")):
183
- raise Exception("Not a FLUX LoRA!")
184
- image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
185
- trigger_word = model_card.data.get("instance_prompt", "")
186
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
187
- fs = HfFileSystem()
188
- try:
189
- list_of_files = fs.ls(link, detail=False)
190
- for file in list_of_files:
191
- if(file.endswith(".safetensors")):
192
- safetensors_name = file.split("/")[-1]
193
- if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
194
- image_elements = file.split("/")
195
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
196
- except Exception as e:
197
- print(e)
198
- gr.Warning(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
199
- raise Exception(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
200
- return split_link[1], link, safetensors_name, trigger_word, image_url
201
 
202
  def check_custom_model(link):
203
  if(link.startswith("https://")):
@@ -257,8 +286,8 @@ css = '''
257
  #title img{width: 100px; margin-right: 0.5em}
258
  #gallery .grid-wrap{height: 10vh}
259
  #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
260
- .card_internal{display: flex;height: 100px;margin-top: .5em}
261
- .card_internal img{margin-right: 1em}
262
  .styler{--form-gap-width: 0px !important}
263
  #progress{height:30px}
264
  #progress .generating{display:none}
@@ -267,18 +296,18 @@ css = '''
267
  '''
268
  with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 3600)) as app:
269
  title = gr.HTML(
270
- """<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA"> FLUX LoRA the Explorer</h1>""",
271
  elem_id="title",
272
  )
273
- selected_index = gr.State(None)
 
274
  with gr.Row():
275
  with gr.Column(scale=3):
276
- prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
277
  with gr.Column(scale=1, elem_id="gen_column"):
278
  generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
279
  with gr.Row():
280
  with gr.Column():
281
- selected_info = gr.Markdown("")
282
  gallery = gr.Gallery(
283
  [(item["image"], item["title"]) for item in loras],
284
  label="LoRA Gallery",
@@ -291,6 +320,17 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 3600)) as app:
291
  gr.Markdown("[Check the list of FLUX LoRas](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
292
  custom_lora_info = gr.HTML(visible=False)
293
  custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
 
 
 
 
 
 
 
 
 
 
 
294
  with gr.Column():
295
  progress_bar = gr.Markdown(elem_id="progress",visible=False)
296
  result = gr.Image(label="Generated Image")
@@ -312,28 +352,37 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 3600)) as app:
312
  with gr.Row():
313
  randomize_seed = gr.Checkbox(True, label="Randomize seed")
314
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
315
- lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=3, step=0.01, value=0.95)
316
 
317
  gallery.select(
318
  update_selection,
319
- inputs=[width, height],
320
- outputs=[prompt, selected_info, selected_index, width, height]
 
 
 
 
 
 
 
 
 
 
321
  )
322
  custom_lora.input(
323
  add_custom_lora,
324
  inputs=[custom_lora],
325
- outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
326
  )
327
  custom_lora_button.click(
328
  remove_custom_lora,
329
- outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
330
  )
331
  gr.on(
332
  triggers=[generate_button.click, prompt.submit],
333
  fn=run_lora,
334
- inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
335
  outputs=[result, seed, progress_bar]
336
  )
337
 
338
  app.queue()
339
- app.launch()
 
55
  else:
56
  print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
57
 
58
+ def update_selection(evt: gr.SelectData, width, height, selected_lora1, selected_lora2):
59
  selected_lora = loras[evt.index]
60
  new_placeholder = f"Type a prompt for {selected_lora['title']}"
61
+
62
+ # Initialize outputs
63
+ outputs = []
64
+
65
+ if selected_lora1 is None:
66
+ selected_lora1 = selected_lora
67
+ selected_lora1_info = f"### LoRA 1 Selected: [{selected_lora1['title']}](https://huggingface.co/{selected_lora1['repo']}) ✨"
68
+ lora_scale1_visible = True
69
+ remove_lora1_visible = True
70
+ elif selected_lora2 is None:
71
+ selected_lora2 = selected_lora
72
+ selected_lora2_info = f"### LoRA 2 Selected: [{selected_lora2['title']}](https://huggingface.co/{selected_lora2['repo']}) ✨"
73
+ lora_scale2_visible = True
74
+ remove_lora2_visible = True
75
+ else:
76
+ raise gr.Error("You can only select up to two LoRAs. Please remove one before selecting another.")
77
+
78
+ # Update placeholder
79
+ placeholder_update = gr.update(placeholder=new_placeholder)
80
+
81
+ # For width and height adjustment
82
  if "aspect" in selected_lora:
83
  if selected_lora["aspect"] == "portrait":
84
  width = 768
 
89
  else:
90
  width = 1024
91
  height = 1024
92
+
93
+ return placeholder_update, selected_lora1, selected_lora2, selected_lora1_info, selected_lora2_info, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), width, height
94
+
95
+ def remove_selected_lora1(selected_lora1, selected_lora1_info):
96
+ selected_lora1 = None
97
+ selected_lora1_info = ""
98
+ return selected_lora1, selected_lora1_info, gr.update(visible=False), gr.update(visible=False)
99
+
100
+ def remove_selected_lora2(selected_lora2, selected_lora2_info):
101
+ selected_lora2 = None
102
+ selected_lora2_info = ""
103
+ return selected_lora2, selected_lora2_info, gr.update(visible=False), gr.update(visible=False)
104
 
105
  @spaces.GPU(duration=70)
106
+ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress):
107
  pipe.to("cuda")
108
  generator = torch.Generator(device="cuda").manual_seed(seed)
109
  with calculateDuration("Generating image"):
 
115
  width=width,
116
  height=height,
117
  generator=generator,
 
118
  output_type="pil",
119
  good_vae=good_vae,
120
  ):
121
  yield img
122
 
123
  @spaces.GPU(duration=70)
124
+ def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, seed):
125
  generator = torch.Generator(device="cuda").manual_seed(seed)
126
  pipe_i2i.to("cuda")
127
  image_input = load_image(image_input_path)
 
134
  width=width,
135
  height=height,
136
  generator=generator,
 
137
  output_type="pil",
138
  ).images[0]
139
  return final_image
140
+
141
+ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, randomize_seed, seed, width, height, selected_lora1, selected_lora2, lora_scale1, lora_scale2, progress=gr.Progress(track_tqdm=True)):
142
+ if selected_lora1 is None and selected_lora2 is None:
143
+ raise gr.Error("You must select at least one LoRA before proceeding.")
144
 
145
+ # Build the prompt mash
146
+ prompt_mash = prompt
147
+
148
+ # Handle trigger words and positions
149
+ trigger_words = []
150
+ if selected_lora1 is not None:
151
+ trigger_word1 = selected_lora1.get("trigger_word", "")
152
+ if trigger_word1:
153
+ if selected_lora1.get("trigger_position") == "prepend":
154
+ trigger_words.insert(0, trigger_word1)
155
  else:
156
+ trigger_words.append(trigger_word1)
157
+ if selected_lora2 is not None:
158
+ trigger_word2 = selected_lora2.get("trigger_word", "")
159
+ if trigger_word2:
160
+ if selected_lora2.get("trigger_position") == "prepend":
161
+ trigger_words.insert(0, trigger_word2)
162
+ else:
163
+ trigger_words.append(trigger_word2)
164
+ # Combine trigger words with the prompt
165
+ if trigger_words:
166
+ prompt_mash = f"{' '.join(trigger_words)} {prompt}"
167
 
168
+ with calculateDuration("Unloading LoRAs"):
169
  pipe.unload_lora_weights()
170
  pipe_i2i.unload_lora_weights()
171
 
172
+ # Load LoRA weights with respective scales
173
+ with calculateDuration("Loading LoRA weights"):
174
+ if image_input is not None:
175
+ if selected_lora1 is not None:
176
+ pipe_i2i.load_lora_weights(selected_lora1['repo'], weight_name=selected_lora1.get('weights'), scale=lora_scale1)
177
+ if selected_lora2 is not None:
178
+ pipe_i2i.load_lora_weights(selected_lora2['repo'], weight_name=selected_lora2.get('weights'), scale=lora_scale2)
179
  else:
180
+ if selected_lora1 is not None:
181
+ pipe.load_lora_weights(selected_lora1['repo'], weight_name=selected_lora1.get('weights'), scale=lora_scale1)
182
+ if selected_lora2 is not None:
183
+ pipe.load_lora_weights(selected_lora2['repo'], weight_name=selected_lora2.get('weights'), scale=lora_scale2)
184
 
185
  # Set random seed for reproducibility
186
  with calculateDuration("Randomizing seed"):
187
  if randomize_seed:
188
  seed = random.randint(0, MAX_SEED)
189
 
190
+ if image_input is not None:
191
+ final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, seed)
 
192
  yield final_image, seed, gr.update(visible=False)
193
  else:
194
+ image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress)
 
195
  # Consume the generator to get the final image
196
  final_image = None
197
  step_counter = 0
198
  for image in image_generator:
199
+ step_counter += 1
200
  final_image = image
201
  progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
202
  yield image, seed, gr.update(value=progress_bar, visible=True)
 
203
  yield final_image, seed, gr.update(value=progress_bar, visible=False)
204
 
205
  def get_huggingface_safetensors(link):
206
+ split_link = link.split("/")
207
+ if(len(split_link) == 2):
208
+ model_card = ModelCard.load(link)
209
+ base_model = model_card.data.get("base_model")
210
+ print(base_model)
211
+ if((base_model != "black-forest-labs/FLUX.1-dev") and (base_model != "black-forest-labs/FLUX.1-schnell")):
212
+ raise Exception("Not a FLUX LoRA!")
213
+ image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
214
+ trigger_word = model_card.data.get("instance_prompt", "")
215
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
216
+ fs = HfFileSystem()
217
+ try:
218
+ list_of_files = fs.ls(link, detail=False)
219
+ for file in list_of_files:
220
+ if(file.endswith(".safetensors")):
221
+ safetensors_name = file.split("/")[-1]
222
+ if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
223
+ image_elements = file.split("/")
224
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
225
+ except Exception as e:
226
+ print(e)
227
+ gr.Warning(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
228
+ raise Exception(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
229
+ return split_link[1], link, safetensors_name, trigger_word, image_url
230
 
231
  def check_custom_model(link):
232
  if(link.startswith("https://")):
 
286
  #title img{width: 100px; margin-right: 0.5em}
287
  #gallery .grid-wrap{height: 10vh}
288
  #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
289
+ .custom_lora_card .card_internal{display: flex;height: 100px;margin-top: .5em}
290
+ .custom_lora_card .card_internal img{margin-right: 1em}
291
  .styler{--form-gap-width: 0px !important}
292
  #progress{height:30px}
293
  #progress .generating{display:none}
 
296
  '''
297
  with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 3600)) as app:
298
  title = gr.HTML(
299
+ """<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA"> LoRA Lab</h1>""",
300
  elem_id="title",
301
  )
302
+ selected_lora1 = gr.State(None)
303
+ selected_lora2 = gr.State(None)
304
  with gr.Row():
305
  with gr.Column(scale=3):
306
+ prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting LoRAs")
307
  with gr.Column(scale=1, elem_id="gen_column"):
308
  generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
309
  with gr.Row():
310
  with gr.Column():
 
311
  gallery = gr.Gallery(
312
  [(item["image"], item["title"]) for item in loras],
313
  label="LoRA Gallery",
 
320
  gr.Markdown("[Check the list of FLUX LoRas](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
321
  custom_lora_info = gr.HTML(visible=False)
322
  custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
323
+ # Selected LoRAs section
324
+ gr.Markdown("### Selected LoRAs")
325
+ with gr.Row():
326
+ with gr.Column():
327
+ selected_lora1_info = gr.Markdown("", visible=False)
328
+ lora_scale1 = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=3, step=0.01, value=0.95, visible=False)
329
+ remove_lora1_button = gr.Button("Remove LoRA 1", visible=False)
330
+ with gr.Column():
331
+ selected_lora2_info = gr.Markdown("", visible=False)
332
+ lora_scale2 = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=3, step=0.01, value=0.95, visible=False)
333
+ remove_lora2_button = gr.Button("Remove LoRA 2", visible=False)
334
  with gr.Column():
335
  progress_bar = gr.Markdown(elem_id="progress",visible=False)
336
  result = gr.Image(label="Generated Image")
 
352
  with gr.Row():
353
  randomize_seed = gr.Checkbox(True, label="Randomize seed")
354
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
 
355
 
356
  gallery.select(
357
  update_selection,
358
+ inputs=[width, height, selected_lora1, selected_lora2],
359
+ outputs=[prompt, selected_lora1, selected_lora2, selected_lora1_info, selected_lora2_info, lora_scale1, remove_lora1_button, lora_scale2, remove_lora2_button, width, height]
360
+ )
361
+ remove_lora1_button.click(
362
+ remove_selected_lora1,
363
+ inputs=[selected_lora1, selected_lora1_info],
364
+ outputs=[selected_lora1, selected_lora1_info, lora_scale1, remove_lora1_button]
365
+ )
366
+ remove_lora2_button.click(
367
+ remove_selected_lora2,
368
+ inputs=[selected_lora2, selected_lora2_info],
369
+ outputs=[selected_lora2, selected_lora2_info, lora_scale2, remove_lora2_button]
370
  )
371
  custom_lora.input(
372
  add_custom_lora,
373
  inputs=[custom_lora],
374
+ outputs=[custom_lora_info, custom_lora_button, gallery, selected_lora1_info, selected_lora2_info, prompt]
375
  )
376
  custom_lora_button.click(
377
  remove_custom_lora,
378
+ outputs=[custom_lora_info, custom_lora_button, gallery, selected_lora1_info, selected_lora2_info, custom_lora]
379
  )
380
  gr.on(
381
  triggers=[generate_button.click, prompt.submit],
382
  fn=run_lora,
383
+ inputs=[prompt, input_image, image_strength, cfg_scale, steps, randomize_seed, seed, width, height, selected_lora1, selected_lora2, lora_scale1, lora_scale2],
384
  outputs=[result, seed, progress_bar]
385
  )
386
 
387
  app.queue()
388
+ app.launch()