# RNRI / app.py
# Page residue from the hosting site (author: barakmeiri; commit "Update app.py",
# bff370f, verified; raw / history / blame links; file size 4.39 kB).
import gradio as gr
import numpy as np
import random
from diffusers import DiffusionPipeline
import torch
from src.euler_scheduler import MyEulerAncestralDiscreteScheduler
from diffusers.pipelines.auto_pipeline import AutoPipelineForImage2Image
from src.sdxl_inversion_pipeline import SDXLDDIMPipeline
from src.config import RunConfig
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Scheduler class installed on both pipelines below.
scheduler_class = MyEulerAncestralDiscreteScheduler
# Both pipelines load the same "stabilityai/sdxl-turbo" checkpoint: pipe_inversion
# through the project's SDXLDDIMPipeline, pipe_inference through diffusers'
# AutoPipelineForImage2Image. No torch_dtype is passed to either call.
pipe_inversion = SDXLDDIMPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True).to(device)
pipe_inference = AutoPipelineForImage2Image.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True).to(device)
# Swap each pipeline's scheduler for scheduler_class, built from that
# pipeline's current scheduler config.
pipe_inference.scheduler = scheduler_class.from_config(pipe_inference.scheduler.config)
pipe_inversion.scheduler = scheduler_class.from_config(pipe_inversion.scheduler.config)
# Order matters: this reads pipe_inference.scheduler.config *after* the
# replacement above, so it is the config of the newly installed scheduler.
pipe_inversion.scheduler_inference = scheduler_class.from_config(pipe_inference.scheduler.config)
# if torch.cuda.is_available():
# torch.cuda.max_memory_allocated(device=device)
# pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
# pipe.enable_xformers_memory_efficient_attention()
# pipe = pipe.to(device)
# else:
# pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
# pipe = pipe.to(device)
# Largest value representable as a signed 32-bit integer (2**31 - 1).
MAX_SEED = np.iinfo("int32").max
# Fixed size limit of 1024; not referenced in the visible part of this file.
MAX_IMAGE_SIZE = 1024
def infer(input_image, description_prompt, target_prompt, guidance_scale, num_inference_steps=4, num_inversion_steps=4, inversion_max_step=0.6):
    """Edit ``input_image`` according to ``target_prompt`` and return the result.

    Builds a ``RunConfig`` from the numeric settings, wraps the module-level
    ``pipe_inversion`` / ``pipe_inference`` pipelines in an
    ``ImageEditorDemo`` together with the image and its description, and asks
    that editor to apply ``target_prompt``.

    Args:
        input_image: Image to edit (the UI component is declared with
            ``type="pil"``).
        description_prompt: Text describing the input image.
        target_prompt: Text describing the desired edit.
        guidance_scale: Forwarded to ``RunConfig(guidance_scale=...)``.
        num_inference_steps: Forwarded to ``RunConfig``.
        num_inversion_steps: Forwarded to ``RunConfig``.
        inversion_max_step: Forwarded to ``RunConfig``.

    Returns:
        The value returned by ``ImageEditorDemo.edit`` -- presumably the
        edited image, since it is wired to a ``gr.Image`` output; TODO confirm
        against ``ImageEditorDemo``.
    """
    config = RunConfig(
        num_inference_steps=num_inference_steps,
        num_inversion_steps=num_inversion_steps,
        guidance_scale=guidance_scale,
        inversion_max_step=inversion_max_step,
    )

    # NOTE(review): ImageEditorDemo is not imported anywhere in the visible
    # part of this file; it must be brought into scope for this call to work.
    editor = ImageEditorDemo(pipe_inversion, pipe_inference, input_image, description_prompt, config)

    # The previous code discarded the result of edit() and returned the
    # undefined name ``image``, which raised NameError on every call.
    return editor.edit(target_prompt)
# Sample text prompts. Currently unused: the gr.Examples call that would
# consume this list is commented out in the UI definition further down.
examples = [
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
"An astronaut riding a green horse",
"A delicious ceviche cheesecake slice",
]
# Custom stylesheet passed to gr.Blocks(css=...): horizontally centers the
# element with id "col-container" and caps its width at 520px.
css="""
#col-container {
margin: 0 auto;
max-width: 520px;
}
"""
# Human-readable label for the compute backend ("GPU" when CUDA is available,
# otherwise "CPU"); interpolated into the page heading of the UI.
# Written as a single conditional expression: the earlier if/else form had
# unindented branch bodies and could not be parsed.
power_device = "GPU" if torch.cuda.is_available() else "CPU"
# Gradio UI definition. Components are laid out in creation order inside the
# enclosing layout context manager, so statement order here is page order.
# NOTE(review): the nesting of the Column/Row contexts below was reconstructed
# from the layout calls themselves -- confirm it matches the intended design.
# NOTE(review): both columns share elem_id "col-container"; kept as found.
with gr.Blocks(css=css) as demo:
    # Page heading; reports which backend the app is running on.
    gr.Markdown(f"\n# RNRI briel and links on device: {power_device}.\n")

    # Input side: source image, the two prompts, advanced settings, run button.
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            input_image = gr.Image(
                label="Input image",
                sources=['upload', 'webcam', 'clipboard'],
                type="pil",
            )

        with gr.Row():
            description_prompt = gr.Text(
                label="Image description",
                show_label=False,
                max_lines=1,
                placeholder="Enter your image description",
                container=False,
            )

        with gr.Row():
            target_prompt = gr.Text(
                label="Edit prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your edit prompt",
                container=False,
            )

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=0.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of RNRI iterations",
                    minimum=1,
                    maximum=12,
                    step=1,
                    value=2,
                )

        with gr.Row():
            run_button = gr.Button("Edit", scale=0)

    # Output side: the edited image.
    with gr.Column(elem_id="col-container"):
        result = gr.Image(label="Result", show_label=False)

    # gr.Examples(
    # examples = examples,
    # inputs = [prompt]
    # )

    # Event listeners must be registered while the Blocks context is open.
    # num_inference_steps appears twice: inputs are passed to infer()
    # positionally, so the second occurrence fills its num_inversion_steps
    # parameter (one slider drives both); inversion_max_step keeps its default.
    # NOTE(review): presumably intentional -- confirm.
    run_button.click(
        fn=infer,
        inputs=[
            input_image,
            description_prompt,
            target_prompt,
            guidance_scale,
            num_inference_steps,
            num_inference_steps,
        ],
        outputs=[result],
    )
# Enable the request queue, then start serving the app. queue() returns the
# same Blocks instance, so calling launch() on `demo` is equivalent to chaining.
demo.queue()
demo.launch()