# Provenance (file-viewer header): last change by taki0112, commit cb415f3 "fix code", 18 kB.
import torch
from pipelines.inverted_ve_pipeline import STYLE_DESCRIPTION_DICT, create_image_grid
import gradio as gr
import os, json
import numpy as np
from PIL import Image
from pipelines.pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
from pipelines.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
from diffusers import ControlNetModel, AutoencoderKL
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from random import randint
from utils import init_latent
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from diffusers import DDIMScheduler
# Pick the compute device and matching precision in one place:
# CUDA runs in half precision, the CPU fallback in full float32.
if torch.cuda.is_available():
    device = 'cuda'
    torch_dtype = torch.float16
else:
    device = 'cpu'
    torch_dtype = torch.float32
def memory_efficient(model):
    """Move ``model`` to the global ``device`` and switch on optional memory savers.

    Every step is best-effort: a failure while moving the model is printed and
    ignored, and each optional ``enable_*`` hook that raises ``AttributeError``
    (for example because the object does not provide it) is reported and skipped.
    The xformers hook is only attempted when running on CUDA.
    """
    try:
        model.to(device)
    except Exception as err:
        print("Error moving model to device:", err)

    hook_names = ["enable_model_cpu_offload", "enable_vae_slicing"]
    if device == 'cuda':
        hook_names.append("enable_xformers_memory_efficient_attention")

    for hook_name in hook_names:
        try:
            getattr(model, hook_name)()
        except AttributeError:
            print(f"{hook_name} is not supported.")
# --- Module-level model loading (runs once at import time) -------------------
# Depth ControlNet and VAE checkpoints shared by the ControlNet pipeline below.
controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch_dtype)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype)
# SDXL pipeline with ControlNet conditioning; uses the VAE loaded above.
model_controlnet = StableDiffusionXLControlNetPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch_dtype
)
# Plain SDXL pipeline used by the vanilla tab.
# NOTE(review): unlike model_controlnet, this pipeline is not given the `vae` loaded above — confirm intended.
model = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch_dtype)
# Apply the best-effort device placement / memory options to each component,
# printing a label first so any "not supported" message can be attributed.
print("vae")
memory_efficient(vae)
print("control")
memory_efficient(controlnet)
print("ControlNet-SDXL")
memory_efficient(model_controlnet)
print("SDXL")
memory_efficient(model)
# Depth estimator and its preprocessor, used by get_depth_map.
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
# BLIP-2 captioning model and processor, used by blip_inf_prompt.
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch_dtype).to(device)
# controlnet_scale, canny thres 1, 2 (2 > 1, 2:1, 3:1)
def parse_config(config):
    """Load a JSON configuration file and return the parsed object.

    Args:
        config: Path to the JSON file.

    Returns:
        Whatever ``json.load`` produces for the file (a dict for the style
        configs used in this module).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Read explicitly as UTF-8 so parsing does not depend on the platform's
    # locale encoding; the path argument is no longer rebound to the result.
    with open(config, 'r', encoding='utf-8') as f:
        return json.load(f)
def get_depth_map(image):
    """Estimate depth for ``image`` and return it as a 1024x1024 PIL image.

    The image is preprocessed with the module-level ``feature_extractor``, run
    through ``depth_estimator`` without gradients under autocast, resized with
    bicubic interpolation, min-max normalised, replicated to three channels and
    converted to 8-bit. Only the first item of the batch is returned.
    """
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    with torch.no_grad(), torch.autocast(device):
        predicted = depth_estimator(pixel_values).predicted_depth

    resized = torch.nn.functional.interpolate(
        predicted.unsqueeze(1),
        size=(1024, 1024),
        mode="bicubic",
        align_corners=False,
    )

    # Min-max normalise each batch item over all of its remaining dimensions.
    reduce_dims = [1, 2, 3]
    lowest = torch.amin(resized, dim=reduce_dims, keepdim=True)
    highest = torch.amax(resized, dim=reduce_dims, keepdim=True)
    normalised = (resized - lowest) / (highest - lowest)

    # Repeat the single depth channel three times, then go channels-last.
    stacked = torch.cat([normalised] * 3, dim=1)
    first_item = stacked.permute(0, 2, 3, 1).cpu().numpy()[0]
    as_uint8 = (first_item * 255.0).clip(0, 255).astype(np.uint8)
    return Image.fromarray(as_uint8)
def get_depth_edge_array(depth_img_path):
    """Return the depth map (as produced by ``get_depth_map``) for the input.

    Despite the parameter name, the argument is passed to ``Image.fromarray``,
    so it is treated as an image array rather than a filesystem path.
    """
    return get_depth_map(Image.fromarray(depth_img_path))
def blip_inf_prompt(image):
    """Caption ``image`` with the module-level BLIP-2 model.

    Args:
        image: Image accepted by ``blip_processor`` (a PIL image in this file).

    Returns:
        The first decoded caption, with special tokens skipped and surrounding
        whitespace stripped.
    """
    # Cast the inputs to the dtype the model was loaded with (``torch_dtype``)
    # instead of a hard-coded float16, which mismatches the float32 model used
    # on the CPU fallback path.
    inputs = blip_processor(images=image, return_tensors="pt").to(device, torch_dtype)
    generated_ids = blip_model.generate(**inputs)
    captions = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)
    return captions[0].strip()
def load_example_controlnet():
    """Build the example rows for the ControlNet tab.

    For every ``.png`` in ``assets/ref`` the style name is taken from the second
    ``_``-separated token of the file name, the matching ``./config/<style>.json``
    is read, and one row is emitted in the input order of ``controlnet_fn``:
    ``[style image, depth-source image, style name, prompt, outputs, scale, steps]``.
    """
    ref_dir = 'assets/ref'
    # Every example shares the same fixed image as its depth source.
    depth_source = './assets/depth_dir/gundam.png'

    rows = []
    for entry in os.listdir(ref_dir):
        if not entry.endswith(".png"):
            continue
        ref_image = os.path.join(ref_dir, entry)
        style = os.path.basename(ref_image).split('_')[1]
        cfg = parse_config('./config/{}.json'.format(style))
        first_object = cfg["inference_info"]["inf_object_list"][0]
        rows.append([ref_image, depth_source, style, first_object, 1, 0.5, 50])
    return rows
def load_example_style():
    """Build the example rows for the vanilla tab.

    For every ``.png`` in ``assets/ref`` the style name is taken from the second
    ``_``-separated token of the file name, the matching ``./config/<style>.json``
    is read, and one row is emitted in the input order of ``style_fn``:
    ``[style image, style name, prompt, outputs, steps]``.
    """
    ref_dir = 'assets/ref'

    rows = []
    for entry in os.listdir(ref_dir):
        if not entry.endswith(".png"):
            continue
        ref_image = os.path.join(ref_dir, entry)
        style = os.path.basename(ref_image).split('_')[1]
        cfg = parse_config('./config/{}.json'.format(style))
        first_object = cfg["inference_info"]["inf_object_list"][0]
        rows.append([ref_image, style, first_object, 1, 50])
    return rows
def style_fn(image_path, style_name, content_text, output_number, diffusion_step=50):
    """Generate images of ``content_text`` in the style of a reference.

    Two modes, chosen by ``style_name``:

    * Non-empty: a predefined style. Settings come from
      ``./config/<style_name>.json`` and the reference prompt is built from the
      config and ``STYLE_DESCRIPTION_DICT``; ``image_path`` is not read.
    * Empty or whitespace-only: the image at ``image_path`` is the style
      reference. It is resized to 1024x1024, captioned via ``blip_inf_prompt``
      to obtain the reference prompt, and passed to the pipeline as ``image``.
      Settings come from ``./config/default.json`` and the global pipeline's
      scheduler is replaced with a ``DDIMScheduler``.

    Args:
        image_path: Path of the style image (used only in user-image mode).
        style_name: Name of a predefined style, or an empty string.
        content_text: Text prompt describing what to generate.
        output_number: Number of images to generate; converted with ``int()``.
        diffusion_step: Value passed as ``num_inference_steps``.

    Returns:
        The grid built by ``create_image_grid`` for the last
        (layer indices, step indices) combination in the config, or ``None``
        if either list in the config is empty.

    Raises:
        ValueError: In user-image mode when no image was supplied, or when
            ``output_number`` cannot be converted to an int.
    """
    # An empty style name means no predefined style was picked, so the
    # uploaded image itself serves as the style reference.
    user_image_flag = not style_name.strip()

    if not user_image_flag:
        real_img = None
        config = parse_config('./config/{}.json'.format(style_name))
        inference_info = config['inference_info']
        reference_info = config['reference_info']

        ref_seed = reference_info['ref_seeds'][0]
        use_inf_negative_prompt = inference_info['use_negative_prompt']
        use_advanced_sampling = inference_info['use_advanced_sampling']
        use_prompt_as_null = False

        ref_object = reference_info["ref_object_list"][0]
        # The description table is keyed by the config's own style name,
        # not by the argument.
        style_key = config["style_name_list"][0]
        style_description_pos = STYLE_DESCRIPTION_DICT[style_key][0]
        style_description_neg = STYLE_DESCRIPTION_DICT[style_key][1]

        if reference_info['with_style_description']:
            ref_prompt = style_description_pos.replace("{object}", ref_object)
        else:
            ref_prompt = ref_object

        if inference_info['with_style_description']:
            inf_prompt = style_description_pos.replace("{object}", content_text)
        else:
            inf_prompt = content_text
    else:
        if not image_path:
            raise ValueError("A style image is required when no predefined style is selected.")

        # NOTE(review): this mutates the shared global pipeline, so later
        # predefined-style calls also run with the DDIM scheduler — confirm intended.
        model.scheduler = DDIMScheduler.from_config(model.scheduler.config)

        # Close the file promptly, and force three channels so RGBA or
        # grayscale uploads produce the same array layout as RGB ones.
        with Image.open(image_path) as opened:
            origin_real_img = opened.convert("RGB").resize((1024, 1024), resample=Image.BICUBIC)
        real_img = np.array(origin_real_img).astype(np.float32) / 255.0

        config = parse_config('./config/default.json')
        inference_info = config['inference_info']

        ref_seed = 0
        use_inf_negative_prompt = False
        use_advanced_sampling = False
        use_prompt_as_null = True

        # The reference prompt is a caption of the uploaded image.
        ref_prompt = blip_inf_prompt(origin_real_img)
        inf_prompt = content_text
        style_description_neg = None

    # Settings read the same way in both modes.
    activate_layer_indices_list = inference_info['activate_layer_indices_list']
    activate_step_indices_list = inference_info['activate_step_indices_list']
    attn_map_save_steps = inference_info['attn_map_save_steps']
    use_shared_attention = inference_info['use_shared_attention']
    adain_queries = inference_info['adain_queries']
    adain_keys = inference_info['adain_keys']
    adain_values = inference_info['adain_values']
    guidance_scale = config['guidance_scale']

    inf_seeds = [randint(0, 10**10) for _ in range(int(output_number))]
    # One extra slot for the reference latent, which always comes first.
    num_images_per_prompt = len(inf_seeds) + 1

    # Inference
    with torch.inference_mode():
        grid = None
        for activate_layer_indices in activate_layer_indices_list:
            for activate_step_indices in activate_step_indices_list:
                model.activate_layer(
                    activate_layer_indices=activate_layer_indices,
                    attn_map_save_steps=attn_map_save_steps,
                    activate_step_indices=activate_step_indices,
                    use_shared_attention=use_shared_attention,
                    adain_queries=adain_queries,
                    adain_keys=adain_keys,
                    adain_values=adain_values,
                )

                # Reference latent first, then one latent per output seed.
                latents = torch.cat(
                    [
                        init_latent(model, device_name=device, dtype=torch_dtype, seed=seed)
                        for seed in [ref_seed, *inf_seeds]
                    ],
                    dim=0,
                )
                # Tensor.to returns a new tensor; the result must be kept.
                latents = latents.to(device)

                # [0] selects the generated images; [1:] drops the reference.
                images = model(
                    prompt=ref_prompt,
                    negative_prompt=style_description_neg,
                    guidance_scale=guidance_scale,
                    num_inference_steps=diffusion_step,
                    latents=latents,
                    num_images_per_prompt=num_images_per_prompt,
                    target_prompt=inf_prompt,
                    use_inf_negative_prompt=use_inf_negative_prompt,
                    use_advanced_sampling=use_advanced_sampling,
                    use_prompt_as_null=use_prompt_as_null,
                    image=real_img,
                )[0][1:]

                n_row = 1
                # "+ 1" is what you would use to also show the reference image.
                # NOTE(review): the reference was sliced off above, so this is one
                # column more than there are images — confirm create_image_grid copes.
                n_col = len(inf_seeds) + 1

                # make grid
                grid = create_image_grid(images, n_row, n_col, padding=10)

    return grid
def controlnet_fn(image_path, depth_image_path, style_name, content_text, output_number, controlnet_scale=0.5, diffusion_step=50):
    """Generate depth-conditioned images of ``content_text`` in a predefined style.

    Settings come from ``./config/<style_name>.json``; the prompts are built from
    the config and ``STYLE_DESCRIPTION_DICT``. The depth condition is computed
    from ``depth_image_path`` by ``get_depth_edge_array``.

    Args:
        image_path: Style image from the UI. Not read by this function; the
            style is determined solely by ``style_name``.
        depth_image_path: Image to derive the depth map from. It is forwarded
            to ``get_depth_edge_array``, which treats it as an image array.
        style_name: Name of the predefined style (selects the config file).
        content_text: Text prompt describing what to generate.
        output_number: Number of images to generate; converted with ``int()``.
        controlnet_scale: Value passed as ``controlnet_conditioning_scale``.
        diffusion_step: Value passed as ``num_inference_steps``.

    Returns:
        The grid built by ``create_image_grid`` for the last
        (layer indices, step indices) combination in the config, or ``None``
        if either list in the config is empty.
    """
    config = parse_config('./config/{}.json'.format(style_name))
    inference_info = config['inference_info']
    reference_info = config['reference_info']

    inf_seeds = [randint(0, 10**10) for _ in range(int(output_number))]

    activate_layer_indices_list = inference_info['activate_layer_indices_list']
    activate_step_indices_list = inference_info['activate_step_indices_list']
    ref_seed = reference_info['ref_seeds'][0]
    attn_map_save_steps = inference_info['attn_map_save_steps']
    guidance_scale = config['guidance_scale']
    use_inf_negative_prompt = inference_info['use_negative_prompt']

    # The description table is keyed by the config's own style name,
    # not by the argument.
    style_key = config["style_name_list"][0]
    ref_object = reference_info["ref_object_list"][0]
    ref_with_style_description = reference_info['with_style_description']
    inf_with_style_description = inference_info['with_style_description']
    use_shared_attention = inference_info['use_shared_attention']
    adain_queries = inference_info['adain_queries']
    adain_keys = inference_info['adain_keys']
    adain_values = inference_info['adain_values']
    use_advanced_sampling = inference_info['use_advanced_sampling']

    # Depth map used as the ControlNet condition.
    depth_image = get_depth_edge_array(depth_image_path)

    style_description_pos = STYLE_DESCRIPTION_DICT[style_key][0]
    style_description_neg = STYLE_DESCRIPTION_DICT[style_key][1]

    if ref_with_style_description:
        ref_prompt = style_description_pos.replace("{object}", ref_object)
    else:
        ref_prompt = ref_object

    if inf_with_style_description:
        inf_prompt = style_description_pos.replace("{object}", content_text)
    else:
        inf_prompt = content_text

    # Inference
    with torch.inference_mode():
        grid = None
        for activate_layer_indices in activate_layer_indices_list:
            for activate_step_indices in activate_step_indices_list:
                model_controlnet.activate_layer(
                    activate_layer_indices=activate_layer_indices,
                    attn_map_save_steps=attn_map_save_steps,
                    activate_step_indices=activate_step_indices,
                    use_shared_attention=use_shared_attention,
                    adain_queries=adain_queries,
                    adain_keys=adain_keys,
                    adain_values=adain_values,
                )

                # Reference latent first, then one latent per output seed.
                latents = torch.cat(
                    [
                        init_latent(model_controlnet, device_name=device, dtype=torch_dtype, seed=seed)
                        for seed in [ref_seed, *inf_seeds]
                    ],
                    dim=0,
                )
                # Tensor.to returns a new tensor; the result must be kept.
                latents = latents.to(device)

                # [0] selects the generated images; [1:] drops the reference.
                images = model_controlnet.generated_ve_inference(
                    prompt=ref_prompt,
                    negative_prompt=style_description_neg,
                    guidance_scale=guidance_scale,
                    num_inference_steps=diffusion_step,
                    controlnet_conditioning_scale=controlnet_scale,
                    latents=latents,
                    num_images_per_prompt=len(inf_seeds) + 1,
                    target_prompt=inf_prompt,
                    image=depth_image,
                    use_inf_negative_prompt=use_inf_negative_prompt,
                    use_advanced_sampling=use_advanced_sampling,
                )[0][1:]

                n_row = 1
                # Use len(inf_seeds) + 1 here to also show the reference image.
                n_col = len(inf_seeds)

                # make grid
                grid = create_image_grid(images, n_row, n_col)

    return grid
# Markdown shown above both demo tabs (passed as `description` to each gr.Interface).
description_md = """
### We introduce `Visual Style Prompting`, which reflects the style of a reference image to the images generated by a pretrained text-to-image diffusion model without finetuning or optimization (e.g., Figure N).
### 📖 [[Paper](https://arxiv.org/abs/2402.12974)] | ✨ [[Project page](https://curryjung.github.io/VisualStylePrompt)] | ✨ [[Code](https://github.com/naver-ai/Visual-Style-Prompting)]
---
### 🔥 To try out our vanilla demo,
1. Choose a `style reference` from the collection of images below.
2. Enter the `text prompt`.
3. Choose the `number of outputs`.
---
### ✨ Visual Style Prompting also works on `ControlNet` which specifies the shape of the results by depthmap or keypoints.
### ‼️ w/ ControlNet ver does not support user style images.
### 🔥 To try out our demo with ControlNet,
1. Upload an `image for depth control`. An off-the-shelf model will produce the depthmap from it.
2. Choose `ControlNet scale` which determines the alignment to the depthmap.
3. Choose a `style reference` from the collection of images below.
4. Enter the `text prompt`. (`Empty text` is okay, but a depthmap description helps.)
5. Choose the `number of outputs`.
### 👉 To achieve faster results, we recommend lowering the diffusion steps to 30.
### Enjoy ! 😄
"""
# Vanilla demo tab. The inputs are listed in the positional order of style_fn's
# parameters, and each row returned by load_example_style() follows that order.
iface_style = gr.Interface(
    fn=style_fn,
    inputs=[
        # type="filepath": style_fn receives a path and opens the file itself.
        gr.components.Image(label="Style Image", type="filepath"),
        # Hidden field; the example rows carry a style name in this position.
        gr.components.Textbox(label='Style name', visible=False),
        gr.components.Textbox(label="Text prompt", placeholder="Enter Text prompt"),
        # Free-text field; style_fn converts the value with int().
        gr.components.Textbox(label="Number of outputs", placeholder="Enter Number of outputs"),
        gr.components.Slider(minimum=10, maximum=50, step=10, value=50, label="Diffusion steps"),
    ],
    outputs=gr.components.Image(label="Generated Image"),
    title="🎨 Visual Style Prompting (default)",
    description=description_md,
    examples=load_example_style(),
)
# ControlNet demo tab. The inputs are listed in the positional order of
# controlnet_fn's parameters, and each row returned by load_example_controlnet()
# follows that order.
iface_controlnet = gr.Interface(
    fn=controlnet_fn,
    inputs=[
        # Shown for reference only: controlnet_fn does not read this image,
        # the style is chosen through the hidden style name below.
        gr.components.Image(label="Style image"),
        # Source image from which the depth map is estimated.
        gr.components.Image(label="Depth image"),
        # Hidden field; the example rows carry a style name in this position.
        gr.components.Textbox(label='Style name', visible=False),
        gr.components.Textbox(label="Text prompt", placeholder="Enter Text prompt"),
        # Free-text field; controlnet_fn converts the value with int().
        gr.components.Textbox(label="Number of outputs", placeholder="Enter Number of outputs"),
        gr.components.Slider(minimum=0.5, maximum=10, step=0.5, value=0.5, label="Controlnet scale"),
        gr.components.Slider(minimum=10, maximum=50, step=10, value=50, label="Diffusion steps"),
    ],
    outputs=gr.components.Image(label="Generated Image"),
    title="🎨 Visual Style Prompting (w/ ControlNet)",
    description=description_md,
    examples=load_example_controlnet(),
)
# Combine the two demos into a single app with one tab per interface;
# the second list gives the tab titles in the same order.
iface = gr.TabbedInterface([iface_style, iface_controlnet], ["Vanilla", "w/ ControlNet"])
# Start the Gradio app (runs at import time, with debug mode enabled).
iface.launch(debug=True)