# PuLiD-Flux / pulidflux.py
# Uploaded by sanjayxm ("Upload pulidflux.py", commit 757ccb1, verified), 16.5 kB.
import torch
from torch import nn, Tensor
from torchvision import transforms
from torchvision.transforms import functional
import os
import logging
import folder_paths
import comfy.utils
from comfy.ldm.flux.layers import timestep_embedding
from insightface.app import FaceAnalysis
from facexlib.parsing import init_parsing_model
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from .eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .encoders_flux import IDFormer, PerceiverAttentionCA
# Model directories, resolved relative to ComfyUI's models folder.
INSIGHTFACE_DIR = os.path.join(folder_paths.models_dir, "insightface")
MODELS_DIR = os.path.join(folder_paths.models_dir, "pulid")

# Register the "pulid" model folder. Search paths from an existing
# registration are reused, otherwise MODELS_DIR is the only search path.
# The extension set is always replaced with supported_pt_extensions.
if "pulid" in folder_paths.folder_names_and_paths:
    current_paths, _ = folder_paths.folder_names_and_paths["pulid"]
else:
    current_paths = [MODELS_DIR]
folder_paths.folder_names_and_paths["pulid"] = (current_paths, folder_paths.supported_pt_extensions)
class PulidFluxModel(nn.Module):
    """Container for the PuLID-Flux weights: an ID encoder and a list of
    cross-attention modules."""

    def __init__(self):
        super().__init__()

        # PuLID attention is applied on block indices that are multiples of
        # these intervals (see the `i % interval == 0` checks in forward_orig).
        self.double_interval = 2
        self.single_interval = 4

        # Init encoder
        self.pulid_encoder = IDFormer()

        # Init attention: one module per block index divisible by its
        # interval, i.e. ceil(19 / double_interval) + ceil(38 / single_interval).
        # NOTE(review): 19 and 38 are presumably the Flux double/single block
        # counts - confirm against the Flux model definition.
        double_count = len(range(0, 19, self.double_interval))
        single_count = len(range(0, 38, self.single_interval))
        self.pulid_ca = nn.ModuleList(
            [PerceiverAttentionCA() for _ in range(double_count + single_count)]
        )

    def from_pretrained(self, path: str):
        """Load the checkpoint at `path` into the matching submodules.

        The first dot-separated component of every checkpoint key names an
        attribute of this module; the remainder of the key is the parameter
        name inside that attribute. Each group is loaded with `strict=True`.
        """
        checkpoint = comfy.utils.load_torch_file(path, safe_load=True)

        # Group the flat checkpoint by its top-level prefix.
        per_module = {}
        for key, value in checkpoint.items():
            prefix, _, remainder = key.partition('.')
            per_module.setdefault(prefix, {})[remainder] = value

        for prefix, weights in per_module.items():
            getattr(self, prefix).load_state_dict(weights, strict=True)

    def get_embeds(self, face_embed, clip_embeds):
        """Run the ID encoder on the face embedding and the CLIP features."""
        return self.pulid_encoder(face_embed, clip_embeds)
def forward_orig(
    self,
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    timesteps: Tensor,
    y: Tensor,
    guidance: Tensor = None,
    control=None,
) -> Tensor:
    """Flux forward pass with PuLID cross-attention injected between blocks.

    Written as a plain function taking `self` so it can be bound onto an
    existing Flux diffusion model. Besides the regular Flux submodules it
    reads `self.pulid_data` (a mapping of node id to a dict with the keys
    'weight', 'embedding', 'sigma_start' and 'sigma_end'), `self.pulid_ca`,
    `self.pulid_double_interval` and `self.pulid_single_interval`.

    Args:
        img, txt: image and text sequences; both must be 3-dimensional.
        img_ids, txt_ids: position ids, concatenated along dim 1 for the
            positional embedder.
        timesteps: current timestep values, compared against each node's
            sigma window.
        y: conditioning vector fed to `self.vector_in`.
        guidance: guidance strength; required when
            `self.params.guidance_embed` is set.
        control: optional ControlNet residuals, read from its "input" and
            "output" entries.

    Returns:
        The output of `self.final_layer` on the image part of the sequence.

    Raises:
        ValueError: if `img` or `txt` is not 3-dimensional, or if guidance is
            required but missing.
    """
    if img.ndim != 3 or txt.ndim != 3:
        raise ValueError("Input img and txt tensors must have 3 dimensions.")

    # running on sequences img
    img = self.img_in(img)
    vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
    if self.params.guidance_embed:
        if guidance is None:
            raise ValueError("Didn't get guidance strength for guidance distilled model.")
        vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))

    vec = vec + self.vector_in(y)
    txt = self.txt_in(txt)

    ids = torch.cat((txt_ids, img_ids), dim=1)
    pe = self.pe_embedder(ids)

    # Neither the timesteps nor the node data change during this call, so
    # decide once which nodes are inside their sigma window instead of
    # re-evaluating the tensor comparison on every PuLID block.
    active_nodes = [
        node_data
        for node_data in (self.pulid_data or {}).values()
        if torch.any((node_data['sigma_start'] >= timesteps) & (timesteps >= node_data['sigma_end']))
    ]

    # Index of the next cross-attention module; it advances once per PuLID
    # block and keeps counting from the double blocks into the single blocks.
    ca_idx = 0

    for i, block in enumerate(self.double_blocks):
        img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        if control is not None:  # Controlnet
            control_i = control.get("input")
            if i < len(control_i):
                add = control_i[i]
                if add is not None:
                    img += add

        # PuLID attention: all active nodes contribute through the same module.
        if active_nodes and i % self.pulid_double_interval == 0:
            for node_data in active_nodes:
                img = img + node_data['weight'] * self.pulid_ca[ca_idx](node_data['embedding'], img)
            ca_idx += 1

    img = torch.cat((txt, img), 1)
    # The text tokens occupy the first txt_len positions from here on.
    txt_len = txt.shape[1]

    for i, block in enumerate(self.single_blocks):
        img = block(img, vec=vec, pe=pe)

        if control is not None:  # Controlnet
            control_o = control.get("output")
            if i < len(control_o):
                add = control_o[i]
                if add is not None:
                    img[:, txt_len:, ...] += add

        # PuLID attention only touches the image tokens, so the sequence is
        # split and re-joined on the blocks where it is actually applied.
        if active_nodes and i % self.pulid_single_interval == 0:
            txt_tokens, real_img = img[:, :txt_len, ...], img[:, txt_len:, ...]
            for node_data in active_nodes:
                real_img = real_img + node_data['weight'] * self.pulid_ca[ca_idx](node_data['embedding'], real_img)
            ca_idx += 1
            img = torch.cat((txt_tokens, real_img), 1)

    img = img[:, txt_len:, ...]

    img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
    return img
def tensor_to_image(tensor):
    """Convert a tensor with values in [0, 1] to a uint8 NumPy array whose
    first three last-axis channels are in reversed order (RGB <-> BGR)."""
    scaled = torch.clamp(tensor * 255, min=0, max=255)
    # The uint8 cast truncates the fractional part.
    as_bytes = scaled.to(torch.uint8).cpu()
    swapped = as_bytes[..., [2, 1, 0]]
    return swapped.numpy()
def image_to_tensor(image):
    """Convert a NumPy image with values in 0..255 to a float tensor in
    [0, 1], reversing the order of the first three last-axis channels."""
    scaled = torch.from_numpy(image).float().div(255.).clamp(0, 1)
    return scaled[..., [2, 1, 0]]
def to_gray(img):
    """Return a 3-channel grayscale version of a 4-D tensor whose dim 1 holds
    at least three channels, using the weights 0.299 / 0.587 / 0.114."""
    first, second, third = img[:, 0:1], img[:, 1:2], img[:, 2:3]
    luma = 0.299 * first + 0.587 * second + 0.114 * third
    # Replicate the single luma channel three times along dim 1.
    return luma.repeat(1, 3, 1, 1)
"""
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Nodes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
"""
class PulidFluxModelLoader:
    """Node that loads a PuLID-Flux checkpoint from the "pulid" model folder."""

    RETURN_TYPES = ("PULIDFLUX",)
    FUNCTION = "load_model"
    CATEGORY = "pulid"

    @classmethod
    def INPUT_TYPES(s):
        available_files = folder_paths.get_filename_list("pulid")
        return {"required": {"pulid_file": (available_files, )}}

    def load_model(self, pulid_file):
        """Instantiate PulidFluxModel and load the weights of `pulid_file`."""
        checkpoint_path = folder_paths.get_full_path("pulid", pulid_file)

        # The model is built here as well: loading takes longer, but it does
        # not have to be repeated whenever parameters of the apply node change.
        pulid_model = PulidFluxModel()

        logging.info("Loading PuLID-Flux model.")
        pulid_model.from_pretrained(path=checkpoint_path)

        return (pulid_model,)
class PulidFluxInsightFaceLoader:
    """Node that creates an InsightFace FaceAnalysis object ("antelopev2")."""

    RETURN_TYPES = ("FACEANALYSIS",)
    FUNCTION = "load_insightface"
    CATEGORY = "pulid"

    @classmethod
    def INPUT_TYPES(s):
        provider_choices = ["CPU", "CUDA", "ROCM"]
        return {
            "required": {
                "provider": (provider_choices, ),
            },
        }

    def load_insightface(self, provider):
        """Build and prepare the analyser for the chosen execution provider."""
        # The provider name is the selected choice plus a fixed suffix.
        execution_provider = provider + 'ExecutionProvider'
        analyser = FaceAnalysis(
            name="antelopev2",  # alternative to buffalo_l
            root=INSIGHTFACE_DIR,
            providers=[execution_provider],
        )
        analyser.prepare(ctx_id=0, det_size=(640, 640))
        return (analyser,)
class PulidFluxEvaClipLoader:
    """Node that loads the visual tower of the EVA02-CLIP-L-14-336 model."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {},
        }

    RETURN_TYPES = ("EVA_CLIP",)
    FUNCTION = "load_eva_clip"
    CATEGORY = "pulid"

    def load_eva_clip(self):
        """Create the EVA-CLIP model and return its visual tower.

        The returned module always carries `image_mean` and `image_std` as
        sequences. A value found on the model is kept; a scalar value is
        expanded to three channels; a missing value falls back to the
        OPENAI_DATASET_MEAN / OPENAI_DATASET_STD constants.
        """
        from .eva_clip.factory import create_model_and_transforms

        model, _, _ = create_model_and_transforms('EVA02-CLIP-L-14-336', 'eva_clip', force_custom_clip=True)

        model = model.visual

        eva_transform_mean = getattr(model, 'image_mean', OPENAI_DATASET_MEAN)
        eva_transform_std = getattr(model, 'image_std', OPENAI_DATASET_STD)
        if not isinstance(eva_transform_mean, (list, tuple)):
            eva_transform_mean = (eva_transform_mean,) * 3
        if not isinstance(eva_transform_std, (list, tuple)):
            eva_transform_std = (eva_transform_std,) * 3

        # Stored as attributes: an nn.Module does not support item assignment,
        # and ApplyPulidFlux reads eva_clip.image_mean / eva_clip.image_std.
        model.image_mean = eva_transform_mean
        model.image_std = eva_transform_std

        return (model,)
class ApplyPulidFlux:
    """Node that computes an identity embedding from reference images and
    registers it on the Flux diffusion model for use by `forward_orig`."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "pulid_flux": ("PULIDFLUX", ),
                "eva_clip": ("EVA_CLIP", ),
                "face_analysis": ("FACEANALYSIS", ),
                "image": ("IMAGE", ),
                "weight": ("FLOAT", {"default": 1.0, "min": -1.0, "max": 5.0, "step": 0.05 }),
                "start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001 }),
                "end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001 }),
            },
            "optional": {
                "attn_mask": ("MASK", ),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID"
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_pulid_flux"
    CATEGORY = "pulid"

    def __init__(self):
        # Reference to the entry this node registered on a patched model:
        # {'data': <pulid_data dict>, 'unique_id': <key>, 'entry': <value>},
        # or None when nothing is registered.
        self.pulid_data_dict = None

    @staticmethod
    def _largest_face_embedding(face_analysis, face_image, device, dtype):
        """Return the InsightFace embedding of the largest detected face as a
        (1, N) tensor, or None when no detection size finds a face."""
        # Try progressively smaller detector input sizes until a face is found.
        for side in range(640, 256, -64):
            face_analysis.det_model.input_size = (side, side)
            faces = face_analysis.get(face_image)
            if faces:
                # Only use the face with the largest bounding-box area.
                largest = max(faces, key=lambda face: (face.bbox[2] - face.bbox[0]) * (face.bbox[3] - face.bbox[1]))
                return torch.from_numpy(largest.embedding).unsqueeze(0).to(device, dtype=dtype)
        return None

    def apply_pulid_flux(self, model, pulid_flux, eva_clip, face_analysis, image, weight, start_at, end_at, attn_mask=None, unique_id=None):
        """Compute the PuLID embedding for `image` and attach it to `model`.

        The first dimension of `image` indexes the reference images. Images
        without a detectable face are skipped; the embeddings of the others
        are averaged. The diffusion model is patched in place (its
        `forward_orig` is replaced on first use) and the node's weight,
        embedding and sigma window are stored under `unique_id`, so several
        nodes can be chained on the same model.

        Returns:
            A 1-tuple holding the same `model` object that was passed in.
        """
        # comfy.model_management is not imported at module level; import it
        # here instead of relying on another module having loaded it.
        import comfy.model_management

        device = comfy.model_management.get_torch_device()
        # Use the dtype of the loaded unet rather than the command-line
        # setting (comfy.model_management.unet_dtype()).
        dtype = model.model.diffusion_model.dtype
        # Because of 8bit models we must check which cast type the unet uses.
        # ZLUDA (Intel, AMD) & GPUs with compute capability < 8.0 don't support bfloat16 etc.
        # Issue: https://github.com/balazik/ComfyUI-PuLID-Flux/issues/6
        if model.model.manual_cast_dtype is not None:
            dtype = model.model.manual_cast_dtype

        eva_clip.to(device, dtype=dtype)
        pulid_flux.to(device, dtype=dtype)

        # TODO: Add masking support! The mask is normalised here but not used yet.
        if attn_mask is not None:
            if attn_mask.dim() > 3:
                attn_mask = attn_mask.squeeze(-1)
            elif attn_mask.dim() < 3:
                attn_mask = attn_mask.unsqueeze(0)
            attn_mask = attn_mask.to(device, dtype=dtype)

        image = tensor_to_image(image)

        face_helper = FaceRestoreHelper(
            upscale_factor=1,
            face_size=512,
            crop_ratio=(1, 1),
            det_model='retinaface_resnet50',
            save_ext='png',
            device=device,
        )
        face_helper.face_parse = init_parsing_model(model_name='bisenet', device=device)

        # Parsing labels that are replaced by white before encoding.
        bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
        cond = []

        # Analyse multiple images at multiple sizes and combine largest area embeddings
        for i in range(image.shape[0]):
            # get insightface embeddings
            iface_embeds = self._largest_face_embedding(face_analysis, image[i], device, dtype)
            if iface_embeds is None:
                # No face detected, skip this image
                logging.warning(f'Warning: No face detected in image {str(i)}')
                continue

            # get eva_clip embeddings
            face_helper.clean_all()
            face_helper.read_image(image[i])
            face_helper.get_face_landmarks_5(only_center_face=True)
            face_helper.align_warp_face()

            if not face_helper.cropped_faces:
                # No face detected, skip this image
                continue

            # Get aligned face image
            align_face = face_helper.cropped_faces[0]
            # Convert bgr face image to tensor
            align_face = image_to_tensor(align_face).unsqueeze(0).permute(0, 3, 1, 2).to(device)
            parsing_out = face_helper.face_parse(functional.normalize(align_face, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0]
            parsing_out = parsing_out.argmax(dim=1, keepdim=True)
            bg = sum(parsing_out == label for label in bg_label).bool()
            white_image = torch.ones_like(align_face)
            # Only keep the face features
            face_features_image = torch.where(bg, white_image, to_gray(align_face))

            # Transform img before sending to eva_clip
            # Apparently MPS only supports NEAREST interpolation?
            interpolation = transforms.InterpolationMode.BICUBIC if 'cuda' in device.type else transforms.InterpolationMode.NEAREST
            face_features_image = functional.resize(face_features_image, eva_clip.image_size, interpolation).to(device, dtype=dtype)
            face_features_image = functional.normalize(face_features_image, eva_clip.image_mean, eva_clip.image_std)

            # eva_clip
            id_cond_vit, id_vit_hidden = eva_clip(face_features_image, return_all_features=False, return_hidden=True, shuffle=False)
            id_cond_vit = id_cond_vit.to(device, dtype=dtype)
            for idx, hidden in enumerate(id_vit_hidden):
                id_vit_hidden[idx] = hidden.to(device, dtype=dtype)

            # L2-normalise along dim 1.
            id_cond_vit = torch.div(id_cond_vit, torch.norm(id_cond_vit, 2, 1, True))

            # Combine embeddings
            id_cond = torch.cat([iface_embeds, id_cond_vit], dim=-1)

            # Pulid_encoder
            cond.append(pulid_flux.get_embeds(id_cond, id_vit_hidden))

        if not cond:
            # No faces detected. Drop whatever this node id registered on the
            # model earlier, so the returned model really carries no data
            # from this node.
            stale_data = getattr(model.model.diffusion_model, "pulid_data", None)
            if stale_data is not None:
                stale_data.pop(unique_id, None)
            logging.warning("PuLID warning: No faces detected in any of the given images, returning unmodified model.")
            return (model,)

        # average embeddings
        cond = torch.cat(cond).to(device, dtype=dtype)
        if cond.shape[0] > 1:
            cond = torch.mean(cond, dim=0, keepdim=True)

        model_sampling = model.get_model_object("model_sampling")
        sigma_start = model_sampling.percent_to_sigma(start_at)
        sigma_end = model_sampling.percent_to_sigma(end_at)

        # Patch the Flux model (original diffusion_model) directly. This does
        # not go through ComfyUI's ModelPatcher, so the underlying model
        # object itself is modified.
        flux_model = model.model.diffusion_model
        # Let's see if we already patched the underlying flux model, if not apply patch
        if not hasattr(flux_model, "pulid_ca"):
            # Add perceiver attention, variables and current node data (weight, embedding, sigma_start, sigma_end)
            # The pulid_data is stored in Dict by unique node index,
            # so we can chain multiple ApplyPulidFlux nodes!
            flux_model.pulid_ca = pulid_flux.pulid_ca
            flux_model.pulid_double_interval = pulid_flux.double_interval
            flux_model.pulid_single_interval = pulid_flux.single_interval
            flux_model.pulid_data = {}
            # Replace model forward_orig with our own
            new_method = forward_orig.__get__(flux_model, flux_model.__class__)
            setattr(flux_model, 'forward_orig', new_method)

        # Patch is already in place, add data (weight, embedding, sigma_start, sigma_end) under unique node index
        node_entry = {
            'weight': weight,
            'embedding': cond,
            'sigma_start': sigma_start,
            'sigma_end': sigma_end,
        }
        flux_model.pulid_data[unique_id] = node_entry

        # Keep a reference for destructor (if node is deleted the data will be deleted as well)
        self.pulid_data_dict = {'data': flux_model.pulid_data, 'unique_id': unique_id, 'entry': node_entry}

        return (model,)

    def __del__(self):
        # Destroy the data for this node. Must never raise: the entry may
        # already be gone, __init__ may not have run, and the method may be
        # invoked more than once.
        registered = getattr(self, 'pulid_data_dict', None)
        if not registered:
            return
        data = registered['data']
        key = registered['unique_id']
        entry = registered.get('entry')
        # Only remove what this node registered; another node instance may
        # have stored a newer entry under the same id in the meantime.
        if entry is None or data.get(key) is entry:
            data.pop(key, None)
        self.pulid_data_dict = None
# Node classes exported by this module, keyed by their class name.
NODE_CLASS_MAPPINGS = {
    node_class.__name__: node_class
    for node_class in (
        PulidFluxModelLoader,
        PulidFluxInsightFaceLoader,
        PulidFluxEvaClipLoader,
        ApplyPulidFlux,
    )
}

# Display names for the same node identifiers.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    PulidFluxModelLoader="Load PuLID Flux Model",
    PulidFluxInsightFaceLoader="Load InsightFace (PuLID Flux)",
    PulidFluxEvaClipLoader="Load Eva Clip (PuLID Flux)",
    ApplyPulidFlux="Apply PuLID Flux",
)