Spaces:
Running
on
Zero
Running
on
Zero
# Copyright 2024 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import contextlib | |
import inspect | |
from typing import Any, Callable, Dict, List, Optional, Union | |
import numpy as np | |
import PIL.Image | |
import torch | |
from packaging import version | |
from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation | |
from ...configuration_utils import FrozenDict | |
from ...image_processor import PipelineImageInput, VaeImageProcessor | |
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin | |
from ...models import AutoencoderKL, UNet2DConditionModel | |
from ...models.lora import adjust_lora_scale_text_encoder | |
from ...schedulers import KarrasDiffusionSchedulers | |
from ...utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers | |
from ...utils.torch_utils import randn_tensor | |
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput | |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents | |
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Pull latents out of a VAE encoder output.

    Args:
        encoder_output: Object returned by `vae.encode(...)`. It is expected to expose either a `latent_dist`
            (with `sample(generator)` / `mode()`) or a plain `latents` attribute.
        generator: Forwarded to `latent_dist.sample` when sampling.
        sample_mode: `"sample"` draws from `latent_dist`; `"argmax"` takes its mode. Any other value (or an
            output without `latent_dist`) falls back to `encoder_output.latents`.

    Raises:
        AttributeError: if no usable latents can be found on `encoder_output`.
    """
    has_distribution = hasattr(encoder_output, "latent_dist")
    if has_distribution and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    if has_distribution and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess | |
def preprocess(image):
    """Legacy image preprocessing, kept for backward compatibility (deprecated).

    Emits a deprecation warning, then:
    - returns a `torch.Tensor` input unchanged;
    - for a PIL image or a list of PIL images, resizes every image (Lanczos) to the first image's size rounded
      down to a multiple of 8, stacks them and returns a float32 NCHW tensor scaled to [-1, 1];
    - for a list of tensors, concatenates them along dim 0;
    - returns any other input unchanged.
    """
    deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
    deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)

    if isinstance(image, torch.Tensor):
        return image

    images = [image] if isinstance(image, PIL.Image.Image) else image

    if isinstance(images[0], torch.Tensor):
        return torch.cat(images, dim=0)
    if not isinstance(images[0], PIL.Image.Image):
        return images

    width, height = images[0].size
    # Snap both sides down to a multiple of 8.
    width, height = width - width % 8, height - height % 8
    arrays = [
        np.array(img.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for img in images
    ]
    # NHWC uint8 -> float32 in [0, 1] -> NCHW -> [-1, 1]
    batch = np.concatenate(arrays, axis=0).astype(np.float32) / 255.0
    batch = batch.transpose(0, 3, 1, 2)
    return torch.from_numpy(2.0 * batch - 1.0)
class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): | |
r""" | |
Pipeline for text-guided depth-based image-to-image generation using Stable Diffusion. | |
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods | |
implemented for all pipelines (downloading, saving, running on a particular device, etc.). | |
The pipeline also inherits the following loading methods: | |
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings | |
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights | |
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights | |
Args: | |
vae ([`AutoencoderKL`]): | |
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. | |
text_encoder ([`~transformers.CLIPTextModel`]): | |
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
tokenizer ([`~transformers.CLIPTokenizer`]): | |
A `CLIPTokenizer` to tokenize text. | |
unet ([`UNet2DConditionModel`]): | |
A `UNet2DConditionModel` to denoise the encoded image latents. | |
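        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        depth_estimator ([`~transformers.DPTForDepthEstimation`]):
            Depth estimation model used to predict a depth map from the input image when no `depth_map` is passed.
        feature_extractor ([`~transformers.DPTFeatureExtractor`]):
            Preprocesses input images into the `pixel_values` consumed by `depth_estimator`.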
""" | |
model_cpu_offload_seq = "text_encoder->unet->vae" | |
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "depth_mask"] | |
def __init__( | |
self, | |
vae: AutoencoderKL, | |
text_encoder: CLIPTextModel, | |
tokenizer: CLIPTokenizer, | |
unet: UNet2DConditionModel, | |
scheduler: KarrasDiffusionSchedulers, | |
depth_estimator: DPTForDepthEstimation, | |
feature_extractor: DPTFeatureExtractor, | |
): | |
super().__init__() | |
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( | |
version.parse(unet.config._diffusers_version).base_version | |
) < version.parse("0.9.0.dev0") | |
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 | |
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: | |
deprecation_message = ( | |
"The configuration file of the unet has set the default `sample_size` to smaller than" | |
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" | |
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" | |
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" | |
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" | |
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" | |
" in the config might lead to incorrect results in future versions. If you have downloaded this" | |
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" | |
" the `unet/config.json` file" | |
) | |
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) | |
new_config = dict(unet.config) | |
new_config["sample_size"] = 64 | |
unet._internal_dict = FrozenDict(new_config) | |
self.register_modules( | |
vae=vae, | |
text_encoder=text_encoder, | |
tokenizer=tokenizer, | |
unet=unet, | |
scheduler=scheduler, | |
depth_estimator=depth_estimator, | |
feature_extractor=feature_extractor, | |
) | |
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) | |
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) | |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt | |
def _encode_prompt( | |
self, | |
prompt, | |
device, | |
num_images_per_prompt, | |
do_classifier_free_guidance, | |
negative_prompt=None, | |
prompt_embeds: Optional[torch.FloatTensor] = None, | |
negative_prompt_embeds: Optional[torch.FloatTensor] = None, | |
lora_scale: Optional[float] = None, | |
**kwargs, | |
): | |
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." | |
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) | |
prompt_embeds_tuple = self.encode_prompt( | |
prompt=prompt, | |
device=device, | |
num_images_per_prompt=num_images_per_prompt, | |
do_classifier_free_guidance=do_classifier_free_guidance, | |
negative_prompt=negative_prompt, | |
prompt_embeds=prompt_embeds, | |
negative_prompt_embeds=negative_prompt_embeds, | |
lora_scale=lora_scale, | |
**kwargs, | |
) | |
# concatenate for backwards comp | |
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) | |
return prompt_embeds | |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt | |
def encode_prompt( | |
self, | |
prompt, | |
device, | |
num_images_per_prompt, | |
do_classifier_free_guidance, | |
negative_prompt=None, | |
prompt_embeds: Optional[torch.FloatTensor] = None, | |
negative_prompt_embeds: Optional[torch.FloatTensor] = None, | |
lora_scale: Optional[float] = None, | |
clip_skip: Optional[int] = None, | |
): | |
r""" | |
Encodes the prompt into text encoder hidden states. | |
Args: | |
prompt (`str` or `List[str]`, *optional*): | |
prompt to be encoded | |
device: (`torch.device`): | |
torch device | |
num_images_per_prompt (`int`): | |
number of images that should be generated per prompt | |
do_classifier_free_guidance (`bool`): | |
whether to use classifier free guidance or not | |
negative_prompt (`str` or `List[str]`, *optional*): | |
The prompt or prompts not to guide the image generation. If not defined, one has to pass | |
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is | |
less than `1`). | |
prompt_embeds (`torch.FloatTensor`, *optional*): | |
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not | |
provided, text embeddings will be generated from `prompt` input argument. | |
negative_prompt_embeds (`torch.FloatTensor`, *optional*): | |
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt | |
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input | |
argument. | |
lora_scale (`float`, *optional*): | |
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. | |
clip_skip (`int`, *optional*): | |
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that | |
the output of the pre-final layer will be used for computing the prompt embeddings. | |
""" | |
# set lora scale so that monkey patched LoRA | |
# function of text encoder can correctly access it | |
if lora_scale is not None and isinstance(self, LoraLoaderMixin): | |
self._lora_scale = lora_scale | |
# dynamically adjust the LoRA scale | |
if not USE_PEFT_BACKEND: | |
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) | |
else: | |
scale_lora_layers(self.text_encoder, lora_scale) | |
if prompt is not None and isinstance(prompt, str): | |
batch_size = 1 | |
elif prompt is not None and isinstance(prompt, list): | |
batch_size = len(prompt) | |
else: | |
batch_size = prompt_embeds.shape[0] | |
if prompt_embeds is None: | |
# textual inversion: process multi-vector tokens if necessary | |
if isinstance(self, TextualInversionLoaderMixin): | |
prompt = self.maybe_convert_prompt(prompt, self.tokenizer) | |
text_inputs = self.tokenizer( | |
prompt, | |
padding="max_length", | |
max_length=self.tokenizer.model_max_length, | |
truncation=True, | |
return_tensors="pt", | |
) | |
text_input_ids = text_inputs.input_ids | |
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids | |
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( | |
text_input_ids, untruncated_ids | |
): | |
removed_text = self.tokenizer.batch_decode( | |
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] | |
) | |
logger.warning( | |
"The following part of your input was truncated because CLIP can only handle sequences up to" | |
f" {self.tokenizer.model_max_length} tokens: {removed_text}" | |
) | |
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: | |
attention_mask = text_inputs.attention_mask.to(device) | |
else: | |
attention_mask = None | |
if clip_skip is None: | |
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) | |
prompt_embeds = prompt_embeds[0] | |
else: | |
prompt_embeds = self.text_encoder( | |
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True | |
) | |
# Access the `hidden_states` first, that contains a tuple of | |
# all the hidden states from the encoder layers. Then index into | |
# the tuple to access the hidden states from the desired layer. | |
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] | |
# We also need to apply the final LayerNorm here to not mess with the | |
# representations. The `last_hidden_states` that we typically use for | |
# obtaining the final prompt representations passes through the LayerNorm | |
# layer. | |
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) | |
if self.text_encoder is not None: | |
prompt_embeds_dtype = self.text_encoder.dtype | |
elif self.unet is not None: | |
prompt_embeds_dtype = self.unet.dtype | |
else: | |
prompt_embeds_dtype = prompt_embeds.dtype | |
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) | |
bs_embed, seq_len, _ = prompt_embeds.shape | |
# duplicate text embeddings for each generation per prompt, using mps friendly method | |
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) | |
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) | |
# get unconditional embeddings for classifier free guidance | |
if do_classifier_free_guidance and negative_prompt_embeds is None: | |
uncond_tokens: List[str] | |
if negative_prompt is None: | |
uncond_tokens = [""] * batch_size | |
elif prompt is not None and type(prompt) is not type(negative_prompt): | |
raise TypeError( | |
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" | |
f" {type(prompt)}." | |
) | |
elif isinstance(negative_prompt, str): | |
uncond_tokens = [negative_prompt] | |
elif batch_size != len(negative_prompt): | |
raise ValueError( | |
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" | |
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" | |
" the batch size of `prompt`." | |
) | |
else: | |
uncond_tokens = negative_prompt | |
# textual inversion: process multi-vector tokens if necessary | |
if isinstance(self, TextualInversionLoaderMixin): | |
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) | |
max_length = prompt_embeds.shape[1] | |
uncond_input = self.tokenizer( | |
uncond_tokens, | |
padding="max_length", | |
max_length=max_length, | |
truncation=True, | |
return_tensors="pt", | |
) | |
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: | |
attention_mask = uncond_input.attention_mask.to(device) | |
else: | |
attention_mask = None | |
negative_prompt_embeds = self.text_encoder( | |
uncond_input.input_ids.to(device), | |
attention_mask=attention_mask, | |
) | |
negative_prompt_embeds = negative_prompt_embeds[0] | |
if do_classifier_free_guidance: | |
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method | |
seq_len = negative_prompt_embeds.shape[1] | |
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) | |
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) | |
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) | |
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: | |
# Retrieve the original scale by scaling back the LoRA layers | |
unscale_lora_layers(self.text_encoder, lora_scale) | |
return prompt_embeds, negative_prompt_embeds | |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker | |
def run_safety_checker(self, image, device, dtype): | |
if self.safety_checker is None: | |
has_nsfw_concept = None | |
else: | |
if torch.is_tensor(image): | |
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") | |
else: | |
feature_extractor_input = self.image_processor.numpy_to_pil(image) | |
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) | |
image, has_nsfw_concept = self.safety_checker( | |
images=image, clip_input=safety_checker_input.pixel_values.to(dtype) | |
) | |
return image, has_nsfw_concept | |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents | |
def decode_latents(self, latents): | |
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" | |
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) | |
latents = 1 / self.vae.config.scaling_factor * latents | |
image = self.vae.decode(latents, return_dict=False)[0] | |
image = (image / 2 + 0.5).clamp(0, 1) | |
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 | |
image = image.cpu().permute(0, 2, 3, 1).float().numpy() | |
return image | |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs | |
def prepare_extra_step_kwargs(self, generator, eta): | |
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature | |
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. | |
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 | |
# and should be between [0, 1] | |
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) | |
extra_step_kwargs = {} | |
if accepts_eta: | |
extra_step_kwargs["eta"] = eta | |
# check if the scheduler accepts generator | |
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) | |
if accepts_generator: | |
extra_step_kwargs["generator"] = generator | |
return extra_step_kwargs | |
def check_inputs( | |
self, | |
prompt, | |
strength, | |
callback_steps, | |
negative_prompt=None, | |
prompt_embeds=None, | |
negative_prompt_embeds=None, | |
callback_on_step_end_tensor_inputs=None, | |
): | |
if strength < 0 or strength > 1: | |
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") | |
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): | |
raise ValueError( | |
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" | |
f" {type(callback_steps)}." | |
) | |
if callback_on_step_end_tensor_inputs is not None and not all( | |
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs | |
): | |
raise ValueError( | |
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" | |
) | |
if prompt is not None and prompt_embeds is not None: | |
raise ValueError( | |
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" | |
" only forward one of the two." | |
) | |
elif prompt is None and prompt_embeds is None: | |
raise ValueError( | |
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." | |
) | |
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): | |
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") | |
if negative_prompt is not None and negative_prompt_embeds is not None: | |
raise ValueError( | |
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" | |
f" {negative_prompt_embeds}. Please make sure to only forward one of the two." | |
) | |
if prompt_embeds is not None and negative_prompt_embeds is not None: | |
if prompt_embeds.shape != negative_prompt_embeds.shape: | |
raise ValueError( | |
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" | |
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" | |
f" {negative_prompt_embeds.shape}." | |
) | |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps | |
def get_timesteps(self, num_inference_steps, strength, device): | |
# get the original timestep using init_timestep | |
init_timestep = min(int(num_inference_steps * strength), num_inference_steps) | |
t_start = max(num_inference_steps - init_timestep, 0) | |
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] | |
if hasattr(self.scheduler, "set_begin_index"): | |
self.scheduler.set_begin_index(t_start * self.scheduler.order) | |
return timesteps, num_inference_steps - t_start | |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents | |
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): | |
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): | |
raise ValueError( | |
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" | |
) | |
image = image.to(device=device, dtype=dtype) | |
batch_size = batch_size * num_images_per_prompt | |
if image.shape[1] == 4: | |
init_latents = image | |
else: | |
if isinstance(generator, list) and len(generator) != batch_size: | |
raise ValueError( | |
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" | |
f" size of {batch_size}. Make sure the batch size matches the length of the generators." | |
) | |
elif isinstance(generator, list): | |
init_latents = [ | |
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) | |
for i in range(batch_size) | |
] | |
init_latents = torch.cat(init_latents, dim=0) | |
else: | |
init_latents = retrieve_latents(self.vae.encode(image), generator=generator) | |
init_latents = self.vae.config.scaling_factor * init_latents | |
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: | |
# expand init_latents for batch_size | |
deprecation_message = ( | |
f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" | |
" images (`image`). Initial images are now duplicating to match the number of text prompts. Note" | |
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" | |
" your script to pass as many initial images as text prompts to suppress this warning." | |
) | |
deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) | |
additional_image_per_prompt = batch_size // init_latents.shape[0] | |
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) | |
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: | |
raise ValueError( | |
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." | |
) | |
else: | |
init_latents = torch.cat([init_latents], dim=0) | |
shape = init_latents.shape | |
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) | |
# get latents | |
init_latents = self.scheduler.add_noise(init_latents, noise, timestep) | |
latents = init_latents | |
return latents | |
def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_guidance, dtype, device): | |
if isinstance(image, PIL.Image.Image): | |
image = [image] | |
else: | |
image = list(image) | |
if isinstance(image[0], PIL.Image.Image): | |
width, height = image[0].size | |
elif isinstance(image[0], np.ndarray): | |
width, height = image[0].shape[:-1] | |
else: | |
height, width = image[0].shape[-2:] | |
if depth_map is None: | |
pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values | |
pixel_values = pixel_values.to(device=device) | |
# The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16. | |
# So we use `torch.autocast` here for half precision inference. | |
if torch.backends.mps.is_available(): | |
autocast_ctx = contextlib.nullcontext() | |
logger.warning( | |
"The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16, but autocast is not yet supported on MPS." | |
) | |
else: | |
autocast_ctx = torch.autocast(device.type, dtype=dtype) | |
with autocast_ctx: | |
depth_map = self.depth_estimator(pixel_values).predicted_depth | |
else: | |
depth_map = depth_map.to(device=device, dtype=dtype) | |
depth_map = torch.nn.functional.interpolate( | |
depth_map.unsqueeze(1), | |
size=(height // self.vae_scale_factor, width // self.vae_scale_factor), | |
mode="bicubic", | |
align_corners=False, | |
) | |
depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True) | |
depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True) | |
depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0 | |
depth_map = depth_map.to(dtype) | |
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method | |
if depth_map.shape[0] < batch_size: | |
repeat_by = batch_size // depth_map.shape[0] | |
depth_map = depth_map.repeat(repeat_by, 1, 1, 1) | |
depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map | |
return depth_map | |
def guidance_scale(self): | |
return self._guidance_scale | |
def clip_skip(self): | |
return self._clip_skip | |
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) | |
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` | |
# corresponds to doing no classifier free guidance. | |
def do_classifier_free_guidance(self): | |
return self._guidance_scale > 1 | |
def cross_attention_kwargs(self): | |
return self._cross_attention_kwargs | |
def num_timesteps(self): | |
return self._num_timesteps | |
@torch.no_grad()
def __call__(
    self,
    prompt: Optional[Union[str, List[str]]] = None,
    image: PipelineImageInput = None,
    depth_map: Optional[torch.FloatTensor] = None,
    strength: float = 0.8,
    num_inference_steps: Optional[int] = 50,
    guidance_scale: Optional[float] = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: Optional[float] = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[Any, int, int, Dict[str, Any]], Dict[str, Any]]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    **kwargs,
):
    r"""
    The call function to the pipeline for generation.

    Runs under `torch.no_grad()`: inference never needs gradients, and without it autograd would record the
    whole denoising loop and VAE decode.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
            `Image` or tensor representing an image batch to be used as the starting point. Can accept image
            latents as `image` only if `depth_map` is not `None`.
        depth_map (`torch.FloatTensor`, *optional*):
            Depth prediction to be used as additional conditioning for the image generation process. If not
            defined, it automatically predicts the depth with `self.depth_estimator`.
        strength (`float`, *optional*, defaults to 0.8):
            Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
            starting point and more noise is added the higher the `strength`. The number of denoising steps depends
            on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
            process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
            essentially ignores `image`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference. This parameter is modulated by `strength`.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`. With
            `"latent"`, the VAE decode is skipped and the denoised latents are passed to
            `image_processor.postprocess` undecoded.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            Its `"scale"` entry, if present, is also used as the LoRA scale for the text encoder.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, *optional*):
            A function that is called at the end of each denoising step during inference. It is called as
            `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`,
            where `callback_kwargs` maps each name in `callback_on_step_end_tensor_inputs` to its current value.
            It must return a dict; any of the keys `latents`, `prompt_embeds`, `negative_prompt_embeds` or
            `depth_mask` present in that dict replace the corresponding value for the remaining steps.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        **kwargs:
            Accepts the deprecated `callback` (called as `callback(step_idx, timestep, latents)`) and
            `callback_steps` (the callback runs on loop iterations divisible by it; every step if only
            `callback` is given). Other keys are ignored.

    Examples:

    ```py
    >>> import torch
    >>> import requests
    >>> from PIL import Image

    >>> from diffusers import StableDiffusionDepth2ImgPipeline

    >>> pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
    ...     "stabilityai/stable-diffusion-2-depth",
    ...     torch_dtype=torch.float16,
    ... )
    >>> pipe.to("cuda")

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> init_image = Image.open(requests.get(url, stream=True).raw)
    >>> prompt = "two tigers"
    >>> n_prompt = "bad, deformed, ugly, bad anatomy"
    >>> image = pipe(prompt=prompt, image=init_image, negative_prompt=n_prompt, strength=0.7).images[0]
    ```

    Returns:
        [`~pipelines.ImagePipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
            returned whose first element is the generated images.

    Raises:
        ValueError: If `image` is `None`.
    """
    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)

    if callback is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
        )
    if callback_steps is not None:
        deprecate(
            "callback_steps",
            "1.0.0",
            "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
        )

    # 1. Check inputs
    self.check_inputs(
        prompt,
        strength,
        callback_steps,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    )

    # A legacy `callback` passed without `callback_steps` would otherwise evaluate `i % None` (TypeError)
    # in the denoising loop; fall back to calling it on every step. Done after `check_inputs` so the
    # validation still sees exactly what the caller passed.
    if callback is not None and callback_steps is None:
        callback_steps = 1

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs

    if image is None:
        raise ValueError("`image` input cannot be undefined.")

    # 2. Define call parameters
    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
        clip_skip=self.clip_skip,
    )
    # Classifier-free guidance needs both an unconditional and a text-conditioned prediction. Concatenate
    # the embeddings into one batch so a single UNet forward pass produces both.
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # 4. Prepare depth mask
    depth_mask = self.prepare_depth_map(
        image,
        depth_map,
        batch_size * num_images_per_prompt,
        self.do_classifier_free_guidance,
        prompt_embeds.dtype,
        device,
    )

    # 5. Preprocess image
    image = self.image_processor.preprocess(image)

    # 6. Set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

    # 7. Prepare latent variables
    latents = self.prepare_latents(
        image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
    )

    # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 9. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            # the depth conditioning is appended to the latents along the channel dimension
            latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=self.cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback_on_step_end is not None:
                # Plain loop on purpose: `locals()` must see this function's frame, which a comprehension
                # would not guarantee on older Python versions.
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                depth_mask = callback_outputs.pop("depth_mask", depth_mask)

            # advance the progress bar and call the legacy callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    if output_type != "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
    else:
        image = latents
    image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return ImagePipelineOutput(images=image)