import spaces
import os
import gradio as gr
import torch
import cv2
from annotator.util import resize_image
from annotator.hed import SOFT_HEDdetector
from annotator.lineart import LineartDetector
from annotator.canny import CannyDetector
from diffusers import UNet2DConditionModel, ControlNetModel
from huggingface_hub import snapshot_download
from PIL import Image
from ip_adapter import StyleShot, StyleContentStableDiffusionControlNetPipeline
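
# Gradio demo for StyleShot: style-driven image generation with an optional
# content image, structure-controlled by one of three preprocessors
# (Contour / Lineart / Canny).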
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
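
# Structure detectors that turn the optional content image into a control map.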
contour_detector = SOFT_HEDdetector()
lineart_detector = LineartDetector()
canny_detector = CannyDetector()
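
# Model repos, downloaded below via snapshot_download if not already on disk.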
base_model_path = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
transformer_block_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
styleshot_model_path = "Gaojunyao/StyleShot"
styleshot_lineart_model_path = "Gaojunyao/StyleShot_lineart"
if not os.path.isdir(base_model_path):
    base_model_path = snapshot_download(base_model_path, local_dir=base_model_path)
    print(f"Downloaded model to {base_model_path}")
if not os.path.isdir(transformer_block_path):
    transformer_block_path = snapshot_download(transformer_block_path, local_dir=transformer_block_path)
    print(f"Downloaded model to {transformer_block_path}")
if not os.path.isdir(styleshot_model_path):
    styleshot_model_path = snapshot_download(styleshot_model_path, local_dir=styleshot_model_path)
    print(f"Downloaded model to {styleshot_model_path}")
if not os.path.isdir(styleshot_lineart_model_path):
    styleshot_lineart_model_path = snapshot_download(styleshot_lineart_model_path, local_dir=styleshot_lineart_model_path)
    print(f"Downloaded model to {styleshot_lineart_model_path}")
# weights for ip-adapter and our content-fusion encoder
contour_ip_ckpt = os.path.join(styleshot_model_path, "pretrained_weight/ip.bin")
contour_style_aware_encoder_path = os.path.join(styleshot_model_path, "pretrained_weight/style_aware_encoder.bin")
contour_transformer_block_path = transformer_block_path
contour_unet = UNet2DConditionModel.from_pretrained(base_model_path, subfolder="unet")
contour_content_fusion_encoder = ControlNetModel.from_unet(contour_unet)
contour_pipe = StyleContentStableDiffusionControlNetPipeline.from_pretrained(base_model_path, controlnet=contour_content_fusion_encoder)
contour_styleshot = StyleShot(device, contour_pipe, contour_ip_ckpt, contour_style_aware_encoder_path, contour_transformer_block_path)
# The Canny pipeline reuses the same IP-Adapter and content-fusion encoder weights as the contour one; only the preprocessor differs at inference time.
canny_ip_ckpt = os.path.join(styleshot_model_path, "pretrained_weight/ip.bin")
canny_style_aware_encoder_path = os.path.join(styleshot_model_path, "pretrained_weight/style_aware_encoder.bin")
canny_transformer_block_path = transformer_block_path
canny_unet = UNet2DConditionModel.from_pretrained(base_model_path, subfolder="unet")
canny_content_fusion_encoder = ControlNetModel.from_unet(canny_unet)
canny_pipe = StyleContentStableDiffusionControlNetPipeline.from_pretrained(base_model_path, controlnet=canny_content_fusion_encoder)
canny_styleshot = StyleShot(device, canny_pipe, canny_ip_ckpt, canny_style_aware_encoder_path, canny_transformer_block_path)
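
# Weights for the lineart variant, loaded from the separate StyleShot_lineart checkpoint.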
lineart_ip_ckpt = os.path.join(styleshot_lineart_model_path, "pretrained_weight/ip.bin")
lineart_style_aware_encoder_path = os.path.join(styleshot_lineart_model_path, "pretrained_weight/style_aware_encoder.bin")
lineart_transformer_block_path = transformer_block_path
lineart_unet = UNet2DConditionModel.from_pretrained(base_model_path, subfolder="unet")
lineart_content_fusion_encoder = ControlNetModel.from_unet(lineart_unet)
lineart_pipe = StyleContentStableDiffusionControlNetPipeline.from_pretrained(base_model_path, controlnet=lineart_content_fusion_encoder)
lineart_styleshot = StyleShot(device, lineart_pipe, lineart_ip_ckpt, lineart_style_aware_encoder_path, lineart_transformer_block_path)

@spaces.GPU(duration=150)
def process(style_image, content_image, prompt, num_samples, image_resolution, condition_scale, style_scale, ddim_steps, guidance_scale, seed, a_prompt, n_prompt, btn1, Contour_Threshold=200):
    style_shots = []
    btns = []
    contour_content_images = []
    contour_results = []
    canny_content_images = []
    canny_results = []
    lineart_content_images = []
    lineart_results = []
    type1 = 'Contour'
    type2 = 'Lineart'
    type3 = 'Canny'
    # With no content image, only the contour pipeline runs (its ControlNet
    # condition is disabled further below).
    if btn1 == type1 or content_image is None:
        style_shots = [contour_styleshot]
        btns = [type1]
    elif btn1 == type2:
        style_shots = [lineart_styleshot]
        btns = [type2]
    elif btn1 == type3:
        style_shots = [canny_styleshot]
        btns = [type3]
    elif btn1 == "All":
        style_shots = [contour_styleshot, lineart_styleshot, canny_styleshot]
        btns = [type1, type2, type3]
    ori_style_image = style_image.copy()
    ori_content_image = content_image.copy() if content_image is not None else None
    for styleshot, btn in zip(style_shots, btns):
        # Fold the "Added Prompt" into the user prompt.
        prompts = [prompt + " " + a_prompt]
        style_image = Image.fromarray(ori_style_image)
        if ori_content_image is not None:
            # Build the structure control map with the selected preprocessor.
            content_image = resize_image(ori_content_image, image_resolution)
            if btn == type1:
                content_image = contour_detector(content_image, threshold=Contour_Threshold)
            elif btn == type2:
                content_image = lineart_detector(content_image, coarse=False)
            elif btn == type3:
                content_image = canny_detector(content_image)
            content_image = Image.fromarray(content_image)
        else:
            # Style-only mode: pass the resized style image as a placeholder
            # and turn the content-fusion ControlNet off.
            content_image = Image.fromarray(cv2.resize(ori_style_image, (image_resolution, image_resolution)))
            condition_scale = 0.0
        g_images = styleshot.generate(style_image=style_image,
                                      prompt=[prompts],
                                      negative_prompt=n_prompt,
                                      scale=style_scale,
                                      num_samples=num_samples,
                                      seed=seed,
                                      num_inference_steps=ddim_steps,
                                      guidance_scale=guidance_scale,
                                      content_image=content_image,
                                      controlnet_conditioning_scale=float(condition_scale))
        if btn == type1:
            contour_content_images = [content_image]
            contour_results = g_images[0]
        elif btn == type2:
            lineart_content_images = [content_image]
            lineart_results = g_images[0]
        elif btn == type3:
            canny_content_images = [content_image]
            canny_results = g_images[0]
    if ori_content_image is None:
        # Style-only mode: keep the contour results but clear the galleries
        # that are only meaningful with a content image.
        contour_content_images = []
        lineart_results = []
        lineart_content_images = []
        canny_results = []
        canny_content_images = []
    # Order must match the `outputs` list wired to run_button.click below.
    return [contour_results, contour_content_images, lineart_results, lineart_content_images, canny_results, canny_content_images]
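
# Gradio UI layout and wiring.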
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## StyleShot Demo")
    with gr.Row():
        with gr.Column():
            style_image = gr.Image(sources=['upload'], type="numpy", label='Style Image')
        with gr.Column():
            with gr.Blocks():
                with gr.Column():
                    content_image = gr.Image(sources=['upload'], type="numpy", label='Content Image (optional)')
                    btn1 = gr.Radio(
                        choices=["Contour", "Lineart", "Canny", "All"],
                        interactive=True,
                        label="Preprocessor",
                        value="All",
                    )
                    gr.Markdown("We recommend 'Contour' for sparse control and 'Lineart' for detailed control. If you choose 'All', we provide results for all three types of control. If you choose 'Contour', you can adjust the 'Contour Threshold' under 'Advanced options' to set the level of detail in the control.")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
    with gr.Row():
        run_button = gr.Button(value="Run")
    with gr.Row():
        with gr.Column():
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=4, step=1)
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
                condition_scale = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                Contour_Threshold = gr.Slider(label="Contour Threshold", minimum=0, maximum=255, value=200, step=1)
                style_scale = gr.Slider(label="Style Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
                guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=42, step=1)
                a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
                n_prompt = gr.Textbox(label="Negative Prompt",
                                      value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
    with gr.Row():
        gr.Markdown("### Results for Contour")
    with gr.Row():
        with gr.Blocks():
            with gr.Row():
                with gr.Column(scale=1):
                    contour_gallery = gr.Gallery(label='Contour Output', show_label=True, elem_id="gallery", columns=[1], rows=[1], height='auto')
                with gr.Column(scale=4):
                    image_gallery = gr.Gallery(label='Result for Contour', show_label=True, elem_id="gallery", columns=[4], rows=[1], height='auto')
    with gr.Row():
        gr.Markdown("### Results for Lineart")
    with gr.Row():
        with gr.Blocks():
            with gr.Row():
                with gr.Column(scale=1):
                    line_gallery = gr.Gallery(label='Lineart Output', show_label=True, elem_id="gallery", columns=[1], rows=[1], height='auto')
                with gr.Column(scale=4):
                    line_image_gallery = gr.Gallery(label='Result for Lineart', show_label=True, elem_id="gallery", columns=[4], rows=[1], height='auto')
    with gr.Row():
        gr.Markdown("### Results for Canny")
    with gr.Row():
        with gr.Blocks():
            with gr.Row():
                with gr.Column(scale=1):
                    canny_gallery = gr.Gallery(label='Canny Output', show_label=True, elem_id="gallery", columns=[1], rows=[1], height='auto')
                with gr.Column(scale=4):
                    canny_image_gallery = gr.Gallery(label='Result for Canny', show_label=True, elem_id="gallery", columns=[4], rows=[1], height='auto')
    ips = [style_image, content_image, prompt, num_samples, image_resolution, condition_scale, style_scale, ddim_steps, guidance_scale, seed, a_prompt, n_prompt, btn1, Contour_Threshold]
    run_button.click(fn=process, inputs=ips, outputs=[image_gallery, contour_gallery, line_image_gallery, line_gallery, canny_image_gallery, canny_gallery])
block.launch(server_name='0.0.0.0')
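
# A minimal sketch of invoking `process` directly (hypothetical file names),
# e.g. for a quick smoke test without the UI:
#
#   import numpy as np
#   style = np.array(Image.open("style.png").convert("RGB"))
#   content = np.array(Image.open("content.png").convert("RGB"))
#   galleries = process(style, content, "a cat", num_samples=1,
#                       image_resolution=512, condition_scale=1.0, style_scale=1.0,
#                       ddim_steps=50, guidance_scale=7.5, seed=42,
#                       a_prompt="best quality", n_prompt="lowres", btn1="All")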