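"""Gradio Space for smoke-testing `diffusers` weights proposed in a Hub pull request.

Given a repo id and a discussion/PR number, the PR branch (``refs/pr/<nr>``) is loaded
with ``DiffusionPipeline.from_pretrained`` and a test image is generated from a prompt.
"""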
from diffusers import DiffusionPipeline
import gradio as gr
import torch
import time
import psutil
start_time = time.time()
device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
def error_str(error, title="Error"):
    return (
        f"""#### {title}
{error}"""
        if error
        else ""
    )
def inference(
    repo_id,
    discuss_nr,
    prompt,
):
    print(psutil.virtual_memory())  # log current memory usage
    seed = 0
    torch_device = "cuda" if "GPU" in device else "cpu"
    generator = torch.Generator(torch_device).manual_seed(seed)
    dtype = torch.float16 if torch_device == "cuda" else torch.float32
    try:
        # Load from the PR branch (refs/pr/<nr>) if a discussion number was given,
        # otherwise fall back to the main branch.
        revision = f"refs/pr/{discuss_nr}" if discuss_nr else None
        pipe = DiffusionPipeline.from_pretrained(repo_id, revision=revision, torch_dtype=dtype)
        pipe.to(torch_device)
        return pipe(prompt, generator=generator, num_inference_steps=25).images, f"Done. Seed: {seed}"
    except Exception as e:
        url = f"https://huggingface.co./{repo_id}/discussions/{discuss_nr}"
        message = f"There is a problem with the diffusers weights of the PR {url}. Error message:\n"
        return None, error_str(message + str(e))
with gr.Blocks(css="style.css") as demo:
    gr.HTML(
        f"""
        <div class="diffusion">
          <p>
            Space to test whether <code>diffusers</code> PRs work.
          </p>
          <p>
            Running on <b>{device}</b>
          </p>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=55):
            with gr.Group():
                repo_id = gr.Textbox(
                    label="Repo id on Hub",
                    placeholder="Path to model, e.g. CompVis/stable-diffusion-v1-4 for https://huggingface.co./CompVis/stable-diffusion-v1-4",
                )
                discuss_nr = gr.Textbox(
                    label="Discussion number",
                    placeholder="Number of the discussion that should be checked, e.g. 171 for https://huggingface.co./CompVis/stable-diffusion-v1-4/discussions/171",
                )
                prompt = gr.Textbox(
                    label="Prompt",
                    value="An astronaut riding a horse on Mars.",
                    placeholder="Enter prompt.",
                )
                gallery = gr.Gallery(
                    label="Generated images", show_label=False, elem_id="gallery"
                ).style(grid=[2], height="auto")
                error_output = gr.Markdown()
                generate = gr.Button(value="Generate").style(
                    rounded=(False, True, True, False)
                )
    inputs = [
        repo_id,
        discuss_nr,
        prompt,
    ]
    outputs = [gallery, error_output]
    prompt.submit(inference, inputs=inputs, outputs=outputs)
    generate.click(inference, inputs=inputs, outputs=outputs)

print(f"Space built in {time.time() - start_time:.2f} seconds")

demo.queue(concurrency_count=1)
demo.launch()