import gradio as gr
import torch
from PIL import Image
from diffusers import DDIMScheduler
from imagic import ImagicStableDiffusionPipeline
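# ImagicStableDiffusionPipeline is assumed to be a local copy (imagic.py) of the
# diffusers "imagic_stable_diffusion" community pipeline.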
has_cuda = torch.cuda.is_available()
device = "cuda" if has_cuda else "cpu"
pipe = ImagicStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    safety_checker=None,
    scheduler=DDIMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
    ),
).to(device)
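# Fixed seed so repeated edits are reproducible across runs.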
generator = torch.Generator(device).manual_seed(0)
def infer(prompt, init_image):
    init_image = Image.open(init_image).convert("RGB")
    init_image = init_image.resize((256, 256))
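    # Imagic stages 1 and 2: optimize a text embedding that reconstructs the
    # input image, then fine-tune the diffusion model around that embedding.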
    res = pipe.train(
        prompt,
        init_image,
        guidance_scale=7.5,
        num_inference_steps=50,
        generator=generator,
        text_embedding_optimization_steps=500,
        model_fine_tuning_optimization_steps=600,
    )
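    # alpha interpolates between the optimized embedding (alpha=0, pure
    # reconstruction) and the target text embedding; alpha=1 applies the edit.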
    with torch.no_grad():
        if has_cuda:
            torch.cuda.empty_cache()
        res = pipe(alpha=1)
    return res.images[0]
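# A minimal sketch, not part of the original app: Imagic trades reconstruction
# fidelity against edit strength by interpolating embeddings. Assuming the
# pipeline's `alpha` keyword (used above), one could sweep several values and
# pick the best-looking result, as the paper suggests. Must be called after
# pipe.train(...) has fine-tuned the model.
def sweep_alpha(alphas=(0.8, 1.0, 1.2, 1.4)):
    images = []
    with torch.no_grad():
        for alpha in alphas:
            images.append(pipe(alpha=alpha).images[0])
    return images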
title = """
Imagic Stable Diffusion • Community Pipeline
Text-Based Real Image Editing with Diffusion Models
This pipeline aims to implement this paper to Stable Diffusion, allowing for real-world image editing.
You can skip the queue by duplicating this space:
"""
article = """
"""
css = '''
#col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
a {text-decoration-line: underline; font-weight: 600;}
.footer {
    margin-bottom: 45px;
    margin-top: 35px;
    text-align: center;
    border-bottom: 1px solid #e5e5e5;
}
.footer>p {
    font-size: .8rem;
    display: inline-block;
    padding: 0 10px;
    transform: translateY(10px);
    background: white;
}
.dark .footer {
    border-color: #303030;
}
.dark .footer>p {
    background: #0b0f19;
}
'''
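# UI layout. Note: this uses the Gradio 3.x API (gr.Image(source="upload") was
# replaced by the `sources` parameter in Gradio 4).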
with gr.Blocks(css=css) as block:
    with gr.Column(elem_id="col-container"):
        gr.HTML(title)
        prompt_input = gr.Textbox(label="Target text", placeholder="Describe the image with what you want to change about the subject")
        image_init = gr.Image(source="upload", type="filepath", label="Input Image")
        submit_btn = gr.Button("Train")
        image_output = gr.Image(label="Edited image")
        examples = [["a sitting dog", "imagic-dog.png"], ["a photo of a bird spreading wings", "imagic-bird.png"]]
        ex = gr.Examples(examples=examples, fn=infer, inputs=[prompt_input, image_init], outputs=[image_output], cache_examples=False, run_on_click=True)
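        # Blank out the auto-generated header row of the examples table.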
        ex.dataset.headers = [""]
        gr.HTML(article)

    submit_btn.click(fn=infer, inputs=[prompt_input, image_init], outputs=[image_output])

block.queue(max_size=12).launch(show_api=False)