ruslanmv commited on
Commit
5e7319a
·
verified ·
1 Parent(s): 807fd2d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -20
app.py CHANGED
@@ -1,8 +1,8 @@
1
  import os
2
  import torch
3
  import gradio as gr
 
4
  import spaces
5
- from diffusers import FluxPipeline, DiffusionPipeline, AutoPipelineForText2Image
6
 
7
  # Helper function to get the Hugging Face token securely
8
  def get_hf_token():
@@ -26,23 +26,23 @@ models = {
26
  "FLUX.1-schnell": {
27
  "pipeline_class": FluxPipeline,
28
  "model_id": "black-forest-labs/FLUX.1-schnell",
29
- "config": {"torch_dtype": torch.bfloat16, "variant": "fp16"},
30
  "description": "**FLUX.1-schnell** is a fast and efficient model designed for quick image generation. It excels at producing high-quality images rapidly, making it ideal for applications where speed is crucial. However, its rapid generation may slightly compromise on the level of detail compared to slower, more meticulous models.",
31
  },
32
  "FLUX.1-dev": {
33
- "pipeline_class": AutoPipelineForText2Image,
34
  "model_id": "black-forest-labs/FLUX.1-dev",
35
  "lora": {
36
  "repo": "strangerzonehf/Flux-Enrich-Art-LoRA",
37
  "trigger_word": "enrich art",
38
  },
39
- "config": {"torch_dtype": torch.bfloat16, "variant": "fp16"},
40
  "description": "**FLUX.1-dev** is a development model that focuses on delivering highly detailed and artistically rich images.",
41
  },
42
  "Flux.1-lite-8B-alpha": {
43
  "pipeline_class": FluxPipeline,
44
  "model_id": "Freepik/flux.1-lite-8B-alpha",
45
- "config": {"torch_dtype": torch.bfloat16, "variant": "fp16"},
46
  "description": "**Flux.1-lite-8B-alpha** is a lightweight model optimized for efficiency and ease of use.",
47
  },
48
  }
@@ -65,7 +65,7 @@ def clear_gpu_memory():
65
  print(f"Error clearing GPU memory: {e}")
66
  return f"Error clearing GPU memory: {e}"
67
 
68
- @spaces.GPU(duration=80)
69
  def load_model(model_key):
70
  """Loads a model, clearing GPU memory first if a different model was loaded."""
71
  global model_load_status
@@ -82,13 +82,12 @@ def load_model(model_key):
82
  config = models[model_key]
83
  pipeline_class = config["pipeline_class"]
84
  model_id = config["model_id"]
85
-
86
- pipe = pipeline_class.from_pretrained(model_id, token=_HF_TOKEN, **config.get("config", {}), low_cpu_mem_usage=True)
87
 
88
  if "lora" in config:
89
  lora_config = config["lora"]
90
- pipe.load_lora_weights(lora_config["repo"], token=_HF_TOKEN, weight_name="FluxEnrichArtLoRA.safetensors")
91
- pipe.fuse_lora()
92
 
93
  if torch.cuda.is_available():
94
  pipe.to("cuda")
@@ -100,18 +99,13 @@ def load_model(model_key):
100
  model_load_status[model_key] = "Failed" # Update load status on error
101
  return f"Error loading model '{model_key}': {e}"
102
 
103
- @spaces.GPU(duration=60)
104
- def generate_image(model_key, prompt, seed=-1):
105
-
106
- model = loaded_models[model_key]
107
-
108
  generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu")
109
  if seed != -1:
110
  generator = generator.manual_seed(seed)
111
 
112
- if "lora" in models.get(model_key, {}):
113
- prompt = f"{models[model_key]['lora']['trigger_word']}, {prompt}"
114
-
115
  with torch.no_grad():
116
  image = model(prompt=prompt, generator=generator).images[0]
117
  return image
@@ -130,7 +124,9 @@ def gradio_generate(selected_model, prompt, seed):
130
  # If still not loaded after attempt, return an error
131
  return f"Model not loaded. Load status: {model_load_status.get(selected_model, 'Not attempted')}.", None
132
 
133
- image = generate_image(selected_model, prompt, seed)
 
 
134
  runtime_info = f"Model: {selected_model}\nSeed: {seed}"
135
  output_path = "generated_image.png"
136
  image.save(output_path)
@@ -157,12 +153,18 @@ with gr.Blocks(
157
  with gr.Column():
158
  gr.Markdown("# Text-to-Image Generator")
159
  model_dropdown = gr.Dropdown(choices=list(models.keys()), label="Select Model")
 
 
 
 
160
  prompt_textbox = gr.Textbox(label="Enter Text Prompt")
161
  seed_slider = gr.Slider(minimum=-1, maximum=1000, step=1, value=-1, label="Random Seed (-1 for random)")
162
  generate_button = gr.Button("Generate Image")
163
  output_image = gr.Image(label="Generated Image")
164
  runtime_info_textbox = gr.Textbox(label="Runtime Information", lines=2, interactive=False)
165
 
 
 
166
  generate_button.click(gradio_generate, inputs=[model_dropdown, prompt_textbox, seed_slider], outputs=[output_image, runtime_info_textbox])
167
 
168
  with gr.Tab("Model Information"):
@@ -174,4 +176,4 @@ with gr.Blocks(
174
  **Credits**: Created by Ruslan Magana Vsevolodovna. For more information, visit [https://ruslanmv.com/](https://ruslanmv.com/).""")
175
 
176
  if __name__ == "__main__":
177
- interface.launch(debug=True, share=True)
 
1
  import os
2
  import torch
3
  import gradio as gr
4
+ from diffusers import FluxPipeline, DiffusionPipeline
5
  import spaces
 
6
 
7
  # Helper function to get the Hugging Face token securely
8
  def get_hf_token():
 
26
  "FLUX.1-schnell": {
27
  "pipeline_class": FluxPipeline,
28
  "model_id": "black-forest-labs/FLUX.1-schnell",
29
+ "config": {"torch_dtype": torch.bfloat16},
30
  "description": "**FLUX.1-schnell** is a fast and efficient model designed for quick image generation. It excels at producing high-quality images rapidly, making it ideal for applications where speed is crucial. However, its rapid generation may slightly compromise on the level of detail compared to slower, more meticulous models.",
31
  },
32
  "FLUX.1-dev": {
33
+ "pipeline_class": DiffusionPipeline,
34
  "model_id": "black-forest-labs/FLUX.1-dev",
35
  "lora": {
36
  "repo": "strangerzonehf/Flux-Enrich-Art-LoRA",
37
  "trigger_word": "enrich art",
38
  },
39
+ "config": {"torch_dtype": torch.bfloat16},
40
  "description": "**FLUX.1-dev** is a development model that focuses on delivering highly detailed and artistically rich images.",
41
  },
42
  "Flux.1-lite-8B-alpha": {
43
  "pipeline_class": FluxPipeline,
44
  "model_id": "Freepik/flux.1-lite-8B-alpha",
45
+ "config": {"torch_dtype": torch.bfloat16},
46
  "description": "**Flux.1-lite-8B-alpha** is a lightweight model optimized for efficiency and ease of use.",
47
  },
48
  }
 
65
  print(f"Error clearing GPU memory: {e}")
66
  return f"Error clearing GPU memory: {e}"
67
 
68
+
69
  def load_model(model_key):
70
  """Loads a model, clearing GPU memory first if a different model was loaded."""
71
  global model_load_status
 
82
  config = models[model_key]
83
  pipeline_class = config["pipeline_class"]
84
  model_id = config["model_id"]
85
+
86
+ pipe = pipeline_class.from_pretrained(model_id, token=_HF_TOKEN, **config.get("config", {}))
87
 
88
  if "lora" in config:
89
  lora_config = config["lora"]
90
+ pipe.load_lora_weights(lora_config["repo"], token=_HF_TOKEN)
 
91
 
92
  if torch.cuda.is_available():
93
  pipe.to("cuda")
 
99
  model_load_status[model_key] = "Failed" # Update load status on error
100
  return f"Error loading model '{model_key}': {e}"
101
 
102
+ @spaces.GPU(duration=120)
103
+ def generate_image(model, prompt, seed=-1):
104
+
 
 
105
  generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu")
106
  if seed != -1:
107
  generator = generator.manual_seed(seed)
108
 
 
 
 
109
  with torch.no_grad():
110
  image = model(prompt=prompt, generator=generator).images[0]
111
  return image
 
124
  # If still not loaded after attempt, return an error
125
  return f"Model not loaded. Load status: {model_load_status.get(selected_model, 'Not attempted')}.", None
126
 
127
+ model = loaded_models[selected_model]
128
+ image = generate_image(model, prompt, seed)
129
+
130
  runtime_info = f"Model: {selected_model}\nSeed: {seed}"
131
  output_path = "generated_image.png"
132
  image.save(output_path)
 
153
  with gr.Column():
154
  gr.Markdown("# Text-to-Image Generator")
155
  model_dropdown = gr.Dropdown(choices=list(models.keys()), label="Select Model")
156
+ #with gr.Row():
157
+ #load_button = gr.Button("Load Model")
158
+ # clear_button = gr.Button("Clear GPU Memory") # Removed clear button
159
+ #load_status = gr.Textbox(label="Model Load Status", interactive=False) # Removed load button
160
  prompt_textbox = gr.Textbox(label="Enter Text Prompt")
161
  seed_slider = gr.Slider(minimum=-1, maximum=1000, step=1, value=-1, label="Random Seed (-1 for random)")
162
  generate_button = gr.Button("Generate Image")
163
  output_image = gr.Image(label="Generated Image")
164
  runtime_info_textbox = gr.Textbox(label="Runtime Information", lines=2, interactive=False)
165
 
166
+ #load_button.click(gradio_load_model, inputs=[model_dropdown], outputs=[load_status]) # Removed load button click action
167
+ # clear_button.click(clear_gpu_memory, outputs=[load_status]) # Removed clear button click action
168
  generate_button.click(gradio_generate, inputs=[model_dropdown, prompt_textbox, seed_slider], outputs=[output_image, runtime_info_textbox])
169
 
170
  with gr.Tab("Model Information"):
 
176
  **Credits**: Created by Ruslan Magana Vsevolodovna. For more information, visit [https://ruslanmv.com/](https://ruslanmv.com/).""")
177
 
178
  if __name__ == "__main__":
179
+ interface.launch(debug=True)