ruslanmv commited on
Commit
42bd6fc
·
verified ·
1 Parent(s): 02baaa7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -17
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import torch
3
  import gradio as gr
4
  import spaces
5
- from diffusers import FluxPipeline, DiffusionPipeline
6
 
7
  # Helper function to get the Hugging Face token securely
8
  def get_hf_token():
@@ -30,13 +30,13 @@ models = {
30
  "description": "**FLUX.1-schnell** is a fast and efficient model designed for quick image generation.",
31
  },
32
  "FLUX.1-dev": {
33
- "pipeline_class": DiffusionPipeline,
34
  "model_id": "black-forest-labs/FLUX.1-dev",
35
  "lora": {
36
  "repo": "strangerzonehf/Flux-Enrich-Art-LoRA",
37
  "trigger_word": "enrich art",
38
  },
39
- "config": {"torch_dtype": torch.bfloat16},
40
  "description": "**FLUX.1-dev** is a development model for highly detailed and artistic images.",
41
  },
42
  }
@@ -52,32 +52,33 @@ def download_all_models():
52
  model_id = config["model_id"]
53
  pipeline_class.from_pretrained(model_id, token=_HF_TOKEN, **config.get("config", {}))
54
  if "lora" in config:
55
- pipeline_class.download_lora_weights(config["lora"]["repo"], token=_HF_TOKEN)
 
56
  print(f"Model '{model_key}' downloaded successfully.")
57
  except Exception as e:
58
  print(f"Error downloading model '{model_key}': {e}")
59
  print("Model download process complete.")
60
 
61
  # Function to load a model when requested
62
- @spaces.GPU
63
  def load_model(model_key):
64
  if model_key in loaded_models:
65
  return f"Model '{model_key}' is already loaded."
66
-
67
  try:
68
  config = models[model_key]
69
  pipeline_class = config["pipeline_class"]
70
  model_id = config["model_id"]
71
 
72
- pipe = pipeline_class.from_pretrained(model_id, token=_HF_TOKEN, **config.get("config", {}))
73
-
 
74
  if "lora" in config:
75
  lora_config = config["lora"]
76
- pipe.load_lora_weights(lora_config["repo"], token=_HF_TOKEN)
 
77
 
78
  if torch.cuda.is_available():
79
  pipe.to("cuda")
80
-
81
  loaded_models[model_key] = pipe
82
  return f"Model '{model_key}' loaded successfully."
83
  except Exception as e:
@@ -86,12 +87,15 @@ def load_model(model_key):
86
  return f"Error loading model '{model_key}': {e}"
87
 
88
  # Function to generate an image from text
89
- @spaces.GPU
90
  def generate_image(model, prompt, seed=-1):
91
  generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu")
92
  if seed != -1:
93
  generator = generator.manual_seed(seed)
94
-
 
 
 
95
  with torch.no_grad():
96
  image = model(prompt=prompt, generator=generator).images[0]
97
  return image
@@ -100,10 +104,8 @@ def generate_image(model, prompt, seed=-1):
100
  def gradio_generate(selected_model, prompt, seed):
101
  if selected_model not in loaded_models:
102
  return "Model not loaded. Please load the model first.", None
103
-
104
  model = loaded_models[selected_model]
105
  image = generate_image(model, prompt, seed)
106
-
107
  runtime_info = f"Model: {selected_model}\nSeed: {seed}"
108
  output_path = "generated_image.png"
109
  image.save(output_path)
@@ -128,14 +130,12 @@ with gr.Blocks() as interface:
128
  generate_button = gr.Button("Generate Image")
129
  output_image = gr.Image(label="Generated Image")
130
  runtime_info_textbox = gr.Textbox(label="Runtime Information", lines=2)
131
-
132
  load_button.click(gradio_load_model, inputs=[model_dropdown], outputs=[load_status])
133
  generate_button.click(
134
  fn=gradio_generate,
135
  inputs=[model_dropdown, prompt_textbox, seed_slider],
136
  outputs=[output_image, runtime_info_textbox],
137
  )
138
-
139
  with gr.Tab("Model Information"):
140
  with gr.Column():
141
  for model_key, model_info in models.items():
@@ -144,4 +144,8 @@ with gr.Blocks() as interface:
144
 
145
  if __name__ == "__main__":
146
  download_all_models()
147
- interface.launch(debug=True)
 
 
 
 
 
2
  import torch
3
  import gradio as gr
4
  import spaces
5
+ from diffusers import FluxPipeline, DiffusionPipeline, AutoPipelineForText2Image
6
 
7
  # Helper function to get the Hugging Face token securely
8
  def get_hf_token():
 
30
  "description": "**FLUX.1-schnell** is a fast and efficient model designed for quick image generation.",
31
  },
32
  "FLUX.1-dev": {
33
+ "pipeline_class": AutoPipelineForText2Image, # Changed to AutoPipelineForText2Image
34
  "model_id": "black-forest-labs/FLUX.1-dev",
35
  "lora": {
36
  "repo": "strangerzonehf/Flux-Enrich-Art-LoRA",
37
  "trigger_word": "enrich art",
38
  },
39
+ "config": {"torch_dtype": torch.bfloat16, "variant": "fp16"},
40
  "description": "**FLUX.1-dev** is a development model for highly detailed and artistic images.",
41
  },
42
  }
 
52
  model_id = config["model_id"]
53
  pipeline_class.from_pretrained(model_id, token=_HF_TOKEN, **config.get("config", {}))
54
  if "lora" in config:
55
+ # Download LoRA weights separately if needed, but not essential for pre-download
56
+ pass
57
  print(f"Model '{model_key}' downloaded successfully.")
58
  except Exception as e:
59
  print(f"Error downloading model '{model_key}': {e}")
60
  print("Model download process complete.")
61
 
62
  # Function to load a model when requested
63
+ #@spaces.GPU # Removed as it's not a standard decorator
64
  def load_model(model_key):
65
  if model_key in loaded_models:
66
  return f"Model '{model_key}' is already loaded."
 
67
  try:
68
  config = models[model_key]
69
  pipeline_class = config["pipeline_class"]
70
  model_id = config["model_id"]
71
 
72
+ pipe = pipeline_class.from_pretrained(
73
+ model_id, token=_HF_TOKEN, **config.get("config", {}), low_cpu_mem_usage=True # Add low_cpu_mem_usage
74
+ )
75
  if "lora" in config:
76
  lora_config = config["lora"]
77
+ pipe.load_lora_weights(lora_config["repo"], token=_HF_TOKEN, weight_name="FluxEnrichArtLoRA.safetensors") # Added specific weight file name
78
+ pipe.fuse_lora()
79
 
80
  if torch.cuda.is_available():
81
  pipe.to("cuda")
 
82
  loaded_models[model_key] = pipe
83
  return f"Model '{model_key}' loaded successfully."
84
  except Exception as e:
 
87
  return f"Error loading model '{model_key}': {e}"
88
 
89
  # Function to generate an image from text
90
+ #@spaces.GPU # Removed as it's not a standard decorator
91
  def generate_image(model, prompt, seed=-1):
92
  generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu")
93
  if seed != -1:
94
  generator = generator.manual_seed(seed)
95
+
96
+ if "lora" in models.get(list(loaded_models.keys())[list(loaded_models.values()).index(model)], {}):
97
+ prompt = f"{models[list(loaded_models.keys())[list(loaded_models.values()).index(model)]]['lora']['trigger_word']}, {prompt}"
98
+
99
  with torch.no_grad():
100
  image = model(prompt=prompt, generator=generator).images[0]
101
  return image
 
104
  def gradio_generate(selected_model, prompt, seed):
105
  if selected_model not in loaded_models:
106
  return "Model not loaded. Please load the model first.", None
 
107
  model = loaded_models[selected_model]
108
  image = generate_image(model, prompt, seed)
 
109
  runtime_info = f"Model: {selected_model}\nSeed: {seed}"
110
  output_path = "generated_image.png"
111
  image.save(output_path)
 
130
  generate_button = gr.Button("Generate Image")
131
  output_image = gr.Image(label="Generated Image")
132
  runtime_info_textbox = gr.Textbox(label="Runtime Information", lines=2)
 
133
  load_button.click(gradio_load_model, inputs=[model_dropdown], outputs=[load_status])
134
  generate_button.click(
135
  fn=gradio_generate,
136
  inputs=[model_dropdown, prompt_textbox, seed_slider],
137
  outputs=[output_image, runtime_info_textbox],
138
  )
 
139
  with gr.Tab("Model Information"):
140
  with gr.Column():
141
  for model_key, model_info in models.items():
 
144
 
145
  if __name__ == "__main__":
146
  download_all_models()
147
+ try:
148
+ interface.launch(debug=True, share=True) # Enable sharing for public link
149
+ except gr.Error as e:
150
+ print(f"Error launching Gradio interface: {e}")
151
+ print("Please make sure you have the necessary permissions and try again.")