multimodalart committed
Commit 0e0ee20
1 Parent(s): 1846c84

Create app.py

Files changed (1)
app.py +107 -0
app.py ADDED
@@ -0,0 +1,107 @@
+ import gradio as gr
+ import json
+ import torch
+ from diffusers import DiffusionPipeline
+ import spaces
+
+ # Load the list of available LoRAs from a JSON file
+ with open('loras.json', 'r') as f:
+     loras = json.load(f)
+
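+ # Assumed shape of loras.json, inferred from the keys read in this file
+ # ("title", "repo", "trigger_word", "image"); the values are placeholders,
+ # not part of this commit:
+ # [
+ #   {
+ #     "title": "Example LoRA",
+ #     "repo": "user/example-flux-lora",
+ #     "trigger_word": "in the style of EXMPL",
+ #     "image": "https://example.com/preview.png"
+ #   }
+ # ]
+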
+ # Initialize the base model
+ base_model = "black-forest-labs/FLUX.1-dev"
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
+ pipe.to("cuda")
+
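+ # On a ZeroGPU Space, the `spaces` package attaches a GPU only while a
+ # function decorated with @spaces.GPU (run_lora below) is executing; on
+ # dedicated GPU hardware the decorator should be a harmless no-op.
+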
+ def update_selection(evt: gr.SelectData):
+     selected_lora = loras[evt.index]
+     new_placeholder = f"Type a prompt for {selected_lora['title']}"
+     lora_repo = selected_lora["repo"]
+     updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"
+     return (
+         gr.update(placeholder=new_placeholder),
+         updated_text,
+         evt.index
+     )
+
+ @spaces.GPU
+ def run_lora(prompt, negative_prompt, cfg_scale, steps, selected_index, seed, width, height, lora_scale):
+     # negative_prompt is kept in the signature so the Gradio wiring below
+     # stays unchanged, but see the note at the pipe() call.
+     if selected_index is None:
+         raise gr.Error("You must select a LoRA before proceeding.")
+
+     selected_lora = loras[selected_index]
+     lora_path = selected_lora["repo"]
+     trigger_word = selected_lora["trigger_word"]
+
+     # Load LoRA weights
+     pipe.load_lora_weights(lora_path)
+
+     # Set random seed for reproducibility
+     generator = torch.Generator(device="cuda").manual_seed(seed)
+
+     # Generate image. FLUX.1-dev is guidance-distilled: its pipeline has no
+     # classifier-free negative prompt, so the UI value is not forwarded.
+     # FLUX is a DiT pipeline, so the LoRA scale is passed via
+     # joint_attention_kwargs rather than the UNet-style cross_attention_kwargs.
+     image = pipe(
+         prompt=f"{prompt} {trigger_word}",
+         num_inference_steps=steps,
+         guidance_scale=cfg_scale,
+         width=width,
+         height=height,
+         generator=generator,
+         joint_attention_kwargs={"scale": lora_scale},
+     ).images[0]
+
+     # Unload LoRA weights so the next selection starts from the clean base model
+     pipe.unload_lora_weights()
+
+     return image
+
+ with gr.Blocks(theme=gr.themes.Soft()) as app:
+     gr.Markdown("# FLUX.1 LoRA the Explorer")
+     selected_index = gr.State(None)
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             result = gr.Image(label="Generated Image", height=768)
+             generate_button = gr.Button("Generate", variant="primary")
+
+         with gr.Column(scale=1):
+             gallery = gr.Gallery(
+                 [(item["image"], item["title"]) for item in loras],
+                 label="LoRA Gallery",
+                 allow_preview=False,
+                 columns=2
+             )
+
+     with gr.Row():
+         with gr.Column():
+             prompt_title = gr.Markdown("### Click on a LoRA in the gallery to select it")
+             selected_info = gr.Markdown("")
+             prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Type a prompt after selecting a LoRA")
+             negative_prompt = gr.Textbox(label="Negative Prompt", lines=2, value="low quality, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry")
+
+         with gr.Column():
+             with gr.Row():
+                 # FLUX.1-dev is typically run at much lower guidance than
+                 # SD-style models; 3.5 is the commonly recommended default
+                 cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
+                 steps = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=30)
+
+             with gr.Row():
+                 width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
+                 height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
+
+             with gr.Row():
+                 seed = gr.Slider(label="Seed", minimum=0, maximum=2**32-1, step=1, value=0, randomize=True)
+                 lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=1)
+
+     gallery.select(update_selection, outputs=[prompt, selected_info, selected_index])
+
+     generate_button.click(
+         fn=run_lora,
+         inputs=[prompt, negative_prompt, cfg_scale, steps, selected_index, seed, width, height, lora_scale],
+         outputs=[result]
+     )
+
+ app.queue()
+ app.launch()
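
For local testing outside the Gradio UI, a minimal smoke test along these lines should exercise the same generation path. This is a sketch, not part of the commit: it assumes a CUDA device, a populated loras.json matching the schema sketched above, and the diffusers FLUX LoRA API used in app.py.

    import json
    import torch
    from diffusers import DiffusionPipeline

    # First entry of the gallery; assumes the loras.json schema noted above
    with open("loras.json") as f:
        lora = json.load(f)[0]

    pipe = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")
    pipe.load_lora_weights(lora["repo"])

    image = pipe(
        prompt=f"a test render {lora['trigger_word']}",
        num_inference_steps=30,
        guidance_scale=3.5,
        width=1024,
        height=1024,
        generator=torch.Generator(device="cuda").manual_seed(0),
        joint_attention_kwargs={"scale": 1.0},  # LoRA strength
    ).images[0]
    image.save("smoke_test.png")

The commit does not pin dependencies; judging only from the imports, a requirements.txt for the Space would need at least the following (accelerate and sentencepiece are an assumption, commonly required to load FLUX's text encoders):

    # assumed requirements.txt, inferred from app.py's imports; not part of this commit
    gradio
    spaces
    torch
    diffusers
    transformers
    accelerate
    sentencepiece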