Spaces:
Running
on
Zero
Running
on
Zero
import spaces | |
import gradio as gr | |
import torch | |
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, DDIMScheduler | |
from compel import Compel, ReturnedEmbeddingsType | |
from PIL import Image | |
import os | |
import time | |
from utils.utils import load_cn_model, load_cn_config, load_tagger_model, load_lora_model, resize_image_aspect_ratio, base_generation | |
from utils.prompt_analysis import PromptAnalysis | |
class Img2Img: | |
def __init__(self): | |
self.setup_paths() | |
self.setup_models() | |
self.compel = self.setup_compel() | |
self.demo = self.layout() | |
def setup_paths(self): | |
self.path = os.getcwd() | |
self.cn_dir = f"{self.path}/controlnet" | |
self.tagger_dir = f"{self.path}/tagger" | |
self.lora_dir = f"{self.path}/lora" | |
os.makedirs(self.cn_dir, exist_ok=True) | |
os.makedirs(self.tagger_dir, exist_ok=True) | |
os.makedirs(self.lora_dir, exist_ok=True) | |
def setup_models(self): | |
load_cn_model(self.cn_dir) | |
load_cn_config(self.cn_dir) | |
load_tagger_model(self.tagger_dir) | |
load_lora_model(self.lora_dir) | |
self.device = "cuda" if torch.cuda.is_available() else "cpu" | |
self.dtype = torch.float16 | |
self.model = "cagliostrolab/animagine-xl-3.1" | |
self.scheduler = DDIMScheduler.from_pretrained(self.model, subfolder="scheduler") | |
self.controlnet = ControlNetModel.from_pretrained(self.cn_dir, torch_dtype=self.dtype, use_safetensors=True) | |
self.pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained( | |
self.model, | |
controlnet=self.controlnet, | |
torch_dtype=self.dtype, | |
use_safetensors=True, | |
scheduler=self.scheduler, | |
) | |
self.pipe.load_lora_weights(self.lora_dir, weight_name="sdxl_BWLine.safetensors") | |
self.pipe = self.pipe.to(self.device) | |
def setup_compel(self): | |
return Compel( | |
tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2], | |
text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2], | |
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, | |
requires_pooled=[False, True], | |
) | |
def layout(self): | |
css = """ | |
#intro{ | |
max-width: 32rem; | |
text-align: center; | |
margin: 0 auto; | |
} | |
""" | |
with gr.Blocks(css=css) as demo: | |
with gr.Row(): | |
with gr.Column(): | |
self.input_image_path = gr.Image(label="入力画像", type='filepath') | |
self.prompt_analysis = PromptAnalysis(self.tagger_dir) | |
self.prompt, self.negative_prompt = self.prompt_analysis.layout(self.input_image_path) | |
self.controlnet_scale = gr.Slider(minimum=0.5, maximum=1.25, value=1.0, step=0.01, label="線画忠実度") | |
generate_button = gr.Button("生成") | |
with gr.Column(): | |
self.output_image = gr.Image(type="pil", label="生成画像") | |
generate_button.click( | |
fn=self.predict, | |
inputs=[self.input_image_path, self.prompt, self.negative_prompt, self.controlnet_scale], | |
outputs=self.output_image | |
) | |
return demo | |
def predict(self, input_image_path, prompt, negative_prompt, controlnet_scale): | |
input_image_pil = Image.open(input_image_path) | |
base_size = input_image_pil.size | |
resize_image = resize_image_aspect_ratio(input_image_pil) | |
resize_image_size = resize_image.size | |
width, height = resize_image_size | |
white_base_pil = base_generation(resize_image.size, (255, 255, 255, 255)).convert("RGB") | |
conditioning, pooled = self.compel([prompt, negative_prompt]) | |
generator = torch.manual_seed(0) | |
last_time = time.time() | |
output_image = self.pipe( | |
image=white_base_pil, | |
control_image=resize_image, | |
strength=1.0, | |
prompt_embeds=conditioning[0:1], | |
pooled_prompt_embeds=pooled[0:1], | |
negative_prompt_embeds=conditioning[1:2], | |
negative_pooled_prompt_embeds=pooled[1:2], | |
width=width, | |
height=height, | |
controlnet_conditioning_scale=float(controlnet_scale), | |
controlnet_start=0.0, | |
controlnet_end=1.0, | |
generator=generator, | |
num_inference_steps=30, | |
guidance_scale=8.5, | |
eta=1.0, | |
) | |
print(f"Time taken: {time.time() - last_time}") | |
output_image = output_image.resize(base_size, Image.LANCZOS) | |
return output_image | |
img2img = Img2Img() | |
img2img.demo.launch() | |