Spaces:
Running
on
Zero
Running
on
Zero
# Copyright 2024 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union | |
import torch | |
from transformers import CLIPTextModelWithProjection, CLIPTokenizer | |
from ...image_processor import PipelineImageInput, VaeImageProcessor | |
from ...models import UVit2DModel, VQModel | |
from ...schedulers import AmusedScheduler | |
from ...utils import replace_example_docstring | |
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput | |
# Usage example for `AmusedInpaintPipeline`, written as a docstring fragment that starts with an
# `Examples:` heading. The sample image and mask are fetched from the official Hugging Face Hub
# host (`huggingface.co`) rather than a third-party mirror.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import AmusedInpaintPipeline
        >>> from diffusers.utils import load_image

        >>> pipe = AmusedInpaintPipeline.from_pretrained(
        ...     "amused/amused-512", variant="fp16", torch_dtype=torch.float16
        ... )
        >>> pipe = pipe.to("cuda")

        >>> prompt = "fall mountains"
        >>> input_image = (
        ...     load_image(
        ...         "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg"
        ...     )
        ...     .resize((512, 512))
        ...     .convert("RGB")
        ... )
        >>> mask = (
        ...     load_image(
        ...         "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png"
        ...     )
        ...     .resize((512, 512))
        ...     .convert("L")
        ... )
        >>> pipe(prompt, input_image, mask).images[0].save("out.png")
        ```
"""
class AmusedInpaintPipeline(DiffusionPipeline):
    """
    Text-guided inpainting pipeline built on a VQ-VAE, a CLIP text encoder and a `UVit2DModel` transformer.

    The input image is encoded by `vqvae` and quantized to a grid of token ids. Positions selected by the mask
    are overwritten with the scheduler's `mask_token_id`, and the transformer together with `scheduler.step`
    iteratively replaces them. The final token grid is decoded back to an image by `vqvae`.
    """

    image_processor: VaeImageProcessor
    vqvae: VQModel
    tokenizer: CLIPTokenizer
    text_encoder: CLIPTextModelWithProjection
    transformer: UVit2DModel
    scheduler: AmusedScheduler

    model_cpu_offload_seq = "text_encoder->transformer->vqvae"

    # TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before
    # the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter
    # off the meta device. There should be a way to fix this instead of just not offloading it
    _exclude_from_cpu_offload = ["vqvae"]

    def __init__(
        self,
        vqvae: VQModel,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModelWithProjection,
        transformer: UVit2DModel,
        scheduler: AmusedScheduler,
    ):
        """
        Register the pipeline components and build the image and mask pre-processors.

        Args:
            vqvae (`VQModel`): Model used to encode images to token ids and decode token ids to images.
            tokenizer (`CLIPTokenizer`): Tokenizer for the text prompts.
            text_encoder (`CLIPTextModelWithProjection`): Text encoder producing the prompt conditioning.
            transformer (`UVit2DModel`): Model predicting logits for the image tokens.
            scheduler (`AmusedScheduler`): Scheduler driving the iterative unmasking.
        """
        super().__init__()

        self.register_modules(
            vqvae=vqvae,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            transformer=transformer,
            scheduler=scheduler,
        )
        self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
        # Masks are resized, converted to a single channel and binarized.
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor,
            do_normalize=False,
            do_binarize=True,
            do_convert_grayscale=True,
            do_resize=True,
        )
        # This pipeline always runs the scheduler with the linear masking schedule.
        self.scheduler.register_to_config(masking_schedule="linear")

    def _encode_text(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize `texts` and return the text encoder's `(text_embeds, hidden_states[-2])`."""
        input_ids = self.tokenizer(
            texts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids.to(self._execution_device)

        outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
        return outputs.text_embeds, outputs.hidden_states[-2]

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[List[str], str]] = None,
        image: PipelineImageInput = None,
        mask_image: PipelineImageInput = None,
        strength: float = 1.0,
        num_inference_steps: int = 12,
        guidance_scale: float = 10.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[torch.Generator] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_encoder_hidden_states: Optional[torch.Tensor] = None,
        output_type="pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        micro_conditioning_aesthetic_score: int = 6,
        micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
        temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
    ):
        """
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
                numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
                list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
                a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
                latents as `image`, but if passing latents directly it is not encoded again.
            mask_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
                are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
                single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
                color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
                H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
                1)`, or `(H, W)`.
            strength (`float`, *optional*, defaults to 1.0):
                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
                essentially ignores `image`.
            num_inference_steps (`int`, *optional*, defaults to 12):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 10.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. A single string is applied
                to every prompt. If not defined and `negative_prompt_embeds` is not passed, an empty negative prompt
                is used. Ignored when not using guidance (`guidance_scale <= 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument. A single vector from the
                pooled and projected final hidden states.
            encoder_hidden_states (`torch.FloatTensor`, *optional*):
                Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            negative_encoder_hidden_states (`torch.FloatTensor`, *optional*):
                Analogous to `encoder_hidden_states` for the positive prompt.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`. Passing
                `"latent"` returns the token ids without decoding them.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.pipeline_utils.ImagePipelineOutput`] instead of a plain
                tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
                The targeted aesthetic score according to the laion aesthetic classifier. See
                https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
                https://arxiv.org/abs/2307.01952.
            micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
                The targeted height, width crop coordinates. See the micro-conditioning section of
                https://arxiv.org/abs/2307.01952.
            temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
                Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`.

        Examples:

        Returns:
            [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
                `tuple` is returned where the first element is a list with the generated images.

        Raises:
            ValueError: If only one of `prompt_embeds`/`encoder_hidden_states` (or of their negative counterparts)
                is passed, if `prompt` and `prompt_embeds` are both passed or both omitted, or if `strength` is
                outside `[0, 1]`.
        """
        # --- argument validation -------------------------------------------------------------------------
        if (prompt_embeds is None) != (encoder_hidden_states is None):
            raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")

        if (negative_prompt_embeds is None) != (negative_encoder_hidden_states is None):
            raise ValueError(
                "pass either both `negative_prompt_embeds` and `negative_encoder_hidden_states` or neither"
            )

        if (prompt is None) == (prompt_embeds is None):
            raise ValueError("pass only one of `prompt` or `prompt_embeds`")

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of `strength` should be in [0.0, 1.0] but is {strength}")

        if isinstance(prompt, str):
            prompt = [prompt]

        # Number of distinct prompts before expansion by `num_images_per_prompt`. Derived from the embeddings
        # when no text prompt is given, so later code never has to call `len()` on a missing prompt.
        num_prompts = len(prompt) if prompt is not None else prompt_embeds.shape[0]
        batch_size = num_prompts * num_images_per_prompt

        use_guidance = guidance_scale > 1.0

        # --- text conditioning ---------------------------------------------------------------------------
        if prompt_embeds is None:
            prompt_embeds, encoder_hidden_states = self._encode_text(prompt)

        prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
        encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)

        if use_guidance:
            if negative_prompt_embeds is None:
                if negative_prompt is None:
                    negative_prompt = ""

                if isinstance(negative_prompt, str):
                    # A single negative prompt is shared by every prompt so both halves of the
                    # guidance batch end up with the same number of rows.
                    negative_prompt = [negative_prompt] * num_prompts

                negative_prompt_embeds, negative_encoder_hidden_states = self._encode_text(negative_prompt)

            negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
            negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)

            # Unconditional rows first, conditional rows second; `chunk(2)` below relies on this order.
            prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
            encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])

        # --- image and micro-conditioning ----------------------------------------------------------------
        image = self.image_processor.preprocess(image)

        height, width = image.shape[-2:]

        # Note that the micro conditionings _do_ flip the order of width, height for the original size
        # and the crop coordinates. This is how it was done in the original code base
        micro_conds = torch.tensor(
            [
                width,
                height,
                micro_conditioning_crop_coord[0],
                micro_conditioning_crop_coord[1],
                micro_conditioning_aesthetic_score,
            ],
            device=self._execution_device,
            dtype=encoder_hidden_states.dtype,
        )

        micro_conds = micro_conds.unsqueeze(0)
        micro_conds = micro_conds.expand(2 * batch_size if use_guidance else batch_size, -1)

        # --- schedule ------------------------------------------------------------------------------------
        self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
        total_timesteps = len(self.scheduler.timesteps)
        # `strength` selects how many of the trailing timesteps are actually run.
        num_inference_steps = int(total_timesteps * strength)
        start_timestep_idx = total_timesteps - num_inference_steps

        # --- encode image to token ids and apply the mask ------------------------------------------------
        needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast

        if needs_upcasting:
            self.vqvae.float()

        latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents
        latents_bsz, channels, latents_height, latents_width = latents.shape
        latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width)

        mask = self.mask_processor.preprocess(
            mask_image, height // self.vae_scale_factor, width // self.vae_scale_factor
        )
        mask = mask.reshape(mask.shape[0], latents_height, latents_width).bool().to(latents.device)
        latents[mask] = self.scheduler.config.mask_token_id

        starting_mask_ratio = mask.sum() / latents.numel()

        latents = latents.repeat(num_images_per_prompt, 1, 1)

        # --- iterative unmasking -------------------------------------------------------------------------
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i in range(start_timestep_idx, total_timesteps):
                timestep = self.scheduler.timesteps[i]

                # With guidance the same latents are evaluated under both conditionings in one batch.
                model_input = torch.cat([latents] * 2) if use_guidance else latents

                model_output = self.transformer(
                    model_input,
                    micro_conds=micro_conds,
                    pooled_text_emb=prompt_embeds,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                )

                if use_guidance:
                    uncond_logits, cond_logits = model_output.chunk(2)
                    model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)

                latents = self.scheduler.step(
                    model_output=model_output,
                    timestep=timestep,
                    sample=latents,
                    generator=generator,
                    starting_mask_ratio=starting_mask_ratio,
                ).prev_sample

                if i == total_timesteps - 1 or ((i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, timestep, latents)

        # --- decode --------------------------------------------------------------------------------------
        if output_type == "latent":
            output = latents
        else:
            output = self.vqvae.decode(
                latents,
                force_not_quantize=True,
                shape=(
                    batch_size,
                    height // self.vae_scale_factor,
                    width // self.vae_scale_factor,
                    self.vqvae.config.latent_channels,
                ),
            ).sample.clip(0, 1)
            output = self.image_processor.postprocess(output, output_type)

        # Restore the original half precision on every output path, including `output_type="latent"`,
        # so the temporary upcast never leaks out of this call.
        if needs_upcasting:
            self.vqvae.half()

        self.maybe_free_model_hooks()

        if not return_dict:
            return (output,)

        return ImagePipelineOutput(output)