Spaces:
Paused
Paused
##### | |
# Modified from https://github.com/huggingface/diffusers/blob/v0.29.1/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py | |
# PhotoMaker v2 @ TencentARC and MCG-NKU | |
# Author: Zhen Li | |
##### | |
# Copyright 2024 TencentARC and The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import inspect | |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union | |
import numpy as np | |
import PIL.Image | |
import torch | |
from transformers import ( | |
CLIPImageProcessor, | |
CLIPTextModel, | |
CLIPTextModelWithProjection, | |
CLIPTokenizer, | |
CLIPVisionModelWithProjection, | |
) | |
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor | |
from diffusers.loaders import ( | |
FromSingleFileMixin, | |
IPAdapterMixin, | |
StableDiffusionXLLoraLoaderMixin, | |
TextualInversionLoaderMixin, | |
) | |
from diffusers.models import AutoencoderKL, ImageProjection, MultiAdapter, T2IAdapter, UNet2DConditionModel | |
from diffusers.models.attention_processor import ( | |
AttnProcessor2_0, | |
LoRAAttnProcessor2_0, | |
LoRAXFormersAttnProcessor, | |
XFormersAttnProcessor, | |
) | |
from diffusers.models.lora import adjust_lora_scale_text_encoder | |
from diffusers.schedulers import KarrasDiffusionSchedulers | |
from diffusers.utils import ( | |
PIL_INTERPOLATION, | |
USE_PEFT_BACKEND, | |
logging, | |
replace_example_docstring, | |
scale_lora_layers, | |
unscale_lora_layers, | |
) | |
from diffusers.utils.torch_utils import randn_tensor | |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin | |
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput | |
from diffusers.pipelines import StableDiffusionXLAdapterPipeline | |
from diffusers.utils import _get_model_file | |
from safetensors import safe_open | |
from huggingface_hub.utils import validate_hf_hub_args | |
from model_v2 import PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken | |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg | |
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale the guided noise prediction so its per-sample standard deviation matches the text-conditioned one.

    Based on Section 3.4 of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg: Noise prediction after classifier-free guidance has been applied.
        noise_pred_text: Text-conditioned noise prediction.
        guidance_rescale: Blend factor; `0.0` returns `noise_cfg` numerically unchanged, `1.0` returns the
            fully rescaled prediction.

    Returns:
        The blended noise prediction, same shape as `noise_cfg`.
    """
    # Statistics are taken over every axis except the leading (batch) axis.
    text_axes = list(range(1, noise_pred_text.ndim))
    cfg_axes = list(range(1, noise_cfg.ndim))
    text_sigma = noise_pred_text.std(dim=text_axes, keepdim=True)
    cfg_sigma = noise_cfg.std(dim=cfg_axes, keepdim=True)

    # Bring the guided prediction back to the text prediction's scale (fixes overexposure).
    rescaled = noise_cfg * (text_sigma / cfg_sigma)

    # Blend with the unscaled guidance result to avoid "plain looking" images.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps | |
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure the scheduler via `set_timesteps` and return the resulting schedule.

    Exactly one of `num_inference_steps`, `timesteps` or `sigmas` drives the schedule. Any extra keyword
    arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps. Only used when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Passed to `set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps overriding the scheduler's spacing strategy. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas overriding the scheduler's spacing strategy. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: the scheduler's `timesteps` after the call and the number of inference
        steps (the schedule length when a custom schedule was supplied).

    Raises:
        ValueError: if both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler derive the schedule from the step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # A custom schedule was requested; make sure this scheduler can take it.
    supported_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())

    if timesteps is not None:
        if "timesteps" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                " timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                " sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
def _preprocess_adapter_image(image, height, width): | |
if isinstance(image, torch.Tensor): | |
return image | |
elif isinstance(image, PIL.Image.Image): | |
image = [image] | |
if isinstance(image[0], PIL.Image.Image): | |
image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image] | |
image = [ | |
i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image | |
] # expand [h, w] or [h, w, c] to [b, h, w, c] | |
image = np.concatenate(image, axis=0) | |
image = np.array(image).astype(np.float32) / 255.0 | |
image = image.transpose(0, 3, 1, 2) | |
image = torch.from_numpy(image) | |
elif isinstance(image[0], torch.Tensor): | |
if image[0].ndim == 3: | |
image = torch.stack(image, dim=0) | |
elif image[0].ndim == 4: | |
image = torch.cat(image, dim=0) | |
else: | |
raise ValueError( | |
f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but recive: {image[0].ndim}" | |
) | |
return image | |
class PhotoMakerStableDiffusionXLAdapterPipeline(StableDiffusionXLAdapterPipeline): | |
def load_photomaker_adapter(
    self,
    pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
    weight_name: str,
    subfolder: str = '',
    trigger_word: str = 'img',
    pm_version: str = 'v2',
    **kwargs,
):
    """
    Load the PhotoMaker ID encoder and LoRA weights into this pipeline and register the trigger word.

    Parameters:
        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
            Can be either:
                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                  the Hub.
                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                  with [`ModelMixin.save_pretrained`].
                - A [torch state
                  dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
                  containing the top-level keys `id_encoder` and `lora_weights`.
        weight_name (`str`):
            The weight name NOT the path to the weight.
        subfolder (`str`, defaults to `""`):
            The subfolder location of a model file within a larger model repository on the Hub or locally.
        trigger_word (`str`, *optional*, defaults to `"img"`):
            The trigger word is used to identify the position of class word in the text prompt,
            and it is recommended not to set it as a common word.
            This trigger word must be placed after the class word when used, otherwise, it will affect the
            performance of the personalized generation.
        pm_version (`str`, *optional*, defaults to `"v2"`):
            Which ID encoder to instantiate: `"v1"` or `"v2"`.
        kwargs:
            `cache_dir`, `force_download`, `resume_download`, `proxies`, `local_files_only`, `token` and
            `revision` are popped and forwarded to `_get_model_file`.

    Raises:
        ValueError: if the loaded state dict lacks `id_encoder` or `lora_weights`.
        NotImplementedError: if `pm_version` is not supported, or the v1 encoder class is not available in
            this module.
    """
    required_sections = ("id_encoder", "lora_weights")

    # Load the main state dict first.
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", None)
    token = kwargs.pop("token", None)
    revision = kwargs.pop("revision", None)

    user_agent = {
        "file_type": "attn_procs_weights",
        "framework": "pytorch",
    }

    if isinstance(pretrained_model_name_or_path_or_dict, dict):
        state_dict = pretrained_model_name_or_path_or_dict
        # Avoid dumping every tensor of a user-supplied state dict into the log lines below.
        source_desc = "state dict"
    else:
        source_desc = pretrained_model_name_or_path_or_dict
        model_file = _get_model_file(
            pretrained_model_name_or_path_or_dict,
            weights_name=weight_name,
            cache_dir=cache_dir,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
        )
        if weight_name.endswith(".safetensors"):
            # The safetensors file is flat; split it back into the two sections by key prefix.
            state_dict = {section: {} for section in required_sections}
            with safe_open(model_file, framework="pt", device="cpu") as f:
                for key in f.keys():
                    for section in required_sections:
                        prefix = section + "."
                        if key.startswith(prefix):
                            # Strip only the leading prefix, not later occurrences inside the key.
                            state_dict[section][key[len(prefix):]] = f.get_tensor(key)
                            break
        else:
            # NOTE(review): `torch.load` unpickles the file; only use this path with trusted checkpoints.
            state_dict = torch.load(model_file, map_location="cpu")

    # Order-independent check: both sections must be present; extra keys are tolerated.
    missing_sections = [section for section in required_sections if section not in state_dict]
    if missing_sections:
        raise ValueError(
            "Required keys (`id_encoder` and `lora_weights`) are missing from the state dict:"
            f" {missing_sections}."
        )

    self.trigger_word = trigger_word
    # NOTE(review): `encode_prompt_with_trigger_word` reads `self.num_tokens`, which is not assigned here —
    # confirm where it is initialised.

    # load finetuned CLIP image encoder and fuse module here if it has not been registered to the pipeline yet
    print(f"Loading PhotoMaker {pm_version} components [1] id_encoder from [{source_desc}]...")
    self.id_image_processor = CLIPImageProcessor()
    if pm_version == "v1":  # PhotoMaker v1
        try:
            id_encoder_cls = PhotoMakerIDEncoder
        except NameError as err:
            # The v1 encoder class is not among this module's visible imports; fail with a clear message.
            raise NotImplementedError(
                "The PhotoMaker version [v1] requires `PhotoMakerIDEncoder`, which is not available in this module"
            ) from err
        id_encoder = id_encoder_cls()
    elif pm_version == "v2":  # PhotoMaker v2
        id_encoder = PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken()
    else:
        raise NotImplementedError(f"The PhotoMaker version [{pm_version}] does not support")

    id_encoder.load_state_dict(state_dict["id_encoder"], strict=True)
    id_encoder = id_encoder.to(self.device, dtype=self.unet.dtype)
    self.id_encoder = id_encoder

    # load lora into models
    print(f"Loading PhotoMaker {pm_version} components [2] lora_weights from [{source_desc}]")
    self.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker")

    # Add trigger word token
    if self.tokenizer is not None:
        self.tokenizer.add_tokens([self.trigger_word], special_tokens=True)

    self.tokenizer_2.add_tokens([self.trigger_word], special_tokens=True)
def encode_prompt_with_trigger_word(
    self,
    prompt: str,
    prompt_2: Optional[str] = None,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[str] = None,
    negative_prompt_2: Optional[str] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
    ### Added args
    num_id_images: int = 1,
    class_tokens_mask: Optional[torch.LongTensor] = None,
):
    """
    Encode a prompt containing the trigger word and build the mask of expanded class-word positions.

    The trigger-word token is removed from the token ids, and the token preceding it (the class word) is
    repeated `num_id_images * self.num_tokens` times. The ids are then truncated or padded to the
    tokenizer's `model_max_length` before being run through the text encoder(s). Only the first prompt of
    a batch is tokenised this way.

    Args:
        prompt: Prompt (or list of prompts) containing exactly one occurrence of `self.trigger_word`.
        prompt_2: Prompt for the second tokenizer / text encoder; defaults to `prompt`.
        device: Target device; defaults to `self._execution_device`.
        num_images_per_prompt: Repeat factor applied to the pooled and negative embeddings.
        do_classifier_free_guidance: Whether negative embeddings are produced.
        negative_prompt / negative_prompt_2: Negative prompts for the two text encoders.
        prompt_embeds / negative_prompt_embeds / pooled_prompt_embeds / negative_pooled_prompt_embeds:
            Pre-computed embeddings; when `prompt_embeds` is given, tokenisation is skipped and
            `class_tokens_mask` must be supplied by the caller.
        lora_scale: LoRA scale applied to the text encoders for the duration of the call.
        clip_skip: Number of layers to skip from the penultimate hidden state for the positive prompt.
        num_id_images: Number of ID images; controls how many times the class token is repeated.
        class_tokens_mask: Pre-computed mask, used only when `prompt_embeds` is provided.

    Returns:
        Tuple `(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds,
        negative_pooled_prompt_embeds, class_tokens_mask)`. `prompt_embeds` is NOT repeated per
        `num_images_per_prompt` here; the pooled and negative embeddings are.

    Raises:
        ValueError: if the prompt does not contain exactly one trigger word, or the negative prompt batch
            size does not match.
        TypeError: if `negative_prompt` and `prompt` have different types.
    """
    device = device or self._execution_device

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if self.text_encoder is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
            else:
                scale_lora_layers(self.text_encoder_2, lora_scale)

    prompt = [prompt] if isinstance(prompt, str) else prompt

    if prompt is not None:
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # Find the token id of the trigger word
    image_token_id = self.tokenizer_2.convert_tokens_to_ids(self.trigger_word)

    # Define tokenizers and text encoders
    tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
    text_encoders = (
        [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
    )

    if prompt_embeds is None:
        prompt_2 = prompt_2 or prompt
        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

        # textual inversion: process multi-vector tokens if necessary
        prompt_embeds_list = []
        prompts = [prompt, prompt_2]
        # NOTE: the loop rebinds `prompt`; the negative-prompt validation further down reads the rebound value.
        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, tokenizer)

            text_inputs = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                print(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer.model_max_length} tokens: {removed_text}"
                )

            clean_index = 0
            clean_input_ids = []
            class_token_index = []
            # Find out the corresponding class word token based on the newly added trigger word token
            for token_id in text_input_ids.tolist()[0]:
                if token_id == image_token_id:
                    # The class word is the token kept just before the trigger word.
                    class_token_index.append(clean_index - 1)
                else:
                    clean_input_ids.append(token_id)
                    clean_index += 1

            # Zero occurrences is as invalid as several; report the actual count.
            if len(class_token_index) != 1:
                raise ValueError(
                    "PhotoMaker requires exactly one trigger word in a single prompt, but found"
                    f" {len(class_token_index)}. Trigger word: {self.trigger_word}, Prompt: {prompt}."
                )
            class_token_index = class_token_index[0]

            # Expand the class word token and corresponding mask
            num_class_tokens = num_id_images * self.num_tokens
            class_token = clean_input_ids[class_token_index]
            clean_input_ids = (
                clean_input_ids[:class_token_index]
                + [class_token] * num_class_tokens
                + clean_input_ids[class_token_index + 1 :]
            )

            # Truncation or padding
            max_len = tokenizer.model_max_length
            if len(clean_input_ids) > max_len:
                clean_input_ids = clean_input_ids[:max_len]
            else:
                clean_input_ids = clean_input_ids + [tokenizer.pad_token_id] * (
                    max_len - len(clean_input_ids)
                )

            class_tokens_mask = [
                class_token_index <= i < class_token_index + num_class_tokens
                for i in range(len(clean_input_ids))
            ]

            clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long).unsqueeze(0)
            class_tokens_mask = torch.tensor(class_tokens_mask, dtype=torch.bool).unsqueeze(0)

            prompt_embeds = text_encoder(clean_input_ids.to(device), output_hidden_states=True)

            # We are only ALWAYS interested in the pooled output of the final text encoder
            pooled_prompt_embeds = prompt_embeds[0]
            if clip_skip is None:
                prompt_embeds = prompt_embeds.hidden_states[-2]
            else:
                # "2" because SDXL always indexes from the penultimate layer.
                prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

            prompt_embeds_list.append(prompt_embeds)

        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
    class_tokens_mask = class_tokens_mask.to(device=device)  # TODO: ignoring two-prompt case

    # get unconditional embeddings for classifier free guidance
    zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
    if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        negative_prompt_2 = negative_prompt_2 or negative_prompt

        # normalize str to list
        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
        negative_prompt_2 = (
            batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
        )

        uncond_tokens: List[str]
        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = [negative_prompt, negative_prompt_2]

        negative_prompt_embeds_list = []
        for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)

            # Pad the negative prompt to the positive prompt's sequence length.
            max_length = prompt_embeds.shape[1]
            uncond_input = tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            negative_prompt_embeds = text_encoder(
                uncond_input.input_ids.to(device),
                output_hidden_states=True,
            )
            # We are only ALWAYS interested in the pooled output of the final text encoder
            negative_pooled_prompt_embeds = negative_prompt_embeds[0]
            negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

            negative_prompt_embeds_list.append(negative_prompt_embeds)

        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    if self.text_encoder_2 is not None:
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
    else:
        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        if self.text_encoder_2 is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        else:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    if do_classifier_free_guidance:
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    if self.text_encoder_2 is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder_2, lora_scale)

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, class_tokens_mask
def interrupt(self):
    """Return the pipeline's current interrupt flag, as stored in `self._interrupt`."""
    # NOTE(review): if callers read this as an attribute (`self.interrupt`), a `@property` decorator
    # may have been lost — confirm against the call sites.
    interrupt_flag = self._interrupt
    return interrupt_flag
def __call__( | |
self, | |
prompt: Union[str, List[str]] = None, | |
prompt_2: Optional[Union[str, List[str]]] = None, | |
image: PipelineImageInput = None, | |
height: Optional[int] = None, | |
width: Optional[int] = None, | |
num_inference_steps: int = 50, | |
timesteps: List[int] = None, | |
sigmas: List[float] = None, | |
denoising_end: Optional[float] = None, | |
guidance_scale: float = 5.0, | |
negative_prompt: Optional[Union[str, List[str]]] = None, | |
negative_prompt_2: Optional[Union[str, List[str]]] = None, | |
num_images_per_prompt: Optional[int] = 1, | |
eta: float = 0.0, | |
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | |
latents: Optional[torch.Tensor] = None, | |
prompt_embeds: Optional[torch.Tensor] = None, | |
negative_prompt_embeds: Optional[torch.Tensor] = None, | |
pooled_prompt_embeds: Optional[torch.Tensor] = None, | |
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, | |
ip_adapter_image: Optional[PipelineImageInput] = None, | |
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, | |
output_type: Optional[str] = "pil", | |
return_dict: bool = True, | |
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, | |
callback_steps: int = 1, | |
cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
guidance_rescale: float = 0.0, | |
original_size: Optional[Tuple[int, int]] = None, | |
crops_coords_top_left: Tuple[int, int] = (0, 0), | |
target_size: Optional[Tuple[int, int]] = None, | |
negative_original_size: Optional[Tuple[int, int]] = None, | |
negative_crops_coords_top_left: Tuple[int, int] = (0, 0), | |
negative_target_size: Optional[Tuple[int, int]] = None, | |
adapter_conditioning_scale: Union[float, List[float]] = 1.0, | |
adapter_conditioning_factor: float = 1.0, | |
clip_skip: Optional[int] = None, | |
# Added parameters (for PhotoMaker) | |
input_id_images: PipelineImageInput = None, | |
start_merge_step: int = 10, # TODO: change to `style_strength_ratio` in the future | |
class_tokens_mask: Optional[torch.LongTensor] = None, | |
id_embeds: Optional[torch.FloatTensor] = None, | |
prompt_embeds_text_only: Optional[torch.FloatTensor] = None, | |
pooled_prompt_embeds_text_only: Optional[torch.FloatTensor] = None, | |
**kwargs, | |
): | |
r""" | |
Function invoked when calling the pipeline for generation. | |
Only the parameters introduced by PhotoMaker are discussed here. | |
For explanations of the previous parameters in StableDiffusionXLControlNetPipeline, please refer to https://github.com/huggingface/diffusers/blob/v0.25.0/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py | |
Args: | |
input_id_images (`PipelineImageInput`, *optional*): | |
Input ID Image to work with PhotoMaker. | |
class_tokens_mask (`torch.LongTensor`, *optional*): | |
Pre-generated class token. When the `prompt_embeds` parameter is provided in advance, it is necessary to prepare the `class_tokens_mask` beforehand for marking out the position of class word. | |
prompt_embeds_text_only (`torch.FloatTensor`, *optional*): | |
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not | |
provided, text embeddings will be generated from `prompt` input argument. | |
pooled_prompt_embeds_text_only (`torch.FloatTensor`, *optional*): | |
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. | |
If not provided, pooled text embeddings will be generated from `prompt` input argument. | |
Returns: | |
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: | |
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a | |
`tuple`. When returning a tuple, the first element is a list with the generated images. | |
""" | |
height, width = self._default_height_width(height, width, image) | |
device = self._execution_device | |
use_adapter = True if image is not None else False | |
print(f"Use adapter: {use_adapter} | output size: {(height, width)}") | |
if use_adapter: | |
if isinstance(self.adapter, MultiAdapter): | |
adapter_input = [] | |
for one_image in image: | |
one_image = _preprocess_adapter_image(one_image, height, width) | |
one_image = one_image.to(device=device, dtype=self.adapter.dtype) | |
adapter_input.append(one_image) | |
else: | |
adapter_input = _preprocess_adapter_image(image, height, width) | |
adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype) | |
original_size = original_size or (height, width) | |
target_size = target_size or (height, width) | |
# 1. Check inputs. Raise error if not correct | |
self.check_inputs( | |
prompt, | |
prompt_2, | |
height, | |
width, | |
callback_steps, | |
negative_prompt, | |
negative_prompt_2, | |
prompt_embeds, | |
negative_prompt_embeds, | |
pooled_prompt_embeds, | |
negative_pooled_prompt_embeds, | |
ip_adapter_image, | |
ip_adapter_image_embeds, | |
) | |
self._guidance_scale = guidance_scale | |
self._clip_skip = clip_skip | |
# | |
if prompt_embeds is not None and class_tokens_mask is None: | |
raise ValueError( | |
"If `prompt_embeds` are provided, `class_tokens_mask` also have to be passed. Make sure to generate `class_tokens_mask` from the same tokenizer that was used to generate `prompt_embeds`." | |
) | |
# check the input id images | |
if input_id_images is None: | |
raise ValueError( | |
"Provide `input_id_images`. Cannot leave `input_id_images` undefined for PhotoMaker pipeline." | |
) | |
if not isinstance(input_id_images, list): | |
input_id_images = [input_id_images] | |
# 2. Define call parameters | |
if prompt is not None and isinstance(prompt, str): | |
batch_size = 1 | |
elif prompt is not None and isinstance(prompt, list): | |
batch_size = len(prompt) | |
else: | |
batch_size = prompt_embeds.shape[0] | |
device = self._execution_device | |
# 3. Encode input prompt | |
lora_scale = ( | |
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None | |
) | |
num_id_images = len(input_id_images) | |
( | |
prompt_embeds, | |
_, | |
pooled_prompt_embeds, | |
_, | |
class_tokens_mask, | |
) = self.encode_prompt_with_trigger_word( | |
prompt=prompt, | |
prompt_2=prompt_2, | |
device=device, | |
num_id_images=num_id_images, | |
class_tokens_mask=class_tokens_mask, | |
num_images_per_prompt=num_images_per_prompt, | |
do_classifier_free_guidance=self.do_classifier_free_guidance, | |
negative_prompt=negative_prompt, | |
negative_prompt_2=negative_prompt_2, | |
prompt_embeds=prompt_embeds, | |
negative_prompt_embeds=negative_prompt_embeds, | |
pooled_prompt_embeds=pooled_prompt_embeds, | |
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, | |
lora_scale=lora_scale, | |
clip_skip=self._clip_skip, | |
) | |
# 4. Encode input prompt without the trigger word for delayed conditioning | |
# encode, remove trigger word token, then decode | |
tokens_text_only = self.tokenizer.encode(prompt, add_special_tokens=False) | |
trigger_word_token = self.tokenizer.convert_tokens_to_ids(self.trigger_word) | |
tokens_text_only.remove(trigger_word_token) | |
prompt_text_only = self.tokenizer.decode(tokens_text_only, add_special_tokens=False) | |
( | |
prompt_embeds_text_only, | |
negative_prompt_embeds, | |
pooled_prompt_embeds_text_only, # TODO: replace the pooled_prompt_embeds with text only prompt | |
negative_pooled_prompt_embeds, | |
) = self.encode_prompt( | |
prompt=prompt_text_only, | |
prompt_2=prompt_2, | |
device=device, | |
num_images_per_prompt=num_images_per_prompt, | |
do_classifier_free_guidance=self.do_classifier_free_guidance, | |
negative_prompt=negative_prompt, | |
negative_prompt_2=negative_prompt_2, | |
prompt_embeds=prompt_embeds_text_only, | |
negative_prompt_embeds=negative_prompt_embeds, | |
pooled_prompt_embeds=pooled_prompt_embeds_text_only, | |
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, | |
lora_scale=lora_scale, | |
clip_skip=self._clip_skip, | |
) | |
# 5. Prepare the input ID images | |
dtype = next(self.id_encoder.parameters()).dtype | |
if not isinstance(input_id_images[0], torch.Tensor): | |
id_pixel_values = self.id_image_processor(input_id_images, return_tensors="pt").pixel_values | |
id_pixel_values = id_pixel_values.unsqueeze(0).to(device=device, dtype=dtype) # TODO: multiple prompts | |
# 6. Get the update text embedding with the stacked ID embedding | |
if id_embeds is not None: | |
id_embeds = id_embeds.unsqueeze(0).to(device=device, dtype=dtype) | |
prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask, id_embeds) | |
else: | |
prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask) | |
bs_embed, seq_len, _ = prompt_embeds.shape | |
# duplicate text embeddings for each generation per prompt, using mps friendly method | |
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) | |
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) | |
# 6.1 Get the ip adapter embedding | |
if ip_adapter_image is not None or ip_adapter_image_embeds is not None: | |
image_embeds = self.prepare_ip_adapter_image_embeds( | |
ip_adapter_image, | |
ip_adapter_image_embeds, | |
device, | |
batch_size * num_images_per_prompt, | |
self.do_classifier_free_guidance, | |
) | |
# 7. Prepare timesteps | |
timesteps, num_inference_steps = retrieve_timesteps( | |
self.scheduler, num_inference_steps, device, timesteps, sigmas | |
) | |
# 8. Prepare latent variables | |
num_channels_latents = self.unet.config.in_channels | |
latents = self.prepare_latents( | |
batch_size * num_images_per_prompt, | |
num_channels_latents, | |
height, | |
width, | |
prompt_embeds.dtype, | |
device, | |
generator, | |
latents, | |
) | |
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline | |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) | |
# 8.5 Optionally get Guidance Scale Embedding | |
timestep_cond = None | |
if self.unet.config.time_cond_proj_dim is not None: | |
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) | |
timestep_cond = self.get_guidance_scale_embedding( | |
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim | |
).to(device=device, dtype=latents.dtype) | |
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline | |
if use_adapter: | |
if isinstance(self.adapter, MultiAdapter): | |
adapter_state = self.adapter(adapter_input, adapter_conditioning_scale) | |
for k, v in enumerate(adapter_state): | |
adapter_state[k] = v | |
else: | |
adapter_state = self.adapter(adapter_input) | |
for k, v in enumerate(adapter_state): | |
adapter_state[k] = v * adapter_conditioning_scale | |
if num_images_per_prompt > 1: | |
for k, v in enumerate(adapter_state): | |
adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1) | |
if self.do_classifier_free_guidance: | |
for k, v in enumerate(adapter_state): | |
adapter_state[k] = torch.cat([v] * 2, dim=0) | |
add_text_embeds = pooled_prompt_embeds | |
if self.text_encoder_2 is None: | |
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) | |
else: | |
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim | |
add_time_ids = self._get_add_time_ids( | |
original_size, | |
crops_coords_top_left, | |
target_size, | |
dtype=prompt_embeds.dtype, | |
text_encoder_projection_dim=text_encoder_projection_dim, | |
) | |
if negative_original_size is not None and negative_target_size is not None: | |
negative_add_time_ids = self._get_add_time_ids( | |
negative_original_size, | |
negative_crops_coords_top_left, | |
negative_target_size, | |
dtype=prompt_embeds.dtype, | |
text_encoder_projection_dim=text_encoder_projection_dim, | |
) | |
else: | |
negative_add_time_ids = add_time_ids | |
if self.do_classifier_free_guidance: | |
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) | |
prompt_embeds = prompt_embeds.to(device) | |
add_text_embeds = add_text_embeds.to(device) | |
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) | |
# 11. Denoising loop | |
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) | |
# Apply denoising_end | |
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: | |
discrete_timestep_cutoff = int( | |
round( | |
self.scheduler.config.num_train_timesteps | |
- (denoising_end * self.scheduler.config.num_train_timesteps) | |
) | |
) | |
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) | |
timesteps = timesteps[:num_inference_steps] | |
with self.progress_bar(total=num_inference_steps) as progress_bar: | |
for i, t in enumerate(timesteps): | |
# expand the latents if we are doing classifier free guidance | |
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents | |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | |
if i <= start_merge_step: | |
current_prompt_embeds = torch.cat( | |
[negative_prompt_embeds, prompt_embeds_text_only], dim=0 | |
) if self.do_classifier_free_guidance else prompt_embeds_text_only | |
add_text_embeds = torch.cat( | |
[negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0 | |
) if self.do_classifier_free_guidance else pooled_prompt_embeds_text_only | |
else: | |
current_prompt_embeds = torch.cat( | |
[negative_prompt_embeds, prompt_embeds], dim=0 | |
) if self.do_classifier_free_guidance else prompt_embeds | |
add_text_embeds = torch.cat( | |
[negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0 | |
) if self.do_classifier_free_guidance else pooled_prompt_embeds | |
if i < int(num_inference_steps * adapter_conditioning_factor) and (use_adapter): | |
down_intrablock_additional_residuals = [state.clone() for state in adapter_state] | |
else: | |
down_intrablock_additional_residuals = None | |
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} | |
if ip_adapter_image is not None or ip_adapter_image_embeds is not None: | |
added_cond_kwargs["image_embeds"] = image_embeds | |
# predict the noise residual | |
noise_pred = self.unet( | |
latent_model_input, | |
t, | |
encoder_hidden_states=current_prompt_embeds, | |
timestep_cond=timestep_cond, | |
cross_attention_kwargs=cross_attention_kwargs, | |
down_intrablock_additional_residuals=down_intrablock_additional_residuals, | |
added_cond_kwargs=added_cond_kwargs, | |
return_dict=False, | |
)[0] | |
# perform guidance | |
if self.do_classifier_free_guidance: | |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | |
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | |
if self.do_classifier_free_guidance and guidance_rescale > 0.0: | |
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf | |
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) | |
# compute the previous noisy sample x_t -> x_t-1 | |
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] | |
# call the callback, if provided | |
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): | |
progress_bar.update() | |
if callback is not None and i % callback_steps == 0: | |
step_idx = i // getattr(self.scheduler, "order", 1) | |
callback(step_idx, t, latents) | |
if not output_type == "latent": | |
# make sure the VAE is in float32 mode, as it overflows in float16 | |
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast | |
if needs_upcasting: | |
self.upcast_vae() | |
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) | |
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] | |
# cast back to fp16 if needed | |
if needs_upcasting: | |
self.vae.to(dtype=torch.float16) | |
else: | |
image = latents | |
return StableDiffusionXLPipelineOutput(images=image) | |
image = self.image_processor.postprocess(image, output_type=output_type) | |
# Offload all models | |
self.maybe_free_model_hooks() | |
if not return_dict: | |
return (image,) | |
return StableDiffusionXLPipelineOutput(images=image) |