|
|
|
|
|
|
|
import inspect |
|
from typing import Callable, List, Optional, Union |
|
|
|
import numpy as np |
|
import PIL |
|
import torch |
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer |
|
|
|
# Mapping from ONNX Runtime tensor type strings (as reported by an ONNX model input's `.type`, e.g.
# "tensor(float)") to NumPy dtypes. Prefer the table shipped with diffusers; fall back to a local table when
# `diffusers.pipelines.onnx_utils` cannot be imported or does not provide `ORT_TO_NP_TYPE`.
try:
    from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE
except ImportError:
    ORT_TO_NP_TYPE = {
        "tensor(bool)": np.bool_,
        "tensor(int8)": np.int8,
        "tensor(uint8)": np.uint8,
        "tensor(int16)": np.int16,
        "tensor(uint16)": np.uint16,
        "tensor(int32)": np.int32,
        "tensor(uint32)": np.uint32,
        "tensor(int64)": np.int64,
        "tensor(uint64)": np.uint64,
        "tensor(float16)": np.float16,
        # ONNX "float" is single precision.
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
    }
|
|
|
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin |
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel |
|
from diffusers.schedulers import KarrasDiffusionSchedulers, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler |
|
from diffusers.utils import ( |
|
PIL_INTERPOLATION, |
|
deprecate, |
|
logging, |
|
randn_tensor, |
|
) |
|
from diffusers.pipeline_utils import DiffusionPipeline |
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput |
|
|
|
|
|
# Module-level logger from diffusers' logging utilities (`diffusers.utils.logging`), used for pipeline warnings.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
def preprocess(image):
    """Convert the pipeline's `image` argument into a NumPy batch ready for the VAE encoder.

    Args:
        image: One of
            - a `np.ndarray`, returned unchanged (it is assumed to be preprocessed already);
            - a list/tuple of `np.ndarray`s, concatenated along axis 0 (each assumed already preprocessed);
            - a `PIL.Image.Image`, or a list/tuple of them.

    For PIL input, every image is resized (Lanczos) to the first image's size rounded down to a multiple of 8.
    Pixel values are then scaled from `[0, 255]` to `[-1, 1]` as float32, and the batch is transposed from
    `NHWC` to `NCHW`.

    Returns:
        `np.ndarray`: the image batch.

    Raises:
        ValueError: If an empty list/tuple is given.
    """
    if isinstance(image, np.ndarray):
        return image

    images = list(image) if isinstance(image, (list, tuple)) else [image]
    if not images:
        raise ValueError("`image` must not be an empty list.")
    if isinstance(images[0], np.ndarray):
        return np.concatenate(images, axis=0)

    # Round the size down to a multiple of 8, the factor the pipeline uses between image and latent resolution.
    w, h = images[0].size
    w, h = (x - x % 8 for x in (w, h))
    arrays = [np.array(img.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])) for img in images]
    batch = np.stack(arrays).astype(np.float32) / 255.0
    batch = batch.transpose(0, 3, 1, 2)
    return 2.0 * batch - 1.0
|
|
|
|
|
class OnnxStableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
    r"""
    ONNX Runtime pipeline for pixel-level image editing by following text instructions (InstructPix2Pix). Based on
    Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    The model components other than the tokenizer, scheduler and feature extractor are [`OnnxRuntimeModel`]s and
    exchange NumPy arrays. The scheduler works on `torch.Tensor`s, so arrays are converted with `torch.from_numpy`
    around every scheduler call.

    Args:
        vae_encoder ([`OnnxRuntimeModel`]):
            VAE encoder used to encode the input `image` into latents (called as `vae_encoder(sample=...)`).
        vae_decoder ([`OnnxRuntimeModel`]):
            VAE decoder used to decode denoised latents into images (called as `vae_decoder(latent_sample=...)`).
        text_encoder ([`OnnxRuntimeModel`]):
            Frozen text encoder (called with `int32` `input_ids`).
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`OnnxRuntimeModel`]):
            Conditional U-Net used to denoise the latents. Its input is the noisy latents concatenated with the
            image latents along the channel axis (8 channels in total).
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the latents, e.g. [`DDIMScheduler`],
            [`LMSDiscreteScheduler`] or [`PNDMScheduler`].
        safety_checker ([`OnnxRuntimeModel`], *optional*):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPFeatureExtractor`], *optional*):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
            Required whenever `safety_checker` is given.
        requires_safety_checker (`bool`, *optional*, defaults to `True`):
            When `True`, a warning is logged if the pipeline is created with `safety_checker=None`.
    """

    vae_encoder: OnnxRuntimeModel
    vae_decoder: OnnxRuntimeModel
    text_encoder: OnnxRuntimeModel
    tokenizer: CLIPTokenizer
    unet: OnnxRuntimeModel
    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
    safety_checker: OnnxRuntimeModel
    feature_extractor: CLIPFeatureExtractor

    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae_encoder: OnnxRuntimeModel,
        vae_decoder: OnnxRuntimeModel,
        text_encoder: OnnxRuntimeModel,
        tokenizer: CLIPTokenizer,
        unet: OnnxRuntimeModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: OnnxRuntimeModel,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()
        # U-Net input channels: 4 noisy-latent channels + 4 image-latent channels.
        self.unet_in_channels = 8
        # Ratio between image resolution and latent resolution.
        self.vae_scale_factor = 8

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        self.register_modules(
            vae_encoder=vae_encoder,
            vae_decoder=vae_decoder,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[np.ndarray, PIL.Image.Image] = None,
        num_inference_steps: int = 100,
        guidance_scale: float = 7.5,
        image_guidance_scale: float = 1.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[np.random.RandomState] = None,
        latents: Optional[np.ndarray] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The instruction or instructions describing how to edit `image`.
            image (`PIL.Image.Image` or `np.ndarray`):
                Image to be edited according to `prompt`. A PIL image is resized down to a multiple of 8 and
                normalized to `[-1, 1]` in `NCHW` layout. A `np.ndarray` is passed through unchanged, so it should
                already be preprocessed the same way.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Text guidance scale, as in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                Higher values push the result to follow `prompt` more closely, usually at the expense of image
                quality. Guidance is applied only when `guidance_scale > 1` and `image_guidance_scale >= 1`.
            image_guidance_scale (`float`, *optional*, defaults to 1.5):
                Image guidance scale, pushing the generated image towards the initial `image`. Higher values keep
                the result closer to `image`, usually at the expense of image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Must have the same type as `prompt` and, if a
                list, the same length. Ignored when guidance is disabled.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only passed to
                schedulers whose `step` method accepts an `eta` argument (e.g. [`DDIMScheduler`]).
            generator (`np.random.RandomState`, *optional*):
                NumPy random state used to sample the initial latents. It also seeds the `torch.Generator` passed to
                schedulers whose `step` accepts a `generator`. Defaults to the global `np.random` state.
            latents (`np.ndarray`, *optional*):
                Pre-generated noisy latents of shape
                `(len(prompt) * num_images_per_prompt, 4, height // 8, width // 8)`. If not provided, they are
                sampled with `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                `"pil"` returns a list of `PIL.Image.Image`; any other value returns a `np.ndarray` in `NHWC` layout
                with values in `[0, 1]`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                Called as `callback(step, timestep, latents)`, with `latents` a `np.ndarray`. It runs on
                progress-bar steps whose index is a multiple of `callback_steps`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called.

        Examples:

        ```py
        >>> import numpy as np
        >>> import PIL.Image
        >>> import requests
        >>> from io import BytesIO

        >>> def download_image(url):
        ...     response = requests.get(url)
        ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")

        >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
        >>> image = download_image(img_url).resize((512, 512))

        >>> # "path/to/instruct-pix2pix-onnx" is a placeholder for an ONNX export of "timbrooks/instruct-pix2pix".
        >>> pipe = OnnxStableDiffusionInstructPix2PixPipeline.from_pretrained(
        ...     "path/to/instruct-pix2pix-onnx", provider="CPUExecutionProvider"
        ... )

        >>> prompt = "make the mountains snowy"
        >>> image = pipe(prompt=prompt, image=image, generator=np.random.RandomState(0)).images[0]
        ```

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content according to the `safety_checker`, or `None` when no safety checker is configured.
        """
        # The NumPy random state samples the initial latents; a torch.Generator seeded from it is handed to
        # schedulers whose `step` accepts a `generator`.
        if generator is not None:
            torch_seed = generator.randint(2147483647)
            torch_gen = torch.Generator().manual_seed(torch_seed)
        else:
            generator = np.random
            torch_gen = None

        self.check_inputs(prompt, callback_steps)
        if image is None:
            raise ValueError("`image` input cannot be undefined.")

        # `check_inputs` guarantees `prompt` is a str or a list.
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        total_batch_size = batch_size * num_images_per_prompt

        # Guidance evaluates three U-Net branches per sample: (text + image), (image only), (unconditional).
        do_classifier_free_guidance = guidance_scale > 1.0 and image_guidance_scale >= 1.0
        scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")

        prompt_embeds = self._encode_prompt(
            prompt,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
        )

        image = preprocess(image)
        height, width = image.shape[-2:]

        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps

        latents_dtype = prompt_embeds.dtype
        image_latents = self._prepare_image_latents(
            image.astype(latents_dtype), total_batch_size, do_classifier_free_guidance
        )

        num_channels_latents = 4
        latents_shape = (
            total_batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if latents is None:
            latents = generator.randn(*latents_shape).astype(latents_dtype)
        elif latents.shape != latents_shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
        # `init_noise_sigma` may be a plain float or a 0-d tensor depending on the scheduler; float() accepts both.
        latents = latents * float(self.scheduler.init_noise_sigma)

        num_channels_image = image_latents.shape[1]
        if num_channels_latents + num_channels_image != self.unet_in_channels:
            raise ValueError(
                f"Incorrect configuration settings! The config of `pipeline.unet`: expects"
                f" {self.unet_in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                f" `num_channels_image`: {num_channels_image} "
                f" = {num_channels_latents+num_channels_image}. Please verify the config of"
                " `pipeline.unet` or your `image` input."
            )

        # Build the `timestep` input with the dtype the exported U-Net declares (float if not found).
        timestep_dtype = next(
            (node.type for node in self.unet.model.get_inputs() if node.name == "timestep"), "tensor(float)"
        )
        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta, torch_gen)

        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # One copy of the latents per guidance branch.
                latent_model_input = np.concatenate([latents] * 3) if do_classifier_free_guidance else latents

                scaled_latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
                scaled_latent_model_input = scaled_latent_model_input.cpu().numpy()
                # Condition on the source image by concatenating its latents along the channel axis.
                scaled_latent_model_input = np.concatenate([scaled_latent_model_input, image_latents], axis=1)

                noise_pred = self.unet(
                    sample=scaled_latent_model_input,
                    timestep=np.array([t], dtype=timestep_dtype),
                    encoder_hidden_states=prompt_embeds,
                )[0]

                # For schedulers exposing `sigmas`, guidance is combined on
                # `latent_model_input - sigma * noise_pred` and converted back to a noise prediction afterwards.
                if scheduler_is_in_sigma_space:
                    step_index = (self.scheduler.timesteps == t).nonzero().item()
                    sigma = float(self.scheduler.sigmas[step_index])
                    noise_pred = latent_model_input - sigma * noise_pred

                if do_classifier_free_guidance:
                    noise_pred_text, noise_pred_image, noise_pred_uncond = np.split(noise_pred, 3)
                    noise_pred = (
                        noise_pred_uncond
                        + guidance_scale * (noise_pred_text - noise_pred_image)
                        + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
                    )

                if scheduler_is_in_sigma_space:
                    noise_pred = (noise_pred - latents) / (-sigma)

                # The scheduler works on torch tensors: convert in, then convert the new latents back to NumPy.
                scheduler_output = self.scheduler.step(
                    torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
                )
                latents = scheduler_output.prev_sample.numpy()

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        image = self.decode_latents(latents)
        image, has_nsfw_concept = self.run_safety_checker(image)

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

    def _prepare_image_latents(self, image, batch_size, do_classifier_free_guidance):
        r"""
        Encode `image` with the VAE encoder and lay the latents out to match the denoising batch.

        Args:
            image (`np.ndarray`):
                Preprocessed image batch, already cast to the latents dtype.
            batch_size (`int`):
                Number of latents being denoised, i.e. `len(prompt) * num_images_per_prompt`.
            do_classifier_free_guidance (`bool`):
                When true, the result is `[image, image, zeros]` stacked along axis 0. This lines up with the
                `[text + image, image only, unconditional]` guidance branches.

        Returns:
            `np.ndarray`: image latents with `batch_size` rows (`3 * batch_size` under guidance).

        Raises:
            ValueError: If `batch_size` is not a multiple of the number of input images.
        """
        image_latents = self.vae_encoder(sample=image)[0]
        num_images = image_latents.shape[0]
        if num_images != batch_size:
            if batch_size % num_images != 0:
                raise ValueError(
                    f"Cannot match an `image` batch of size {num_images} to {batch_size} generated images; the"
                    " number of generated images must be a multiple of the number of input images."
                )
            # Repeat each image consecutively, mirroring how `_encode_prompt` repeats prompt embeddings with
            # `np.repeat`, so row k of both arrays refers to the same prompt.
            image_latents = np.repeat(image_latents, batch_size // num_images, axis=0)

        if do_classifier_free_guidance:
            uncond_image_latents = np.zeros_like(image_latents)
            image_latents = np.concatenate((image_latents, image_latents, uncond_image_latents), axis=0)
        return image_latents

    def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`):
                prompt to be encoded
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Ignored when guidance is disabled.

        Returns:
            `np.ndarray`: the prompt embeddings, each repeated `num_images_per_prompt` times along axis 0. Under
            guidance, the layout is `[prompt, negative, negative]`; the negative embeddings serve both the image-only
            and the unconditional branches.

        Raises:
            TypeError: If `negative_prompt` and `prompt` have different types.
            ValueError: If a list `negative_prompt` has a different length than `prompt`.
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        text_input_ids = text_inputs.input_ids
        # Re-tokenize without truncation to report dropped tokens. Pad to the longest prompt so that a batch mixing
        # short and over-long prompts still forms a rectangular array.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="np").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not np.array_equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
        prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)

        if not do_classifier_free_guidance:
            return prompt_embeds

        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt] * batch_size
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        max_length = text_input_ids.shape[-1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="np",
        )
        negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
        negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)

        # Order: [text-conditioned, image-only (negative text), unconditional (negative text)].
        return np.concatenate((prompt_embeds, negative_prompt_embeds, negative_prompt_embeds))

    def run_safety_checker(self, image):
        r"""
        Run the safety checker (if configured) on a batch of decoded images, one image at a time.

        Args:
            image (`np.ndarray`): decoded images in `NHWC` layout with values in `[0, 1]`.

        Returns:
            `(image, has_nsfw_concept)`: the images as returned by the checker, and a list with one flag per image.
            The flag list is `None` when no safety checker is configured, in which case `image` is returned as-is.
        """
        if self.safety_checker is None:
            return image, None

        safety_checker_input = self.feature_extractor(
            self.numpy_to_pil(image), return_tensors="np"
        ).pixel_values.astype(image.dtype)

        checked_images, has_nsfw_concept = [], []
        for i in range(image.shape[0]):
            image_i, has_nsfw_concept_i = self.safety_checker(
                clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
            )
            checked_images.append(image_i)
            has_nsfw_concept.append(has_nsfw_concept_i[0])
        return np.concatenate(checked_images), has_nsfw_concept

    def prepare_extra_step_kwargs(self, generator, eta, torch_gen):
        r"""
        Build the optional keyword arguments for `self.scheduler.step`.

        Only arguments that the scheduler's `step` signature accepts are included: `eta` (used e.g. by DDIM) and
        `generator` (which receives `torch_gen`). The NumPy `generator` argument is not used here; it is kept for
        interface compatibility.
        """
        step_params = set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if "eta" in step_params:
            extra_step_kwargs["eta"] = eta
        if "generator" in step_params:
            extra_step_kwargs["generator"] = torch_gen
        return extra_step_kwargs

    def decode_latents(self, latents):
        r"""
        Decode latents into images with values in `[0, 1]`, in `NHWC` layout.

        The latents are divided by 0.18215 (presumably the Stable Diffusion VAE scaling factor — confirm against the
        exported VAE) and decoded one sample at a time.
        """
        latents = 1 / 0.18215 * latents
        image = np.concatenate(
            [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
        )
        # Map the decoder output from [-1, 1] to [0, 1].
        image = np.clip(image / 2 + 0.5, 0, 1)
        return image.transpose((0, 2, 3, 1))

    def check_inputs(self, prompt, callback_steps):
        r"""
        Validate call arguments.

        Raises:
            ValueError: If `prompt` is not a `str` or `list`, or `callback_steps` is not a positive `int`.
        """
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )
|
|
|
|