# d-edit / utils.py
# niulx's picture
# Update utils.py
# a126919 verified
from transformers import PretrainedConfig
from PIL import Image
import torch
import numpy as np
import PIL
import os
from tqdm.auto import tqdm
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
# Module-wide default device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def _shift_slices(shift, size):
    """Return ``(source, destination)`` slices that move data by ``shift`` along an axis of length ``size``.

    A positive ``shift`` moves content towards higher indices. Shifts whose
    magnitude is at least ``size`` yield two empty slices.
    """
    if shift >= 0:
        keep = max(size - shift, 0)
        return slice(0, keep), slice(size - keep, size)
    keep = max(size + shift, 0)
    return slice(size - keep, size), slice(0, keep)


def myroll2d(a, delta_x, delta_y):
    """Shift a 2-D array or tensor without wrap-around, filling vacated cells with 0.

    Args:
        a: 2-D ``np.ndarray`` or ``torch.Tensor``.
        delta_x: Integer shift applied along axis 0 (positive moves content
            towards higher indices).
        delta_y: Integer shift applied along axis 1 (same sign convention).

    Returns:
        A new ``uint8`` array/tensor with the same height and width as ``a``.
        Tensors are allocated with ``torch.zeros`` default placement.

    Raises:
        TypeError: If ``a`` is neither a NumPy array nor a torch tensor.
    """
    h, w = a.shape[0], a.shape[1]
    if isinstance(a, np.ndarray):
        b = np.zeros([h, w]).astype(np.uint8)
    elif isinstance(a, torch.Tensor):
        b = torch.zeros([h, w]).to(torch.uint8)
    else:
        raise TypeError(
            f"myroll2d expects a numpy array or torch tensor, got {type(a).__name__}"
        )
    # NOTE: despite the parameter names, delta_x acts on axis 0 and delta_y on
    # axis 1; that mapping is kept so existing callers see the same result.
    # Each axis is bounded by its own length so non-square inputs are handled.
    src_rows, dst_rows = _shift_slices(delta_x, h)
    src_cols, dst_cols = _shift_slices(delta_y, w)
    b[dst_rows, dst_cols] = a[src_rows, src_cols]
    return b
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision = None, subfolder: str = "text_encoder"
):
    """Return the text-encoder class named by a pretrained checkpoint's config.

    The config found under ``subfolder`` is loaded and the first entry of its
    ``architectures`` list selects the class.

    Args:
        pretrained_model_name_or_path: Model id or local path handed to
            ``PretrainedConfig.from_pretrained``.
        revision: Optional revision forwarded to ``from_pretrained``.
        subfolder: Sub-directory that holds the text encoder config.

    Returns:
        ``CLIPTextModel`` or ``CLIPTextModelWithProjection`` (the class itself).

    Raises:
        ValueError: If the architecture is neither of the two supported ones.
    """
    config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, revision=revision, subfolder=subfolder
    )
    architecture = config.architectures[0]

    if architecture not in ("CLIPTextModel", "CLIPTextModelWithProjection"):
        raise ValueError(f"{architecture} is not supported.")

    # The classes are imported lazily, only once we know which one is needed.
    if architecture == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel

    from transformers import CLIPTextModelWithProjection

    return CLIPTextModelWithProjection
@torch.no_grad()
def image2latent(image, vae = None, dtype=None):
    """Encode an image into scaled VAE latents.

    Args:
        image: One of
            * a 4-D ``torch.Tensor``, which is treated as already being latents
              and returned unchanged;
            * a PIL image (any ``PIL.Image.Image`` subclass);
            * a NumPy array. It is divided by 127.5 and shifted by -1, then
              permuted with ``(2, 0, 1)``, so it is presumably an HxWxC image
              with values in [0, 255] (confirm against callers).
        vae: Autoencoder exposing ``encode(...).latent_dist.sample()`` and
            ``config.scaling_factor``. Required unless ``image`` is already a
            4-D tensor.
        dtype: dtype the pixel tensor is cast to before encoding.

    Returns:
        The latents, multiplied by ``vae.config.scaling_factor`` when they were
        produced by the VAE.

    Raises:
        ValueError: If encoding is needed but ``vae`` is ``None``.
    """
    # Already-encoded latents pass straight through.
    if isinstance(image, torch.Tensor) and image.dim() == 4:
        return image

    # ``Image`` is the PIL.Image *module*; the class to test against is
    # ``Image.Image``, which covers PNG/JPEG/etc. image objects alike.
    if isinstance(image, Image.Image):
        image = np.array(image)

    if vae is None:
        raise ValueError("image2latent requires a `vae` to encode a non-latent image")

    pixels = torch.from_numpy(image).float() / 127.5 - 1
    # Add a batch axis and move to the module-level ``device``.
    pixels = pixels.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)
    latents = vae.encode(pixels).latent_dist.sample()
    return latents * vae.config.scaling_factor
@torch.no_grad()
def latent2image(latents, return_type = 'np', vae = None):
    """Decode VAE latents into an image.

    The VAE is always upcast with ``upcast_vae`` before decoding and is put
    back to the dtype it had on entry afterwards, even if decoding raises.

    Args:
        latents: Latent tensor; it is divided by ``vae.config.scaling_factor``
            before being passed to ``vae.decode``.
        return_type: ``'np'`` converts the decoded output to a ``uint8`` NumPy
            array after mapping ``x / 2 + 0.5`` into [0, 1] and permuting with
            ``(0, 2, 3, 1)``; any other value returns the raw decoded tensor.
        vae: Autoencoder used for decoding.

    Returns:
        A ``uint8`` NumPy array when ``return_type == 'np'``, otherwise the
        tensor returned by ``vae.decode``.
    """
    # Remember the caller's dtype so decoding does not silently change the
    # model (previously the VAE was always left in float16).
    original_dtype = vae.dtype
    upcast_vae(vae)
    try:
        # Match the latents to whatever dtype post_quant_conv ended up in.
        latents = latents.to(next(iter(vae.post_quant_conv.parameters())).dtype)
        image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
    finally:
        vae.to(dtype=original_dtype)

    if return_type == 'np':
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        image = (image * 255).astype(np.uint8)
    return image
def upcast_vae(vae):
    """Cast ``vae`` to float32 in place, keeping some sub-modules in the original dtype when possible.

    When the first attention processor of the decoder's mid block is one of the
    torch-2.0 / xformers processors, ``post_quant_conv``, ``decoder.conv_in``
    and ``decoder.mid_block`` are cast back to the dtype the VAE had on entry.

    Args:
        vae: Autoencoder to modify in place. Nothing is returned.
    """
    original_dtype = vae.dtype
    vae.to(dtype=torch.float32)

    mid_block_processor = vae.decoder.mid_block.attentions[0].processor
    memory_efficient_processors = (
        AttnProcessor2_0,
        XFormersAttnProcessor,
        LoRAXFormersAttnProcessor,
        LoRAAttnProcessor2_0,
    )
    if not isinstance(mid_block_processor, memory_efficient_processors):
        return

    # With xformers or torch 2.0 attention the attention block does not need
    # to be in float32, which can save lots of memory.
    for module in (vae.post_quant_conv, vae.decoder.conv_in, vae.decoder.mid_block):
        module.to(original_dtype)
def prompt_to_emb_length_sdxl(prompt, tokenizer, text_encoder, length = None):
    """Tokenize one prompt and encode it, returning token-level and pooled embeddings.

    Args:
        prompt: Prompt text; it is wrapped in a single-element batch.
        tokenizer: Tokenizer called with ``padding="max_length"`` and
            ``truncation=True``.
        text_encoder: Encoder called with ``output_hidden_states=True``.
        length: ``max_length`` passed to the tokenizer.

    Returns:
        A dict with ``"prompt_embeds"`` (the second-to-last hidden state) and
        ``"pooled_prompt_embeds"`` (the first element of the encoder output,
        flattened to ``(batch, -1)``).
    """
    tokens = tokenizer(
        [prompt],
        padding="max_length",
        max_length=length,
        truncation=True,
        return_tensors="pt",
    )
    # Token ids are moved to the module-level ``device`` before encoding.
    encoder_output = text_encoder(tokens.input_ids.to(device), output_hidden_states=True)

    pooled = encoder_output[0]
    hidden = encoder_output.hidden_states[-2]

    batch_size, seq_len, _ = hidden.shape
    return {
        "prompt_embeds": hidden.view(batch_size, seq_len, -1),
        "pooled_prompt_embeds": pooled.view(batch_size, -1),
    }
def prompt_to_emb_length_sd(prompt, tokenizer, text_encoder, length = None):
    """Tokenize one prompt and return the first output of the text encoder.

    Args:
        prompt: Prompt text; it is wrapped in a single-element batch.
        tokenizer: Tokenizer called with ``padding="max_length"`` and
            ``truncation=True``.
        text_encoder: Encoder applied to the token ids.
        length: ``max_length`` passed to the tokenizer.

    Returns:
        Element 0 of the text encoder's output.
    """
    tokenizer_kwargs = dict(
        padding="max_length",
        max_length=length,
        truncation=True,
        return_tensors="pt",
    )
    token_ids = tokenizer([prompt], **tokenizer_kwargs).input_ids
    # Token ids are moved to the module-level ``device`` before encoding.
    encoder_output = text_encoder(token_ids.to(device))
    return encoder_output[0]
def sdxl_prepare_input_decom(
    set_string_list,
    tokenizer,
    tokenizer_2,
    text_encoder_1,
    text_encoder_2,
    length = 20,
    bsz = 1,
    weight_dtype = torch.float32,
    resolution = 1024,
    normal_token_id_list = None
):
    """Encode a list of prompts with both SDXL text encoders.

    Each prompt is encoded by ``(tokenizer, text_encoder_1)`` and
    ``(tokenizer_2, text_encoder_2)``; the two token-level embeddings are
    concatenated along the last axis. Prompts that contain ``#`` or ``$`` and
    whose index is not in ``normal_token_id_list`` are tokenized to ``length``
    tokens; all others to 77.

    Args:
        set_string_list: Non-empty sequence of prompt strings.
        tokenizer, tokenizer_2: Tokenizers for the first / second encoder.
        text_encoder_1, text_encoder_2: The two text encoders.
        length: Token length used for the special (``#`` / ``$``) prompts.
        bsz: Number of times each prompt embedding is repeated along axis 0.
        weight_dtype: dtype the embeddings are cast to.
        resolution: Value used for original size and target size in
            ``add_time_ids`` (crop offset is ``(0, 0)``).
        normal_token_id_list: Indices of prompts that are encoded at length 77
            even if they contain ``#`` or ``$``. ``None`` means no such indices.

    Returns:
        ``(encoder_hidden_states_list, add_text_embeds, add_time_ids)`` where
        ``add_text_embeds`` is the mean over prompts of the second encoder's
        pooled embedding and ``add_time_ids`` is a ``[1, 6]`` tensor.

    Raises:
        ValueError: If ``set_string_list`` is empty.
    """
    if not set_string_list:
        raise ValueError("set_string_list must contain at least one prompt")
    normal_ids = () if normal_token_id_list is None else normal_token_id_list

    encoders = ((tokenizer, text_encoder_1), (tokenizer_2, text_encoder_2))
    encoder_hidden_states_list = []
    pooled_sum = 0
    for m_idx, prompt in enumerate(set_string_list):
        is_special = ("#" in prompt or "$" in prompt) and m_idx not in normal_ids
        max_length = length if is_special else 77

        prompt_embeds_list = []
        for encoder_idx, (tok, encoder) in enumerate(encoders):
            out = prompt_to_emb_length_sdxl(prompt, tok, encoder, length = max_length)
            print(m_idx, prompt)
            prompt_embeds = out["prompt_embeds"].to(dtype=weight_dtype)
            # Only the second encoder's pooled output contributes to add_text_embeds.
            if encoder_idx == 1:
                pooled_sum = pooled_sum + out["pooled_prompt_embeds"].to(dtype=weight_dtype)
            prompt_embeds_list.append(prompt_embeds.repeat(bsz, 1, 1))

        encoder_hidden_states_list.append(torch.concat(prompt_embeds_list, dim=-1))

    add_text_embeds = pooled_sum / len(set_string_list)

    original_size = (resolution, resolution)
    crops_coords_top_left = (0, 0)
    target_size = (resolution, resolution)
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    add_time_ids = torch.tensor(
        [add_time_ids], dtype=prompt_embeds.dtype, device=pooled_sum.device
    )  # [1, 6]
    return encoder_hidden_states_list, add_text_embeds, add_time_ids
def sd_prepare_input_decom(
    set_string_list,
    tokenizer,
    text_encoder_1,
    length = 20,
    bsz = 1,
    weight_dtype = torch.float32,
    normal_token_id_list = None
):
    """Encode a list of prompts with a single text encoder.

    Prompts that contain ``#`` or ``$`` and whose index is not in
    ``normal_token_id_list`` are tokenized to ``length`` tokens; all others
    to 77.

    Args:
        set_string_list: Sequence of prompt strings.
        tokenizer: Tokenizer for ``text_encoder_1``.
        text_encoder_1: Text encoder.
        length: Token length used for the special (``#`` / ``$``) prompts.
        bsz: Number of times each embedding is repeated along axis 0.
        weight_dtype: dtype the embeddings are cast to.
        normal_token_id_list: Indices of prompts that are encoded at length 77
            even if they contain ``#`` or ``$``. ``None`` means no such indices.

    Returns:
        A list with one embedding tensor per prompt (empty for no prompts).
    """
    normal_ids = () if normal_token_id_list is None else normal_token_id_list

    encoder_hidden_states_list = []
    for m_idx, prompt in enumerate(set_string_list):
        is_special = ("#" in prompt or "$" in prompt) and m_idx not in normal_ids
        max_length = length if is_special else 77
        encoder_hidden_states = prompt_to_emb_length_sd(
            prompt, tokenizer, text_encoder_1, length = max_length
        )
        print(m_idx, prompt)
        encoder_hidden_states = encoder_hidden_states.repeat(bsz, 1, 1)
        encoder_hidden_states_list.append(encoder_hidden_states.to(dtype=weight_dtype))
    return encoder_hidden_states_list
def load_mask(input_folder):
    """Load the non-edited segmentation masks stored as ``.npy`` files in a folder.

    A file is picked up when its name contains ``"mask"``, ``".npy"`` and
    ``"_"`` and does not contain ``"Edited"``. Files are ordered by the integer
    that follows the first four characters of the part before the first
    underscore (e.g. ``mask3_dog.npy`` -> 3). The label is everything between
    that first underscore and the final four characters of the name, so labels
    may themselves contain underscores.

    Args:
        input_folder: Directory to scan.

    Returns:
        ``(mask_list, mask_label_list)``: ``uint8`` torch tensors and their
        labels, in matching order.

    Prints ``"please check mask"`` (without raising) when the masks do not sum
    to exactly 1 everywhere, including when no mask was found.
    """
    files = [
        file_name for file_name in os.listdir(input_folder)
        if "mask" in file_name and ".npy" in file_name
        and "_" in file_name and "Edited" not in file_name
    ]
    files.sort(key=lambda name: int(name.split("_")[0][4:]))

    mask_list = []
    mask_label_list = []
    for file_name in files:
        mask_np = np.load(os.path.join(input_folder, file_name)).astype('uint8')
        mask_list.append(torch.from_numpy(mask_np))
        # Split only on the first underscore so "mask0_traffic_light.npy"
        # yields "traffic_light" instead of a truncated label.
        mask_label_list.append(file_name.split("_", 1)[1][:-4])

    # Explicit check instead of ``assert`` (which is stripped under -O).
    try:
        masks_ok = bool(mask_list) and bool(torch.all(sum(mask_list) == 1))
    except (RuntimeError, TypeError):
        # e.g. masks of different shapes cannot be summed
        masks_ok = False
    if not masks_ok:
        print("please check mask")
    return mask_list, mask_label_list
def load_image(image_path, left=0, right=0, top=0, bottom=0, size = 512):
    """Load an image, trim its borders, center-crop it to a square and resize it.

    Args:
        image_path: Path to an image file (``str`` or ``os.PathLike``), or an
            already-loaded HxWxC NumPy array, which is used as is.
        left, right, top, bottom: Number of pixels to trim from each side. Each
            value is clamped so that at least one row/column remains.
        size: Edge length of the square output.

    Returns:
        A NumPy array of the resized ``size`` x ``size`` image.
    """
    if isinstance(image_path, (str, os.PathLike)):
        # Convert to RGB so grayscale / palette / RGBA files all yield three
        # channels, and close the file handle once the pixels are read.
        with Image.open(image_path) as img:
            image = np.array(img.convert("RGB"))
    else:
        image = image_path

    h, w, _ = image.shape
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    # The vertical clamp depends on the height only (it used to subtract
    # ``left``, mixing the two axes).
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]

    # Center-crop the longer side so the image is square before resizing.
    h, w, _ = image.shape
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    return np.array(Image.fromarray(image).resize((size, size)))
def mask_union_torch(*masks):
    """Return where the float sum of the given masks is strictly positive.

    Each mask is converted to ``torch.float`` before summing. With no
    arguments the plain Python value ``False`` is returned.
    """
    as_float = []
    for mask in masks:
        as_float.append(mask.to(torch.float))
    total = sum(as_float)
    return total > 0
def load_mask_edit(input_folder):
    """Load the edited segmentation masks stored as ``.npy`` files in a folder.

    A file is picked up when its name contains ``"mask"``, ``".npy"``, ``"_"``
    and ``"Edited"`` and does not contain ``"-1"``. Files are ordered by the
    integer that follows the first ten characters of the part before the first
    underscore (e.g. ``maskEdited3_dog.npy`` -> 3). The label is everything
    between that first underscore and the final four characters of the name,
    so labels may themselves contain underscores.

    Args:
        input_folder: Directory to scan.

    Returns:
        ``(mask_list, mask_label_list)``: ``uint8`` torch tensors and their
        labels, in matching order.

    Prints a hint (without raising or stopping) when the masks do not sum to
    exactly 1 everywhere, including when no edited mask was found.
    """
    files = [
        file_name for file_name in os.listdir(input_folder)
        if "mask" in file_name and ".npy" in file_name and "_" in file_name
        and "Edited" in file_name and "-1" not in file_name
    ]
    files.sort(key=lambda name: int(name.split("_")[0][10:]))

    mask_list = []
    mask_label_list = []
    for file_name in files:
        mask_np = np.load(os.path.join(input_folder, file_name)).astype('uint8')
        mask_list.append(torch.from_numpy(mask_np))
        # Split only on the first underscore so labels containing "_" survive.
        mask_label_list.append(file_name.split("_", 1)[1][:-4])

    # Explicit check instead of ``assert`` (which is stripped under -O); the
    # leftover ``pdb.set_trace()`` is gone so non-interactive runs cannot hang.
    try:
        masks_ok = bool(mask_list) and bool(torch.all(sum(mask_list) == 1))
    except (RuntimeError, TypeError):
        # e.g. masks of different shapes cannot be summed
        masks_ok = False
    if not masks_ok:
        print("Make sure maskEdited is in the folder, if not, generate using the UI")
    return mask_list, mask_label_list
def save_images(images,filename, num_rows=1, offset_ratio=0.02):
    """Save each image to its own file, named ``<stem>_<index><ext>`` next to ``filename``.

    Args:
        images: A list/tuple of NumPy images, a 4-D NumPy array (iterated along
            axis 0), or a single NumPy image.
        filename: Template path; its directory, stem and extension are reused
            for every output, e.g. ``out/img.png`` -> ``out/img_0.png``, ...
            The directory is created if it does not exist.
        num_rows: ``len(images) % num_rows`` all-white images are appended and
            saved as well (NOTE(review): this looks like a leftover from a
            grid layout; kept for compatibility). Ignored for a single image.
        offset_ratio: Unused; kept for interface compatibility.

    Returns:
        None. Nothing is written when ``images`` is empty.
    """
    if isinstance(images, (list, tuple)):
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0

    if len(images) == 0:
        return

    blank = np.ones(images[0].shape, dtype=np.uint8) * 255
    frames = [image.astype(np.uint8) for image in images] + [blank] * num_empty

    # os.path handles both path separators and names containing extra dots.
    folder, base_name = os.path.split(filename)
    stem, extension = os.path.splitext(base_name)
    if folder:
        os.makedirs(folder, exist_ok=True)

    for i, frame in enumerate(frames):
        out_path = os.path.join(folder, f"{stem}_{i}{extension}")
        Image.fromarray(frame).save(out_path)
        print("saved to ", out_path)