import glob
import os
import time
from typing import List, Optional, Union, Any, Dict, Tuple, Literal
from collections import deque

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from torchvision.models.optical_flow import raft_small

from diffusers import LCMScheduler, StableDiffusionPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
    retrieve_latents,
)

from .image_utils import postprocess_image, forward_backward_consistency_check
from .models.utils import get_nn_latent
from .image_filter import SimilarImageFilter


class StreamV2V:
    def __init__(
        self,
        pipe: StableDiffusionPipeline,
        t_index_list: List[int],
        torch_dtype: torch.dtype = torch.float16,
        width: int = 512,
        height: int = 512,
        do_add_noise: bool = True,
        use_denoising_batch: bool = True,
        frame_buffer_size: int = 1,
        cfg_type: Literal["none", "full", "self", "initialize"] = "self",
    ) -> None:
        self.device = pipe.device
        self.dtype = torch_dtype
        self.generator = None

        self.height = height
        self.width = width

        self.latent_height = int(height // pipe.vae_scale_factor)
        self.latent_width = int(width // pipe.vae_scale_factor)

        self.frame_bff_size = frame_buffer_size
        self.denoising_steps_num = len(t_index_list)

        self.cfg_type = cfg_type

        if use_denoising_batch:
            self.batch_size = self.denoising_steps_num * frame_buffer_size
            if self.cfg_type == "initialize":
                self.trt_unet_batch_size = (
                    self.denoising_steps_num + 1
                ) * self.frame_bff_size
            elif self.cfg_type == "full":
                self.trt_unet_batch_size = (
                    2 * self.denoising_steps_num * self.frame_bff_size
                )
            else:
                self.trt_unet_batch_size = self.denoising_steps_num * frame_buffer_size
        else:
            self.trt_unet_batch_size = self.frame_bff_size
            self.batch_size = frame_buffer_size

        self.t_list = t_index_list

        self.do_add_noise = do_add_noise
        self.use_denoising_batch = use_denoising_batch

        self.similar_image_filter = False
        self.similar_filter = SimilarImageFilter()
        self.prev_image_tensor = None
        self.prev_x_t_latent = None
        self.prev_image_result = None

        self.pipe = pipe
        self.image_processor = VaeImageProcessor(pipe.vae_scale_factor)

        self.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
        self.text_encoder = pipe.text_encoder
        self.unet = pipe.unet
        self.vae = pipe.vae

        self.flow_model = (
            raft_small(pretrained=True, progress=False).to(device=pipe.device).eval()
        )

        self.cached_x_t_latent = deque(maxlen=4)

        self.inference_time_ema = 0

    def load_lcm_lora(
        self,
        pretrained_model_name_or_path_or_dict: Union[
            str, Dict[str, torch.Tensor]
        ] = "latent-consistency/lcm-lora-sdv1-5",
        adapter_name: Optional[Any] = "lcm",
        **kwargs,
    ) -> None:
        self.pipe.load_lora_weights(
            pretrained_model_name_or_path_or_dict, adapter_name, **kwargs
        )

    def load_lora(
        self,
        pretrained_lora_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[Any] = None,
        **kwargs,
    ) -> None:
        self.pipe.load_lora_weights(
            pretrained_lora_model_name_or_path_or_dict, adapter_name, **kwargs
        )

    def fuse_lora(
        self,
        fuse_unet: bool = True,
        fuse_text_encoder: bool = True,
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
    ) -> None:
        self.pipe.fuse_lora(
            fuse_unet=fuse_unet,
            fuse_text_encoder=fuse_text_encoder,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
        )

    def enable_similar_image_filter(
        self, threshold: float = 0.98, max_skip_frame: float = 10
    ) -> None:
        self.similar_image_filter = True
        self.similar_filter.set_threshold(threshold)
        self.similar_filter.set_max_skip_frame(max_skip_frame)

    def disable_similar_image_filter(self) -> None:
        self.similar_image_filter = False

    @torch.no_grad()
    def prepare(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 50,
        guidance_scale: float = 1.2,
        delta: float = 1.0,
        generator: Optional[torch.Generator] = torch.Generator(),
        seed: int = 2,
    ) -> None:
        self.generator = generator
        self.generator.manual_seed(seed)
        # initialize x_t_latent (it can be any random tensor)
        if self.denoising_steps_num > 1:
            self.x_t_latent_buffer = torch.zeros(
                (
                    (self.denoising_steps_num - 1) * self.frame_bff_size,
                    4,
                    self.latent_height,
                    self.latent_width,
                ),
                dtype=self.dtype,
                device=self.device,
            )
        else:
            self.x_t_latent_buffer = None

        if self.cfg_type == "none":
            self.guidance_scale = 1.0
        else:
            self.guidance_scale = guidance_scale
        self.delta = delta

        do_classifier_free_guidance = False
        if self.guidance_scale > 1.0:
            do_classifier_free_guidance = True

        encoder_output = self.pipe.encode_prompt(
            prompt=prompt,
            device=self.device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=True,
            negative_prompt=negative_prompt,
        )
        self.prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1)
        self.null_prompt_embeds = encoder_output[1]

        if self.use_denoising_batch and self.cfg_type == "full":
            uncond_prompt_embeds = encoder_output[1].repeat(self.batch_size, 1, 1)
        elif self.cfg_type == "initialize":
            uncond_prompt_embeds = encoder_output[1].repeat(self.frame_bff_size, 1, 1)

        if self.guidance_scale > 1.0 and (
            self.cfg_type == "initialize" or self.cfg_type == "full"
        ):
            self.prompt_embeds = torch.cat(
                [uncond_prompt_embeds, self.prompt_embeds], dim=0
            )

        self.scheduler.set_timesteps(num_inference_steps, self.device)
        self.timesteps = self.scheduler.timesteps.to(self.device)

        # make sub timesteps list based on the indices in the t_list list and the values in the timesteps list
        self.sub_timesteps = []
        for t in self.t_list:
            self.sub_timesteps.append(self.timesteps[t])

        sub_timesteps_tensor = torch.tensor(
            self.sub_timesteps, dtype=torch.long, device=self.device
        )
        self.sub_timesteps_tensor = torch.repeat_interleave(
            sub_timesteps_tensor,
            repeats=self.frame_bff_size if self.use_denoising_batch else 1,
            dim=0,
        )

        self.init_noise = torch.randn(
            (self.batch_size, 4, self.latent_height, self.latent_width),
            generator=generator,
        ).to(device=self.device, dtype=self.dtype)

        self.randn_noise = self.init_noise[:1].clone()
        self.warp_noise = self.init_noise[:1].clone()

        self.stock_noise = torch.zeros_like(self.init_noise)

        c_skip_list = []
        c_out_list = []
        for timestep in self.sub_timesteps:
            c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete(
                timestep
            )
            c_skip_list.append(c_skip)
            c_out_list.append(c_out)

        self.c_skip = (
            torch.stack(c_skip_list)
            .view(len(self.t_list), 1, 1, 1)
            .to(dtype=self.dtype, device=self.device)
        )
        self.c_out = (
            torch.stack(c_out_list)
            .view(len(self.t_list), 1, 1, 1)
            .to(dtype=self.dtype, device=self.device)
        )

        alpha_prod_t_sqrt_list = []
        beta_prod_t_sqrt_list = []
        for timestep in self.sub_timesteps:
            alpha_prod_t_sqrt = self.scheduler.alphas_cumprod[timestep].sqrt()
            beta_prod_t_sqrt = (1 - self.scheduler.alphas_cumprod[timestep]).sqrt()
            alpha_prod_t_sqrt_list.append(alpha_prod_t_sqrt)
            beta_prod_t_sqrt_list.append(beta_prod_t_sqrt)
        alpha_prod_t_sqrt = (
            torch.stack(alpha_prod_t_sqrt_list)
            .view(len(self.t_list), 1, 1, 1)
            .to(dtype=self.dtype, device=self.device)
        )
        beta_prod_t_sqrt = (
            torch.stack(beta_prod_t_sqrt_list)
            .view(len(self.t_list), 1, 1, 1)
            .to(dtype=self.dtype, device=self.device)
        )
        self.alpha_prod_t_sqrt = torch.repeat_interleave(
            alpha_prod_t_sqrt,
            repeats=self.frame_bff_size if self.use_denoising_batch else 1,
            dim=0,
        )
        self.beta_prod_t_sqrt = torch.repeat_interleave(
            beta_prod_t_sqrt,
            repeats=self.frame_bff_size if self.use_denoising_batch else 1,
            dim=0,
        )

    @torch.no_grad()
    def update_prompt(self, prompt: str) -> None:
        encoder_output = self.pipe.encode_prompt(
            prompt=prompt,
            device=self.device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=False,
        )
        self.prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        t_index: int,
    ) -> torch.Tensor:
        noisy_samples = (
            self.alpha_prod_t_sqrt[t_index] * original_samples
            + self.beta_prod_t_sqrt[t_index] * noise
        )
        return noisy_samples

    def scheduler_step_batch(
        self,
        model_pred_batch: torch.Tensor,
        x_t_latent_batch: torch.Tensor,
        idx: Optional[int] = None,
    ) -> torch.Tensor:
        # TODO: use t_list to select beta_prod_t_sqrt
        # LCM boundary condition: denoised = c_out * f_theta(x_t) + c_skip * x_t
        if idx is None:
            F_theta = (
                x_t_latent_batch - self.beta_prod_t_sqrt * model_pred_batch
            ) / self.alpha_prod_t_sqrt
            denoised_batch = self.c_out * F_theta + self.c_skip * x_t_latent_batch
        else:
            F_theta = (
                x_t_latent_batch - self.beta_prod_t_sqrt[idx] * model_pred_batch
            ) / self.alpha_prod_t_sqrt[idx]
            denoised_batch = (
                self.c_out[idx] * F_theta + self.c_skip[idx] * x_t_latent_batch
            )

        return denoised_batch

    def unet_step(
        self,
        x_t_latent: torch.Tensor,
        t_list: Union[torch.Tensor, list[int]],
        idx: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"):
            x_t_latent_plus_uc = torch.concat([x_t_latent[0:1], x_t_latent], dim=0)
            t_list = torch.concat([t_list[0:1], t_list], dim=0)
        elif self.guidance_scale > 1.0 and (self.cfg_type == "full"):
            x_t_latent_plus_uc = torch.concat([x_t_latent, x_t_latent], dim=0)
            t_list = torch.concat([t_list, t_list], dim=0)
        else:
            x_t_latent_plus_uc = x_t_latent

        model_pred = self.unet(
            x_t_latent_plus_uc,
            t_list,
            encoder_hidden_states=self.prompt_embeds,
            return_dict=False,
        )[0]

        if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"):
            noise_pred_text = model_pred[1:]
            self.stock_noise = torch.concat(
                [model_pred[0:1], self.stock_noise[1:]], dim=0
            )  # comment this out for self-output CFG
        elif self.guidance_scale > 1.0 and (self.cfg_type == "full"):
            noise_pred_uncond, noise_pred_text = model_pred.chunk(2)
        else:
            noise_pred_text = model_pred
        if self.guidance_scale > 1.0 and (
            self.cfg_type == "self" or self.cfg_type == "initialize"
        ):
            noise_pred_uncond = self.stock_noise * self.delta
        if self.guidance_scale > 1.0 and self.cfg_type != "none":
            model_pred = noise_pred_uncond + self.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
        else:
            model_pred = noise_pred_text

        # compute the previous noisy sample x_t -> x_t-1
        if self.use_denoising_batch:
            denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx)
            if self.cfg_type == "self" or self.cfg_type == "initialize":
                # update the cached noise used as the virtual unconditional prediction
                scaled_noise = self.beta_prod_t_sqrt * self.stock_noise
                delta_x = self.scheduler_step_batch(model_pred, scaled_noise, idx)
                alpha_next = torch.concat(
                    [
                        self.alpha_prod_t_sqrt[1:],
                        torch.ones_like(self.alpha_prod_t_sqrt[0:1]),
                    ],
                    dim=0,
                )
                delta_x = alpha_next * delta_x
                beta_next = torch.concat(
                    [
                        self.beta_prod_t_sqrt[1:],
                        torch.ones_like(self.beta_prod_t_sqrt[0:1]),
                    ],
                    dim=0,
                )
                delta_x = delta_x / beta_next
                init_noise = torch.concat(
                    [self.init_noise[1:], self.init_noise[0:1]], dim=0
                )
                self.stock_noise = init_noise + delta_x
        else:
            # denoised_batch = self.scheduler.step(model_pred, t_list[0], x_t_latent).denoised
            denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx)

        return denoised_batch, model_pred

    def norm_noise(self, noise):
        # Compute the mean and std of the noise tensor
        mean = noise.mean()
        std = noise.std()
        # Normalize the noise to zero mean and unit std
        normalized_noise = (noise - mean) / std
        return normalized_noise

    def encode_image(self, image_tensors: torch.Tensor) -> torch.Tensor:
        image_tensors = image_tensors.to(
            device=self.device,
            dtype=self.vae.dtype,
        )
        img_latent = retrieve_latents(self.vae.encode(image_tensors), self.generator)
        img_latent = img_latent * self.vae.config.scaling_factor
        x_t_latent = self.add_noise(img_latent, self.init_noise[0], 0)
        return x_t_latent

    def decode_image(self, x_0_pred_out: torch.Tensor) -> torch.Tensor:
        output_latent = self.vae.decode(
            x_0_pred_out / self.vae.config.scaling_factor, return_dict=False
        )[0]
        return output_latent

    def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor:
        prev_latent_batch = self.x_t_latent_buffer

        if self.use_denoising_batch:
            t_list = self.sub_timesteps_tensor
            if self.denoising_steps_num > 1:
                x_t_latent = torch.cat((x_t_latent, prev_latent_batch), dim=0)
                self.stock_noise = torch.cat(
                    (self.init_noise[0:1], self.stock_noise[:-1]), dim=0
                )
            x_0_pred_batch, model_pred = self.unet_step(x_t_latent, t_list)

            if self.denoising_steps_num > 1:
                x_0_pred_out = x_0_pred_batch[-1].unsqueeze(0)
                if self.do_add_noise:
                    self.x_t_latent_buffer = (
                        self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]
                        + self.beta_prod_t_sqrt[1:] * self.init_noise[1:]
                    )
                else:
                    self.x_t_latent_buffer = (
                        self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]
                    )
            else:
                x_0_pred_out = x_0_pred_batch
                self.x_t_latent_buffer = None
        else:
            self.init_noise = x_t_latent
            for idx, t in enumerate(self.sub_timesteps_tensor):
                t = t.view(1).repeat(self.frame_bff_size)
                x_0_pred, model_pred = self.unet_step(x_t_latent, t, idx)
                if idx < len(self.sub_timesteps_tensor) - 1:
                    if self.do_add_noise:
                        x_t_latent = self.alpha_prod_t_sqrt[
                            idx + 1
                        ] * x_0_pred + self.beta_prod_t_sqrt[
                            idx + 1
                        ] * torch.randn_like(
                            x_0_pred, device=self.device, dtype=self.dtype
                        )
                    else:
                        x_t_latent = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred
            x_0_pred_out = x_0_pred

        return x_0_pred_out

    @torch.no_grad()
    def __call__(
        self, x: Union[torch.Tensor, PIL.Image.Image, np.ndarray] = None
    ) -> torch.Tensor:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        if x is not None:
            x = self.image_processor.preprocess(x, self.height, self.width).to(
                device=self.device, dtype=self.dtype
            )
            if self.similar_image_filter:
                x = self.similar_filter(x)
                if x is None:
                    time.sleep(self.inference_time_ema)
                    return self.prev_image_result
            x_t_latent = self.encode_image(x)
        else:
            # TODO: check the dimension of x_t_latent
            x_t_latent = torch.randn((1, 4, self.latent_height, self.latent_width)).to(
                device=self.device, dtype=self.dtype
            )
        x_0_pred_out = self.predict_x0_batch(x_t_latent)
        x_output = self.decode_image(x_0_pred_out).detach().clone()

        self.prev_image_result = x_output
        end.record()
        torch.cuda.synchronize()
        inference_time = start.elapsed_time(end) / 1000
        self.inference_time_ema = 0.9 * self.inference_time_ema + 0.1 * inference_time
        return x_output
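

if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the original pipeline): shows how the
    # streaming img2img loop might be driven. The checkpoint id, device, t_index_list,
    # prompt, and synthetic frames below are assumptions chosen for demonstration.
    # Because this module uses relative imports, run it from within its package
    # (e.g. `python -m <package>.<module>`).
    demo_pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    stream = StreamV2V(demo_pipe, t_index_list=[32, 45], torch_dtype=torch.float16)
    stream.load_lcm_lora()  # attach the LCM LoRA so few-step denoising works
    stream.fuse_lora()
    stream.prepare(prompt="a watercolor painting of a city street")

    # Feed a few random frames; in practice these would come from a video stream.
    for _ in range(4):
        frame = PIL.Image.fromarray(
            (np.random.rand(512, 512, 3) * 255).astype(np.uint8)
        )
        output = stream(frame)  # decoded (1, 3, H, W) tensor; early outputs lag behind
        print(output.shape)     # because of the pipelined latent buffer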