Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,915 Bytes
0925cf1 40822a4 c63d488 0925cf1 8ef1d5d 0925cf1 0b23474 9da8ea8 a17e285 92ec9db 6bb6b9b 92ec9db 0925cf1 a17e285 c6747cf 874cb7c c6747cf 0b23474 c8f91a3 b282552 c6747cf b282552 c8f91a3 edf126d 92ec9db 0925cf1 a031477 82d2444 8caeb79 9627d6f 44549ae 18c183f 874cb7c 82d2444 c8f91a3 874cb7c 82d2444 0b23474 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import torch
from diffusers import DiffusionPipeline
import gradio as gr
import os
import spaces
def parse_env_list(name):
    """Return the comma-separated entries of environment variable *name*.

    Whitespace around each entry is stripped and empty entries (from an unset
    or empty variable, or stray commas) are dropped. An unset variable
    therefore yields ``[]`` instead of raising ``AttributeError`` on ``None``.
    """
    raw = os.environ.get(name, "")
    return [entry for entry in (part.strip() for part in raw.split(",")) if entry]


# Load the models once at startup, outside generate_images, so every request
# reuses the same pipelines instead of reloading weights per call.
model_list = parse_env_list("MODELS")
# NOTE(review): lora_list is not referenced in the code visible here.
lora_list = parse_env_list("LORAS")
models = {}
for model_name in model_list:
    try:
        models[model_name] = DiffusionPipeline.from_pretrained(
            model_name, torch_dtype=torch.float16
        ).to("cuda")
    except Exception as e:
        # Best-effort: a model that fails to load is reported and skipped, so
        # the app can still start with whichever models did load.
        print(f"Error loading model {model_name}: {e}")
@spaces.GPU
def generate_images(
    model_name,
    prompt,
    negative_prompt,
    num_inference_steps,
    guidance_scale,
    height,
    width,
    num_images=4,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate ``num_images`` images for ``prompt`` with a preloaded pipeline.

    Args:
        model_name: Key into the module-level ``models`` dict.
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        num_inference_steps: Number of denoising steps (cast to ``int``).
        guidance_scale: Classifier-free guidance scale.
        height: Output height in pixels (cast to ``int``).
        width: Output width in pixels (cast to ``int``).
        num_images: How many separate pipeline calls to make (cast to ``int``).
        progress: Unused directly; with ``track_tqdm=True`` Gradio mirrors the
            pipeline's tqdm progress bars in the UI.

    Returns:
        A list holding the first image from each pipeline call.

    Raises:
        gr.Error: If ``model_name`` is not a loaded model. This is shown to
            the user instead of silently returning an empty gallery.
    """
    pipe = models.get(model_name)
    if pipe is None:
        raise gr.Error(f"Model {model_name!r} is not loaded.")

    # Slider values arrive as plain numbers. Cast the count/size/step values
    # to int so range() and the pipeline never receive a float.
    num_images = int(num_images)
    num_inference_steps = int(num_inference_steps)
    height = int(height)
    width = int(width)

    # One call per image (rather than batching) keeps peak GPU memory at a
    # single image's worth.
    return [
        pipe(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
        )["images"][0]
        for _ in range(num_images)
    ]
# Create the Gradio blocks: generation controls on the left, results gallery
# on the right.
with gr.Blocks() as demo:
    with gr.Row(equal_height=False):
        with gr.Group():
            with gr.Column():
                # Offer only models that actually loaded, and default to the
                # first of them. Defaulting to model_list[0] could select a
                # value missing from `choices` when that model failed to load.
                loaded_models = list(models.keys())
                model_dropdown = gr.Dropdown(
                    choices=loaded_models,
                    value=loaded_models[0] if loaded_models else None,
                    label="Model",
                )
                prompt = gr.Textbox(label="Prompt")
                with gr.Accordion("Advanced", open=False):
                    negative_prompt = gr.Textbox(label="Negative Prompt", value="lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract], kid, kid looking, child, childish look")
                    num_inference_steps = gr.Slider(minimum=10, maximum=50, step=1, value=25, label="Number of Inference Steps")
                    guidance_scale = gr.Slider(minimum=1, maximum=20, step=0.5, value=7.5, label="Guidance Scale")
                    height = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Height")
                    width = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Width")
                    num_images = gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Number of Images")
                generate_btn = gr.Button("Generate Image")
        with gr.Column():
            output_gallery = gr.Gallery(label="Generated Images", height=480, scale=1)

    # The input order must match generate_images' positional parameters.
    generate_btn.click(
        generate_images,
        inputs=[
            model_dropdown,
            prompt,
            negative_prompt,
            num_inference_steps,
            guidance_scale,
            height,
            width,
            num_images,
        ],
        outputs=output_gallery,
    )

demo.launch()
|