"""Gradio demo: text-to-image generation with Stable Diffusion 1.5, powered by Refiners."""

import gradio as gr
import pillow_heif
import spaces
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import StableDiffusion_1
|
# Register HEIF/AVIF support so Pillow can open images in those formats.
pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()

TITLE = """
# SD1.5 with Refiners
"""
|
DEVICE_CPU = torch.device("cpu")
DEVICE_GPU = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Check CUDA availability before probing bf16 support (the probe can raise on CPU-only setups).
DTYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32

# Build the model on CPU first; everything is moved to the GPU below, once the weights are loaded.
model = StableDiffusion_1(device=DEVICE_CPU, dtype=DTYPE)
|
# Fetch the pinned SD1.5 weights from the Hub and load each component.
model.unet.load_from_safetensors(
    tensors_path=hf_hub_download(
        repo_id="refiners/sd15.unet",
        filename="model.safetensors",
        revision="6b01fc610c7465fa79e44c52c4d2eb0ea56821c9",
    )
)
model.lda.load_from_safetensors(
    tensors_path=hf_hub_download(
        repo_id="refiners/sd15.autoencoder",
        filename="model.safetensors",
        revision="7565efe4812d8e14072111ab326b15eea4c908a5",
    )
)
model.clip_text_encoder.load_from_safetensors(
    tensors_path=hf_hub_download(
        repo_id="refiners/sd15.text_encoder",
        filename="model.safetensors",
        revision="1b5023ecf0d646b7403f4ad182b6f0ab6b251fef",
    )
)
|
# Move each component (and the solver's internal tensors) to the GPU explicitly,
# then update the model's own device/dtype bookkeeping to match.
model.to(device=DEVICE_GPU, dtype=DTYPE)
model.unet.to(device=DEVICE_GPU, dtype=DTYPE)
model.lda.to(device=DEVICE_GPU, dtype=DTYPE)
model.clip_text_encoder.to(device=DEVICE_GPU, dtype=DTYPE)
model.solver.to(device=DEVICE_GPU, dtype=DTYPE)
model.device = DEVICE_GPU
model.dtype = DTYPE
|
|
@spaces.GPU
@no_grad()
def process(
    prompt: str,
    negative_prompt: str,
    condition_scale: float,
    num_inference_steps: int,
    seed: int,
) -> Image.Image:
    assert condition_scale >= 0
    assert num_inference_steps > 0
    assert seed >= 0

    manual_seed(seed)
    # Apply the requested step count before iterating over model.steps.
    model.set_inference_steps(num_inference_steps)

    # Encode both prompts; the embedding carries the negative prompt
    # for classifier-free guidance.
    clip_text_embedding = model.compute_clip_text_embedding(
        text=prompt,
        negative_text=negative_prompt,
    )

    # Start from pure noise and denoise one solver step at a time.
    x = model.init_latents(size=(512, 512))
    for step in model.steps:
        x = model(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=condition_scale,
        )

    # Decode the final latents into a PIL image.
    image = model.lda.latents_to_image(x)

    return image
|
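# Calling the pipeline directly, without the UI (a sketch; the values and the
# output filename are illustrative, not part of the app):
#
#   image = process(
#       prompt="a cute cat, detailed high-quality professional image",
#       negative_prompt="lowres, bad anatomy, bad hands, cropped, worst quality",
#       condition_scale=7.5,
#       num_inference_steps=30,
#       seed=2,
#   )
#   image.save("cat.png")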
|
with gr.Blocks() as demo:
    gr.Markdown(TITLE)

    with gr.Column():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button(
                value="Run",
                scale=0,
            )

        output_image = gr.Image(
            label="Output Image",
            image_mode="RGB",
            type="pil",
        )
|
        with gr.Accordion("Advanced Settings", open=True):
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="Enter your (optional) negative prompt",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=100_000,
                value=2,
                step=1,
            )
            condition_scale = gr.Slider(
                label="Condition scale",
                minimum=0,
                maximum=20,
                value=7.5,
                step=0.05,
            )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=50,
                value=30,
                step=1,
            )
|
    run_button.click(
        fn=process,
        inputs=[
            prompt,
            negative_prompt,
            condition_scale,
            num_inference_steps,
            seed,
        ],
        outputs=output_image,
    )
|
    gr.Examples(
        examples=[
            [
                "a cute cat, detailed high-quality professional image",
                "lowres, bad anatomy, bad hands, cropped, worst quality",
                7.5,
                30,
                2,
            ],
            [
                "a cute dog, detailed high-quality professional image",
                "lowres, bad anatomy, bad hands, cropped, worst quality",
                7.5,
                30,
                2,
            ],
        ],
        inputs=[
            prompt,
            negative_prompt,
            condition_scale,
            num_inference_steps,
            seed,
        ],
        outputs=output_image,
        fn=process,
        cache_examples=True,
        cache_mode="lazy",
        run_on_click=False,
    )
|
demo.launch()
|
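# To run locally (a sketch; assumes a CUDA GPU and these PyPI package names):
#   pip install torch gradio spaces pillow-heif refiners huggingface-hub
#   python app.py  # assumes this file is saved as app.py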
|