Spaces: Running on Zero

Update app.py
app.py CHANGED
@@ -1,8 +1,8 @@
 import os
 import torch
 import gradio as gr
+from diffusers import FluxPipeline, DiffusionPipeline
 import spaces
-from diffusers import FluxPipeline, DiffusionPipeline, AutoPipelineForText2Image

 # Helper function to get the Hugging Face token securely
 def get_hf_token():
@@ -26,23 +26,23 @@ models = {
     "FLUX.1-schnell": {
         "pipeline_class": FluxPipeline,
         "model_id": "black-forest-labs/FLUX.1-schnell",
-        "config": {"torch_dtype": torch.bfloat16
+        "config": {"torch_dtype": torch.bfloat16},
         "description": "**FLUX.1-schnell** is a fast and efficient model designed for quick image generation. It excels at producing high-quality images rapidly, making it ideal for applications where speed is crucial. However, its rapid generation may slightly compromise on the level of detail compared to slower, more meticulous models.",
     },
     "FLUX.1-dev": {
-        "pipeline_class":
+        "pipeline_class": DiffusionPipeline,
         "model_id": "black-forest-labs/FLUX.1-dev",
         "lora": {
             "repo": "strangerzonehf/Flux-Enrich-Art-LoRA",
             "trigger_word": "enrich art",
         },
-        "config": {"torch_dtype": torch.bfloat16
+        "config": {"torch_dtype": torch.bfloat16},
         "description": "**FLUX.1-dev** is a development model that focuses on delivering highly detailed and artistically rich images.",
     },
     "Flux.1-lite-8B-alpha": {
         "pipeline_class": FluxPipeline,
         "model_id": "Freepik/flux.1-lite-8B-alpha",
-        "config": {"torch_dtype": torch.bfloat16
+        "config": {"torch_dtype": torch.bfloat16},
         "description": "**Flux.1-lite-8B-alpha** is a lightweight model optimized for efficiency and ease of use.",
     },
 }
@@ -65,7 +65,7 @@ def clear_gpu_memory():
         print(f"Error clearing GPU memory: {e}")
         return f"Error clearing GPU memory: {e}"

-
+
 def load_model(model_key):
     """Loads a model, clearing GPU memory first if a different model was loaded."""
     global model_load_status
@@ -82,13 +82,12 @@ def load_model(model_key):
         config = models[model_key]
         pipeline_class = config["pipeline_class"]
         model_id = config["model_id"]
-
-        pipe = pipeline_class.from_pretrained(model_id, token=_HF_TOKEN, **config.get("config", {})
+
+        pipe = pipeline_class.from_pretrained(model_id, token=_HF_TOKEN, **config.get("config", {}))

         if "lora" in config:
             lora_config = config["lora"]
-            pipe.load_lora_weights(lora_config["repo"], token=_HF_TOKEN
-            pipe.fuse_lora()
+            pipe.load_lora_weights(lora_config["repo"], token=_HF_TOKEN)

         if torch.cuda.is_available():
             pipe.to("cuda")
@@ -100,18 +99,13 @@ def load_model(model_key):
         model_load_status[model_key] = "Failed"  # Update load status on error
         return f"Error loading model '{model_key}': {e}"

-@spaces.GPU(duration=
-def generate_image(
-
-    model = loaded_models[model_key]
-
+@spaces.GPU(duration=120)
+def generate_image(model, prompt, seed=-1):
+
     generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu")
     if seed != -1:
         generator = generator.manual_seed(seed)

-    if "lora" in models.get(model_key, {}):
-        prompt = f"{models[model_key]['lora']['trigger_word']}, {prompt}"
-
     with torch.no_grad():
         image = model(prompt=prompt, generator=generator).images[0]
     return image
@@ -130,7 +124,9 @@ def gradio_generate(selected_model, prompt, seed):
         # If still not loaded after attempt, return an error
         return f"Model not loaded. Load status: {model_load_status.get(selected_model, 'Not attempted')}.", None

-
+    model = loaded_models[selected_model]
+    image = generate_image(model, prompt, seed)
+
     runtime_info = f"Model: {selected_model}\nSeed: {seed}"
     output_path = "generated_image.png"
     image.save(output_path)
@@ -157,12 +153,18 @@ with gr.Blocks(
         with gr.Column():
             gr.Markdown("# Text-to-Image Generator")
             model_dropdown = gr.Dropdown(choices=list(models.keys()), label="Select Model")
+            #with gr.Row():
+            #load_button = gr.Button("Load Model")
+            # clear_button = gr.Button("Clear GPU Memory") # Removed clear button
+            #load_status = gr.Textbox(label="Model Load Status", interactive=False) # Removed load button
             prompt_textbox = gr.Textbox(label="Enter Text Prompt")
             seed_slider = gr.Slider(minimum=-1, maximum=1000, step=1, value=-1, label="Random Seed (-1 for random)")
             generate_button = gr.Button("Generate Image")
             output_image = gr.Image(label="Generated Image")
             runtime_info_textbox = gr.Textbox(label="Runtime Information", lines=2, interactive=False)

+            #load_button.click(gradio_load_model, inputs=[model_dropdown], outputs=[load_status]) # Removed load button click action
+            # clear_button.click(clear_gpu_memory, outputs=[load_status]) # Removed clear button click action
             generate_button.click(gradio_generate, inputs=[model_dropdown, prompt_textbox, seed_slider], outputs=[output_image, runtime_info_textbox])

         with gr.Tab("Model Information"):
@@ -174,4 +176,4 @@ with gr.Blocks(
     **Credits**: Created by Ruslan Magana Vsevolodovna. For more information, visit [https://ruslanmv.com/](https://ruslanmv.com/).""")

 if __name__ == "__main__":
-    interface.launch(debug=True
+    interface.launch(debug=True)
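Note: the diff shows only the signature of `get_hf_token()`; its body is untouched by this commit. For readers following along, a minimal sketch of such a helper, assuming the token lives in an `HF_TOKEN` Space secret (the actual variable name and error handling in the app are not shown here and may differ):

```python
import os

# Hypothetical sketch -- the helper's real body is not part of this diff.
# On Hugging Face Spaces, secrets are exposed to the app as environment variables.
def get_hf_token():
    token = os.environ.get("HF_TOKEN")  # assumed secret name
    if token is None:
        print("Warning: HF_TOKEN is not set; gated models such as FLUX.1-dev may fail to load.")
    return token
```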
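Taken together, the commit restores the load-and-generate path that the truncated lines had broken. A minimal standalone sketch of that path, using the same model and LoRA IDs as the diff (the prompt, seed, and output filename are illustrative, and FLUX.1-dev is a gated repository, so a valid token may be required):

```python
import torch
from diffusers import DiffusionPipeline

# Load FLUX.1-dev plus the LoRA the app uses (IDs taken from the diff above).
# Pass token=... to from_pretrained/load_lora_weights if the repo requires it.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights("strangerzonehf/Flux-Enrich-Art-LoRA")
if torch.cuda.is_available():
    pipe.to("cuda")

# A fixed seed makes generation reproducible; the app treats -1 as "random".
device = "cuda" if torch.cuda.is_available() else "cpu"
generator = torch.Generator(device=device).manual_seed(42)

# "enrich art" is the LoRA trigger word declared in the app's model config.
image = pipe(prompt="enrich art, a lighthouse at dawn", generator=generator).images[0]
image.save("generated_image.png")
```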