File size: 3,049 Bytes
dfaea05
d44229e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfaea05
8772541
d44229e
 
 
 
 
dfaea05
 
d44229e
dfaea05
d44229e
dfaea05
d44229e
dfaea05
8110171
8772541
dfaea05
d44229e
47fe7d6
dfaea05
8772541
dfaea05
 
d44229e
 
 
dfaea05
 
 
 
 
 
 
 
 
 
 
 
d44229e
 
 
 
dfaea05
d44229e
8772541
d44229e
8772541
 
 
dfaea05
 
 
 
 
 
d44229e
 
 
 
 
 
dfaea05
 
 
d44229e
 
dfaea05
8772541
d44229e
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from diffusers import DiffusionPipeline
import gradio as gr
import torch
import time
import psutil


# Wall-clock timestamp taken at import, used below to report how long the
# Space took to build.
start_time = time.time()

# Human-readable label for the compute backend; inference() checks for the
# substring "GPU" to decide between "cuda" and "cpu".
if torch.cuda.is_available():
    device = "GPU 🔥"
else:
    device = "CPU 🥶"


def error_str(error, title="Error"):
    """Format an error as a Markdown snippet for display in the UI.

    Args:
        error: The error text. Any falsy value means "no error".
        title: Heading shown above the error text.

    Returns:
        An empty string when ``error`` is falsy. Otherwise a level-4 Markdown
        heading with ``title``, then a newline, then ``error`` indented by
        twelve spaces.
    """
    if not error:
        return ""
    indent = " " * 12
    return "#### {}\n{}{}".format(title, indent, error)


def inference(
    repo_id,
    discuss_nr,
    prompt,
):
    """Load a Hub model (optionally at a PR revision) and generate images.

    Args:
        repo_id: Hub repository id, e.g. "CompVis/stable-diffusion-v1-4".
        discuss_nr: Discussion/PR number as entered in the UI. ``None``,
            empty, or whitespace-only input means "no PR": the repository's
            default revision is loaded.
        prompt: Text prompt passed to the pipeline.

    Returns:
        An ``(images, message)`` tuple. On success, ``images`` is the
        pipeline output's ``.images`` and ``message`` reports the seed. On
        any exception, ``images`` is None and ``message`` is a Markdown
        error string built by ``error_str``.
    """
    print(psutil.virtual_memory())  # print memory usage

    # Fixed seed so repeated runs of the same PR are comparable.
    seed = 0
    torch_device = "cuda" if "GPU" in device else "cpu"

    generator = torch.Generator(torch_device).manual_seed(seed)

    # float16 only on CUDA; the CPU path runs in float32.
    dtype = torch.float16 if torch_device == "cuda" else torch.float32

    # Treat None and blank input as "no PR" so we never request a
    # nonexistent revision such as "refs/pr/None" or "refs/pr/ ".
    discuss_nr = "" if discuss_nr is None else str(discuss_nr).strip()
    revision = f"refs/pr/{discuss_nr}" if discuss_nr else None

    try:
        pipe = DiffusionPipeline.from_pretrained(repo_id, revision=revision, torch_dtype=dtype)
        pipe.to(torch_device)

        return pipe(prompt, generator=generator, num_inference_steps=25).images, f"Done. Seed: {seed}"
    except Exception as e:
        # UI boundary: report any load/generation failure to the user instead
        # of crashing the Gradio callback.
        url = f"https://huggingface.co./{repo_id}/discussions/{discuss_nr}"
        message = f"There is a problem with your diffusers weights of the PR: {url}. Error message: \n"
        # Exceptions are not strings; convert explicitly (str + Exception
        # would raise TypeError inside this handler).
        return None, error_str(message + str(e))


with gr.Blocks(css="style.css") as demo:
    # Header banner; shows which backend (GPU/CPU) the Space is running on.
    gr.HTML(
        f"""
            <div class="diffusion">
              <p>
               Space to test whether `diffusers` PRs work.
              </p>
              <p>
               Running on <b>{device}</b>
              </p>
            </div>
        """
    )
    with gr.Row():

        with gr.Column(scale=55):
            with gr.Group():
                repo_id = gr.Textbox(
                    label="Repo id on Hub",
                    placeholder="Path to model, e.g. CompVis/stable-diffusion-v1-4 for https://huggingface.co./CompVis/stable-diffusion-v1-4",
                )
                discuss_nr = gr.Textbox(
                    label="Discussion number",
                    placeholder="Number of the discussion that should be checked, e.g. 171 for https://huggingface.co./CompVis/stable-diffusion-v1-4/discussions/171",
                )
                # Gradio 3.x components take their initial content via
                # `value=`; the legacy `default=` keyword is not applied.
                prompt = gr.Textbox(
                    label="Prompt",
                    value="An astronaut riding a horse on Mars.",
                    placeholder="Enter prompt.",
                )
                gallery = gr.Gallery(
                    label="Generated images", show_label=False, elem_id="gallery"
                ).style(grid=[2], height="auto")

            # Receives the Markdown error string from inference() (empty on success).
            error_output = gr.Markdown()

            generate = gr.Button(value="Generate").style(
                rounded=(False, True, True, False)
            )

    # Input order must match inference(repo_id, discuss_nr, prompt); output
    # order matches its (images, message) return tuple.
    inputs = [
        repo_id,
        discuss_nr,
        prompt,
    ]
    outputs = [gallery, error_output]
    # Pressing Enter in the prompt box and clicking the button both run inference.
    prompt.submit(inference, inputs=inputs, outputs=outputs)
    generate.click(inference, inputs=inputs, outputs=outputs)

# Report elapsed time since start_time was recorded at import.
print("Space built in {:.2f} seconds".format(time.time() - start_time))

# Enable Gradio's request queue with a single worker, then start the server.
demo.queue(concurrency_count=1)
demo.launch()