Text-to-Images / app.py
ruslanmv's picture
Update app.py
3de03d4 verified
raw
history blame
9.26 kB
import os
import torch
import gradio as gr
from diffusers import FluxPipeline, DiffusionPipeline
import spaces
# Helper function to get the Hugging Face token securely
def get_hf_token():
    """Return the Hugging Face access token for the current environment.

    If ``google.colab`` can be imported, the token is read from the Colab
    secret named ``HF_TOKEN``. Otherwise the ``HF_TOKEN`` environment
    variable is consulted.

    Returns:
        The token string, or ``None`` when ``google.colab`` is not importable
        and the environment variable is unset.

    Raises:
        RuntimeError: If ``google.colab`` is importable but the secret lookup
            yields a falsy value.
    """
    try:
        from google.colab import userdata

        colab_secret = userdata.get('HF_TOKEN')
        if not colab_secret:
            raise RuntimeError("HF_TOKEN not found in Colab secrets.")
        return colab_secret
    except ImportError:
        # Not running under Colab: fall back to the process environment.
        return os.getenv("HF_TOKEN")
# Resolve the token once at import time and fail fast: every model download
# below passes this token to the Hugging Face Hub.
_HF_TOKEN = get_hf_token()
if not _HF_TOKEN:
    raise ValueError(
        "HF_TOKEN is not available. Please set it in Colab secrets or environment variables."
    )
# Registry of selectable models, keyed by the name shown in the UI dropdown.
# Each entry provides:
#   pipeline_class - diffusers class whose ``from_pretrained`` builds the pipeline
#   model_id       - Hugging Face Hub repository id
#   config         - extra keyword arguments forwarded to ``from_pretrained``
#   lora           - (optional) LoRA repository to load on top of the base model
#   description    - Markdown text rendered in the "Model Information" tab
models = {
    "FLUX.1-schnell": {
        "pipeline_class": FluxPipeline,
        "model_id": "black-forest-labs/FLUX.1-schnell",
        "config": {"torch_dtype": torch.bfloat16},
        "description": (
            "**FLUX.1-schnell** is a fast and efficient model designed for quick image generation. "
            "It excels at producing high-quality images rapidly, making it ideal for applications where speed is crucial. "
            "However, its rapid generation may slightly compromise on the level of detail compared to slower, more meticulous models."
        ),
    },
    "FLUX.1-dev": {
        "pipeline_class": DiffusionPipeline,
        "model_id": "black-forest-labs/FLUX.1-dev",
        "lora": {
            "repo": "strangerzonehf/Flux-Enrich-Art-LoRA",
            "trigger_word": "enrich art",
        },
        "config": {"torch_dtype": torch.bfloat16},
        "description": "**FLUX.1-dev** is a development model that focuses on delivering highly detailed and artistically rich images.",
    },
    "Flux.1-lite-8B-alpha": {
        "pipeline_class": FluxPipeline,
        "model_id": "Freepik/flux.1-lite-8B-alpha",
        "config": {"torch_dtype": torch.bfloat16},
        "description": "**Flux.1-lite-8B-alpha** is a lightweight model optimized for efficiency and ease of use.",
    },
}

# Cache of instantiated pipelines, keyed by the same names as ``models``.
loaded_models = {}
# Last known load outcome per model name (e.g. "Loaded", "Failed").
model_load_status = {}
def clear_gpu_memory():
    """Drop every cached pipeline and release the GPU memory it was holding.

    ``loaded_models`` is emptied in place; ``model_load_status`` is left
    untouched, so status entries may outlive the pipelines they describe.

    Returns:
        ``"GPU memory cleared."`` on success, otherwise a message describing
        the error. Exceptions are reported rather than raised because the
        return value is meant to be shown to the user.
    """
    import gc  # local import keeps this block self-contained

    try:
        # Mutating in place means no ``global`` declaration is needed and
        # every other reference to the cache sees it emptied.
        loaded_models.clear()
        # Removing the dict entries only drops one reference. Force a
        # collection so pipelines kept alive by reference cycles are really
        # destroyed before the CUDA allocator is asked to release its cache.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        print("GPU memory cleared.")
        return "GPU memory cleared."
    except Exception as e:
        print(f"Error clearing GPU memory: {e}")
        return f"Error clearing GPU memory: {e}"
def load_model(model_key):
    """Load the pipeline registered under ``model_key`` into ``loaded_models``.

    At most one model is kept cached: if a different model is cached, the
    cache is cleared first. If the requested model is already cached it is
    reused instead of being loaded again.

    Args:
        model_key: Name of an entry in ``models``.

    Returns:
        A human-readable status message. The outcome is also recorded in
        ``model_load_status[model_key]`` as ``"Model not found."``,
        ``"Loaded"`` or ``"Failed"``.
    """
    if model_key not in models:
        model_load_status[model_key] = "Model not found."
        return f"Model '{model_key}' not found in the available models."

    # Already resident: running from_pretrained again would only rebuild the
    # same pipeline while the old copy still occupies memory.
    if model_key in loaded_models:
        model_load_status[model_key] = "Loaded"
        return f"Model '{model_key}' is already loaded."

    # Anything still cached at this point is a different model; evict it so
    # the new pipeline has room.
    if loaded_models:
        clear_gpu_memory()

    try:
        config = models[model_key]
        pipe = config["pipeline_class"].from_pretrained(
            config["model_id"], token=_HF_TOKEN, **config.get("config", {})
        )
        if "lora" in config:
            # NOTE(review): the LoRA entry's "trigger_word" is not used here;
            # confirm whether prompts are expected to include it.
            pipe.load_lora_weights(config["lora"]["repo"], token=_HF_TOKEN)
        if torch.cuda.is_available():
            pipe.to("cuda")
        loaded_models[model_key] = pipe
        model_load_status[model_key] = "Loaded"
        return f"Model '{model_key}' loaded successfully."
    except Exception as e:
        # Reported as a message because the result is shown in the UI.
        model_load_status[model_key] = "Failed"
        return f"Error loading model '{model_key}': {e}"
@spaces.GPU(duration=120)
def generate_image(model, prompt, seed=-1):
    """Run ``model`` on ``prompt`` and return the first generated image.

    Args:
        model: A loaded pipeline, called as ``model(prompt=..., generator=...)``.
        prompt: Text prompt passed to the pipeline.
        seed: Seed for the random generator; ``-1`` requests a fresh random
            seed on every call.

    Returns:
        The first element of the pipeline result's ``images``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device=device)
    if seed == -1:
        # A newly constructed torch.Generator starts from a fixed default
        # seed, so without reseeding every "random" request would produce
        # the same image.
        generator.seed()
    else:
        # The UI slider may deliver the value as a float; manual_seed
        # requires an int.
        generator.manual_seed(int(seed))
    with torch.no_grad():
        image = model(prompt=prompt, generator=generator).images[0]
    return image
def gradio_generate(selected_model, prompt, seed):
    """Gradio callback: load the model if needed, generate an image, save it.

    Args:
        selected_model: Model name chosen in the dropdown (may be empty).
        prompt: Text prompt.
        seed: Seed value from the slider (``-1`` means random).

    Returns:
        ``(image_path, runtime_info)``, in the order of the click handler's
        outputs (image first, text second). On failure the image slot is
        ``None`` and the text slot carries the error message.
    """
    import tempfile  # local import keeps this block self-contained

    if not selected_model:
        return None, "No model selected. Please select a model to load."

    if selected_model not in loaded_models:
        # load_model() records its outcome in model_load_status on every
        # path, so a stale status entry needs no separate cleanup here.
        load_model(selected_model)
        if selected_model not in loaded_models:
            status = model_load_status.get(selected_model, 'Not attempted')
            # The message goes in the text slot; putting it first would hand
            # the image component a string that is not a file path.
            return None, f"Model not loaded. Load status: {status}."

    image = generate_image(loaded_models[selected_model], prompt, seed)
    runtime_info = f"Model: {selected_model}\nSeed: {seed}"

    # A unique file per request: a single fixed filename would let
    # overlapping requests overwrite each other's result.
    with tempfile.NamedTemporaryFile(
        prefix="generated_image_", suffix=".png", delete=False
    ) as tmp:
        output_path = tmp.name
    image.save(output_path)
    return output_path, runtime_info
def gradio_load_model(selected_model):
    """Gradio callback that loads ``selected_model`` and returns a status message.

    An empty selection is answered with a hint instead of being passed on to
    ``load_model``.
    """
    if selected_model:
        return load_model(selected_model)
    return "No model selected. Please select a model to load."
# ---------------------------------------------------------------------------
# Gradio UI: an "Image Generator" tab (model, prompt, seed -> image + info),
# a "Model Information" tab listing every registry entry, and a credits footer.
# ---------------------------------------------------------------------------
model_names = list(models.keys())

with gr.Blocks(
    css="""
.container {
max-width: 800px;
margin: auto;
padding: 20px;
background-color: #f8f8f8;
border-radius: 10px;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
"""
) as interface:
    with gr.Tab("Image Generator"):
        with gr.Column():
            gr.Markdown("# Text-to-Image Generator")
            model_dropdown = gr.Dropdown(choices=model_names, label="Select Model")
            prompt_textbox = gr.Textbox(label="Enter Text Prompt")
            seed_slider = gr.Slider(
                minimum=-1,
                maximum=1000,
                step=1,
                value=-1,
                label="Random Seed (-1 for random)",
            )
            generate_button = gr.Button("Generate Image")
            output_image = gr.Image(label="Generated Image")
            runtime_info_textbox = gr.Textbox(
                label="Runtime Information", lines=2, interactive=False
            )
            # Add example prompts at the bottom
            gr.Markdown("### Example Prompts")
            # Each example pre-fills the first registered model and a prompt.
            examples = gr.Examples(
                examples=[
                    [model_names[0], example_prompt]
                    for example_prompt in ("Sexy girl", "Beautiful Woman", "Future City")
                ],
                inputs=[model_dropdown, prompt_textbox],
            )
        # Outputs are positional: image first, runtime text second.
        generate_button.click(
            gradio_generate,
            inputs=[model_dropdown, prompt_textbox, seed_slider],
            outputs=[output_image, runtime_info_textbox],
        )
    with gr.Tab("Model Information"):
        for model_key, model_info in models.items():
            gr.Markdown(f"## {model_key}")
            gr.Markdown(model_info["description"])
    gr.Markdown("""---
**Credits**: Created by Ruslan Magana Vsevolodovna. For more information, visit [https://ruslanmv.com/](https://ruslanmv.com/).""")

if __name__ == "__main__":
    interface.launch(debug=True)
# NOTE: A disabled, older copy of the UI definition used to be kept here as a
# module-level triple-quoted string. It had no runtime effect and duplicated
# the live layout, so it was removed; consult version control if the old
# variant (with the commented-out "Load Model" / "Clear GPU Memory" buttons)
# is ever needed.