# imaginpaint / app.py
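# Gradio Space for mask-based image fill with SDXL: the user paints a mask over an
# input image and the masked region is regenerated by RealVisXL V5.0 Lightning
# guided by ControlNet Union (promax). Korean prompts are auto-translated to English.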
import gradio as gr
import spaces
import torch
from diffusers import AutoencoderKL, TCDScheduler
from diffusers.models.model_loading_utils import load_state_dict
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
from transformers import pipeline
from controlnet_union import ControlNetModel_Union
from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
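# Base SDXL checkpoints selectable in the UI; further models could be added as
# "Display name": "repo_id" entries.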
MODELS = {
"RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
}
# Load the Korean-to-English translation model
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
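# The translation pipeline returns a list of dicts; "translation_text" holds the English output.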
config_file = hf_hub_download(
"xinsir/controlnet-union-sdxl-1.0",
filename="config_promax.json",
)
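# Instantiate ControlNet Union from its config, then load the "promax" safetensors
# weights manually rather than via from_pretrained.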
config = ControlNetModel_Union.load_config(config_file)
controlnet_model = ControlNetModel_Union.from_config(config)
model_file = hf_hub_download(
"xinsir/controlnet-union-sdxl-1.0",
filename="diffusion_pytorch_model_promax.safetensors",
)
state_dict = load_state_dict(model_file)
model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
    controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
)
model.to(device="cuda", dtype=torch.float16)
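# The fp16-fixed SDXL VAE avoids NaN/overflow artifacts when decoding in half precision.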
vae = AutoencoderKL.from_pretrained(
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
).to("cuda")
pipe = StableDiffusionXLFillPipeline.from_pretrained(
"SG161222/RealVisXL_V5.0_Lightning",
torch_dtype=torch.float16,
vae=vae,
controlnet=model,
variant="fp16",
).to("cuda")
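# Swap in the TCD scheduler, typically paired with distilled "Lightning"-style
# checkpoints for few-step sampling.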
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
def translate_if_korean(text):
    # Check whether the text contains Hangul characters
    # (Compatibility Jamo U+3131-U+318E or Syllables U+AC00-U+D7A3)
    if any('\u3131' <= char <= '\u318E' or '\uAC00' <= char <= '\uD7A3' for char in text):
        # Translate to English if Korean is detected
        translated = translator(text)[0]['translation_text']
        print(f"Translated prompt: {translated}")  # Debug output
        return translated
    return text
@spaces.GPU
def fill_image(prompt, image, model_selection):
    # model_selection is currently informational; the single Lightning pipeline is loaded at startup
    # Translate the prompt to English if it is written in Korean
    translated_prompt = translate_if_korean(prompt)
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(translated_prompt, "cuda", True)

    source = image["background"]
    mask = image["layers"][0]
    alpha_channel = mask.split()[3]
    binary_mask = alpha_channel.point(lambda p: 255 if p > 0 else 0)

    # Black out the masked region of the control image so the pipeline fills it
    cnet_image = source.copy()
    cnet_image.paste(0, (0, 0), binary_mask)

    # Stream intermediate results; the loop variable is renamed to avoid shadowing `image`
    for generated in pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        image=cnet_image,
    ):
        yield generated, cnet_image

    # Composite the final generation back into the source through the mask
    generated = generated.convert("RGBA")
    cnet_image.paste(generated, (0, 0), binary_mask)

    yield source, cnet_image
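# Clears the result slider; wired as the first step of the generate/submit event chains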
def clear_result():
    return gr.update(value=None)
css = """
footer {
    visibility: hidden;
}
.sample-image {
    display: flex;
    justify-content: center;
    margin-top: 20px;
}
"""
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                info="Describe what to fill the masked area with (Korean or English)",
                lines=3,
            )
        with gr.Column():
            model_selection = gr.Dropdown(
                choices=list(MODELS.keys()),
                value="RealVisXL V5.0 Lightning",
                label="Model",
            )
            run_button = gr.Button("Generate")

    with gr.Row():
        input_image = gr.ImageMask(
            type="pil",
            label="Input Image",
            crop_size=(1024, 1024),
            layers=False,
        )
        result = ImageSlider(
            interactive=False,
            label="Generated Image",
        )

    use_as_input_button = gr.Button("Use as Input Image", visible=False)

    # Sample image
    with gr.Row(elem_classes="sample-image"):
        sample_image = gr.Image("sample.png", label="Sample Image", height=256, width=256)
    def use_output_as_input(output_image):
        # The ImageSlider value is a (source, generated) pair; reuse the generated image
        return gr.update(value=output_image[1])

    use_as_input_button.click(
        fn=use_output_as_input,
        inputs=[result],
        outputs=[input_image],
    )
    run_button.click(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=False),
        inputs=None,
        outputs=use_as_input_button,
    ).then(
        fn=fill_image,
        inputs=[prompt, input_image, model_selection],
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=True),
        inputs=None,
        outputs=use_as_input_button,
    )
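
    # Submitting the prompt textbox runs the same chain as the Generate button.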
    prompt.submit(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=False),
        inputs=None,
        outputs=use_as_input_button,
    ).then(
        fn=fill_image,
        inputs=[prompt, input_image, model_selection],
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=True),
        inputs=None,
        outputs=use_as_input_button,
    )
demo.launch(share=False)