from __future__ import annotations
import gradio as gr
# import spaces
from PIL import Image
import torch
from my_run import run as run_model
# @spaces.GPU
def main_pipeline(
    input_image: str,
    src_prompt: str,
    tgt_prompt: str,
    seed: int,
    w1: float,
    w2: float = 1.0,
):
    """Run one Turbo Edit pass on an image and return the edited result.

    The UI wires five components to this function (image, source prompt,
    target prompt, seed, ``w``). ``w2`` has no UI control, so it keeps its
    default of 1.0, the value this demo has always used.

    Args:
        input_image: Path of the uploaded image. The input ``gr.Image`` uses
            ``type="filepath"``. Gradio passes ``None`` when nothing was
            uploaded.
        src_prompt: Prompt describing the input image.
        tgt_prompt: Prompt describing the desired edited image.
        seed: Random seed forwarded to ``run_model``.
        w1: First weight forwarded to ``run_model`` (the "w" slider).
        w2: Second weight forwarded to ``run_model``. Defaults to 1.0.

    Returns:
        Whatever ``run_model`` returns. The result component is
        ``gr.Image(type="pil")``, so presumably a PIL image (TODO confirm).

    Raises:
        gr.Error: If no input image was provided. The message is shown in
            the UI instead of an opaque failure inside the model.
    """
    if not input_image:
        raise gr.Error("Please upload an input image.")
    # Slider values are numbers and may arrive as floats. The seed is
    # declared as int, so normalize it before handing it to the model.
    res_image = run_model(input_image, src_prompt, tgt_prompt, int(seed), w1, w2)
    return res_image
with gr.Blocks(css="app/style.css", theme="Nymbo/Nymbo_Theme") as demo:
    # Page title. Must be a single valid string literal; a literal broken
    # across lines is an unterminated string and a SyntaxError.
    gr.HTML("<h1>Turbo Edit</h1>")
    with gr.Row():
        # Left column: all user inputs plus the trigger button.
        with gr.Column():
            input_image = gr.Image(
                label="Input image", type="filepath", height=512, width=512
            )
            src_prompt = gr.Text(
                label="Source Prompt",
                max_lines=1,
                placeholder="Source Prompt",
            )
            tgt_prompt = gr.Text(
                label="Target Prompt",
                max_lines=1,
                placeholder="Target Prompt",
            )
            with gr.Accordion("Advanced Options", open=False):
                seed = gr.Slider(
                    label="seed", minimum=0, maximum=16 * 1024, value=7865, step=1
                )
                w1 = gr.Slider(
                    label="w", minimum=1.0, maximum=3.0, value=1.5, step=0.05
                )
                # The pipeline's second weight (w2) is intentionally not
                # exposed in the UI; it stays at 1.0.
            run_button = gr.Button("Edit")
        # Right column: the edited output image.
        with gr.Column():
            result = gr.Image(label="Result", type="pil", height=512, width=512)

    # Single source of truth for the pipeline's inputs/outputs, shared by the
    # examples table and the button so the two wirings cannot drift apart.
    # Order must match main_pipeline's positional parameters.
    inputs = [
        input_image,
        src_prompt,
        tgt_prompt,
        seed,
        w1,
    ]
    outputs = [result]

    # Each row follows the order of `inputs`:
    # (input_image, src_prompt, tgt_prompt, seed, w1).
    examples = [
        [
            "examples_demo/1.jpeg",
            "a dreamy cat sleeping on a floating leaf",
            "a dreamy bear sleeping on a floating leaf",
            7,
            1.3,
        ],
        [
            "examples_demo/2.jpeg",
            "A painting of a cat and a bunny surrounded by flowers",
            "a polygonal illustration of a cat and a bunny",
            2,
            1.5,
        ],
        [
            "examples_demo/3.jpg",
            "a chess pawn wearing a crown",
            "a chess pawn wearing a hat",
            2,
            1.3,
        ],
    ]
    # cache_examples=True makes Gradio precompute each example's output by
    # calling main_pipeline, so clicking an example shows a stored result.
    gr.Examples(
        examples=examples,
        inputs=inputs,
        outputs=outputs,
        fn=main_pipeline,
        cache_examples=True,
    )

    run_button.click(fn=main_pipeline, inputs=inputs, outputs=outputs)
# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not block on a running Gradio server.
# The queue holds at most 50 pending requests.
if __name__ == "__main__":
    demo.queue(max_size=50).launch(share=False)