Spaces:
Paused
Paused
from __future__ import annotations | |
import gradio as gr | |
# import spaces | |
from PIL import Image | |
import torch | |
from my_run import run as run_model | |
# @spaces.GPU | |
def main_pipeline(
    input_image: str,
    src_prompt: str,
    tgt_prompt: str,
    seed: int,
    w1: float,
    w2: float = 1.0,
):
    """Run one Turbo Edit request and return the model's result.

    Thin adapter between the Gradio UI and ``my_run.run`` (imported as
    ``run_model``). It is wired to the "Edit" button and to ``gr.Examples``.

    Args:
        input_image: Path to the uploaded image (the input component is a
            ``gr.Image`` with ``type="filepath"``). ``None`` when the user
            did not upload anything.
        src_prompt: Prompt describing the source image.
        tgt_prompt: Prompt describing the desired edited image.
        seed: Seed chosen in the UI. Gradio sliders may deliver the value as
            a float, so it is coerced to ``int`` before being forwarded.
        w1: First weight forwarded to ``run_model`` (the "w" slider).
        w2: Second weight forwarded to ``run_model``. Not exposed in the UI,
            so it normally keeps its default of 1.0.

    Returns:
        Whatever ``run_model`` returns. The result component is
        ``gr.Image(type="pil")``, so presumably a PIL image — confirm
        against ``my_run.run``.

    Raises:
        gr.Error: If no input image was provided. Gradio shows this message
            to the user instead of a traceback.
    """
    if input_image is None:
        # Fail early with a readable message rather than deep inside the model.
        raise gr.Error("Please upload an input image.")
    # int(): slider values can arrive as floats (e.g. 7865.0). The handling of
    # seed inside run_model is not visible here; an explicit int is the safe
    # contract for seeding APIs.
    return run_model(input_image, src_prompt, tgt_prompt, int(seed), w1, w2)
# Cached example rows. Each row is ordered like `inputs` below:
# [input_image, src_prompt, tgt_prompt, seed, w1].
examples = [
    [
        "examples_demo/1.jpeg",
        "a dreamy cat sleeping on a floating leaf",
        "a dreamy bear sleeping on a floating leaf",
        7,
        1.3,
    ],
    [
        "examples_demo/2.jpeg",
        "A painting of a cat and a bunny surrounded by flowers",
        "a polygonal illustration of a cat and a bunny",
        2,
        1.5,
    ],
    [
        "examples_demo/3.jpg",
        "a chess pawn wearing a crown",
        "a chess pawn wearing a hat",
        2,
        1.3,
    ],
]

with gr.Blocks(css="app/style.css", theme="Nymbo/Nymbo_Theme") as demo:
    gr.HTML("<center><h1>Turbo Edit</h1></center>")

    with gr.Row():
        # Left column: the image, both prompts, tuning knobs, and the trigger.
        with gr.Column():
            input_image = gr.Image(
                label="Input image",
                type="filepath",
                height=512,
                width=512,
            )
            src_prompt = gr.Text(
                label="Source Prompt", max_lines=1, placeholder="Source Prompt"
            )
            tgt_prompt = gr.Text(
                label="Target Prompt", max_lines=1, placeholder="Target Prompt"
            )
            with gr.Accordion("Advanced Options", open=False):
                seed = gr.Slider(
                    label="seed",
                    minimum=0,
                    maximum=16 * 1024,
                    value=7865,
                    step=1,
                )
                w1 = gr.Slider(
                    label="w",
                    minimum=1.0,
                    maximum=3.0,
                    value=1.5,
                    step=0.05,
                )
                # A second weight slider (w2) is intentionally not shown;
                # main_pipeline supplies its value.
            run_button = gr.Button("Edit")

        # Right column: the edited image.
        with gr.Column():
            result = gr.Image(label="Result", type="pil", height=512, width=512)

    # One input/output wiring shared by the examples and the Edit button.
    # Its order must match main_pipeline's positional parameters.
    inputs = [input_image, src_prompt, tgt_prompt, seed, w1]
    outputs = [result]

    gr.Examples(
        examples=examples,
        inputs=inputs,
        outputs=outputs,
        fn=main_pipeline,
        cache_examples=True,
    )
    run_button.click(fn=main_pipeline, inputs=inputs, outputs=outputs)

demo.queue(max_size=50).launch(share=False)