File size: 3,583 Bytes
0925cf1
40822a4
c63d488
0925cf1
c92c1fc
0925cf1
0b23474
c92c1fc
 
fd34bb6
c92c1fc
c2b6d1e
c92c1fc
 
 
c576f11
c92c1fc
 
 
0925cf1
a17e285
c6747cf
e66a721
 
 
 
 
c6747cf
 
874cb7c
65dc494
c6747cf
e66a721
8f724dc
 
 
c8f91a3
8f724dc
 
db07984
dfe65d8
 
 
 
 
 
 
 
 
 
e66a721
1cdbaa3
b282552
c28f29b
 
 
edf126d
92ec9db
daa6766
a031477
761d42b
 
 
 
 
 
 
 
 
 
 
 
 
 
8dae870
fc8e9a3
82d2444
761d42b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
from diffusers import DiffusionPipeline
import gradio as gr
import os
import spaces

# Load the models outside of the generate_images function so they are created
# once at startup instead of on every request.


def _parse_env_list(var_name):
    """Return the comma-separated entries of environment variable *var_name*.

    Whitespace around each entry is stripped and empty entries are dropped, so
    an unset variable, an empty value, or a trailing comma yields no spurious
    ``""`` entries (which would otherwise be passed to ``from_pretrained``).
    """
    raw = os.environ.get(var_name, "")
    return [entry.strip() for entry in raw.split(",") if entry.strip()]


model_list = _parse_env_list("MODELS")
# NOTE(review): LoRA names are parsed and counted here, but nothing in the
# visible code applies them to any pipeline.
lora_list = _parse_env_list("LORAS")

print(f"Detected {len(model_list)} models and {len(lora_list)} LoRAs.")

# Maps model name -> float16 pipeline moved to CUDA. A model that fails to load
# is reported and skipped, so one bad entry does not stop the app from starting.
models = {}
for model_name in model_list:
    try:
        print(f"\n\nLoading {model_name}...")
        models[model_name] = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")

@spaces.GPU
def generate_images(
    model_name,
    prompt,
    negative_prompt,
    num_inference_steps,
    guidance_scale,
    height,
    width,
    num_images=4,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate images for *prompt* with the pipeline selected in the UI.

    Args:
        model_name: Key into the module-level ``models`` dict.
        prompt: Text prompt. ``None`` or whitespace-only prompts are rejected
            with a UI warning.
        negative_prompt: Forwarded to the pipeline unchanged.
        num_inference_steps: Number of denoising steps (coerced to ``int``).
        guidance_scale: Forwarded to the pipeline (coerced to ``float``).
        height: Output height in pixels (coerced to ``int``).
        width: Output width in pixels (coerced to ``int``).
        num_images: How many images to produce; each image is a separate
            pipeline call (coerced to ``int``).
        progress: Not read directly; ``track_tqdm=True`` asks Gradio to mirror
            tqdm progress bars in the UI.

    Returns:
        A list of images for the gallery, or ``[]`` when the prompt is empty or
        the selected model is not loaded (a UI warning is shown instead).
    """
    if prompt is None or not prompt.strip():
        gr.Warning("Prompt empty!")
        return []

    pipe = models.get(model_name)
    if pipe is None:
        # The model was never loaded (or failed to load); tell the user why
        # the gallery stays empty instead of returning silently.
        gr.Warning(f"Model {model_name} is not loaded!")
        return []

    print(f"Prompt is: [ {prompt} ]")

    # Gradio documents Slider values as floats; range() and the pipeline's
    # step/size arguments need ints.
    steps = int(num_inference_steps)
    img_height = int(height)
    img_width = int(width)

    # One pipeline call per image, keeping the first image of each result.
    return [
        pipe(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            guidance_scale=float(guidance_scale),
            height=img_height,
            width=img_width,
        )["images"][0]
        for _ in range(int(num_images))
    ]

# Create the Gradio blocks
with gr.Blocks(css="style.css", theme='derekzen/stardust') as demo:
    with gr.Row(equal_height=False):
        with gr.Column(elem_id="input_column"):
            with gr.Group(elem_id="input_group"):
                # Default to the first model that actually loaded: model_list[0]
                # may have failed in the loading loop and would then not be one
                # of the dropdown's choices.
                model_dropdown = gr.Dropdown(choices=list(models.keys()), value=next(iter(models), None), label="Model", elem_id="model_dropdown")
                prompt = gr.Textbox(label="Prompt", elem_id="prompt_textbox")
                generate_btn = gr.Button("Generate Image", elem_id="generate_button")
            # Less commonly changed generation settings, collapsed by default.
            with gr.Accordion("Advanced", open=False, elem_id="advanced_accordion"):
                negative_prompt = gr.Textbox(label="Negative Prompt", value="lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]", elem_id="negative_prompt_textbox")
                num_inference_steps = gr.Slider(minimum=10, maximum=50, step=1, value=25, label="Number of Inference Steps", elem_id="num_inference_steps_slider")
                guidance_scale = gr.Slider(minimum=1, maximum=20, step=0.5, value=7.5, label="Guidance Scale", elem_id="guidance_scale_slider")
                height = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Height", elem_id="height_slider")
                width = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Width", elem_id="width_slider")
                num_images = gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Number of Images", elem_id="num_images_slider")
        with gr.Column(elem_id="output_column"):
            output_gallery = gr.Gallery(label="Generated Images", height=480, scale=1, elem_id="output_gallery")

        # Input order must match generate_images' positional parameters.
        generate_btn.click(generate_images, inputs=[model_dropdown, prompt, negative_prompt, num_inference_steps, guidance_scale, height, width, num_images], outputs=output_gallery)

demo.launch()