fia
CPU Settings
64d8b6a unverified
raw
history blame contribute delete
No virus
5.76 kB
import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
from datasets import load_dataset
from PIL import Image
import re
import streamlit as st
# Checkpoint to load, and the device on which `infer` creates its random generator.
model_id = "CompVis/stable-diffusion-v1-4"
device = "cpu"

# The Hugging Face access token is read from the Streamlit secret "AUTH_KEY".
# If you are running this code locally, either run `huggingface-cli login` or
# supply your User Access Token (https://huggingface.co/settings/tokens) as the
# `use_auth_token` argument below.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    use_auth_token=st.secrets["AUTH_KEY"],
)
def dummy(images, **kwargs):
    """Pass-through stand-in for a safety checker.

    Args:
        images: The images to hand back; returned as the same object, untouched.
        **kwargs: Accepted and ignored.

    Returns:
        A ``(images, False)`` tuple.
    """
    flagged = False
    return images, flagged
# Install the pass-through `dummy` as the pipeline's safety checker.
pipe.safety_checker = dummy
def infer(prompt, width, height, steps, scale, seed):
    """Run the Stable Diffusion pipeline for a single prompt.

    Args:
        prompt: Text prompt; passed to the pipeline as a one-element list.
        width: Output image width in pixels.
        height: Output image height in pixels.
        steps: Number of inference (denoising) steps.
        scale: Guidance scale forwarded to the pipeline.
        seed: Seed for the random generator. ``-1`` means "do not seed":
            no generator is passed and the pipeline uses its own randomness.

    Returns:
        The value stored under the ``"sample"`` key of the pipeline output.
    """
    # -1 is the "random" sentinel (the Seed slider's default). Any other value
    # must seed a generator so the result is reproducible. The previous code
    # had this test inverted: it seeded only with -1 and ignored real seeds.
    generator = None
    if seed != -1:
        generator = torch.Generator(device=device).manual_seed(int(seed))

    # UI sliders may deliver integral values as floats; these arguments must
    # be integers, so normalise them before the call.
    images_list = pipe(
        [prompt],
        height=int(height),
        width=int(width),
        num_inference_steps=int(steps),
        guidance_scale=scale,
        generator=generator,
    )
    return images_list["sample"]
# Custom stylesheet handed to `gr.Blocks(css=...)`.
# The focus-ring `--tw-ring-shadow` value uses `calc(3px + var(...))`; without
# the `+` operator the calc() expression is invalid CSS.
css = """
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
color: white;
border-color: black;
background: black;
}
input[type='range'] {
accent-color: black;
}
.dark input[type='range'] {
accent-color: #dfdfdf;
}
.container {
max-width: 730px;
margin: auto;
padding-top: 1.5rem;
}
#gallery {
min-height: 22rem;
margin-bottom: 15px;
margin-left: auto;
margin-right: auto;
border-bottom-right-radius: .5rem !important;
border-bottom-left-radius: .5rem !important;
}
#gallery>div>.h-full {
min-height: 20rem;
}
.details:hover {
text-decoration: underline;
}
.gr-button {
white-space: nowrap;
}
.gr-button:focus {
border-color: rgb(147 197 253 / var(--tw-border-opacity));
outline: none;
box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
--tw-border-opacity: 1;
--tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
--tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px + var(--tw-ring-offset-width)) var(--tw-ring-color);
--tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
--tw-ring-opacity: .5;
}
.footer {
margin-bottom: 45px;
margin-top: 35px;
text-align: center;
border-bottom: 1px solid #e5e5e5;
}
.footer>p {
font-size: .8rem;
display: inline-block;
padding: 0 10px;
transform: translateY(10px);
background: white;
}
.dark .footer {
border-color: #303030;
}
.dark .footer>p {
background: #0b0f19;
}
.acknowledgments h4{
margin: 1.25em 0 .25em 0;
font-weight: bold;
font-size: 115%;
}
"""
# Build the Gradio UI: title banner, prompt box with a generate button, an
# image gallery, and sliders for the generation parameters passed to `infer`.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from source whose indentation had been lost -- confirm against the intended layout.
with gr.Blocks(css=css) as block:
    gr.HTML(
        """
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
<div
style="
display: inline-flex;
align-items: center;
gap: 0.8rem;
font-size: 1.75rem;
"
>
<h1 style="font-weight: 900; margin-bottom: 7px;">
Stable Diffusion CPU
</h1>
</div>
</div>
"""
    )
    with gr.Group():
        with gr.Box():
            # Prompt field and its button share one row.
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Enter your prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )

        gallery = gr.Gallery(
            label="Generated images",
            show_label=False,
            elem_id="gallery",
        ).style(grid=[2], height="auto")

        # Output image dimensions.
        with gr.Row().style(mobile_collapse=False, equal_height=True):
            width = gr.Slider(label="Width", minimum=32, maximum=1024, value=512, step=8)
            height = gr.Slider(label="Height", minimum=32, maximum=1024, value=512, step=8)

        # Sampling controls.
        with gr.Row():
            steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=30, step=1)
            scale = gr.Slider(
                label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
            )

        # The slider defaults to -1, its minimum.
        seed = gr.Slider(
            label="Seed",
            minimum=-1,
            maximum=2147483647,
            step=1,
            value=-1,
        )

        # Pressing Enter in the prompt box and clicking the button both run
        # `infer` with the same inputs and send the result to the gallery.
        infer_inputs = [text, width, height, steps, scale, seed]
        for register in (text.submit, btn.click):
            register(infer, inputs=infer_inputs, outputs=gallery)

# Enable a request queue capped at 10 entries, then start the app.
block.queue(max_size=10).launch()