import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline, FluxPipeline, AutoencoderTiny, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
from download import download_all_models, models, download_vaes

# Call the download functions at the start of the app
download_all_models()
download_vaes()
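# Downloading everything before the UI starts keeps later @spaces.GPU calls
# short: on ZeroGPU Spaces the GPU is attached per request, so network I/O
# should happen on the CPU side during startup, not inside timed GPU calls.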
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load VAEs - these can be reused across models
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
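# taef1 is a tiny, fast approximation of the FLUX VAE used to decode
# intermediate latents for the live preview; good_vae is the full
# AutoencoderKL used to decode the final image at full quality.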
def load_model(model_name):
    model_info = models[model_name]
    pipeline_class = model_info["pipeline_class"]
    model_id = model_info["model_id"]
    config = model_info["config"]
    pipeline = pipeline_class.from_pretrained(model_id, **config, vae=taef1).to(device)
    # Assign the custom function for live preview if it's a FluxPipeline or DiffusionPipeline
    if pipeline_class in (FluxPipeline, DiffusionPipeline):
        pipeline.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipeline)
    return pipeline, model_info.get("description", "No description available.")
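# Note: `.__get__(pipeline)` binds the standalone helper as a method of this
# pipeline instance, so inside the helper `self` refers to the pipeline and
# it can be called like any other pipeline method.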
# Initialize with default model
current_model_name = "FLUX.1-dev"
pipe, model_description = load_model(current_model_name)
torch.cuda.empty_cache()

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
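# MAX_SEED uses the int32 maximum so any chosen seed fits both the Gradio
# seed slider and torch.Generator.manual_seed; MAX_IMAGE_SIZE bounds the
# width/height sliders below.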
@spaces.GPU(duration=75)
def infer(model_name, prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
    global pipe, current_model_name  # Access the global pipe and model name
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    # Only reload the model if a different one is selected
    if model_name != current_model_name:
        pipe, _ = load_model(model_name)
        current_model_name = model_name
        torch.cuda.empty_cache()  # release VRAM cached by the previous pipeline
    for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
        output_type="pil",
        good_vae=good_vae,
    ):
        yield img, seed
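# Because `infer` is a generator, Gradio streams each yielded image to the
# `result` component, so the preview refines live as denoising progresses.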
examples = [
    ["FLUX.1-dev", "a tiny astronaut hatching from an egg on the moon"],
    ["FLUX.1-schnell", "a cat holding a sign that says hello world"],
    ["Flux.1-lite-8B-alpha", "an anime illustration of a wiener schnitzel"],
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# FLUX.1 Model Selector
Select a model to generate images using the FLUX pipeline.
""")
        model_selector = gr.Dropdown(
            choices=list(models.keys()),
            value=current_model_name,
            label="Select Model",
        )
        model_description_box = gr.Markdown(model_description)
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=15,
                    step=0.1,
                    value=3.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )

        def update_description(selected_model):
            return models[selected_model]["description"]

        model_selector.change(
            fn=update_description,
            inputs=[model_selector],
            outputs=[model_description_box],
        )
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[model_selector, prompt],  # Correct order of inputs
            outputs=[result, seed],
            cache_examples=False,
        )
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[model_selector, prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, seed],
    )

demo.launch()
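# A minimal sketch (assumption: running outside the ZeroGPU runtime with a
# local GPU) of driving the generator directly, without the Gradio UI:
#
#     for img, used_seed in infer("FLUX.1-dev",
#                                 "a red fox in the snow",
#                                 randomize_seed=True):
#         pass  # each `img` is a progressively refined PIL image
#     img.save(f"flux_{used_seed}.png")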
# working1 (previous version, kept for reference)
'''
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline, FluxPipeline, AutoencoderTiny, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images

# Define models and their configurations
models = {
    "FLUX.1-schnell": {
        "pipeline_class": FluxPipeline,
        "model_id": "black-forest-labs/FLUX.1-schnell",
        "config": {"torch_dtype": torch.bfloat16},
        "description": "**FLUX.1-schnell** is a fast and efficient model designed for quick image generation. It excels at producing high-quality images rapidly, making it ideal for applications where speed is crucial. However, its rapid generation may slightly compromise on the level of detail compared to slower, more meticulous models.",
    },
    "FLUX.1-dev": {
        "pipeline_class": DiffusionPipeline,
        "model_id": "black-forest-labs/FLUX.1-dev",
        "config": {"torch_dtype": torch.bfloat16},
        "description": "**FLUX.1-dev** is a development model that focuses on delivering highly detailed and artistically rich images.",
    },
    "Flux.1-lite-8B-alpha": {
        "pipeline_class": FluxPipeline,
        "model_id": "Freepik/flux.1-lite-8B-alpha",
        "config": {"torch_dtype": torch.bfloat16},
        "description": "**Flux.1-lite-8B-alpha** is a lightweight model optimized for efficiency and ease of use.",
    },
}

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load VAEs - these can be reused across models
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)

def load_model(model_name):
    model_info = models[model_name]
    pipeline_class = model_info["pipeline_class"]
    model_id = model_info["model_id"]
    config = model_info["config"]
    pipeline = pipeline_class.from_pretrained(model_id, **config, vae=taef1).to(device)
    # Assign the custom function for live preview if it's a FluxPipeline or DiffusionPipeline
    if pipeline_class in (FluxPipeline, DiffusionPipeline):
        pipeline.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipeline)
    return pipeline, model_info.get("description", "No description available.")

# Initialize with default model
current_model_name = "FLUX.1-dev"
pipe, model_description = load_model(current_model_name)
torch.cuda.empty_cache()

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

@spaces.GPU(duration=75)
def infer(model_name, prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
    global pipe, current_model_name  # Access the global pipe and model name
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    # Only reload the model if a different one is selected
    if model_name != current_model_name:
        pipe, _ = load_model(model_name)
        current_model_name = model_name
    for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
        output_type="pil",
        good_vae=good_vae,
    ):
        yield img, seed

examples = [
    ["FLUX.1-dev", "a tiny astronaut hatching from an egg on the moon"],
    ["FLUX.1-schnell", "a cat holding a sign that says hello world"],
    ["Flux.1-lite-8B-alpha", "an anime illustration of a wiener schnitzel"],
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# FLUX.1 Model Selector
Select a model to generate images using the FLUX pipeline.
""")
        model_selector = gr.Dropdown(
            choices=list(models.keys()),
            value=current_model_name,
            label="Select Model",
        )
        model_description_box = gr.Markdown(model_description)
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=15,
                    step=0.1,
                    value=3.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )

        def update_description(selected_model):
            return models[selected_model]["description"]

        model_selector.change(
            fn=update_description,
            inputs=[model_selector],
            outputs=[model_description_box],
        )
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[model_selector, prompt],  # Correct order of inputs
            outputs=[result, seed],
            cache_examples=False,
        )
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[model_selector, prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, seed],
    )

demo.launch()
'''
# original
'''
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1).to(device)
torch.cuda.empty_cache()

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)

@spaces.GPU(duration=75)
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
        output_type="pil",
        good_vae=good_vae,
    ):
        yield img, seed

examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cat holding a sign that says hello world",
    "an anime illustration of a wiener schnitzel",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# FLUX.1 [dev]
12B param rectified flow transformer guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
[[non-commercial license](https://huggingface.co./black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co./black-forest-labs/FLUX.1-dev)]
""")
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=15,
                    step=0.1,
                    value=3.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt],
            outputs=[result, seed],
            cache_examples="lazy",
        )
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, seed],
    )

demo.launch()
'''
'''
import os
import torch
import gradio as gr
from diffusers import FluxPipeline, DiffusionPipeline
import spaces

# Helper function to get the Hugging Face token securely
def get_hf_token():
    try:
        from google.colab import userdata
        hf_token = userdata.get('HF_TOKEN')
        if hf_token:
            return hf_token
        else:
            raise RuntimeError("HF_TOKEN not found in Colab secrets.")
    except ImportError:
        return os.getenv("HF_TOKEN", None)

# Securely get the token
_HF_TOKEN = get_hf_token()
if not _HF_TOKEN:
    raise ValueError("HF_TOKEN is not available. Please set it in Colab secrets or environment variables.")

# Define models and their configurations
models = {
    "FLUX.1-schnell": {
        "pipeline_class": FluxPipeline,
        "model_id": "black-forest-labs/FLUX.1-schnell",
        "config": {"torch_dtype": torch.bfloat16},
        "description": "**FLUX.1-schnell** is a fast and efficient model designed for quick image generation. It excels at producing high-quality images rapidly, making it ideal for applications where speed is crucial. However, its rapid generation may slightly compromise on the level of detail compared to slower, more meticulous models.",
    },
    "FLUX.1-dev": {
        "pipeline_class": DiffusionPipeline,
        "model_id": "black-forest-labs/FLUX.1-dev",
        "lora": {
            "repo": "strangerzonehf/Flux-Enrich-Art-LoRA",
            "trigger_word": "enrich art",
        },
        "config": {"torch_dtype": torch.bfloat16},
        "description": "**FLUX.1-dev** is a development model that focuses on delivering highly detailed and artistically rich images.",
    },
    "Flux.1-lite-8B-alpha": {
        "pipeline_class": FluxPipeline,
        "model_id": "Freepik/flux.1-lite-8B-alpha",
        "config": {"torch_dtype": torch.bfloat16},
        "description": "**Flux.1-lite-8B-alpha** is a lightweight model optimized for efficiency and ease of use.",
    },
}
# NOTE: this second assignment overrides the full dict above, restricting the
# app to FLUX.1-schnell only (it appears to be a testing leftover).
models = {
    "FLUX.1-schnell": {
        "pipeline_class": FluxPipeline,
        "model_id": "black-forest-labs/FLUX.1-schnell",
        "config": {"torch_dtype": torch.bfloat16},
        "description": "**FLUX.1-schnell** is a fast and efficient model designed for quick image generation. It excels at producing high-quality images rapidly, making it ideal for applications where speed is crucial. However, its rapid generation may slightly compromise on the level of detail compared to slower, more meticulous models.",
    },
}
# Function to pre-download models
def download_all_models():
    print("Downloading all models...")
    for model_key, config in models.items():
        try:
            pipeline_class = config["pipeline_class"]
            model_id = config["model_id"]
            # Attempt to download the pipeline without loading it into memory
            pipeline_class.from_pretrained(model_id, token=_HF_TOKEN, **config.get("config", {}))
            if "lora" in config:
                pipeline_class.download_lora_weights(config["lora"]["repo"], token=_HF_TOKEN)
            print(f"Model '{model_key}' downloaded successfully.")
        except Exception as e:
            print(f"Error downloading model '{model_key}': {e}")
    print("Model download process complete.")

loaded_models = {}
model_load_status = {}  # Dictionary to track model load status

def clear_gpu_memory():
    """Clears GPU memory. Keeps model status information."""
    global loaded_models
    try:
        for model_key in list(loaded_models.keys()):  # Iterate over a copy to allow deletion
            del loaded_models[model_key]
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        print("GPU memory cleared.")
        return "GPU memory cleared."
    except Exception as e:
        print(f"Error clearing GPU memory: {e}")
        return f"Error clearing GPU memory: {e}"

def load_model(model_key):
    """Loads a model, clearing GPU memory first if a different model was loaded."""
    global model_load_status
    if model_key not in models:
        model_load_status[model_key] = "Model not found."
        return f"Model '{model_key}' not found in the available models."
    # Clear GPU memory only if a different model is already loaded
    if loaded_models and list(loaded_models.keys())[0] != model_key:
        clear_gpu_memory()
    try:
        config = models[model_key]
        pipeline_class = config["pipeline_class"]
        model_id = config["model_id"]
        pipe = pipeline_class.from_pretrained(model_id, token=_HF_TOKEN, **config.get("config", {}))
        if "lora" in config:
            lora_config = config["lora"]
            pipe.load_lora_weights(lora_config["repo"], token=_HF_TOKEN)
        if torch.cuda.is_available():
            pipe.to("cuda")
        loaded_models[model_key] = pipe
        model_load_status[model_key] = "Loaded"  # Update load status
        return f"Model '{model_key}' loaded successfully."
    except Exception as e:
        model_load_status[model_key] = "Failed"  # Update load status on error
        return f"Error loading model '{model_key}': {e}"

@spaces.GPU()
def generate_image(model, prompt, seed=-1):
    generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu")
    if seed != -1:
        generator = generator.manual_seed(seed)
    with torch.no_grad():
        image = model(prompt=prompt, generator=generator).images[0]
    return image

def gradio_generate(selected_model, prompt, seed):
    if selected_model not in loaded_models:
        if selected_model in model_load_status and model_load_status[selected_model] == "Loaded":
            # Model should be loaded but isn't in loaded_models, clear it from the status
            del model_load_status[selected_model]
        if selected_model not in model_load_status or model_load_status[selected_model] != "Loaded":
            # Attempt to load the model if not already attempted or failed
            load_model(selected_model)
        if selected_model not in loaded_models:
            # If still not loaded after attempt, return an error
            return f"Model not loaded. Load status: {model_load_status.get(selected_model, 'Not attempted')}.", None
    model = loaded_models[selected_model]
    image = generate_image(model, prompt, seed)
    runtime_info = f"Model: {selected_model}\nSeed: {seed}"
    output_path = "generated_image.png"
    image.save(output_path)
    return output_path, runtime_info

def gradio_load_model(selected_model):
    if not selected_model:
        return "No model selected. Please select a model to load."
    return load_model(selected_model)

import gradio as gr

with gr.Blocks(
    css="""
    .container {
        max-width: 800px;
        margin: auto;
        padding: 20px;
        background-color: #ffffff; /* Pure white for a clean background */
        border-radius: 10px; /* Smooth rounded corners */
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.05); /* Subtle shadow for a lighter feel */
        color: #333333; /* Dark gray text for good contrast and readability */
    }
    """
) as interface:
    with gr.Tab("Image Generator"):
        with gr.Column():
            gr.Markdown("# Text-to-Image Generator")
            model_dropdown = gr.Dropdown(choices=list(models.keys()), label="Select Model")
            prompt_textbox = gr.Textbox(label="Enter Text Prompt")
            seed_slider = gr.Slider(minimum=-1, maximum=1000, step=1, value=-1, label="Random Seed (-1 for random)")
            generate_button = gr.Button("Generate Image")
            output_image = gr.Image(label="Generated Image")
            runtime_info_textbox = gr.Textbox(label="Runtime Information", lines=2, interactive=False)
            # Add example prompts at the bottom
            gr.Markdown("### Example Prompts")
            examples = gr.Examples(
                examples=[
                    [list(models.keys())[0], "Sexy Woman", "sample2.png"],
                    # [list(models.keys())[2], "Sexy girl", "sample3.png"],
                    # [list(models.keys())[1], "Future City", "sample1.png"]
                ],
                inputs=[model_dropdown, prompt_textbox, output_image],
            )
            generate_button.click(gradio_generate, inputs=[model_dropdown, prompt_textbox, seed_slider], outputs=[output_image, runtime_info_textbox])
    with gr.Tab("Model Information"):
        for model_key, model_info in models.items():
            gr.Markdown(f"## {model_key}")
            gr.Markdown(model_info["description"])
    gr.Markdown("""---
**Credits**: Created by Ruslan Magana Vsevolodovna. For more information, visit [https://ruslanmv.com/](https://ruslanmv.com/).""")

if __name__ == "__main__":
    # Pre-download all models at startup
    download_all_models()
    interface.launch(debug=True)
'''
'''
import os
import torch
import gradio as gr
from diffusers import FluxPipeline, DiffusionPipeline
import spaces

# Helper function to get the Hugging Face token securely
def get_hf_token():
    try:
        from google.colab import userdata
        hf_token = userdata.get('HF_TOKEN')
        if hf_token:
            return hf_token
        else:
            raise RuntimeError("HF_TOKEN not found in Colab secrets.")
    except ImportError:
        return os.getenv("HF_TOKEN", None)

# Securely get the token
_HF_TOKEN = get_hf_token()
if not _HF_TOKEN:
    raise ValueError("HF_TOKEN is not available. Please set it in Colab secrets or environment variables.")

# Define models and their configurations
models = {
    "FLUX.1-schnell": {
        "pipeline_class": FluxPipeline,
        "model_id": "black-forest-labs/FLUX.1-schnell",
        "config": {"torch_dtype": torch.bfloat16},
        "description": "**FLUX.1-schnell** is a fast and efficient model designed for quick image generation. It excels at producing high-quality images rapidly, making it ideal for applications where speed is crucial. However, its rapid generation may slightly compromise on the level of detail compared to slower, more meticulous models.",
    },
    "FLUX.1-dev": {
        "pipeline_class": DiffusionPipeline,
        "model_id": "black-forest-labs/FLUX.1-dev",
        "lora": {
            "repo": "strangerzonehf/Flux-Enrich-Art-LoRA",
            "trigger_word": "enrich art",
        },
        "config": {"torch_dtype": torch.bfloat16},
        "description": "**FLUX.1-dev** is a development model that focuses on delivering highly detailed and artistically rich images.",
    },
    "Flux.1-lite-8B-alpha": {
        "pipeline_class": FluxPipeline,
        "model_id": "Freepik/flux.1-lite-8B-alpha",
        "config": {"torch_dtype": torch.bfloat16},
        "description": "**Flux.1-lite-8B-alpha** is a lightweight model optimized for efficiency and ease of use.",
    },
}

# Function to pre-download models
def download_all_models():
    print("Downloading all models...")
    for model_key, config in models.items():
        try:
            pipeline_class = config["pipeline_class"]
            model_id = config["model_id"]
            # Attempt to download the pipeline without loading it into memory
            pipeline_class.from_pretrained(model_id, token=_HF_TOKEN, **config.get("config", {}))
            if "lora" in config:
                pipeline_class.download_lora_weights(config["lora"]["repo"], token=_HF_TOKEN)
            print(f"Model '{model_key}' downloaded successfully.")
        except Exception as e:
            print(f"Error downloading model '{model_key}': {e}")
    print("Model download process complete.")

loaded_models = {}
model_load_status = {}  # Dictionary to track model load status

def clear_gpu_memory():
    """Clears GPU memory. Keeps model status information."""
    global loaded_models
    try:
        for model_key in list(loaded_models.keys()):  # Iterate over a copy to allow deletion
            del loaded_models[model_key]
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        print("GPU memory cleared.")
        return "GPU memory cleared."
    except Exception as e:
        print(f"Error clearing GPU memory: {e}")
        return f"Error clearing GPU memory: {e}"

def load_model(model_key):
    """Loads a model, clearing GPU memory first if a different model was loaded."""
    global model_load_status
    if model_key not in models:
        model_load_status[model_key] = "Model not found."
        return f"Model '{model_key}' not found in the available models."
    # Clear GPU memory only if a different model is already loaded
    if loaded_models and list(loaded_models.keys())[0] != model_key:
        clear_gpu_memory()
    try:
        config = models[model_key]
        pipeline_class = config["pipeline_class"]
        model_id = config["model_id"]
        pipe = pipeline_class.from_pretrained(model_id, token=_HF_TOKEN, **config.get("config", {}))
        if "lora" in config:
            lora_config = config["lora"]
            pipe.load_lora_weights(lora_config["repo"], token=_HF_TOKEN)
        if torch.cuda.is_available():
            pipe.to("cuda")
        loaded_models[model_key] = pipe
        model_load_status[model_key] = "Loaded"  # Update load status
        return f"Model '{model_key}' loaded successfully."
    except Exception as e:
        model_load_status[model_key] = "Failed"  # Update load status on error
        return f"Error loading model '{model_key}': {e}"

@spaces.GPU()
def generate_image(model, prompt, seed=-1):
    generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu")
    if seed != -1:
        generator = generator.manual_seed(seed)
    with torch.no_grad():
        image = model(prompt=prompt, generator=generator).images[0]
    return image

def gradio_generate(selected_model, prompt, seed):
    if selected_model not in loaded_models:
        if selected_model in model_load_status and model_load_status[selected_model] == "Loaded":
            # Model should be loaded but isn't in loaded_models, clear it from the status
            del model_load_status[selected_model]
        if selected_model not in model_load_status or model_load_status[selected_model] != "Loaded":
            # Attempt to load the model if not already attempted or failed
            load_model(selected_model)
        if selected_model not in loaded_models:
            # If still not loaded after attempt, return an error
            return f"Model not loaded. Load status: {model_load_status.get(selected_model, 'Not attempted')}.", None
    model = loaded_models[selected_model]
    image = generate_image(model, prompt, seed)
    runtime_info = f"Model: {selected_model}\nSeed: {seed}"
    output_path = "generated_image.png"
    image.save(output_path)
    return output_path, runtime_info

def gradio_load_model(selected_model):
    if not selected_model:
        return "No model selected. Please select a model to load."
    return load_model(selected_model)

with gr.Blocks(
    css="""
    .container {
        max-width: 800px;
        margin: auto;
        padding: 20px;
        background-color: #f8f8f8;
        border-radius: 10px;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
    }
    """
) as interface:
    with gr.Tab("Image Generator"):
        with gr.Column():
            gr.Markdown("# Text-to-Image Generator")
            model_dropdown = gr.Dropdown(choices=list(models.keys()), label="Select Model")
            prompt_textbox = gr.Textbox(label="Enter Text Prompt")
            seed_slider = gr.Slider(minimum=-1, maximum=1000, step=1, value=-1, label="Random Seed (-1 for random)")
            generate_button = gr.Button("Generate Image")
            output_image = gr.Image(label="Generated Image")
            runtime_info_textbox = gr.Textbox(label="Runtime Information", lines=2, interactive=False)
            # Add example prompts at the bottom
            gr.Markdown("### Example Prompts")
            examples = gr.Examples(
                examples=[
                    [list(models.keys())[0], "Sexy Woman", "sample2.png"],
                    [list(models.keys())[2], "Sexy girl", "sample3.png"],
                    [list(models.keys())[1], "Future City", "sample1.png"]
                ],
                inputs=[model_dropdown, prompt_textbox, output_image],
            )
            generate_button.click(gradio_generate, inputs=[model_dropdown, prompt_textbox, seed_slider], outputs=[output_image, runtime_info_textbox])
    with gr.Tab("Model Information"):
        for model_key, model_info in models.items():
            gr.Markdown(f"## {model_key}")
            gr.Markdown(model_info["description"])
    gr.Markdown("""---
**Credits**: Created by Ruslan Magana Vsevolodovna. For more information, visit [https://ruslanmv.com/](https://ruslanmv.com/).""")

if __name__ == "__main__":
    # Pre-download all models at startup
    download_all_models()
    # Load only the first model
    load_model(list(models.keys())[0])
    interface.launch(debug=True)
'''
'''
with gr.Blocks(
    css="""
    .container {
        max-width: 800px;
        margin: auto;
        padding: 20px;
        background-color: #f8f8f8;
        border-radius: 10px;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
    }
    """
) as interface:
    with gr.Tab("Image Generator"):
        with gr.Column():
            gr.Markdown("# Text-to-Image Generator")
            model_dropdown = gr.Dropdown(choices=list(models.keys()), label="Select Model")
            prompt_textbox = gr.Textbox(label="Enter Text Prompt")
            seed_slider = gr.Slider(minimum=-1, maximum=1000, step=1, value=-1, label="Random Seed (-1 for random)")
            generate_button = gr.Button("Generate Image")
            output_image = gr.Image(label="Generated Image")
            runtime_info_textbox = gr.Textbox(label="Runtime Information", lines=2, interactive=False)
            # Add example prompts at the bottom
            gr.Markdown("### Example Prompts")
            examples = gr.Examples(
                examples=[
                    [list(models.keys())[0], "Sexy girl"],
                    [list(models.keys())[0], "Beautiful Woman"],
                    [list(models.keys())[0], "Future City"]
                ],
                inputs=[model_dropdown, prompt_textbox],
            )
            generate_button.click(gradio_generate, inputs=[model_dropdown, prompt_textbox, seed_slider], outputs=[output_image, runtime_info_textbox])
    with gr.Tab("Model Information"):
        for model_key, model_info in models.items():
            gr.Markdown(f"## {model_key}")
            gr.Markdown(model_info["description"])
    gr.Markdown("""---
**Credits**: Created by Ruslan Magana Vsevolodovna. For more information, visit [https://ruslanmv.com/](https://ruslanmv.com/).""")

if __name__ == "__main__":
    interface.launch(debug=True)
'''
'''
with gr.Blocks(
    css="""
    .container {
        max-width: 800px;
        margin: auto;
        padding: 20px;
        background-color: #f8f8f8;
        border-radius: 10px;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
    }
    """
) as interface:
    with gr.Tab("Image Generator"):
        with gr.Column():
            gr.Markdown("# Text-to-Image Generator")
            model_dropdown = gr.Dropdown(choices=list(models.keys()), label="Select Model")
            # with gr.Row():
            #     load_button = gr.Button("Load Model")
            #     clear_button = gr.Button("Clear GPU Memory")  # Removed clear button
            # load_status = gr.Textbox(label="Model Load Status", interactive=False)  # Removed load button
            prompt_textbox = gr.Textbox(label="Enter Text Prompt")
            seed_slider = gr.Slider(minimum=-1, maximum=1000, step=1, value=-1, label="Random Seed (-1 for random)")
            generate_button = gr.Button("Generate Image")
            output_image = gr.Image(label="Generated Image")
            runtime_info_textbox = gr.Textbox(label="Runtime Information", lines=2, interactive=False)
            # load_button.click(gradio_load_model, inputs=[model_dropdown], outputs=[load_status])  # Removed load button click action
            # clear_button.click(clear_gpu_memory, outputs=[load_status])  # Removed clear button click action
            generate_button.click(gradio_generate, inputs=[model_dropdown, prompt_textbox, seed_slider], outputs=[output_image, runtime_info_textbox])
    with gr.Tab("Model Information"):
        for model_key, model_info in models.items():
            gr.Markdown(f"## {model_key}")
            gr.Markdown(model_info["description"])
    gr.Markdown("""---
**Credits**: Created by Ruslan Magana Vsevolodovna. For more information, visit [https://ruslanmv.com/](https://ruslanmv.com/).""")

if __name__ == "__main__":
    interface.launch(debug=True)
'''