rafaaa2105's picture
Update app.py
fc8e9a3 verified
raw
history blame
No virus
3.63 kB
import torch
from diffusers import DiffusionPipeline
import gradio as gr
import os
import spaces
# Load the models outside of the generate_images function
model_list = [model.strip() for model in os.environ.get("MODELS").split(",")]
lora_list = [model.strip() for model in os.environ.get("LORAS").split(",")]
models = {}
for model_name in model_list:
try:
models[model_name] = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")
except Exception as e:
print(f"Error loading model {model_name}: {e}")
@spaces.GPU
def generate_images(
model_name,
prompt,
negative_prompt,
num_inference_steps,
guidance_scale,
height,
width,
num_images=4,
progress=gr.Progress(track_tqdm=True)
):
pipe = models.get(model_name)
if pipe is None:
return []
outputs = []
for _ in range(num_images):
output = pipe(
prompt,
negative_prompt=negative_prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
height=height,
width=width
)["images"][0]
outputs.append(output)
return outputs
# Create the Gradio blocks
with gr.Blocks() as demo:
with gr.Row(equal_height=False):
with gr.Column():
with gr.Group():
model_dropdown = gr.Dropdown(choices=list(models.keys()), value=model_list[0] if model_list else None, label="Model")
prompt = gr.Textbox(label="Prompt")
with gr.Accordion("Advanced", open=False):
negative_prompt = gr.Textbox(label="Negative Prompt", value="verybadimagenegative_v1.3, ng_deepnegative_v1_75t, (ugly face:0.8),cross-eyed,sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, bad anatomy, DeepNegative, facing away, tilted head, {Multiple people}, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worstquality, low quality, normal quality, jpegartifacts, signature, watermark, username, blurry, bad feet, cropped, poorly drawn hands, poorly drawn face, mutation, deformed, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, extra fingers, fewer digits, extra limbs, extra arms,extra legs, malformed limbs, fused fingers, too many fingers, long neck, cross-eyed,mutated hands, polar lowres, bad body, bad proportions, gross proportions, text, error, missing fingers, missing arms, missing legs, extra digit, extra arms, extra leg, extra foot, ((repeating hair))")
num_inference_steps = gr.Slider(minimum=10, maximum=50, step=1, value=25, label="Number of Inference Steps")
guidance_scale = gr.Slider(minimum=1, maximum=20, step=0.5, value=7.5, label="Guidance Scale")
height = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Height")
width = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Width")
num_images = gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Number of Images")
generate_btn = gr.Button("Generate Image")
with gr.Column():
output_gallery = gr.Gallery(label="Generated Images", height=480, scale=1)
generate_btn.click(generate_images, inputs=[model_dropdown, prompt, negative_prompt, num_inference_steps, guidance_scale, height, width, num_images], outputs=output_gallery)
demo.launch()