from email.policy import default
from this import d
import gradio as gr
import numpy as np
import torch
import gc
from huggingface_hub import hf_hub_download
import requests
import random
import os
import sys
import pickle
from PIL import Image
from tqdm.auto import tqdm
from datetime import datetime
from utils.gradio_utils import is_torch2_available
if is_torch2_available():
from utils.gradio_utils import \
AttnProcessor2_0 as AttnProcessor
else:
from utils.gradio_utils import AttnProcessor
import diffusers
from diffusers import StableDiffusionXLPipeline
from utils import PhotoMakerStableDiffusionXLPipeline
from diffusers import DDIMScheduler
import torch.nn.functional as F
from utils.gradio_utils import cal_attn_mask_xl
import copy
import os
from diffusers.utils import load_image
from utils.utils import get_comic
from utils.style_template import styles
import torch.nn.functional as F
# Local paths to the IP-Adapter image encoder directory and adapter checkpoint.
image_encoder_path = "./data/models/ip_adapter/sdxl_models/image_encoder"
ip_ckpt = "./data/models/ip_adapter/sdxl_models/ip-adapter_sdxl_vit-h.bin"

# Exclude loopback addresses from any configured HTTP proxy.
os.environ["no_proxy"] = "localhost,127.0.0.1,::1"

# Style presets come from utils.style_template.styles.
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "Japanese Anime"

use_va = False

# Display name -> Hugging Face repo id of each selectable SDXL checkpoint.
# (A module-level ``global`` statement is a no-op, so none is needed here;
# functions that rebind ``models_dict`` must declare ``global`` themselves.)
models_dict = {
    # "Juggernaut": "RunDiffusion/Juggernaut-XL-v8",
    "RealVision": "SG161222/RealVisXL_V4.0",
    "SDXL": "stabilityai/stable-diffusion-xl-base-1.0",
    "Unstable": "stablediffusionapi/sdxl-unstable-diffusers-y",
}

# NOTE: this runs at import time and may download the PhotoMaker weights.
photomaker_path = hf_hub_download(repo_id="TencentARC/PhotoMaker", filename="photomaker-v1.bin", repo_type="model")

# Largest value representable as a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
def setup_seed(seed):
    """Seed torch, NumPy and Python's ``random`` with the same value.

    Also switches cuDNN into deterministic mode.

    Args:
        seed: Integer seed forwarded unchanged to each generator.
    """
    # Same order as before: torch first, then NumPy, then the stdlib RNG.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True
def set_text_unfinished():
    """Return a Gradio update that shows the "generation in progress" text.

    Returns:
        The result of ``gr.update`` with ``visible=True`` and the status
        message as ``value``.
    """
    # The literal used to be split across physical lines inside plain double
    # quotes, which is a syntax error; the line breaks are now explicit.
    # NOTE(review): surrounding markup may have been lost from this message
    # -- confirm against the component that displays it.
    return gr.update(
        visible=True,
        value="\n(Not Finished) Generating ··· The intermediate results will be shown.\n",
    )
def set_text_finished():
    """Return a Gradio update that shows the "generation finished" text.

    Returns:
        The result of ``gr.update`` with ``visible=True`` and the status
        message as ``value``.
    """
    # The literal used to end on the following physical line inside plain
    # double quotes, which is a syntax error; the line break is now explicit.
    # NOTE(review): surrounding markup may have been lost from this message
    # -- confirm against the component that displays it.
    return gr.update(visible=True, value="Generation Finished\n")
#################################################
def get_image_path_list(folder_name):
    """List every entry of ``folder_name`` as a sorted list of joined paths.

    Args:
        folder_name: Directory to list.

    Returns:
        A list of ``os.path.join(folder_name, entry)`` strings in ascending
        order. No filtering is applied, so sub-directories and non-image
        files are included as well.
    """
    entries = os.listdir(folder_name)
    full_paths = [os.path.join(folder_name, entry) for entry in entries]
    full_paths.sort()
    return full_paths
#################################################
class SpatialAttnProcessor2_0(torch.nn.Module):
    r"""
    Self-attention processor built on ``F.scaled_dot_product_attention``
    (PyTorch 2.0) that can run attention over the concatenated tokens of
    several images at once.

    The processor is driven by module-level globals (``write``, ``cur_step``,
    ``attn_count``, ``total_count``, ``mask1024``, ``mask4096``, ``sa32``,
    ``sa64``, ``height``, ``width``) that must be set before it is called;
    they are not defined in this class.

    Args:
        hidden_size (`int`, *optional*):
            Stored on the instance; not otherwise used by this class.
        cross_attention_dim (`int`, *optional*):
            Stored on the instance; not otherwise used by this class.
        id_length (`int`, defaults to 4):
            Number of leading batch entries that are cached separately in
            `id_bank` while `write` is true. `total_length` is `id_length + 1`.
        device (`str`, defaults to `"mps"`):
            Device the cached states are moved to and that is passed to
            `cal_attn_mask_xl`.
        dtype (`torch.dtype`, defaults to `torch.float32`):
            Dtype passed to `cal_attn_mask_xl`.
    """
    def __init__(self, hidden_size=None, cross_attention_dim=None, id_length=4, device="mps", dtype=torch.float32):
        super().__init__()
        # scaled_dot_product_attention only exists in PyTorch >= 2.0.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.device = device
        self.dtype = dtype
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.total_length = id_length + 1
        self.id_length = id_length
        # cur_step -> [first id_length batch entries, remaining batch entries]
        self.id_bank = {}
    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None):
        """Run one attention call and advance the global step bookkeeping.

        In write mode (global ``write`` true) the incoming hidden states are
        cached in ``id_bank`` under the current step; otherwise the cached
        states for the current step are concatenated with the incoming ones
        and used as ``encoder_hidden_states``. Depending on ``cur_step`` and a
        random draw, attention is computed either by ``__call1__`` (masked,
        across images) or by ``__call2__`` (plain, per batch entry).

        Returns:
            The attention output produced by ``__call1__`` or ``__call2__``.

        Raises:
            KeyError: in read mode, if nothing was cached for ``cur_step``.
        """
        # All per-run state lives in module globals shared by every processor.
        global total_count,attn_count,cur_step,mask1024,mask4096
        global sa32, sa64
        global write
        global height,width
        if write:
            # Cache copies of this step's hidden states, split at id_length
            # along the batch dimension.
            self.id_bank[cur_step] = [hidden_states[:self.id_length].clone(), hidden_states[self.id_length:].clone()]
        else:
            # Interleave the cached states with the current batch along dim 0:
            # cached[0], current[:1], cached[1], current[1:].
            encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].to(self.device),hidden_states[:1],self.id_bank[cur_step][1].to(self.device),hidden_states[1:]))
        # The very first step always uses plain attention.
        if cur_step <1:
            hidden_states = self.__call2__(attn, hidden_states,None,attention_mask,temb)
        else: # 256 1024 4096
            # Draw in [0, 1); the masked path is taken when the draw exceeds
            # the threshold (0.3 before step 20, 0.1 afterwards).
            random_number = random.random()
            if cur_step <20:
                rand_num = 0.3
            else:
                rand_num = 0.1
            if random_number > rand_num:
                # mask1024 is used when the token count equals
                # (height//32) * (width//32); mask4096 otherwise.
                if not write:
                    # Read mode: keep only the rows after the first
                    # shape[0] // total_length * id_length rows.
                    if hidden_states.shape[1] == (height//32) * (width//32):
                        attention_mask = mask1024[mask1024.shape[0] // self.total_length * self.id_length:]
                    else:
                        attention_mask = mask4096[mask4096.shape[0] // self.total_length * self.id_length:]
                else:
                    # Write mode: keep the leading square block of that size.
                    if hidden_states.shape[1] == (height//32) * (width//32):
                        attention_mask = mask1024[:mask1024.shape[0] // self.total_length * self.id_length,:mask1024.shape[0] // self.total_length * self.id_length]
                    else:
                        attention_mask = mask4096[:mask4096.shape[0] // self.total_length * self.id_length,:mask4096.shape[0] // self.total_length * self.id_length]
                hidden_states = self.__call1__(attn, hidden_states,encoder_hidden_states,attention_mask,temb)
            else:
                hidden_states = self.__call2__(attn, hidden_states,None,attention_mask,temb)
        # Count processor invocations; once attn_count reaches total_count,
        # start the next step and regenerate both masks.
        attn_count +=1
        if attn_count == total_count:
            attn_count = 0
            cur_step += 1
            mask1024,mask4096 = cal_attn_mask_xl(self.total_length,self.id_length,sa32,sa64,height,width, device=self.device, dtype= self.dtype)
        return hidden_states
    def __call1__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Masked attention over the merged token sequences of several images.

        The batch is regrouped so that the tokens of ``total_batch_size // 2``
        images form one sequence; keys/values come from
        ``encoder_hidden_states`` (regrouped in chunks of ``id_length + 1``
        images) or, when that is ``None``, from the regrouped hidden states.
        ``attention_mask`` is passed to ``scaled_dot_product_attention`` as is.

        Returns:
            The attention output with the same leading batch size as the
            input.
        """
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten (B, C, H, W) into (B, H*W, C).
            total_batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(total_batch_size, channel, height * width).transpose(1, 2)
        total_batch_size,nums_token,channel = hidden_states.shape
        # presumably the batch holds two halves (e.g. unconditional and
        # conditional) of img_nums images each -- TODO confirm with caller
        img_nums = total_batch_size//2
        # Merge the tokens of img_nums images into one long sequence.
        hidden_states = hidden_states.view(-1,img_nums,nums_token,channel).reshape(-1,img_nums * nums_token,channel)
        batch_size, sequence_length, _ = hidden_states.shape
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states # B, N, C
        else:
            # Merge the tokens of (id_length + 1) images into one sequence.
            encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,nums_token,channel).reshape(-1,(self.id_length+1) * nums_token,channel)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        # Split into heads: (batch, heads, seq_len, head_dim).
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        # Merge heads and restore the original per-image batch layout.
        hidden_states = hidden_states.transpose(1, 2).reshape(total_batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(total_batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
    def __call2__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None):
        """Plain scaled-dot-product attention, one sequence per batch entry.

        Keys/values come from ``encoder_hidden_states`` (regrouped in chunks
        of ``id_length + 1`` images) or, when that is ``None``, from
        ``hidden_states`` itself. A non-``None`` ``attention_mask`` is first
        prepared with ``attn.prepare_attention_mask``.

        Returns:
            The attention output, reshaped back to 4-D when the input was 4-D.
        """
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten (B, C, H, W) into (B, H*W, C).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        batch_size, sequence_length, channel = (
            hidden_states.shape
        )
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states # B, N, C
        else:
            # Merge the tokens of (id_length + 1) images into one sequence.
            encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,sequence_length,channel).reshape(-1,(self.id_length+1) * sequence_length,channel)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        # Split into heads: (batch, heads, seq_len, head_dim).
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        # Merge heads back into (batch, seq_len, heads * head_dim).
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
def set_attention_processor(unet, id_length, is_ipadapter=False):
    """Install an attention processor on every attention layer of ``unet``.

    Processors whose name ends in ``attn1.processor`` and that sit in an
    ``up_blocks`` layer get a ``SpatialAttnProcessor2_0``; the other
    ``attn1`` processors get the default ``AttnProcessor``. Remaining
    processors get an ``IPAttnProcessor2_0`` when ``is_ipadapter`` is true
    and an ``AttnProcessor`` otherwise.

    The mapping is also stored in the module-level ``attn_procs``; the UNet
    itself receives a deep copy of it.

    Args:
        unet: Model exposing ``attn_processors``, ``config``, ``device`` and
            ``set_attn_processor``.
        id_length: Forwarded to ``SpatialAttnProcessor2_0``.
        is_ipadapter: Whether to use ``IPAttnProcessor2_0`` on processors
            that have a cross-attention dimension.
    """
    global attn_procs
    attn_procs = {}
    for layer_name in unet.attn_processors:
        if layer_name.endswith("attn1.processor"):
            cross_attention_dim = None
        else:
            cross_attention_dim = unet.config.cross_attention_dim

        # Channel width of the block this layer belongs to; the block index
        # is the single character right after the "<kind>_blocks." prefix.
        if layer_name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif layer_name.startswith("up_blocks"):
            block_index = int(layer_name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_index]
        elif layer_name.startswith("down_blocks"):
            block_index = int(layer_name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_index]

        has_cross_attention = cross_attention_dim is not None
        if not has_cross_attention and layer_name.startswith("up_blocks"):
            processor = SpatialAttnProcessor2_0(id_length=id_length)
        elif has_cross_attention and is_ipadapter:
            # NOTE(review): IPAttnProcessor2_0 is not imported in the visible
            # part of this file -- confirm it is defined before enabling
            # is_ipadapter.
            processor = IPAttnProcessor2_0(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                scale=1,
                num_tokens=4,
            ).to(unet.device, dtype=torch.float16)
        else:
            processor = AttnProcessor()
        attn_procs[layer_name] = processor

    unet.set_attn_processor(copy.deepcopy(attn_procs))
#################################################
#################################################
# Markup for the sketch canvas.
# NOTE(review): this is empty, yet get_js_colors below looks up the element
# with id "canvas-root" -- confirm the markup was not lost.
canvas_html = ""
# JavaScript snippet: downloads sketch-canvas.js from the URL below, wraps its
# text in a Blob and appends it to <head> as a module script.
load_js = """
async () => {
const url = "https://huggingface.co./datasets/radames/gradio-components/raw/main/sketch-canvas.js"
fetch(url)
.then(res => res.text())
.then(text => {
const script = document.createElement('script');
script.type = "module"
script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' }));
document.head.appendChild(script);
});
}
"""
# JavaScript snippet: returns the `_data` property of the element with id
# "canvas-root", wrapped in a one-element array.
get_js_colors = """
async (canvasData) => {
const canvasEl = document.getElementById("canvas-root");
return [canvasEl._data]
}
"""
css = '''
#color-bg{display:flex;justify-content: center;align-items: center;}
.color-bg-item{width: 100%; height: 32px}
#main_button{width:100%}