BLIP-Diffusion / app_zero_shot.py
hysts's picture
hysts HF staff
Update
8db855c
raw
history blame
3.86 kB
#!/usr/bin/env python
import gradio as gr
import PIL.Image
import spaces
import torch
from diffusers.pipelines import BlipDiffusionPipeline
from settings import DEFAULT_NEGATIVE_PROMPT, MAX_INFERENCE_STEPS
from utils import MAX_SEED, randomize_seed_fn
# Pick the compute device. The BLIP-Diffusion weights are loaded (fp16) only
# when CUDA is available; on a CPU-only host `pipe` is never bound, so `run`
# cannot be used there.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    pipe = BlipDiffusionPipeline.from_pretrained("Salesforce/blipdiffusion", torch_dtype=torch.float16).to(device)
else:
    device = torch.device("cpu")
@spaces.GPU
def run(
    condition_image: PIL.Image.Image,
    condition_subject: str,
    target_subject: str,
    prompt: str,
    negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
    seed: int = 0,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 25,
) -> PIL.Image.Image:
    """Generate one 512x512 zero-shot subject-driven image with BLIP-Diffusion.

    Args:
        condition_image: Reference image of the subject. ``None`` (nothing
            uploaded) is rejected with a user-facing error.
        condition_subject: Text naming the subject shown in ``condition_image``.
        target_subject: Text naming the subject to render in the output.
        prompt: Text prompt describing the desired output.
        negative_prompt: Forwarded to the pipeline as ``neg_prompt``.
        seed: Seed for the ``torch.Generator`` on ``device``, so runs are
            reproducible for the same inputs.
        guidance_scale: Forwarded to the pipeline unchanged.
        num_inference_steps: Number of denoising steps; must not exceed
            ``MAX_INFERENCE_STEPS``.

    Returns:
        The first image in the pipeline output's ``images``.

    Raises:
        gr.Error: If no condition image is given, or if
            ``num_inference_steps`` exceeds ``MAX_INFERENCE_STEPS``.
    """
    # An empty gr.Image component yields None; fail with a readable message
    # instead of an opaque error from deep inside the pipeline.
    if condition_image is None:
        raise gr.Error("Please upload a condition image.")
    # The limit itself is allowed (the UI slider's maximum equals it), so the
    # message must say "less than or equal to".
    if num_inference_steps > MAX_INFERENCE_STEPS:
        error_message = f"Number of inference steps must be less than or equal to {MAX_INFERENCE_STEPS}"
        raise gr.Error(error_message)
    # Generator.manual_seed requires an integer; coerce in case the UI slider
    # delivers the seed as a float.
    generator = torch.Generator(device=device).manual_seed(int(seed))
    output = pipe(
        prompt,
        condition_image,
        condition_subject,
        target_subject,
        generator=generator,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        neg_prompt=negative_prompt,
        height=512,
        width=512,
    )
    return output.images[0]
# Gradio UI: inputs and advanced options on the left, the generated image on
# the right, one example row, and the event wiring that calls `run`.
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: reference image, subject names, prompt and run button.
        with gr.Column():
            # NOTE(review): created without `type=`, so depending on the Gradio
            # version `run` may receive a numpy array rather than the PIL image
            # its annotation states — confirm.
            condition_image = gr.Image(label="Condition Image")
            condition_subject = gr.Textbox(label="Condition Subject")
            target_subject = gr.Textbox(label="Target Subject")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button()
            # Collapsed by default (open=False). The initial values below match
            # the keyword defaults of `run`.
            with gr.Accordion(label="Advanced options", open=False):
                negative_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                # Read by randomize_seed_fn together with `seed` (see gr.on below).
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=0,
                    maximum=10,
                    step=0.1,
                    value=7.5,
                )
                # Maximum equals the limit that `run` enforces.
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=MAX_INFERENCE_STEPS,
                    step=1,
                    value=25,
                )
        # Right column: output image.
        with gr.Column():
            result = gr.Image(label="Result")

    # One example row covering only the four basic inputs, in the same order
    # as `inputs=` below (image path, condition subject, target subject, prompt).
    gr.Examples(
        examples=[
            [
                "images/dog.png",
                "dog",
                "dog",
                "swimming underwater",
            ],
        ],
        inputs=[
            condition_image,
            condition_subject,
            target_subject,
            prompt,
        ],
        outputs=result,
        fn=run,
    )

    # Same order as run()'s positional parameters.
    inputs = [
        condition_image,
        condition_subject,
        target_subject,
        prompt,
        negative_prompt,
        seed,
        guidance_scale,
        num_inference_steps,
    ]
    # Pressing Enter in any textbox, or clicking the button, first updates
    # `seed` via randomize_seed_fn. That helper lives in utils and presumably
    # draws a new seed when the checkbox is ticked — verify there. Then `run`
    # is called with all eight inputs.
    gr.on(
        triggers=[
            condition_subject.submit,
            target_subject.submit,
            prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        # Seed step: not exposed as an API endpoint, no concurrency cap.
        api_name=False,
        concurrency_limit=None,
    ).then(
        fn=run,
        inputs=inputs,
        outputs=result,
        api_name="run-zero-shot",
        # Generation shares the "gpu" concurrency group, one job at a time.
        concurrency_id="gpu",
        concurrency_limit=1,
    )
if __name__ == "__main__":
    # Enable the request queue (at most 20 pending jobs), then start the server.
    app = demo.queue(max_size=20)
    app.launch()