File size: 2,915 Bytes
0925cf1
40822a4
c63d488
0925cf1
8ef1d5d
0925cf1
0b23474
9da8ea8
 
 
a17e285
92ec9db
 
6bb6b9b
92ec9db
 
0925cf1
a17e285
c6747cf
 
 
 
 
 
 
 
874cb7c
c6747cf
 
0b23474
 
c8f91a3
 
b282552
 
c6747cf
 
 
 
 
 
 
 
b282552
 
c8f91a3
edf126d
92ec9db
0925cf1
a031477
 
 
82d2444
 
8caeb79
9627d6f
44549ae
 
18c183f
 
874cb7c
82d2444
c8f91a3
 
874cb7c
82d2444
0b23474
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
from diffusers import DiffusionPipeline
import gradio as gr
import os
import spaces

# Load the models outside of the generate_images function so every request
# reuses the same pipelines (each one moved to "cuda" in float16 once, here).
# MODELS and LORAS are comma-separated lists of ids/paths. A missing variable
# or blank entries (e.g. a trailing comma) are skipped instead of crashing at
# import time or producing "" ids that can never load.
model_list = [model.strip() for model in os.environ.get("MODELS", "").split(",") if model.strip()]
# NOTE(review): lora_list is not used by any code visible in this file.
lora_list = [model.strip() for model in os.environ.get("LORAS", "").split(",") if model.strip()]

# Maps model id -> loaded pipeline. Models that fail to load are reported and
# left out, so the app still starts with whatever did load.
models = {}
for model_name in model_list:
    try:
        models[model_name] = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")
    except Exception as e:
        # Best-effort loading: report and continue with the remaining models.
        print(f"Error loading model {model_name}: {e}")

@spaces.GPU
def generate_images(
    model_name, 
    prompt, 
    negative_prompt, 
    num_inference_steps, 
    guidance_scale, 
    height,
    width,
    num_images=4,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate ``num_images`` images for ``prompt`` with the selected pipeline.

    Args:
        model_name: Key into the module-level ``models`` dict.
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Negative prompt forwarded to the pipeline.
        num_inference_steps: Number of inference steps (cast to int).
        guidance_scale: Guidance scale forwarded to the pipeline.
        height: Output image height in pixels (cast to int).
        width: Output image width in pixels (cast to int).
        num_images: How many images to generate; one pipeline call per image
            (cast to int).
        progress: Gradio progress tracker created with ``track_tqdm=True``;
            not referenced directly in this function.

    Returns:
        A list with the first entry of each call's ``"images"`` output, or an
        empty list when ``model_name`` is not among the loaded models.
    """
    pipe = models.get(model_name)
    if pipe is None:
        # Unknown or failed-to-load model: nothing to show in the gallery.
        return []

    # Gradio sliders pass their values as floats. range() needs an int, and
    # diffusers expects integer step counts and pixel sizes.
    steps = int(num_inference_steps)
    img_height = int(height)
    img_width = int(width)
    count = int(num_images)

    outputs = []
    for _ in range(count):
        # One pipeline call per image; keep only the first image of each call.
        output = pipe(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            height=img_height,
            width=img_width,
        )["images"][0]
        outputs.append(output)

    return outputs

# Create the Gradio blocks
with gr.Blocks() as demo:
    # Only successfully loaded models are selectable. Default to the first of
    # those (not model_list[0], which may have failed to load and would then be
    # a value outside the dropdown's choices).
    loaded_models = list(models)
    with gr.Row(equal_height=False):
        with gr.Group():
            with gr.Column():
                model_dropdown = gr.Dropdown(choices=loaded_models, value=loaded_models[0] if loaded_models else None, label="Model")
                prompt = gr.Textbox(label="Prompt")
                with gr.Accordion("Advanced", open=False):
                    negative_prompt = gr.Textbox(label="Negative Prompt", value="lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract], kid, kid looking, child, childish look")
                    num_inference_steps = gr.Slider(minimum=10, maximum=50, step=1, value=25, label="Number of Inference Steps")
                    guidance_scale = gr.Slider(minimum=1, maximum=20, step=0.5, value=7.5, label="Guidance Scale")
                    height = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Height")
                    width = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Width")
                    num_images = gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Number of Images")
                generate_btn = gr.Button("Generate Image")
            with gr.Column():
                output_gallery = gr.Gallery(label="Generated Images", height=480, scale=1)
            # Input order must match generate_images' positional parameters.
            generate_btn.click(generate_images, inputs=[model_dropdown, prompt, negative_prompt, num_inference_steps, guidance_scale, height, width, num_images], outputs=output_gallery)

demo.launch()