Update app.py
app.py CHANGED
@@ -2,7 +2,7 @@ import os
 import torch
 import gradio as gr
 import spaces
-from diffusers import FluxPipeline, DiffusionPipeline
+from diffusers import FluxPipeline, DiffusionPipeline, AutoPipelineForText2Image
 
 # Helper function to get the Hugging Face token securely
 def get_hf_token():
@@ -30,13 +30,13 @@ models = {
         "description": "**FLUX.1-schnell** is a fast and efficient model designed for quick image generation.",
     },
     "FLUX.1-dev": {
-        "pipeline_class": FluxPipeline,
+        "pipeline_class": AutoPipelineForText2Image,  # Changed to AutoPipelineForText2Image
         "model_id": "black-forest-labs/FLUX.1-dev",
         "lora": {
             "repo": "strangerzonehf/Flux-Enrich-Art-LoRA",
             "trigger_word": "enrich art",
         },
-        "config": {"torch_dtype": torch.bfloat16},
+        "config": {"torch_dtype": torch.bfloat16, "variant": "fp16"},
         "description": "**FLUX.1-dev** is a development model for highly detailed and artistic images.",
     },
 }
@@ -52,32 +52,33 @@ def download_all_models():
             model_id = config["model_id"]
             pipeline_class.from_pretrained(model_id, token=_HF_TOKEN, **config.get("config", {}))
             if "lora" in config:
-
+                # Download LoRA weights separately if needed, but not essential for pre-download
+                pass
             print(f"Model '{model_key}' downloaded successfully.")
         except Exception as e:
             print(f"Error downloading model '{model_key}': {e}")
     print("Model download process complete.")
 
 # Function to load a model when requested
-@spaces.GPU
+#@spaces.GPU  # Removed as it's not a standard decorator
 def load_model(model_key):
     if model_key in loaded_models:
         return f"Model '{model_key}' is already loaded."
-
     try:
         config = models[model_key]
         pipeline_class = config["pipeline_class"]
         model_id = config["model_id"]
 
-        pipe = pipeline_class.from_pretrained(
-            model_id, token=_HF_TOKEN, **config.get("config", {}))
+        pipe = pipeline_class.from_pretrained(
+            model_id, token=_HF_TOKEN, **config.get("config", {}), low_cpu_mem_usage=True  # Add low_cpu_mem_usage
+        )
         if "lora" in config:
             lora_config = config["lora"]
-            pipe.load_lora_weights(lora_config["repo"], token=_HF_TOKEN)
+            pipe.load_lora_weights(lora_config["repo"], token=_HF_TOKEN, weight_name="FluxEnrichArtLoRA.safetensors")  # Added specific weight file name
+            pipe.fuse_lora()
 
         if torch.cuda.is_available():
             pipe.to("cuda")
-
         loaded_models[model_key] = pipe
         return f"Model '{model_key}' loaded successfully."
     except Exception as e:
@@ -86,12 +87,15 @@ def load_model(model_key):
         return f"Error loading model '{model_key}': {e}"
 
 # Function to generate an image from text
-@spaces.GPU
+#@spaces.GPU  # Removed as it's not a standard decorator
 def generate_image(model, prompt, seed=-1):
     generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu")
     if seed != -1:
         generator = generator.manual_seed(seed)
-
+
+    if "lora" in models.get(list(loaded_models.keys())[list(loaded_models.values()).index(model)], {}):
+        prompt = f"{models[list(loaded_models.keys())[list(loaded_models.values()).index(model)]]['lora']['trigger_word']}, {prompt}"
+
     with torch.no_grad():
         image = model(prompt=prompt, generator=generator).images[0]
     return image
@@ -100,10 +104,8 @@ def generate_image(model, prompt, seed=-1):
 def gradio_generate(selected_model, prompt, seed):
     if selected_model not in loaded_models:
         return "Model not loaded. Please load the model first.", None
-
     model = loaded_models[selected_model]
     image = generate_image(model, prompt, seed)
-
     runtime_info = f"Model: {selected_model}\nSeed: {seed}"
     output_path = "generated_image.png"
     image.save(output_path)
@@ -128,14 +130,12 @@ with gr.Blocks() as interface:
         generate_button = gr.Button("Generate Image")
         output_image = gr.Image(label="Generated Image")
         runtime_info_textbox = gr.Textbox(label="Runtime Information", lines=2)
-
         load_button.click(gradio_load_model, inputs=[model_dropdown], outputs=[load_status])
         generate_button.click(
             fn=gradio_generate,
             inputs=[model_dropdown, prompt_textbox, seed_slider],
             outputs=[output_image, runtime_info_textbox],
         )
-
     with gr.Tab("Model Information"):
         with gr.Column():
             for model_key, model_info in models.items():
@@ -144,4 +144,8 @@ with gr.Blocks() as interface:
 
 if __name__ == "__main__":
     download_all_models()
-    interface.launch()
+    try:
+        interface.launch(debug=True, share=True)  # Enable sharing for public link
+    except gr.Error as e:
+        print(f"Error launching Gradio interface: {e}")
+        print("Please make sure you have the necessary permissions and try again.")
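For context, a minimal self-contained sketch of the generation path this commit moves FLUX.1-dev to: AutoPipelineForText2Image with bfloat16 weights, the LoRA loaded and fused, and the trigger word prepended to the prompt. The model ID, LoRA repo, weight filename, and trigger word are taken from the diff above; the prompt, seed, and CUDA assumption are illustrative only, and whether the named weight file (or the fp16 variant added in the config) actually exists on the Hub is not verified here.

# Sketch only: mirrors the new load_model()/generate_image() flow from this commit.
# Assumes a CUDA GPU and an HF token in the environment; prompt/seed are examples.
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)
# weight_name matches the filename the commit assumes; adjust if the repo differs.
pipe.load_lora_weights(
    "strangerzonehf/Flux-Enrich-Art-LoRA",
    weight_name="FluxEnrichArtLoRA.safetensors",
)
pipe.fuse_lora()  # bake the LoRA into the base weights, as load_model() now does
pipe.to("cuda")

prompt = "enrich art, a lighthouse at dusk"  # trigger word prepended, per generate_image()
generator = torch.Generator(device="cuda").manual_seed(42)
image = pipe(prompt=prompt, generator=generator).images[0]
image.save("generated_image.png")

Note that fuse_lora() makes the adapter permanent for that pipeline instance; keeping the weights unfused after load_lora_weights() would let the same base model switch between adapters.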