diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..556a7df5ea13836d87176ea7d2aa0f8bd875b6e0 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text +*.jpeg filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index ba556240496eb9363142f2a31964f4a1f576a023..d974641ea5b1873beb9440ff29bfdf4d9e9d6a2e 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,17 @@ --- title: StreamingSVD -emoji: 🌍 -colorFrom: pink -colorTo: gray +emoji: 🎥 +colorFrom: yellow +colorTo: green sdk: gradio sdk_version: 4.43.0 +suggested_hardware: a100-large +suggested_storage: large app_file: app.py -pinned: false license: mit ---- - -Check out the configuration reference at https://huggingface.co./docs/hub/spaces-config-reference +tags: + - StreamingSVD + - long-video-generation + - PAIR +short_description: Image-to-Video +disable_embedding: false \ No newline at end of file diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f457a8f9b99269077b792696dd03490d1e320e5e --- /dev/null +++ b/config.yaml @@ -0,0 +1,316 @@ +# pytorch_lightning==2.2.2 +seed_everything: 33 +trainer: + accelerator: auto + strategy: auto + devices: '1' + num_nodes: 1 + precision: 16-mixed + logger: False +model: + class_path: diffusion_trainer.streaming_svd.StreamingSVD + init_args: + vfi: + class_path: modules.params.vfi.VFIParams + init_args: + ckpt_path_local: checkpoint/VFI/ours.pkl + ckpt_path_global: https://drive.google.com/file/d/1XCNoyhA1RX3m8W-XJK8H8inH47l36kxP/view?usp=sharing + i2v_enhance: + class_path: modules.params.i2v_enhance.I2VEnhanceParams + init_args: + ckpt_path_local: checkpoint/i2v_enhance/ + ckpt_path_global: ali-vilab/i2vgen-xl + module_loader: + class_path: modules.loader.module_loader.GenericModuleLoader + init_args: + pipeline_repo: stabilityai/stable-video-diffusion-img2vid-xt + pipeline_obj: streamingt2v_pipeline + set_prediction_type: '' + module_names: + - network_config + - model + - controlnet + - denoiser + - conditioner + - first_stage_model + - sampler + - svd_pipeline + module_config: + controlnet: + class_path: modules.loader.module_loader_config.ModuleLoaderConfig + init_args: + loader_cls_path: models.control.controlnet.ControlNet + cls_func: from_unet + cls_func_fast_dev_run: '' + kwargs_diffusers: null + model_params: + merging_mode: addition + zero_conv_mode: Identity + frame_expansion: none + downsample_controlnet_cond: true + use_image_encoder_normalization: true + use_controlnet_mask: false + condition_encoder: '' + conditioning_embedding_out_channels: + - 32 + - 96 + - 256 + - 512 + kwargs_diff_trainer_params: null + args: [] + dependent_modules: + model: model + dependent_modules_cloned: null + state_dict_path: '' + strict_loading: true + state_dict_filters: [] + network_config: + class_path: models.diffusion.video_model.VideoUNet + init_args: + in_channels: 8 + model_channels: 320 + out_channels: 4 + num_res_blocks: 2 + num_conditional_frames: null + attention_resolutions: + - 4 + - 2 + - 1 + dropout: 0.0 + channel_mult: + - 1 + - 2 + - 4 + - 4 + conv_resample: true + dims: 2 + num_classes: sequential + use_checkpoint: False + num_heads: -1 + num_head_channels: 64 + num_heads_upsample: -1 + use_scale_shift_norm: false + resblock_updown: false + transformer_depth: 1 + transformer_depth_middle: null + context_dim: 1024 + time_downup: false + time_context_dim: null + extra_ff_mix_layer: true + use_spatial_context: true + merge_strategy: learned_with_images + merge_factor: 0.5 + spatial_transformer_attn_type: softmax-xformers + video_kernel_size: + - 3 + - 1 + - 1 + use_linear_in_transformer: true + adm_in_channels: 768 + disable_temporal_crossattention: false + max_ddpm_temb_period: 10000 + merging_mode: attention_cross_attention + controlnet_mode: true + use_apm: false + model: + class_path: modules.loader.module_loader_config.ModuleLoaderConfig + init_args: + loader_cls_path: models.svd.sgm.modules.diffusionmodules.wrappers.OpenAIWrapper + cls_func: '' + cls_func_fast_dev_run: '' + kwargs_diffusers: + compile_model: false + model_params: null + model_params_fast_dev_run: null + kwargs_diff_trainer_params: null + args: [] + dependent_modules: + diffusion_model: network_config + dependent_modules_cloned: null + state_dict_path: '' + strict_loading: true + state_dict_filters: [] + denoiser: + class_path: models.svd.sgm.modules.diffusionmodules.denoiser.Denoiser + init_args: + scaling_config: + target: models.svd.sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise + sampler: + class_path: models.svd.sgm.modules.diffusionmodules.sampling.EulerEDMSampler + init_args: + s_churn: 0.0 + s_tmin: 0.0 + s_tmax: .inf + s_noise: 1.0 + discretization_config: + target: models.diffusion.discretizer.AlignYourSteps + params: + sigma_max: 700.0 + num_steps: 30 + guider_config: + target: models.svd.sgm.modules.diffusionmodules.guiders.LinearPredictionGuider + params: + max_scale: 3.0 + min_scale: 1.5 + num_frames: 25 + verbose: false + device: cuda + conditioner: + class_path: models.svd.sgm.modules.GeneralConditioner + init_args: + emb_models: + - is_trainable: false + input_key: cond_frames_without_noise + target: models.svd.sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder + params: + n_cond_frames: 1 + n_copies: 1 + open_clip_embedding_config: + target: models.svd.sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder + params: + freeze: true + - input_key: fps_id + is_trainable: false + target: models.svd.sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + - input_key: motion_bucket_id + is_trainable: false + target: models.svd.sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + - input_key: cond_frames + is_trainable: false + target: models.svd.sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder + params: + disable_encoder_autocast: true + n_cond_frames: 1 + n_copies: 1 + is_ae: true + encoder_config: + target: models.svd.sgm.models.autoencoder.AutoencoderKLModeOnly + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + - input_key: cond_aug + is_trainable: false + target: models.svd.sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + first_stage_model: + class_path: models.svd.sgm.AutoencodingEngine + init_args: + encoder_config: + target: models.svd.sgm.modules.diffusionmodules.model.Encoder + params: + attn_type: vanilla + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + decoder_config: + target: models.svd.sgm.modules.autoencoding.temporal_ae.VideoDecoder + params: + attn_type: vanilla + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + video_kernel_size: + - 3 + - 1 + - 1 + loss_config: + target: torch.nn.Identity + regularizer_config: + target: models.svd.sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer + optimizer_config: null + lr_g_factor: 1.0 + trainable_ae_params: null + ae_optimizer_args: null + trainable_disc_params: null + disc_optimizer_args: null + disc_start_iter: 0 + diff_boost_factor: 3.0 + ckpt_engine: null + ckpt_path: null + additional_decode_keys: null + ema_decay: null + monitor: null + input_key: jpg + svd_pipeline: + class_path: modules.loader.module_loader_config.ModuleLoaderConfig + init_args: + loader_cls_path: diffusers.StableVideoDiffusionPipeline + cls_func: from_pretrained + cls_func_fast_dev_run: '' + kwargs_diffusers: + torch_dtype: torch.float16 + variant: fp16 + use_safetensors: true + model_params: null + model_params_fast_dev_run: null + kwargs_diff_trainer_params: null + args: + - stabilityai/stable-video-diffusion-img2vid-xt + dependent_modules: null + dependent_modules_cloned: null + state_dict_path: '' + strict_loading: true + state_dict_filters: [] + root_cls: null + diff_trainer_params: + class_path: modules.params.diffusion_trainer.params_streaming_diff_trainer.DiffusionTrainerParams + init_args: + scale_factor: 0.18215 + streamingsvd_ckpt: + class_path: modules.params.diffusion_trainer.params_streaming_diff_trainer.CheckpointDescriptor + init_args: + ckpt_path_local: checkpoint/StreamingSVD/model.safetensors + ckpt_path_global: PAIR/StreamingSVD/resolve/main/model.safetensors + disable_first_stage_autocast: true + inference_params: + class_path: modules.params.diffusion.inference_params.T2VInferenceParams + init_args: + n_autoregressive_generations: 2 # Number of autoregression for StreamingSVD + num_conditional_frames: 7 # is this used? + anchor_frames: '6' # Take the (Number+1)th frame as CLIP encoding for StreamingSVD + reset_seed_per_generation: true # If true, the seed is reset on every generation diff --git a/dataloader/dataset_factory.py b/dataloader/dataset_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..220e02184395e4df900fef76be375a1655945bdf --- /dev/null +++ b/dataloader/dataset_factory.py @@ -0,0 +1,13 @@ +from pathlib import Path +from torch.utils.data import Dataset + +from dataloader.single_image_dataset import SingleImageDataset + + +class SingleImageDatasetFactory(): + + def __init__(self, file: Path): + self.data_path = file + + def get_dataset(self, max_samples: int = None) -> Dataset: + return SingleImageDataset(file=self.data_path) diff --git a/dataloader/single_image_dataset.py b/dataloader/single_image_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1a151ae0656d729e174e63c4b194bbe1d270fd17 --- /dev/null +++ b/dataloader/single_image_dataset.py @@ -0,0 +1,16 @@ +import torch +import numpy as np +from torch.utils.data import Dataset + + +class SingleImageDataset(Dataset): + + def __init__(self, file: np.ndarray): + super().__init__() + self.images = [file] + + def __len__(self): + return len(self.images) + + def __getitem__(self, index): + return {"image": self.images[index], "sample_id": torch.tensor(index, dtype=torch.int64)} diff --git a/dataloader/video_data_module.py b/dataloader/video_data_module.py new file mode 100644 index 0000000000000000000000000000000000000000..97b4be25ca2559ef3970c8d0a305dfab900d00c8 --- /dev/null +++ b/dataloader/video_data_module.py @@ -0,0 +1,32 @@ +import pytorch_lightning as pl +import torch +from pytorch_lightning.utilities.types import (EVAL_DATALOADERS) +from dataloader.dataset_factory import SingleImageDatasetFactory + + +class VideoDataModule(pl.LightningDataModule): + + def __init__(self, + workers: int, + predict_dataset_factory: SingleImageDatasetFactory = None, + ) -> None: + super().__init__() + self.num_workers = workers + + self.video_data_module = {} + # TODO read size from loaded unet via unet.sample_sizes + self.predict_dataset_factory = predict_dataset_factory + + def setup(self, stage: str) -> None: + if stage == "predict": + self.video_data_module["predict"] = self.predict_dataset_factory.get_dataset( + ) + + def predict_dataloader(self) -> EVAL_DATALOADERS: + return torch.utils.data.DataLoader(self.video_data_module["predict"], + batch_size=1, + pin_memory=True, + num_workers=self.num_workers, + collate_fn=None, + shuffle=False, + drop_last=False) diff --git a/diffusion_trainer/abstract_trainer.py b/diffusion_trainer/abstract_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..29d0e85f1a33e003bd5a975d38bb260cca203a4d --- /dev/null +++ b/diffusion_trainer/abstract_trainer.py @@ -0,0 +1,108 @@ +import os + +import pytorch_lightning as pl +import torch + +from typing import Any + +from modules.params.diffusion.inference_params import InferenceParams +from modules.loader.module_loader import GenericModuleLoader +from modules.params.diffusion_trainer.params_streaming_diff_trainer import DiffusionTrainerParams + + +class AbstractTrainer(pl.LightningModule): + + def __init__(self, + inference_params: Any, + diff_trainer_params: DiffusionTrainerParams, + module_loader: GenericModuleLoader, + ): + + super().__init__() + + self.inference_params = inference_params + self.diff_trainer_params = diff_trainer_params + self.module_loader = module_loader + + self.on_start_once_called = False + self._setup_methods = [] + + module_loader( + trainer=self, + diff_trainer_params=diff_trainer_params) + + # ------ IMPLEMENTATION HOOKS ------- + + def post_init(self, batch): + ''' + Is called after LightningDataModule and LightningModule is created, but before any training/validation/prediction. + First possible access to the 'trainer' object (e.g. to get 'device'). + ''' + + def generate_output(self, batch, batch_idx, inference_params: InferenceParams): + ''' + Is called during validation to generate for each batch an output. + Return the meta information about produced result (where result were stored). + This is used for the metric evaluation. + ''' + + # ------- HELPER FUNCTIONS ------- + + def _reset_random_generator(self): + ''' + Reset the random generator to the same seed across all workers. The generator is used only for inference. + ''' + if not hasattr(self, "random_generator"): + self.random_generator = torch.Generator(device=self.device) + # set seed according to 'seed_everything' in config + seed = int(os.environ.get("PL_GLOBAL_SEED", 42)) + else: + seed = self.random_generator.initial_seed() + self.random_generator.manual_seed(seed) + + # ----- PREDICT HOOKS ------ + + def on_predict_start(self): + self.on_start() + + def predict_step(self, batch, batch_idx): + self.on_inference_step(batch=batch, batch_idx=batch_idx) + + def on_predict_epoch_start(self): + self.on_inference_epoch_start() + + # ----- CUSTOM HOOKS ----- + + # Global Hooks (Called by Training, Validation and Prediction) + + # abstract method + + def _on_start_once(self): + ''' + Will be called only once by on_start. Thus, it will be called by the first call of train,validation or prediction. + ''' + if self.on_start_once_called: + return + else: + self.on_start_once_called = True + self.post_init() + + def on_start(self): + ''' + Called at the beginning of training, validation and prediction. + ''' + self._on_start_once() + + # Inference Hooks (Called by Validation and Prediction) + + # ----- Inference Hooks (called by 'validation' and 'predict') ------ + + def on_inference_epoch_start(self): + # reset seed at every inference + self._reset_random_generator() + + def on_inference_step(self, batch, batch_idx): + if self.inference_params.reset_seed_per_generation: + self._reset_random_generator() + self.generate_output( + batch=batch, inference_params=self.inference_params, batch_idx=batch_idx) diff --git a/diffusion_trainer/streaming_svd.py b/diffusion_trainer/streaming_svd.py new file mode 100644 index 0000000000000000000000000000000000000000..dcbf3c0b88a7fee72ab8a858d833e715baf17daf --- /dev/null +++ b/diffusion_trainer/streaming_svd.py @@ -0,0 +1,508 @@ +from modules.loader.module_loader import GenericModuleLoader +from modules.params.diffusion_trainer.params_streaming_diff_trainer import DiffusionTrainerParams +import torch +from modules.params.diffusion.inference_params import InferenceParams +from utils import result_processor +from modules.loader.module_loader import GenericModuleLoader +from tqdm import tqdm +from PIL import Image, ImageFilter +from utils.inference_utils import resize_and_crop,get_padding_for_aspect_ratio +import numpy as np +from safetensors.torch import load_file as load_safetensors +import math +from einops import repeat, rearrange +from torchvision.transforms import ToTensor +from models.svd.sgm.modules.autoencoding.temporal_ae import VideoDecoder +import PIL +from modules.params.vfi import VFIParams +from modules.params.i2v_enhance import I2VEnhanceParams +from typing import List,Union +from models.diffusion.wrappers import StreamingWrapper +from diffusion_trainer.abstract_trainer import AbstractTrainer +from utils.loader import download_ckpt +import torchvision.transforms.functional as TF +from diffusers import AutoPipelineForInpainting, DEISMultistepScheduler +from transformers import BlipProcessor, BlipForConditionalGeneration + +class StreamingSVD(AbstractTrainer): + def __init__(self, + module_loader: GenericModuleLoader, + diff_trainer_params: DiffusionTrainerParams, + inference_params: InferenceParams, + vfi: VFIParams, + i2v_enhance: I2VEnhanceParams, + ): + super().__init__(inference_params=inference_params, + diff_trainer_params=diff_trainer_params, + module_loader=module_loader, + ) + + # network config is wrapped by OpenAIWrapper, so we dont need a direct reference anymore + # this corresponds to the config yaml defined at model.module_loader.module_config.model.dependent_modules + del self.network_config + self.diff_trainer_params: DiffusionTrainerParams + self.vfi = vfi + self.i2v_enhance = i2v_enhance + + def on_inference_epoch_start(self): + super().on_inference_epoch_start() + + # for StreamingSVD we use a model wrapper that combines the base SVD model and the control model. + self.inference_model = StreamingWrapper( + diffusion_model=self.model.diffusion_model, + controlnet=self.controlnet, + num_frame_conditioning=self.inference_params.num_conditional_frames + ) + + def post_init(self): + self.svd_pipeline.set_progress_bar_config(disable=True) + if self.device.type != "cpu": + self.svd_pipeline.enable_model_cpu_offload(gpu_id = self.device.index) + + # re-use the open clip already loaded for image conditioner for image_encoder_apm + embedders = self.conditioner.embedders + for embedder in embedders: + if hasattr(embedder,"input_key") and embedder.input_key == "cond_frames_without_noise": + self.image_encoder_apm = embedder.open_clip + self.first_stage_model.to("cpu") + self.conditioner.embedders[3].encoder.to("cpu") + self.conditioner.embedders[0].open_clip.to("cpu") + + pipe = AutoPipelineForInpainting.from_pretrained( + 'Lykon/dreamshaper-8-inpainting', torch_dtype=torch.float16, variant="fp16", safety_checker=None, requires_safety_checker=False) + + pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config) + pipe = pipe.to(self.device) + pipe.enable_model_cpu_offload(gpu_id = self.device.index) + self.inpaint_pipe = pipe + + processor = BlipProcessor.from_pretrained( + "Salesforce/blip-image-captioning-large") + + + model = BlipForConditionalGeneration.from_pretrained( + "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(self.device) + def blip(x): return processor.decode(model.generate(** processor(x, + return_tensors='pt').to("cuda", torch.float16))[0], skip_special_tokens=True) + self.blip = blip + + # Adapted from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/simple_video_sample.py + def get_unique_embedder_keys_from_conditioner(self, conditioner): + return list(set([x.input_key for x in conditioner.embedders])) + + + # Adapted from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/simple_video_sample.py + def get_batch_sgm(self, keys, value_dict, N, T, device): + batch = {} + batch_uc = {} + + for key in keys: + if key == "fps_id": + batch[key] = ( + torch.tensor([value_dict["fps_id"]]) + .to(device) + .repeat(int(math.prod(N))) + ) + elif key == "motion_bucket_id": + batch[key] = ( + torch.tensor([value_dict["motion_bucket_id"]]) + .to(device) + .repeat(int(math.prod(N))) + ) + elif key == "cond_aug": + batch[key] = repeat( + torch.tensor([value_dict["cond_aug"]]).to(device), + "1 -> b", + b=math.prod(N), + ) + elif key == "cond_frames": + batch[key] = repeat(value_dict["cond_frames"], + "1 ... -> b ...", b=N[0]) + elif key == "cond_frames_without_noise": + batch[key] = repeat( + value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0] + ) + else: + batch[key] = value_dict[key] + + if T is not None: + batch["num_video_frames"] = T + + for key in batch.keys(): + if key not in batch_uc and isinstance(batch[key], torch.Tensor): + batch_uc[key] = torch.clone(batch[key]) + return batch, batch_uc + + # Adapted from https://github.com/Stability-AI/generative-models/blob/main/sgm/models/diffusion.py + @torch.no_grad() + def decode_first_stage(self, z): + self.first_stage_model.to(self.device) + + z = 1.0 / self.diff_trainer_params.scale_factor * z + #n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0]) + n_samples = min(z.shape[0],8) + #print("SVD decoder started") + import time + start = time.time() + n_rounds = math.ceil(z.shape[0] / n_samples) + all_out = [] + with torch.autocast("cuda", enabled=not self.diff_trainer_params.disable_first_stage_autocast): + for n in range(n_rounds): + if isinstance(self.first_stage_model.decoder, VideoDecoder): + kwargs = {"timesteps": len( + z[n * n_samples: (n + 1) * n_samples])} + else: + kwargs = {} + out = self.first_stage_model.decode( + z[n * n_samples: (n + 1) * n_samples], **kwargs + ) + all_out.append(out) + out = torch.cat(all_out, dim=0) + # print(f"SVD decoder finished after {time.time()-start} seconds.") + self.first_stage_model.to("cpu") + return out + + + # Adapted from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/simple_video_sample.py + def _generate_conditional_output(self, svd_input_frame, inference_params: InferenceParams, **params): + C = 4 + F = 8 # spatial compression TODO read from model + + H = svd_input_frame.shape[-2] + W = svd_input_frame.shape[-1] + num_frames = self.sampler.guider.num_frames + + shape = (num_frames, C, H // F, W // F) + batch_size = 1 + + image = svd_input_frame[None,:] + cond_aug = 0.02 + + value_dict = {} + value_dict["motion_bucket_id"] = 127 + value_dict["fps_id"] = 6 + value_dict["cond_aug"] = cond_aug + value_dict["cond_frames_without_noise"] = image + value_dict["cond_frames"] =image + cond_aug * torch.rand_like(image) + + batch, batch_uc = self.get_batch_sgm( + self.get_unique_embedder_keys_from_conditioner( + self.conditioner), + value_dict, + [1, num_frames], + T=num_frames, + device=self.device, + ) + + self.conditioner.embedders[3].encoder.to(self.device) + self.conditioner.embedders[0].open_clip.to(self.device) + c, uc = self.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=[ + "cond_frames", + "cond_frames_without_noise", + ], + ) + self.conditioner.embedders[3].encoder.to("cpu") + self.conditioner.embedders[0].open_clip.to("cpu") + + + for k in ["crossattn", "concat"]: + uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames) + uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames) + c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames) + c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames) + + randn = torch.randn(shape, device=self.device) + + additional_model_inputs = {} + additional_model_inputs["image_only_indicator"] = torch.zeros(2*batch_size,num_frames).to(self.device) + additional_model_inputs["num_video_frames"] = batch["num_video_frames"] + + # StreamingSVD inputs + additional_model_inputs["batch_size"] = 2*batch_size + additional_model_inputs["num_conditional_frames"] = self.inference_params.num_conditional_frames + additional_model_inputs["ctrl_frames"] = params["ctrl_frames"] + + self.inference_model.diffusion_model = self.inference_model.diffusion_model.to( + self.device) + self.inference_model.controlnet = self.inference_model.controlnet.to( + self.device) + + c["vector"] = c["vector"].to(randn.dtype) + uc["vector"] = uc["vector"].to(randn.dtype) + def denoiser(input, sigma, c): + return self.denoiser(self.inference_model,input,sigma,c, **additional_model_inputs) + samples_z = self.sampler(denoiser,randn,cond=c,uc=uc) + + self.inference_model.diffusion_model = self.inference_model.diffusion_model.to( "cpu") + self.inference_model.controlnet = self.inference_model.controlnet.to("cpu") + samples_x = self.decode_first_stage(samples_z) + + samples = torch.clamp(samples_x,min=-1.0,max=1.0) + return samples + + + def extract_anchor_frames(self, video, input_range,inference_params: InferenceParams): + """ + Extracts anchor frames from the input video based on the provided inference parameters. + + Parameters: + - video: torch.Tensor + The input video tensor. + - input_range: list + The pixel value range of input video. + - inference_params: InferenceParams + An object containing inference parameters. + - anchor_frames: str + Specifies how the anchor frames are encoded. It can be either a single number specifying which frame is used as the anchor frame, + or a range in the format "a:b" indicating that frames from index a up to index b (inclusive) are used as anchor frames. + + Returns: + - torch.Tensor + The extracted anchor frames from the input video. + """ + video = result_processor.convert_range(video=video.clone(),input_range=input_range,output_range=[-1,1]) + + if video.shape[1] == 3 and video.shape[0]>3: + video = rearrange(video,"F C W H -> 1 F C W H") + elif video.shape[0]>3 and video.shape[-1] == 3: + video = rearrange(video,"F W H C -> 1 F C W H") + else: + raise NotImplementedError(f"Unexpected video input format: {video.shape}") + + if ":" in inference_params.anchor_frames: + anchor_frames = inference_params.anchor_frames.split(":") + anchor_frames = [int(anchor_frame) for anchor_frame in anchor_frames] + assert len(anchor_frames) == 2,"Anchor frames encoding wrong." + anchor = video[:,anchor_frames[0]:anchor_frames[1]] + else: + anchor_frame = int(inference_params.anchor_frames) + anchor = video[:, anchor_frame].unsqueeze(0) + + return anchor + + def extract_ctrl_frames(self,video: torch.FloatType, input_range: List[int], inference_params: InferenceParams): + """ + Extracts control frames from the input video. + + Parameters: + - video: torch.Tensor + The input video tensor. + - input_range: list + The pixel value range of input video. + - inference_params: InferenceParams + An object containing inference parameters. + + Returns: + - torch.Tensor + The extracted control image encoding frames from the input video. + """ + video = result_processor.convert_range(video=video.clone(), input_range=input_range, output_range=[-1, 1]) + if video.shape[1] == 3 and video.shape[0] > 3: + video = rearrange(video, "F C W H -> 1 F C W H") + elif video.shape[0] > 3 and video.shape[-1] == 3: + video = rearrange(video, "F W H C -> 1 F C W H") + else: + raise NotImplementedError( + f"Unexpected video input format: {video.shape}") + + # return the last num_conditional_frames frames + video = video[:, -inference_params.num_conditional_frames:] + return video + + + def _autoregressive_generation(self,initial_generation: Union[torch.FloatType,List[torch.FloatType]], inference_params:InferenceParams): + """ + Perform autoregressive generation of video chunks based on the initial generation and inference parameters. + + Parameters: + - initial_generation: torch.Tensor or list of torch.Tensor + The initial generation or list of initial generation video chunks. + - inference_params: InferenceParams + An object containing inference parameters. + + Returns: + - torch.Tensor + The generated video resulting from autoregressive generation. + """ + + # input is [-1,1] float + result_chunks = initial_generation + if not isinstance(result_chunks,list): + result_chunks = [result_chunks] + + # make sure + if (result_chunks[0].shape[1] >3) and (result_chunks[0].shape[-1] == 3): + result_chunks = [rearrange(result_chunks[0],"F W H C -> F C W H")] + + # generating chunk by conditioning on the previous chunks + for _ in tqdm(list(range(inference_params.n_autoregressive_generations)),desc="StreamingSVD"): + + # extract anchor frames based on the entire, so far generated, video + # note that we do note use anchor frame in StreamingSVD (apart from the anchor frame already used by SVD). + anchor_frames = self.extract_anchor_frames( + video = torch.cat(result_chunks), + inference_params=inference_params, + input_range=[-1, 1], + ) + + # extract control frames based on the last generated chunk + ctrl_frames = self.extract_ctrl_frames( + video = result_chunks[-1], + input_range=[-1, 1], + inference_params=inference_params, + ) + + # select the anchor frame for svd + svd_input_frame = result_chunks[0][int(inference_params.anchor_frames)] + + # generate the next chunk + # result is [F, C, H, W], range is [-1,1] float. + result = self._generate_conditional_output( + svd_input_frame = svd_input_frame, + inference_params=inference_params, + anchor_frames=anchor_frames, + ctrl_frames=ctrl_frames, + ) + + # from each generation, we keep all frames except for the first frames + result = result[inference_params.num_conditional_frames:] + result_chunks.append(result) + torch.cuda.empty_cache() + + # concat all chunks to one long video + result_chunks = [result_processor.convert_range(chunk,output_range=[0,255],input_range=[-1,1]) for chunk in result_chunks] + result = result_processor.concat_chunks(result_chunks) + torch.cuda.empty_cache() + return result + + def ensure_image_ratio(self,source_image: PIL,target_aspect_ratio = 16/9): + + if source_image.width / source_image.height == target_aspect_ratio: + return source_image, None + + image = source_image.copy().convert("RGBA") + mask = image.split()[-1] + image = image.convert("RGB") + padding = get_padding_for_aspect_ratio(image) + + + mask_padded = TF.pad(mask, padding) + mask_padded_size = mask_padded.size + mask_padded_resized = TF.resize(mask_padded, (512, 512), + interpolation=TF.InterpolationMode.NEAREST) + mask_padded_resized = TF.invert(mask_padded_resized) + + # image + padded_input_image = TF.pad(image, padding, padding_mode="reflect") + resized_image = TF.resize(padded_input_image, (512, 512)) + + image_tensor = (self.inpaint_pipe.image_processor.preprocess( + resized_image).cuda().half()) + latent_tensor = self.inpaint_pipe._encode_vae_image(image_tensor, None) + self.inpaint_pipe.scheduler.set_timesteps(999) + noisy_latent_tensor = self.inpaint_pipe.scheduler.add_noise( + latent_tensor, + torch.randn_like(latent_tensor), + self.inpaint_pipe.scheduler.timesteps[:1], + ) + + prompt = self.blip(source_image) + if prompt.startswith("there is "): + prompt = prompt[len("there is "):] + + output_image_normalized_size = self.inpaint_pipe( + prompt=prompt, + image=resized_image, + mask_image=mask_padded_resized, + latents=noisy_latent_tensor, + ).images[0] + + output_image_extended_size = TF.resize( + output_image_normalized_size, mask_padded_size[::-1]) + + blured_outpainting_mask = TF.invert(mask_padded).filter( + ImageFilter.GaussianBlur(radius=5)) + + final_image = Image.composite( + output_image_extended_size, padded_input_image, blured_outpainting_mask) + return final_image, TF.invert(mask_padded) + + + def image_to_video(self, batch, inference_params: InferenceParams, batch_idx): + + """ + Performs image to video based on the input batch and inference parameters. + It runs SVD-XT one to generate the first chunk, then auto-regressively applies StreamingSVD. + + Parameters: + - batch: dict + The input batch containing the start image for generating the video. + - inference_params: InferenceParams + An object containing inference parameters. + - batch_idx: int + The index of the batch. + + Returns: + - torch.Tensor + The generated video based on the image image. + """ + batch_key = "image" + assert batch_key == "image", f"Generating video from {batch_key} not implemented." + input_image = PIL.Image.fromarray(batch[batch_key][0].cpu().numpy()) + # TODO remove conversion forth and back + + outpainted_image, _ = self.ensure_image_ratio(input_image) + + #image = Image.fromarray(np.uint8(image)) + ''' + if image.width/image.height != 16/9: + print(f"Warning! For best results, we assume the aspect ratio of the input image to be 16:9. Found ratio {image.width}:{image.height}.") + ''' + scaled_outpainted_image, expanded_size = resize_and_crop(outpainted_image) + assert scaled_outpainted_image.width == 1024 and scaled_outpainted_image.height == 576, f"Wrong shape for file {batch[batch_key]} with shape {scaled_outpainted_image.width}:{scaled_outpainted_image.height}." + + # Generating first chunk + with torch.autocast(device_type="cuda",enabled=False): + video_chunks = self.svd_pipeline( + scaled_outpainted_image, decode_chunk_size=8).frames[0] + + video_chunks = torch.stack([ToTensor()(frame) for frame in video_chunks]) + video_chunks = video_chunks * 2.0 - 1 # [-1,1], float + + video_chunks = video_chunks.to(self.device) + + video = self._autoregressive_generation( + initial_generation=video_chunks, + inference_params=inference_params) + + return video, scaled_outpainted_image, expanded_size + + + def generate_output(self, batch, batch_idx,inference_params: InferenceParams): + """ + Generate output video based on the input batch and inference parameters. + + Parameters: + - batch: dict + The input batch containing data for generating the output video. + - batch_idx: int + The index of the batch. + - inference_params: InferenceParams + An object containing inference parameters. + + Returns: + - torch.Tensor + The generated video. Note the result is also accessible via self.trainer.generated_video + """ + + sample_id = batch["sample_id"].item() + video, scaled_outpainted_image, expanded_size = self.image_to_video( + batch, inference_params=inference_params, batch_idx=sample_id) + + self.trainer.generated_video = video.numpy() + self.trainer.expanded_size = expanded_size + self.trainer.scaled_outpainted_image = scaled_outpainted_image + return video diff --git a/gradio_demo.py b/gradio_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..79cd7a05c6b3b6bc88b4a1cefe49536cc5677e2e --- /dev/null +++ b/gradio_demo.py @@ -0,0 +1,214 @@ +import os +import gradio as gr +from utils.gradio_utils import * +import argparse + +GRADIO_CACHE = "" + +parser = argparse.ArgumentParser() +parser.add_argument('--public_access', action='store_true') +args = parser.parse_args() + +streaming_svd = StreamingSVD(load_argv=False) +on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR" + +examples = [ + ["Experience the dance of jellyfish: float through mesmerizing swarms of jellyfish, pulsating with otherworldly grace and beauty.", + "200 - frames (recommended)", 33, None, None], + ["Dive into the depths of the ocean: explore vibrant coral reefs, mysterious underwater caves, and the mesmerizing creatures that call the sea home.", + "200 - frames (recommended)", 33, None, None], + ["A cute cat.", + "200 - frames (recommended)", 33, None, None], + ["", + "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test1.jpg", None], + ["", + "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test2.jpg", None], + ["", + "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test3.png", None], + ["", + "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test4.png", None], + ["", + "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test5.jpg", None], + ["", + "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test6.png", None], + ["", + "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test7.jpg", None], + ["", + "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test8.jpg", None], + ["", + "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test9.jpg", None], + ["", + "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test10.jpg", None], + ["", + "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test11.jpg", None], + ] + +def generate(prompt, num_frames, seed, image: np.ndarray): + if num_frames == [] or num_frames is None: + num_frames = 50 + else: + num_frames = int(num_frames.split(" ")[0]) + if num_frames > 200: # and on_huggingspace: + num_frames = 200 + + if image is None: + image = text_to_image_gradio( + prompt=prompt, streaming_svd=streaming_svd, seed=seed) + + video_file_stage_one = image_to_video_vfi_gradio( + img=image, num_frames=num_frames, streaming_svd=streaming_svd, seed=seed, gradio_cache=GRADIO_CACHE) + + expanded_size, orig_size, scaled_outpainted_image = retrieve_intermediate_data(video_file_stage_one) + + video_file_stage_two = enhance_video_vfi_gradio( + img=scaled_outpainted_image, video=video_file_stage_one.replace("__cropped__", "__expanded__"), num_frames=24, streaming_svd=streaming_svd, seed=seed, expanded_size=expanded_size, orig_size=orig_size, gradio_cache=GRADIO_CACHE) + + return image, video_file_stage_one, video_file_stage_two + + +def enhance(prompt, num_frames, seed, image: np.ndarray, video:str): + if num_frames == [] or num_frames is None: + num_frames = 50 + else: + num_frames = int(num_frames.split(" ")[0]) + if num_frames > 200: # and on_huggingspace: + num_frames = 200 + + # User directly applied Long Video Generation (without preview) with Flux. + if image is None: + image = text_to_image_gradio( + prompt=prompt, streaming_svd=streaming_svd, seed=seed) + + # User directly applied Long Video Generation (without preview) with or without Flux. + if video is None: + video = image_to_video_gradio( + img=image, num_frames=(num_frames+1) // 2, streaming_svd=streaming_svd, seed=seed, gradio_cache=GRADIO_CACHE) + expanded_size, orig_size, scaled_outpainted_image = retrieve_intermediate_data(video) + + # Here the video is path and image is numpy array + video_file_stage_two = enhance_video_vfi_gradio( + img=scaled_outpainted_image, video=video.replace("__cropped__", "__expanded__"), num_frames=num_frames, streaming_svd=streaming_svd, seed=seed, expanded_size=expanded_size, orig_size=orig_size, gradio_cache=GRADIO_CACHE) + + return image, video_file_stage_two + + +with gr.Blocks() as demo: + GRADIO_CACHE = demo.GRADIO_CACHE + gr.HTML(""" +
+

+ StreamingSVD +

+

+ A StreamingT2V method for high-quality long video generation +

+

+ Roberto Henschel1*, Levon Khachatryan1*, Daniil Hayrapetyan1*, Hayk Poghosyan1, Vahram Tadevosyan1, Zhangyang Wang1,2, Shant Navasardyan1, Humphrey Shi1,3 +

+

+ 1Picsart AI Resarch (PAIR), 2UT Austin, 3SHI Labs @ Georgia Tech, Oregon & UIUC +

+

+ *Equal Contribution +

+

+ [arXiv] + [GitHub] +

+

+ StreamingSVD is an advanced autoregressive technique for text-to-video and image-to-video generation, + generating long hiqh-quality videos with rich motion dynamics, turning SVD into a long video generator. + Our method ensures temporal consistency throughout the video, aligns closely to the input text/image, + and maintains high frame-level image quality. Our demonstrations include successful examples of videos + up to 200 frames, spanning 8 seconds, and can be extended for even longer durations. +

+
+ """) + + if on_huggingspace: + gr.HTML(""" +

For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. +
+ + Duplicate Space +

""") + + with gr.Row(): + with gr.Column(scale=1): + with gr.Row(): + with gr.Column(): + with gr.Row(): + num_frames = gr.Dropdown(["50 - frames (recommended)", "80 - frames (recommended)", "140 - frames (recommended)", "200 - frames (recommended)", "500 - frames", "1000 - frames", "10000 - frames"], + label="Number of Video Frames", info="For >200 frames use local workstation!", value="50 - frames (recommended)") + with gr.Row(): + prompt_stage1 = gr.Textbox(label='Text-to-Video (Enter text prompt here)', + interactive=True, max_lines=1) + with gr.Row(): + image_stage1 = gr.Image(label='Image-to-Video (Upload Image here, text prompt will be ignored for I2V if entered)', + show_label=True, show_download_button=True, interactive=True, height=250) + with gr.Column(): + video_stage1 = gr.Video(label='Long Video Preview', show_label=True, + interactive=False, show_download_button=True, height=203) + with gr.Row(): + run_button_stage1 = gr.Button("Long Video Generation (faster preview)") + with gr.Row(): + with gr.Column(): + with gr.Accordion('Advanced options', open=False): + seed = gr.Slider(label='Seed', minimum=0, + maximum=65536, value=33, step=1,) + + with gr.Column(scale=3): + with gr.Row(): + video_stage2 = gr.Video(label='High-Quality Long Video (Preview or Full)', show_label=True, + interactive=False, show_download_button=True, height=700) + with gr.Row(): + run_button_stage2 = gr.Button("Long Video Generation (full high-quality)") + + inputs_t2v = [prompt_stage1, num_frames, + seed, image_stage1] + inputs_v2v = [prompt_stage1, num_frames, seed, + image_stage1, video_stage1] + + run_button_stage1.click(fn=generate, inputs=inputs_t2v, + outputs=[image_stage1, video_stage1, video_stage2]) + run_button_stage2.click(fn=enhance, inputs=inputs_v2v, + outputs=[image_stage1, video_stage2]) + + + gr.Examples(examples=examples, + inputs=inputs_v2v, + outputs=[image_stage1, video_stage2], + fn=enhance, + cache_examples=True, + run_on_click=False, + ) + + + ''' + ''' + gr.HTML(""" +
+

+ Version: v1.0 +

+

+ Caution: + We would like the raise the awareness of users of this demo of its potential issues and concerns. + Like previous large foundation models, StreamingSVD could be problematic in some cases, partially we use pretrained ModelScope, therefore StreamingSVD can Inherit Its Imperfections. + So far, we keep all features available for research testing both to show the great potential of the StreamingSVD framework and to collect important feedback to improve the model in the future. + We welcome researchers and users to report issues with the HuggingFace community discussion feature or email the authors. +

+

+ Biases and content acknowledgement: + Beware that StreamingSVD may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography, and violence. + StreamingSVD in this demo is meant only for research purposes. +

+
+ """) + + +if on_huggingspace: + demo.queue(max_size=20) + demo.launch(debug=True) +else: + demo.queue(api_open=False).launch(share=args.public_access) diff --git a/i2v_enhance/i2v_enhance_interface.py b/i2v_enhance/i2v_enhance_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..74b76c581ab2a7b515be3bb262f1cee52b1e5955 --- /dev/null +++ b/i2v_enhance/i2v_enhance_interface.py @@ -0,0 +1,128 @@ +import torch +from i2v_enhance.pipeline_i2vgen_xl import I2VGenXLPipeline +from tqdm import tqdm +from PIL import Image +import numpy as np +from einops import rearrange +import i2v_enhance.thirdparty.VFI.config as cfg +from i2v_enhance.thirdparty.VFI.Trainer import Model as VFI +from pathlib import Path +from modules.params.vfi import VFIParams +from modules.params.i2v_enhance import I2VEnhanceParams +from utils.loader import download_ckpt + + +def vfi_init(ckpt_cfg: VFIParams, device_id=0): + cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config(F=32, depth=[ + 2, 2, 2, 4, 4]) + vfi = VFI(-1) + + ckpt_file = Path(download_ckpt( + local_path=ckpt_cfg.ckpt_path_local, global_path=ckpt_cfg.ckpt_path_global)) + + vfi.load_model(ckpt_file.as_posix()) + vfi.eval() + vfi.device() + assert device_id == 0, "VFI on rank!=0 not implemented yet." + return vfi + + +def vfi_process(video, vfi, video_len): + video = video[:(video_len//2+1)] + + video = [i[:, :, :3]/255. for i in video] + video = [i[:, :, ::-1] for i in video] + video = np.stack(video, axis=0) + video = rearrange(torch.from_numpy(video), + 'b h w c -> b c h w').to("cuda", torch.float32) + + frames = [] + for i in tqdm(range(video.shape[0]-1), desc="VFI"): + I0_ = video[i:i+1, ...] + I2_ = video[i+1:i+2, ...] + frames.append((I0_[0].detach().cpu().numpy().transpose( + 1, 2, 0) * 255.0).astype(np.uint8)[:, :, ::-1]) + + mid = (vfi.inference(I0_, I2_, TTA=True, fast_TTA=True)[ + 0].detach().cpu().numpy().transpose(1, 2, 0) * 255.0).astype(np.uint8) + frames.append(mid[:, :, ::-1]) + + frames.append((video[-1].detach().cpu().numpy().transpose(1, + 2, 0) * 255.0).astype(np.uint8)[:, :, ::-1]) + if video_len % 2 == 0: + frames.append((video[-1].detach().cpu().numpy().transpose(1, + 2, 0) * 255.0).astype(np.uint8)[:, :, ::-1]) + + del vfi + del video + torch.cuda.empty_cache() + + video = [Image.fromarray(frame).resize((1280, 720)) for frame in frames] + del frames + return video + + +def i2v_enhance_init(i2vgen_cfg: I2VEnhanceParams): + generator = torch.manual_seed(8888) + try: + pipeline = I2VGenXLPipeline.from_pretrained( + i2vgen_cfg.ckpt_path_local, torch_dtype=torch.float16, variant="fp16") + except Exception as e: + pipeline = I2VGenXLPipeline.from_pretrained( + i2vgen_cfg.ckpt_path_global, torch_dtype=torch.float16, variant="fp16") + pipeline.save_pretrained(i2vgen_cfg.ckpt_path_local) + pipeline.enable_model_cpu_offload() + return pipeline, generator + + +def i2v_enhance_process(image, video, pipeline, generator, overlap_size, strength, chunk_size=38, use_randomized_blending=False): + prompt = "High Quality, HQ, detailed." + negative_prompt = "Distorted, blurry, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms" + + if use_randomized_blending: + # We first need to enhance key-frames (the 1st frame of each chunk) + video_chunks = [video[i:i+chunk_size] for i in range(0, len( + video), chunk_size-overlap_size) if len(video[i:i+chunk_size]) == chunk_size] + video_short = [chunk[0] for chunk in video_chunks] + + # If randomized blending then we must have a list of starting images (1 for each chunk) + image = pipeline( + prompt=prompt, + height=720, + width=1280, + image=image, + video=video_short, + strength=strength, + overlap_size=0, + chunk_size=len(video_short), + num_frames=len(video_short), + num_inference_steps=30, + decode_chunk_size=1, + negative_prompt=negative_prompt, + guidance_scale=9.0, + generator=generator, + ).frames[0] + + # Remove the last few frames (< chunk_size) of the video that do not fit into one chunk. + max_idx = (chunk_size - overlap_size) * \ + (len(video_chunks) - 1) + chunk_size + video = video[:max_idx] + + frames = pipeline( + prompt=prompt, + height=720, + width=1280, + image=image, + video=video, + strength=strength, + overlap_size=overlap_size, + chunk_size=chunk_size, + num_frames=chunk_size, + num_inference_steps=30, + decode_chunk_size=1, + negative_prompt=negative_prompt, + guidance_scale=9.0, + generator=generator, + ).frames[0] + + return frames diff --git a/i2v_enhance/pipeline_i2vgen_xl.py b/i2v_enhance/pipeline_i2vgen_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..dec4827a60b8e5d145909053ae62ec7e26f25d6d --- /dev/null +++ b/i2v_enhance/pipeline_i2vgen_xl.py @@ -0,0 +1,988 @@ +# Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.models import AutoencoderKL +from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet +from diffusers.schedulers import DDIMScheduler +from diffusers.utils import ( + BaseOutput, + logging, + replace_example_docstring, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +import random + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import I2VGenXLPipeline + >>> from diffusers.utils import export_to_gif, load_image + + >>> pipeline = I2VGenXLPipeline.from_pretrained( + ... "ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16" + ... ) + >>> pipeline.enable_model_cpu_offload() + + >>> image_url = ( + ... "https://huggingface.co./datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0009.png" + ... ) + >>> image = load_image(image_url).convert("RGB") + + >>> prompt = "Papers were floating in the air on a table in the library" + >>> negative_prompt = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms" + >>> generator = torch.manual_seed(8888) + + >>> frames = pipeline( + ... prompt=prompt, + ... image=image, + ... num_inference_steps=50, + ... negative_prompt=negative_prompt, + ... guidance_scale=9.0, + ... generator=generator, + ... ).frames[0] + >>> video_path = export_to_gif(frames, "i2v.gif") + ``` +""" + + +@dataclass +class I2VGenXLPipelineOutput(BaseOutput): + r""" + Output class for image-to-video pipeline. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised + PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)` + """ + + frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]] + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents + + +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError( + "Could not access latents of provided encoder_output") + + +class I2VGenXLPipeline( + DiffusionPipeline, + StableDiffusionMixin, +): + r""" + Pipeline for image-to-video generation as proposed in [I2VGenXL](https://i2vgen-xl.github.io/). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co./openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`I2VGenXLUNet`]): + A [`I2VGenXLUNet`] to denoise the encoded video latents. + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + image_encoder: CLIPVisionModelWithProjection, + feature_extractor: CLIPImageProcessor, + unet: I2VGenXLUNet, + scheduler: DDIMScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** ( + len(self.vae.config.block_out_channels) - 1) + # `do_resize=False` as we do custom resizing. + self.video_processor = VideoProcessor( + vae_scale_factor=self.vae_scale_factor, do_resize=False) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + def encode_prompt( + self, + prompt, + device, + num_videos_per_prompt, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1: -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm( + prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to( + dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_videos_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + # Apply clip_skip to negative prompt embeds + if clip_skip is None: + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + else: + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + negative_prompt_embeds = negative_prompt_embeds[-1][-( + clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + negative_prompt_embeds = self.text_encoder.text_model.final_layer_norm( + negative_prompt_embeds) + + if self.do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def _encode_image(self, image, device, num_videos_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.video_processor.pil_to_numpy(image) + image = self.video_processor.numpy_to_pt(image) + + # Normalize the image with CLIP training stats. + image = self.feature_extractor( + images=image, + do_normalize=True, + do_center_crop=False, + do_resize=False, + do_rescale=False, + return_tensors="pt", + ).pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) + image_embeddings = image_embeddings.view( + bs_embed * num_videos_per_prompt, seq_len, -1) + + if self.do_classifier_free_guidance: + negative_image_embeddings = torch.zeros_like(image_embeddings) + image_embeddings = torch.cat( + [negative_image_embeddings, image_embeddings]) + + return image_embeddings + + def decode_latents(self, latents, decode_chunk_size=None): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_frames, channels, height, width) + + if decode_chunk_size is not None: + frames = [] + for i in range(0, latents.shape[0], decode_chunk_size): + frame = self.vae.decode( + latents[i: i + decode_chunk_size]).sample + frames.append(frame) + image = torch.cat(frames, dim=0) + else: + image = self.vae.decode(latents).sample + + decode_shape = (batch_size, num_frames, -1) + image.shape[2:] + video = image[None, :].reshape(decode_shape).permute(0, 2, 1, 3, 4) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature( + self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + def prepare_image_latents( + self, + image, + device, + num_frames, + num_videos_per_prompt, + ): + image = image.to(device=device) + image_latents = self.vae.encode(image).latent_dist.sample() + image_latents = image_latents * self.vae.config.scaling_factor + + # Add frames dimension to image latents + image_latents = image_latents.unsqueeze(2) + + # Append a position mask for each subsequent frame + # after the intial image latent frame + frame_position_mask = [] + for frame_idx in range(num_frames - 1): + scale = (frame_idx + 1) / (num_frames - 1) + frame_position_mask.append( + torch.ones_like(image_latents[:, :, :1]) * scale) + if frame_position_mask: + frame_position_mask = torch.cat(frame_position_mask, dim=2) + image_latents = torch.cat( + [image_latents, frame_position_mask], dim=2) + + # duplicate image_latents for each generation per prompt, using mps friendly method + image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1, 1) + + if self.do_classifier_free_guidance: + image_latents = torch.cat([image_latents] * 2) + + return image_latents + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min( + int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Similar to image, we need to prepare the latents for the video. + def prepare_video_latents( + self, video, timestep, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + video = video.to(device=device, dtype=dtype) + is_long = video.shape[2] > 16 + + # change from (b, c, f, h, w) -> (b * f, c, w, h) + bsz, channel, frames, width, height = video.shape + video = video.permute(0, 2, 1, 3, 4).reshape( + bsz * frames, channel, width, height) + + if video.shape[1] == 4: + init_latents = video + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + elif isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode( + video[i: i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + if not is_long: + # 1 step encoding + init_latents = retrieve_latents( + self.vae.encode(video), generator=generator) + else: + # chunk by chunk encoding. for low-memory consumption. + video_list = torch.chunk( + video, video.shape[0] // 16, dim=0) + with torch.no_grad(): + init_latents = [] + for video_chunk in video_list: + video_chunk = retrieve_latents( + self.vae.encode(video_chunk), generator=generator) + init_latents.append(video_chunk) + init_latents = torch.cat(init_latents, dim=0) + # torch.cuda.empty_cache() + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `video` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, + device=device, dtype=dtype) + + latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = latents[None, :].reshape( + (bsz, frames, latents.shape[1]) + latents.shape[2:]).permute(0, 2, 1, 3, 4) + + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + # Now image can be either a single image or a list of images (when randomized blending is enalbled). + image: Union[List[PipelineImageInput], PipelineImageInput] = None, + video: Union[List[np.ndarray], torch.Tensor] = None, + strength: float = 0.97, + overlap_size: int = 0, + chunk_size: int = 38, + height: Optional[int] = 720, + width: Optional[int] = 1280, + target_fps: Optional[int] = 38, + num_frames: int = 38, + num_inference_steps: int = 50, + guidance_scale: float = 9.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + eta: float = 0.0, + num_videos_per_prompt: Optional[int] = 1, + decode_chunk_size: Optional[int] = 1, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = 1, + ): + r""" + The call function to the pipeline for image-to-video generation with [`I2VGenXLPipeline`]. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`): + Image or images to guide image generation. If you provide a tensor, it needs to be compatible with + [`CLIPImageProcessor`](https://huggingface.co./lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). + video (`List[np.ndarray]` or `torch.Tensor`): + Video to guide video enhancement. + strength (`float`, *optional*, defaults to 0.97): + Indicates extent to transform the reference `video`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + overlap_size (`int`, *optional*, defaults to 0): + This parameter is used in randomized blending, when it is enabled. + It defines the size of overlap between neighbouring chunks. + chunk_size (`int`, *optional*, defaults to 38): + This parameter is used in randomized blending, when it is enabled. + It defines the number of frames we will enhance during each chunk of randomized blending. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + target_fps (`int`, *optional*): + Frames per second. The rate at which the generated images shall be exported to a video after + generation. This is also used as a "micro-condition" while generation. + num_frames (`int`, *optional*): + The number of video frames to generate. + num_inference_steps (`int`, *optional*): + The number of denoising steps. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + num_videos_per_prompt (`int`, *optional*): + The number of images to generate per prompt. + decode_chunk_size (`int`, *optional*): + The number of frames to decode at a time. The higher the chunk size, the higher the temporal + consistency between frames, but also the higher the memory consumption. By default, the decoder will + decode all frames at once for maximal quality. Reduce `decode_chunk_size` to reduce memory usage. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + Returns: + [`pipelines.i2vgen_xl.pipeline_i2vgen_xl.I2VGenXLPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`pipelines.i2vgen_xl.pipeline_i2vgen_xl.I2VGenXLPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, image, height, width, + negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + self._guidance_scale = guidance_scale + + # 3.1 Encode input text prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 3.2 Encode image prompt + # 3.2.1 Image encodings. + # https://github.com/ali-vilab/i2vgen-xl/blob/2539c9262ff8a2a22fa9daecbfd13f0a2dbc32d0/tools/inferences/inference_i2vgen_entrance.py#L114 + # As now we can have a list of images (when randomized blending), we encode each image separately as before. + image_embeddings_list = [] + for img in image: + cropped_image = _center_crop_wide(img, (width, width)) + cropped_image = _resize_bilinear( + cropped_image, (self.feature_extractor.crop_size["width"], + self.feature_extractor.crop_size["height"]) + ) + image_embeddings = self._encode_image( + cropped_image, device, num_videos_per_prompt) + image_embeddings_list.append(image_embeddings) + + # 3.2.2 Image latents. + # As now we can have a list of images (when randomized blending), we encode each image separately as before. + image_latents_list = [] + for img in image: + resized_image = _center_crop_wide(img, (width, height)) + img = self.video_processor.preprocess(resized_image).to( + device=device, dtype=image_embeddings_list[0].dtype) + image_latents = self.prepare_image_latents( + img, + device=device, + num_frames=num_frames, + num_videos_per_prompt=num_videos_per_prompt, + ) + image_latents_list.append(image_latents) + + # 3.3 Prepare additional conditions for the UNet. + if self.do_classifier_free_guidance: + fps_tensor = torch.tensor([target_fps, target_fps]).to(device) + else: + fps_tensor = torch.tensor([target_fps]).to(device) + fps_tensor = fps_tensor.repeat( + batch_size * num_videos_per_prompt, 1).ravel() + + # 3.4 Preprocess video, similar to images. + video = self.video_processor.preprocess_video(video).to( + device=device, dtype=image_embeddings_list[0].dtype) + num_images_per_prompt = 1 + + # 4. Prepare timesteps. This will be used for modified SDEdit approach. + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat( + batch_size * num_images_per_prompt) + + # 5. Prepare latent variables. Now we get latents for input video. + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_video_latents( + video, + latent_timestep, + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - \ + num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latents_denoised = torch.empty_like(latents) + + CHUNK_START = 0 + # Each chunk must have a corresponding 1st frame + for idx in range(len(image_latents_list)): + latents_chunk = latents[:, :, + CHUNK_START:CHUNK_START + chunk_size] + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat( + [latents_chunk] * 2) if self.do_classifier_free_guidance else latents_chunk + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + fps=fps_tensor, + image_latents=image_latents_list[idx], + image_embeddings=image_embeddings_list[idx], + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk( + 2) + noise_pred = noise_pred_uncond + guidance_scale * \ + (noise_pred_text - noise_pred_uncond) + + # reshape latents_chunk + batch_size, channel, frames, width, height = latents_chunk.shape + latents_chunk = latents_chunk.permute(0, 2, 1, 3, 4).reshape( + batch_size * frames, channel, width, height) + noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape( + batch_size * frames, channel, width, height) + + # compute the previous noisy sample x_t -> x_t-1 + latents_chunk = self.scheduler.step( + noise_pred, t, latents_chunk, **extra_step_kwargs).prev_sample + + # reshape latents back + latents_chunk = latents_chunk[None, :].reshape( + batch_size, frames, channel, width, height).permute(0, 2, 1, 3, 4) + + # Make sure random_offset is set correctly. + if CHUNK_START == 0: + random_offset = 0 + else: + if overlap_size != 0: + random_offset = random.randint(0, overlap_size - 1) + else: + random_offset = 0 + + # Apply Randomized Blending. + latents_denoised[:, :, CHUNK_START + random_offset:CHUNK_START + + chunk_size] = latents_chunk[:, :, random_offset:] + CHUNK_START += chunk_size - overlap_size + + latents = latents_denoised + + if CHUNK_START + overlap_size > latents_denoised.shape[2]: + raise NotImplementedError(f"Video of size={latents_denoised.shape[2]} is not dividable into chunks " + f"with size={chunk_size} and overlap={overlap_size}") + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # 8. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents( + latents, decode_chunk_size=decode_chunk_size) + video = self.video_processor.postprocess_video( + video=video_tensor, output_type=output_type) + + # 9. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return I2VGenXLPipelineOutput(frames=video) + + +# The following utilities are taken and adapted from +# https://github.com/ali-vilab/i2vgen-xl/blob/main/utils/transforms.py. + + +def _convert_pt_to_pil(image: Union[torch.Tensor, List[torch.Tensor]]): + if isinstance(image, list) and isinstance(image[0], torch.Tensor): + image = torch.cat(image, 0) + + if isinstance(image, torch.Tensor): + if image.ndim == 3: + image = image.unsqueeze(0) + + image_numpy = VaeImageProcessor.pt_to_numpy(image) + image_pil = VaeImageProcessor.numpy_to_pil(image_numpy) + image = image_pil + + return image + + +def _resize_bilinear( + image: Union[torch.Tensor, List[torch.Tensor], PIL.Image.Image, List[PIL.Image.Image]], resolution: Tuple[int, int] +): + # First convert the images to PIL in case they are float tensors (only relevant for tests now). + image = _convert_pt_to_pil(image) + + if isinstance(image, list): + image = [u.resize(resolution, PIL.Image.BILINEAR) for u in image] + else: + image = image.resize(resolution, PIL.Image.BILINEAR) + return image + + +def _center_crop_wide( + image: Union[torch.Tensor, List[torch.Tensor], PIL.Image.Image, List[PIL.Image.Image]], resolution: Tuple[int, int] +): + # First convert the images to PIL in case they are float tensors (only relevant for tests now). + image = _convert_pt_to_pil(image) + + if isinstance(image, list): + scale = min(image[0].size[0] / resolution[0], + image[0].size[1] / resolution[1]) + image = [u.resize((round(u.width // scale), round(u.height // + scale)), resample=PIL.Image.BOX) for u in image] + + # center crop + x1 = (image[0].width - resolution[0]) // 2 + y1 = (image[0].height - resolution[1]) // 2 + image = [u.crop((x1, y1, x1 + resolution[0], y1 + resolution[1])) + for u in image] + return image + else: + scale = min(image.size[0] / resolution[0], + image.size[1] / resolution[1]) + image = image.resize((round(image.width // scale), + round(image.height // scale)), resample=PIL.Image.BOX) + x1 = (image.width - resolution[0]) // 2 + y1 = (image.height - resolution[1]) // 2 + image = image.crop((x1, y1, x1 + resolution[0], y1 + resolution[1])) + return image diff --git a/i2v_enhance/thirdparty/VFI/Trainer.py b/i2v_enhance/thirdparty/VFI/Trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..8bb45af555724200e31ae74581da6ac57a1fd616 --- /dev/null +++ b/i2v_enhance/thirdparty/VFI/Trainer.py @@ -0,0 +1,168 @@ +# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/Trainer.py +import torch +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import AdamW +from i2v_enhance.thirdparty.VFI.model.loss import * +from i2v_enhance.thirdparty.VFI.config import * + + +class Model: + def __init__(self, local_rank): + backbonetype, multiscaletype = MODEL_CONFIG['MODEL_TYPE'] + backbonecfg, multiscalecfg = MODEL_CONFIG['MODEL_ARCH'] + self.net = multiscaletype(backbonetype(**backbonecfg), **multiscalecfg) + self.name = MODEL_CONFIG['LOGNAME'] + self.device() + + # train + self.optimG = AdamW(self.net.parameters(), lr=2e-4, weight_decay=1e-4) + self.lap = LapLoss() + if local_rank != -1: + self.net = DDP(self.net, device_ids=[local_rank], output_device=local_rank) + + def train(self): + self.net.train() + + def eval(self): + self.net.eval() + + def device(self): + self.net.to(torch.device("cuda")) + + def unload(self): + self.net.to(torch.device("cpu")) + + def load_model(self, name=None, rank=0): + def convert(param): + return { + k.replace("module.", ""): v + for k, v in param.items() + if "module." in k and 'attn_mask' not in k and 'HW' not in k + } + if rank <= 0 : + if name is None: + name = self.name + # self.net.load_state_dict(convert(torch.load(f'ckpt/{name}.pkl'))) + self.net.load_state_dict(convert(torch.load(f'{name}'))) + + def save_model(self, rank=0): + if rank == 0: + torch.save(self.net.state_dict(),f'ckpt/{self.name}.pkl') + + @torch.no_grad() + def hr_inference(self, img0, img1, TTA = False, down_scale = 1.0, timestep = 0.5, fast_TTA = False): + ''' + Infer with down_scale flow + Noting: return BxCxHxW + ''' + def infer(imgs): + img0, img1 = imgs[:, :3], imgs[:, 3:6] + imgs_down = F.interpolate(imgs, scale_factor=down_scale, mode="bilinear", align_corners=False) + + flow, mask = self.net.calculate_flow(imgs_down, timestep) + + flow = F.interpolate(flow, scale_factor = 1/down_scale, mode="bilinear", align_corners=False) * (1/down_scale) + mask = F.interpolate(mask, scale_factor = 1/down_scale, mode="bilinear", align_corners=False) + + af, _ = self.net.feature_bone(img0, img1) + pred = self.net.coraseWarp_and_Refine(imgs, af, flow, mask) + return pred + + imgs = torch.cat((img0, img1), 1) + if fast_TTA: + imgs_ = imgs.flip(2).flip(3) + input = torch.cat((imgs, imgs_), 0) + preds = infer(input) + return (preds[0] + preds[1].flip(1).flip(2)).unsqueeze(0) / 2. + + if TTA == False: + return infer(imgs) + else: + return (infer(imgs) + infer(imgs.flip(2).flip(3)).flip(2).flip(3)) / 2 + + @torch.no_grad() + def inference(self, img0, img1, TTA = False, timestep = 0.5, fast_TTA = False): + imgs = torch.cat((img0, img1), 1) + ''' + Noting: return BxCxHxW + ''' + if fast_TTA: + imgs_ = imgs.flip(2).flip(3) + input = torch.cat((imgs, imgs_), 0) + _, _, _, preds = self.net(input, timestep=timestep) + return (preds[0] + preds[1].flip(1).flip(2)).unsqueeze(0) / 2. + + _, _, _, pred = self.net(imgs, timestep=timestep) + if TTA == False: + return pred + else: + _, _, _, pred2 = self.net(imgs.flip(2).flip(3), timestep=timestep) + return (pred + pred2.flip(2).flip(3)) / 2 + + @torch.no_grad() + def multi_inference(self, img0, img1, TTA = False, down_scale = 1.0, time_list=[], fast_TTA = False): + ''' + Run backbone once, get multi frames at different timesteps + Noting: return a list of [CxHxW] + ''' + assert len(time_list) > 0, 'Time_list should not be empty!' + def infer(imgs): + img0, img1 = imgs[:, :3], imgs[:, 3:6] + af, mf = self.net.feature_bone(img0, img1) + imgs_down = None + if down_scale != 1.0: + imgs_down = F.interpolate(imgs, scale_factor=down_scale, mode="bilinear", align_corners=False) + afd, mfd = self.net.feature_bone(imgs_down[:, :3], imgs_down[:, 3:6]) + + pred_list = [] + for timestep in time_list: + if imgs_down is None: + flow, mask = self.net.calculate_flow(imgs, timestep, af, mf) + else: + flow, mask = self.net.calculate_flow(imgs_down, timestep, afd, mfd) + flow = F.interpolate(flow, scale_factor = 1/down_scale, mode="bilinear", align_corners=False) * (1/down_scale) + mask = F.interpolate(mask, scale_factor = 1/down_scale, mode="bilinear", align_corners=False) + + pred = self.net.coraseWarp_and_Refine(imgs, af, flow, mask) + pred_list.append(pred) + + return pred_list + + imgs = torch.cat((img0, img1), 1) + if fast_TTA: + imgs_ = imgs.flip(2).flip(3) + input = torch.cat((imgs, imgs_), 0) + preds_lst = infer(input) + return [(preds_lst[i][0] + preds_lst[i][1].flip(1).flip(2))/2 for i in range(len(time_list))] + + preds = infer(imgs) + if TTA is False: + return [preds[i][0] for i in range(len(time_list))] + else: + flip_pred = infer(imgs.flip(2).flip(3)) + return [(preds[i][0] + flip_pred[i][0].flip(1).flip(2))/2 for i in range(len(time_list))] + + def update(self, imgs, gt, learning_rate=0, training=True): + for param_group in self.optimG.param_groups: + param_group['lr'] = learning_rate + if training: + self.train() + else: + self.eval() + + if training: + flow, mask, merged, pred = self.net(imgs) + loss_l1 = (self.lap(pred, gt)).mean() + + for merge in merged: + loss_l1 += (self.lap(merge, gt)).mean() * 0.5 + + self.optimG.zero_grad() + loss_l1.backward() + self.optimG.step() + return pred, loss_l1 + else: + with torch.no_grad(): + flow, mask, merged, pred = self.net(imgs) + return pred, 0 diff --git a/i2v_enhance/thirdparty/VFI/ckpt/Put ours.pkl files here.txt b/i2v_enhance/thirdparty/VFI/ckpt/Put ours.pkl files here.txt new file mode 100644 index 0000000000000000000000000000000000000000..1367c7a25a5cf19a8e1665567ca4f5ec41aa7f65 --- /dev/null +++ b/i2v_enhance/thirdparty/VFI/ckpt/Put ours.pkl files here.txt @@ -0,0 +1 @@ +here is the link to the all EMA-VFI models:https://drive.google.com/drive/folders/16jUa3HkQ85Z5lb5gce1yoaWkP-rdCd0o \ No newline at end of file diff --git a/i2v_enhance/thirdparty/VFI/ckpt/__init__.py b/i2v_enhance/thirdparty/VFI/ckpt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/i2v_enhance/thirdparty/VFI/config.py b/i2v_enhance/thirdparty/VFI/config.py new file mode 100644 index 0000000000000000000000000000000000000000..7ff22763fd62319f2c287a12fc8840463715cc3c --- /dev/null +++ b/i2v_enhance/thirdparty/VFI/config.py @@ -0,0 +1,49 @@ +# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/config.py +from functools import partial +import torch.nn as nn + +from i2v_enhance.thirdparty.VFI.model import feature_extractor +from i2v_enhance.thirdparty.VFI.model import flow_estimation + +'''==========Model config==========''' +def init_model_config(F=32, W=7, depth=[2, 2, 2, 4, 4]): + '''This function should not be modified''' + return { + 'embed_dims':[F, 2*F, 4*F, 8*F, 16*F], + 'motion_dims':[0, 0, 0, 8*F//depth[-2], 16*F//depth[-1]], + 'num_heads':[8*F//32, 16*F//32], + 'mlp_ratios':[4, 4], + 'qkv_bias':True, + 'norm_layer':partial(nn.LayerNorm, eps=1e-6), + 'depths':depth, + 'window_sizes':[W, W] + }, { + 'embed_dims':[F, 2*F, 4*F, 8*F, 16*F], + 'motion_dims':[0, 0, 0, 8*F//depth[-2], 16*F//depth[-1]], + 'depths':depth, + 'num_heads':[8*F//32, 16*F//32], + 'window_sizes':[W, W], + 'scales':[4, 8, 16], + 'hidden_dims':[4*F, 4*F], + 'c':F + } + +MODEL_CONFIG = { + 'LOGNAME': 'ours', + 'MODEL_TYPE': (feature_extractor, flow_estimation), + 'MODEL_ARCH': init_model_config( + F = 32, + W = 7, + depth = [2, 2, 2, 4, 4] + ) +} + +# MODEL_CONFIG = { +# 'LOGNAME': 'ours_small', +# 'MODEL_TYPE': (feature_extractor, flow_estimation), +# 'MODEL_ARCH': init_model_config( +# F = 16, +# W = 7, +# depth = [2, 2, 2, 2, 2] +# ) +# } \ No newline at end of file diff --git a/i2v_enhance/thirdparty/VFI/dataset.py b/i2v_enhance/thirdparty/VFI/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..dcd5583151b581f3de4db2500cbef82f75271fa2 --- /dev/null +++ b/i2v_enhance/thirdparty/VFI/dataset.py @@ -0,0 +1,93 @@ +# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/dataset.py +import cv2 +import os +import torch +import numpy as np +import random +from torch.utils.data import Dataset +from config import * + +cv2.setNumThreads(1) +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +class VimeoDataset(Dataset): + def __init__(self, dataset_name, path, batch_size=32, model="RIFE"): + self.batch_size = batch_size + self.dataset_name = dataset_name + self.model = model + self.h = 256 + self.w = 448 + self.data_root = path + self.image_root = os.path.join(self.data_root, 'sequences') + train_fn = os.path.join(self.data_root, 'tri_trainlist.txt') + test_fn = os.path.join(self.data_root, 'tri_testlist.txt') + with open(train_fn, 'r') as f: + self.trainlist = f.read().splitlines() + with open(test_fn, 'r') as f: + self.testlist = f.read().splitlines() + self.load_data() + + def __len__(self): + return len(self.meta_data) + + def load_data(self): + if self.dataset_name != 'test': + self.meta_data = self.trainlist + else: + self.meta_data = self.testlist + + def aug(self, img0, gt, img1, h, w): + ih, iw, _ = img0.shape + x = np.random.randint(0, ih - h + 1) + y = np.random.randint(0, iw - w + 1) + img0 = img0[x:x+h, y:y+w, :] + img1 = img1[x:x+h, y:y+w, :] + gt = gt[x:x+h, y:y+w, :] + return img0, gt, img1 + + def getimg(self, index): + imgpath = os.path.join(self.image_root, self.meta_data[index]) + imgpaths = [imgpath + '/im1.png', imgpath + '/im2.png', imgpath + '/im3.png'] + + img0 = cv2.imread(imgpaths[0]) + gt = cv2.imread(imgpaths[1]) + img1 = cv2.imread(imgpaths[2]) + return img0, gt, img1 + + def __getitem__(self, index): + img0, gt, img1 = self.getimg(index) + + if 'train' in self.dataset_name: + img0, gt, img1 = self.aug(img0, gt, img1, 256, 256) + if random.uniform(0, 1) < 0.5: + img0 = img0[:, :, ::-1] + img1 = img1[:, :, ::-1] + gt = gt[:, :, ::-1] + if random.uniform(0, 1) < 0.5: + img1, img0 = img0, img1 + if random.uniform(0, 1) < 0.5: + img0 = img0[::-1] + img1 = img1[::-1] + gt = gt[::-1] + if random.uniform(0, 1) < 0.5: + img0 = img0[:, ::-1] + img1 = img1[:, ::-1] + gt = gt[:, ::-1] + + p = random.uniform(0, 1) + if p < 0.25: + img0 = cv2.rotate(img0, cv2.ROTATE_90_CLOCKWISE) + gt = cv2.rotate(gt, cv2.ROTATE_90_CLOCKWISE) + img1 = cv2.rotate(img1, cv2.ROTATE_90_CLOCKWISE) + elif p < 0.5: + img0 = cv2.rotate(img0, cv2.ROTATE_180) + gt = cv2.rotate(gt, cv2.ROTATE_180) + img1 = cv2.rotate(img1, cv2.ROTATE_180) + elif p < 0.75: + img0 = cv2.rotate(img0, cv2.ROTATE_90_COUNTERCLOCKWISE) + gt = cv2.rotate(gt, cv2.ROTATE_90_COUNTERCLOCKWISE) + img1 = cv2.rotate(img1, cv2.ROTATE_90_COUNTERCLOCKWISE) + + img0 = torch.from_numpy(img0.copy()).permute(2, 0, 1) + img1 = torch.from_numpy(img1.copy()).permute(2, 0, 1) + gt = torch.from_numpy(gt.copy()).permute(2, 0, 1) + return torch.cat((img0, img1, gt), 0) diff --git a/i2v_enhance/thirdparty/VFI/model/__init__.py b/i2v_enhance/thirdparty/VFI/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f92d852e46d9f2f94468a7acf39fd001a3ef825a --- /dev/null +++ b/i2v_enhance/thirdparty/VFI/model/__init__.py @@ -0,0 +1,5 @@ +from .feature_extractor import feature_extractor +from .flow_estimation import MultiScaleFlow as flow_estimation + + +__all__ = ['feature_extractor', 'flow_estimation'] \ No newline at end of file diff --git a/i2v_enhance/thirdparty/VFI/model/feature_extractor.py b/i2v_enhance/thirdparty/VFI/model/feature_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..9a6b8f8147cd30a491305f8adfb67646ef5770c0 --- /dev/null +++ b/i2v_enhance/thirdparty/VFI/model/feature_extractor.py @@ -0,0 +1,516 @@ +# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/model/feature_extractor.py +import torch +import torch.nn as nn +import math +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +def window_partition(x, window_size): + B, H, W, C = x.shape + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0]*window_size[1], C) + ) + return windows + + +def window_reverse(windows, window_size, H, W): + nwB, N, C = windows.shape + windows = windows.view(-1, window_size[0], window_size[1], C) + B = int(nwB / (H * W / window_size[0] / window_size[1])) + x = windows.view( + B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +def pad_if_needed(x, size, window_size): + n, h, w, c = size + pad_h = math.ceil(h / window_size[0]) * window_size[0] - h + pad_w = math.ceil(w / window_size[1]) * window_size[1] - w + if pad_h > 0 or pad_w > 0: # center-pad the feature on H and W axes + img_mask = torch.zeros((1, h+pad_h, w+pad_w, 1)) # 1 H W 1 + h_slices = ( + slice(0, pad_h//2), + slice(pad_h//2, h+pad_h//2), + slice(h+pad_h//2, None), + ) + w_slices = ( + slice(0, pad_w//2), + slice(pad_w//2, w+pad_w//2), + slice(w+pad_w//2, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, window_size + ) # nW, window_size*window_size, 1 + mask_windows = mask_windows.squeeze(-1) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill( + attn_mask != 0, float(-100.0) + ).masked_fill(attn_mask == 0, float(0.0)) + return nn.functional.pad( + x, + (0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2), + ), attn_mask + return x, None + + +def depad_if_needed(x, size, window_size): + n, h, w, c = size + pad_h = math.ceil(h / window_size[0]) * window_size[0] - h + pad_w = math.ceil(w / window_size[1]) * window_size[1] - w + if pad_h > 0 or pad_w > 0: # remove the center-padding on feature + return x[:, pad_h // 2 : pad_h // 2 + h, pad_w // 2 : pad_w // 2 + w, :].contiguous() + return x + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + self.relu = nn.ReLU(inplace=True) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class InterFrameAttention(nn.Module): + def __init__(self, dim, motion_dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.motion_dim = motion_dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.cor_embed = nn.Linear(2, motion_dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.motion_proj = nn.Linear(motion_dim, motion_dim) + self.proj_drop = nn.Dropout(proj_drop) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x1, x2, cor, H, W, mask=None): + B, N, C = x1.shape + B, N, C_c = cor.shape + q = self.q(x1).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + kv = self.kv(x2).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + cor_embed_ = self.cor_embed(cor) + cor_embed = cor_embed_.reshape(B, N, self.num_heads, self.motion_dim // self.num_heads).permute(0, 2, 1, 3) + k, v = kv[0], kv[1] + attn = (q @ k.transpose(-2, -1)) * self.scale + + if mask is not None: + nW = mask.shape[0] # mask: nW, N, N + attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze( + 1 + ).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = attn.softmax(dim=-1) + else: + attn = attn.softmax(dim=-1) + + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + c_reverse = (attn @ cor_embed).transpose(1, 2).reshape(B, N, -1) + motion = self.motion_proj(c_reverse-cor_embed_) + x = self.proj(x) + x = self.proj_drop(x) + return x, motion + + +class MotionFormerBlock(nn.Module): + def __init__(self, dim, motion_dim, num_heads, window_size=0, shift_size=0, mlp_ratio=4., bidirectional=True, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,): + super().__init__() + self.window_size = window_size + if not isinstance(self.window_size, (tuple, list)): + self.window_size = to_2tuple(window_size) + self.shift_size = shift_size + if not isinstance(self.shift_size, (tuple, list)): + self.shift_size = to_2tuple(shift_size) + self.bidirectional = bidirectional + self.norm1 = norm_layer(dim) + self.attn = InterFrameAttention( + dim, + motion_dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, cor, H, W, B): + x = x.view(2*B, H, W, -1) + x_pad, mask = pad_if_needed(x, x.size(), self.window_size) + cor_pad, _ = pad_if_needed(cor, cor.size(), self.window_size) + + if self.shift_size[0] or self.shift_size[1]: + _, H_p, W_p, C = x_pad.shape + x_pad = torch.roll(x_pad, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2)) + cor_pad = torch.roll(cor_pad, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2)) + + if hasattr(self, 'HW') and self.HW.item() == H_p * W_p: + shift_mask = self.attn_mask + else: + shift_mask = torch.zeros((1, H_p, W_p, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)) + w_slices = (slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + shift_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(shift_mask, self.window_size).squeeze(-1) + shift_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + shift_mask = shift_mask.masked_fill(shift_mask != 0, + float(-100.0)).masked_fill(shift_mask == 0, + float(0.0)) + + if mask is not None: + shift_mask = shift_mask.masked_fill(mask != 0, + float(-100.0)) + self.register_buffer("attn_mask", shift_mask) + self.register_buffer("HW", torch.Tensor([H_p*W_p])) + else: + shift_mask = mask + + if shift_mask is not None: + shift_mask = shift_mask.to(x_pad.device) + + + _, Hw, Ww, C = x_pad.shape + x_win = window_partition(x_pad, self.window_size) + cor_win = window_partition(cor_pad, self.window_size) + + nwB = x_win.shape[0] + x_norm = self.norm1(x_win) + + x_reverse = torch.cat([x_norm[nwB//2:], x_norm[:nwB//2]]) + x_appearence, x_motion = self.attn(x_norm, x_reverse, cor_win, H, W, shift_mask) + x_norm = x_norm + self.drop_path(x_appearence) + + x_back = x_norm + x_back_win = window_reverse(x_back, self.window_size, Hw, Ww) + x_motion = window_reverse(x_motion, self.window_size, Hw, Ww) + + if self.shift_size[0] or self.shift_size[1]: + x_back_win = torch.roll(x_back_win, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2)) + x_motion = torch.roll(x_motion, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2)) + + x = depad_if_needed(x_back_win, x.size(), self.window_size).view(2*B, H * W, -1) + x_motion = depad_if_needed(x_motion, cor.size(), self.window_size).view(2*B, H * W, -1) + + x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) + return x, x_motion + + +class ConvBlock(nn.Module): + def __init__(self, in_dim, out_dim, depths=2,act_layer=nn.PReLU): + super().__init__() + layers = [] + for i in range(depths): + if i == 0: + layers.append(nn.Conv2d(in_dim, out_dim, 3,1,1)) + else: + layers.append(nn.Conv2d(out_dim, out_dim, 3,1,1)) + layers.extend([ + act_layer(out_dim), + ]) + self.conv = nn.Sequential(*layers) + + def _init_weights(self, m): + if isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.conv(x) + return x + + +class OverlapPatchEmbed(nn.Module): + def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + patch_size = to_2tuple(patch_size) + + self.patch_size = patch_size + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2)) + self.norm = nn.LayerNorm(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x.shape + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + + return x, H, W + + +class CrossScalePatchEmbed(nn.Module): + def __init__(self, in_dims=[16,32,64], embed_dim=768): + super().__init__() + base_dim = in_dims[0] + + layers = [] + for i in range(len(in_dims)): + for j in range(2 ** i): + layers.append(nn.Conv2d(in_dims[-1-i], base_dim, 3, 2**(i+1), 1+j, 1+j)) + self.layers = nn.ModuleList(layers) + self.proj = nn.Conv2d(base_dim * len(layers), embed_dim, 1, 1) + self.norm = nn.LayerNorm(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, xs): + ys = [] + k = 0 + for i in range(len(xs)): + for _ in range(2 ** i): + ys.append(self.layers[k](xs[-1-i])) + k += 1 + x = self.proj(torch.cat(ys,1)) + _, _, H, W = x.shape + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + + return x, H, W + + +class MotionFormer(nn.Module): + def __init__(self, in_chans=3, embed_dims=[32, 64, 128, 256, 512], motion_dims=64, num_heads=[8, 16], + mlp_ratios=[4, 4], qkv_bias=True, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, + depths=[2, 2, 2, 6, 2], window_sizes=[11, 11],**kwarg): + super().__init__() + self.depths = depths + self.num_stages = len(embed_dims) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + + self.conv_stages = self.num_stages - len(num_heads) + + for i in range(self.num_stages): + if i == 0: + block = ConvBlock(in_chans,embed_dims[i],depths[i]) + else: + if i < self.conv_stages: + patch_embed = nn.Sequential( + nn.Conv2d(embed_dims[i-1], embed_dims[i], 3,2,1), + nn.PReLU(embed_dims[i]) + ) + block = ConvBlock(embed_dims[i],embed_dims[i],depths[i]) + else: + if i == self.conv_stages: + patch_embed = CrossScalePatchEmbed(embed_dims[:i], + embed_dim=embed_dims[i]) + else: + patch_embed = OverlapPatchEmbed(patch_size=3, + stride=2, + in_chans=embed_dims[i - 1], + embed_dim=embed_dims[i]) + + block = nn.ModuleList([MotionFormerBlock( + dim=embed_dims[i], motion_dim=motion_dims[i], num_heads=num_heads[i-self.conv_stages], window_size=window_sizes[i-self.conv_stages], + shift_size= 0 if (j % 2) == 0 else window_sizes[i-self.conv_stages] // 2, + mlp_ratio=mlp_ratios[i-self.conv_stages], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer) + for j in range(depths[i])]) + + norm = norm_layer(embed_dims[i]) + setattr(self, f"norm{i + 1}", norm) + setattr(self, f"patch_embed{i + 1}", patch_embed) + cur += depths[i] + + setattr(self, f"block{i + 1}", block) + + self.cor = {} + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def get_cor(self, shape, device): + k = (str(shape), str(device)) + if k not in self.cor: + tenHorizontal = torch.linspace(-1.0, 1.0, shape[2], device=device).view( + 1, 1, 1, shape[2]).expand(shape[0], -1, shape[1], -1).permute(0, 2, 3, 1) + tenVertical = torch.linspace(-1.0, 1.0, shape[1], device=device).view( + 1, 1, shape[1], 1).expand(shape[0], -1, -1, shape[2]).permute(0, 2, 3, 1) + self.cor[k] = torch.cat([tenHorizontal, tenVertical], -1).to(device) + return self.cor[k] + + def forward(self, x1, x2): + B = x1.shape[0] + x = torch.cat([x1, x2], 0) + motion_features = [] + appearence_features = [] + xs = [] + for i in range(self.num_stages): + motion_features.append([]) + patch_embed = getattr(self, f"patch_embed{i + 1}",None) + block = getattr(self, f"block{i + 1}",None) + norm = getattr(self, f"norm{i + 1}",None) + if i < self.conv_stages: + if i > 0: + x = patch_embed(x) + x = block(x) + xs.append(x) + else: + if i == self.conv_stages: + x, H, W = patch_embed(xs) + else: + x, H, W = patch_embed(x) + cor = self.get_cor((x.shape[0], H, W), x.device) + for blk in block: + x, x_motion = blk(x, cor, H, W, B) + motion_features[i].append(x_motion.reshape(2*B, H, W, -1).permute(0, 3, 1, 2).contiguous()) + x = norm(x) + x = x.reshape(2*B, H, W, -1).permute(0, 3, 1, 2).contiguous() + motion_features[i] = torch.cat(motion_features[i], 1) + appearence_features.append(x) + return appearence_features, motion_features + + +class DWConv(nn.Module): + def __init__(self, dim): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).reshape(B, C, H, W) + x = self.dwconv(x) + x = x.reshape(B, C, -1).transpose(1, 2) + + return x + + +def feature_extractor(**kargs): + model = MotionFormer(**kargs) + return model \ No newline at end of file diff --git a/i2v_enhance/thirdparty/VFI/model/flow_estimation.py b/i2v_enhance/thirdparty/VFI/model/flow_estimation.py new file mode 100644 index 0000000000000000000000000000000000000000..2bbaf7f5a1fad58027861d41e598862702680fda --- /dev/null +++ b/i2v_enhance/thirdparty/VFI/model/flow_estimation.py @@ -0,0 +1,141 @@ +# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/model/flow_estimation +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .warplayer import warp +from .refine import * + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.PReLU(out_planes) + ) + + +class Head(nn.Module): + def __init__(self, in_planes, scale, c, in_else=17): + super(Head, self).__init__() + self.upsample = nn.Sequential(nn.PixelShuffle(2), nn.PixelShuffle(2)) + self.scale = scale + self.conv = nn.Sequential( + conv(in_planes*2 // (4*4) + in_else, c), + conv(c, c), + conv(c, 5), + ) + + def forward(self, motion_feature, x, flow): # /16 /8 /4 + motion_feature = self.upsample(motion_feature) #/4 /2 /1 + if self.scale != 4: + x = F.interpolate(x, scale_factor = 4. / self.scale, mode="bilinear", align_corners=False) + if flow != None: + if self.scale != 4: + flow = F.interpolate(flow, scale_factor = 4. / self.scale, mode="bilinear", align_corners=False) * 4. / self.scale + x = torch.cat((x, flow), 1) + x = self.conv(torch.cat([motion_feature, x], 1)) + if self.scale != 4: + x = F.interpolate(x, scale_factor = self.scale // 4, mode="bilinear", align_corners=False) + flow = x[:, :4] * (self.scale // 4) + else: + flow = x[:, :4] + mask = x[:, 4:5] + return flow, mask + + +class MultiScaleFlow(nn.Module): + def __init__(self, backbone, **kargs): + super(MultiScaleFlow, self).__init__() + self.flow_num_stage = len(kargs['hidden_dims']) + self.feature_bone = backbone + self.block = nn.ModuleList([Head( kargs['motion_dims'][-1-i] * kargs['depths'][-1-i] + kargs['embed_dims'][-1-i], + kargs['scales'][-1-i], + kargs['hidden_dims'][-1-i], + 6 if i==0 else 17) + for i in range(self.flow_num_stage)]) + self.unet = Unet(kargs['c'] * 2) + + def warp_features(self, xs, flow): + y0 = [] + y1 = [] + B = xs[0].size(0) // 2 + for x in xs: + y0.append(warp(x[:B], flow[:, 0:2])) + y1.append(warp(x[B:], flow[:, 2:4])) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5 + return y0, y1 + + def calculate_flow(self, imgs, timestep, af=None, mf=None): + img0, img1 = imgs[:, :3], imgs[:, 3:6] + B = img0.size(0) + flow, mask = None, None + # appearence_features & motion_features + if (af is None) or (mf is None): + af, mf = self.feature_bone(img0, img1) + for i in range(self.flow_num_stage): + t = torch.full(mf[-1-i][:B].shape, timestep, dtype=torch.float).cuda() + if flow != None: + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + flow_, mask_ = self.block[i]( + torch.cat([t*mf[-1-i][:B],(1-t)*mf[-1-i][B:],af[-1-i][:B],af[-1-i][B:]],1), + torch.cat((img0, img1, warped_img0, warped_img1, mask), 1), + flow + ) + flow = flow + flow_ + mask = mask + mask_ + else: + flow, mask = self.block[i]( + torch.cat([t*mf[-1-i][:B],(1-t)*mf[-1-i][B:],af[-1-i][:B],af[-1-i][B:]],1), + torch.cat((img0, img1), 1), + None + ) + + return flow, mask + + def coraseWarp_and_Refine(self, imgs, af, flow, mask): + img0, img1 = imgs[:, :3], imgs[:, 3:6] + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + c0, c1 = self.warp_features(af, flow) + tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) + res = tmp[:, :3] * 2 - 1 + mask_ = torch.sigmoid(mask) + merged = warped_img0 * mask_ + warped_img1 * (1 - mask_) + pred = torch.clamp(merged + res, 0, 1) + return pred + + + # Actually consist of 'calculate_flow' and 'coraseWarp_and_Refine' + def forward(self, x, timestep=0.5): + img0, img1 = x[:, :3], x[:, 3:6] + B = x.size(0) + flow_list = [] + merged = [] + mask_list = [] + warped_img0 = img0 + warped_img1 = img1 + flow = None + # appearence_features & motion_features + af, mf = self.feature_bone(img0, img1) + for i in range(self.flow_num_stage): + t = torch.full(mf[-1-i][:B].shape, timestep, dtype=torch.float).cuda() + if flow != None: + flow_d, mask_d = self.block[i]( torch.cat([t*mf[-1-i][:B], (1-timestep)*mf[-1-i][B:],af[-1-i][:B],af[-1-i][B:]],1), + torch.cat((img0, img1, warped_img0, warped_img1, mask), 1), flow) + flow = flow + flow_d + mask = mask + mask_d + else: + flow, mask = self.block[i]( torch.cat([t*mf[-1-i][:B], (1-t)*mf[-1-i][B:],af[-1-i][:B],af[-1-i][B:]],1), + torch.cat((img0, img1), 1), None) + mask_list.append(torch.sigmoid(mask)) + flow_list.append(flow) + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + merged.append(warped_img0 * mask_list[i] + warped_img1 * (1 - mask_list[i])) + + c0, c1 = self.warp_features(af, flow) + tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) + res = tmp[:, :3] * 2 - 1 + pred = torch.clamp(merged[-1] + res, 0, 1) + return flow_list, mask_list, merged, pred \ No newline at end of file diff --git a/i2v_enhance/thirdparty/VFI/model/loss.py b/i2v_enhance/thirdparty/VFI/model/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..30e07bfa5b11e0c627232a25b0155929b3362270 --- /dev/null +++ b/i2v_enhance/thirdparty/VFI/model/loss.py @@ -0,0 +1,95 @@ +# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/model/loss.py +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def gauss_kernel(channels=3): + kernel = torch.tensor([[1., 4., 6., 4., 1], + [4., 16., 24., 16., 4.], + [6., 24., 36., 24., 6.], + [4., 16., 24., 16., 4.], + [1., 4., 6., 4., 1.]]) + kernel /= 256. + kernel = kernel.repeat(channels, 1, 1, 1) + kernel = kernel.to(device) + return kernel + +def downsample(x): + return x[:, :, ::2, ::2] + +def upsample(x): + cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3) + cc = cc.view(x.shape[0], x.shape[1], x.shape[2]*2, x.shape[3]) + cc = cc.permute(0,1,3,2) + cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2]*2).to(device)], dim=3) + cc = cc.view(x.shape[0], x.shape[1], x.shape[3]*2, x.shape[2]*2) + x_up = cc.permute(0,1,3,2) + return conv_gauss(x_up, 4*gauss_kernel(channels=x.shape[1])) + +def conv_gauss(img, kernel): + img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect') + out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1]) + return out + +def laplacian_pyramid(img, kernel, max_levels=3): + current = img + pyr = [] + for level in range(max_levels): + filtered = conv_gauss(current, kernel) + down = downsample(filtered) + up = upsample(down) + diff = current-up + pyr.append(diff) + current = down + return pyr + +class LapLoss(torch.nn.Module): + def __init__(self, max_levels=5, channels=3): + super(LapLoss, self).__init__() + self.max_levels = max_levels + self.gauss_kernel = gauss_kernel(channels=channels) + + def forward(self, input, target): + pyr_input = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels) + pyr_target = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels) + return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target)) + +class Ternary(nn.Module): + def __init__(self, device): + super(Ternary, self).__init__() + patch_size = 7 + out_channels = patch_size * patch_size + self.w = np.eye(out_channels).reshape( + (patch_size, patch_size, 1, out_channels)) + self.w = np.transpose(self.w, (3, 2, 0, 1)) + self.w = torch.tensor(self.w).float().to(device) + + def transform(self, img): + patches = F.conv2d(img, self.w, padding=3, bias=None) + transf = patches - img + transf_norm = transf / torch.sqrt(0.81 + transf**2) + return transf_norm + + def rgb2gray(self, rgb): + r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :] + gray = 0.2989 * r + 0.5870 * g + 0.1140 * b + return gray + + def hamming(self, t1, t2): + dist = (t1 - t2) ** 2 + dist_norm = torch.mean(dist / (0.1 + dist), 1, True) + return dist_norm + + def valid_mask(self, t, padding): + n, _, h, w = t.size() + inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t) + mask = F.pad(inner, [padding] * 4) + return mask + + def forward(self, img0, img1): + img0 = self.transform(self.rgb2gray(img0)) + img1 = self.transform(self.rgb2gray(img1)) + return self.hamming(img0, img1) * self.valid_mask(img0, 1) diff --git a/i2v_enhance/thirdparty/VFI/model/refine.py b/i2v_enhance/thirdparty/VFI/model/refine.py new file mode 100644 index 0000000000000000000000000000000000000000..246f77276d99e1f4fe08503c140589e1fe271136 --- /dev/null +++ b/i2v_enhance/thirdparty/VFI/model/refine.py @@ -0,0 +1,71 @@ +import torch +import torch.nn as nn +import math +from timm.models.layers import trunc_normal_ + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.PReLU(out_planes) + ) + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True), + nn.PReLU(out_planes) + ) + +class Conv2(nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2, self).__init__() + self.conv1 = conv(in_planes, out_planes, 3, stride, 1) + self.conv2 = conv(out_planes, out_planes, 3, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + +class Unet(nn.Module): + def __init__(self, c, out=3): + super(Unet, self).__init__() + self.down0 = Conv2(17+c, 2*c) + self.down1 = Conv2(4*c, 4*c) + self.down2 = Conv2(8*c, 8*c) + self.down3 = Conv2(16*c, 16*c) + self.up0 = deconv(32*c, 8*c) + self.up1 = deconv(16*c, 4*c) + self.up2 = deconv(8*c, 2*c) + self.up3 = deconv(4*c, c) + self.conv = nn.Conv2d(c, out, 3, 1, 1) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1): + s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow,c0[0], c1[0]), 1)) + s1 = self.down1(torch.cat((s0, c0[1], c1[1]), 1)) + s2 = self.down2(torch.cat((s1, c0[2], c1[2]), 1)) + s3 = self.down3(torch.cat((s2, c0[3], c1[3]), 1)) + x = self.up0(torch.cat((s3, c0[4], c1[4]), 1)) + x = self.up1(torch.cat((x, s2), 1)) + x = self.up2(torch.cat((x, s1), 1)) + x = self.up3(torch.cat((x, s0), 1)) + x = self.conv(x) + return torch.sigmoid(x) diff --git a/i2v_enhance/thirdparty/VFI/model/warplayer.py b/i2v_enhance/thirdparty/VFI/model/warplayer.py new file mode 100644 index 0000000000000000000000000000000000000000..86418218d5b256ad93f844eced08d3dac54799ff --- /dev/null +++ b/i2v_enhance/thirdparty/VFI/model/warplayer.py @@ -0,0 +1,21 @@ +# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/model/warplayer.py +import torch + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +backwarp_tenGrid = {} + +def warp(tenInput, tenFlow): + k = (str(tenFlow.device), str(tenFlow.size())) + if k not in backwarp_tenGrid: + tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view( + 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) + tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view( + 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) + backwarp_tenGrid[k] = torch.cat( + [tenHorizontal, tenVertical], 1).to(device) + + tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), + tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1) + + g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) + return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True) diff --git a/i2v_enhance/thirdparty/VFI/train.py b/i2v_enhance/thirdparty/VFI/train.py new file mode 100644 index 0000000000000000000000000000000000000000..deb2a1d3550422d7436e1d72fbfa25910b9bf07c --- /dev/null +++ b/i2v_enhance/thirdparty/VFI/train.py @@ -0,0 +1,105 @@ +# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/train.py +import os +import cv2 +import math +import time +import torch +import torch.distributed as dist +import numpy as np +import random +import argparse + +from Trainer import Model +from dataset import VimeoDataset +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data.distributed import DistributedSampler +from config import * + +device = torch.device("cuda") +exp = os.path.abspath('.').split('/')[-1] + +def get_learning_rate(step): + if step < 2000: + mul = step / 2000 + return 2e-4 * mul + else: + mul = np.cos((step - 2000) / (300 * args.step_per_epoch - 2000) * math.pi) * 0.5 + 0.5 + return (2e-4 - 2e-5) * mul + 2e-5 + +def train(model, local_rank, batch_size, data_path): + if local_rank == 0: + writer = SummaryWriter('log/train_EMAVFI') + step = 0 + nr_eval = 0 + best = 0 + dataset = VimeoDataset('train', data_path) + sampler = DistributedSampler(dataset) + train_data = DataLoader(dataset, batch_size=batch_size, num_workers=8, pin_memory=True, drop_last=True, sampler=sampler) + args.step_per_epoch = train_data.__len__() + dataset_val = VimeoDataset('test', data_path) + val_data = DataLoader(dataset_val, batch_size=batch_size, pin_memory=True, num_workers=8) + print('training...') + time_stamp = time.time() + for epoch in range(300): + sampler.set_epoch(epoch) + for i, imgs in enumerate(train_data): + data_time_interval = time.time() - time_stamp + time_stamp = time.time() + imgs = imgs.to(device, non_blocking=True) / 255. + imgs, gt = imgs[:, 0:6], imgs[:, 6:] + learning_rate = get_learning_rate(step) + _, loss = model.update(imgs, gt, learning_rate, training=True) + train_time_interval = time.time() - time_stamp + time_stamp = time.time() + if step % 200 == 1 and local_rank == 0: + writer.add_scalar('learning_rate', learning_rate, step) + writer.add_scalar('loss', loss, step) + if local_rank == 0: + print('epoch:{} {}/{} time:{:.2f}+{:.2f} loss:{:.4e}'.format(epoch, i, args.step_per_epoch, data_time_interval, train_time_interval, loss)) + step += 1 + nr_eval += 1 + if nr_eval % 3 == 0: + evaluate(model, val_data, nr_eval, local_rank) + model.save_model(local_rank) + + dist.barrier() + +def evaluate(model, val_data, nr_eval, local_rank): + if local_rank == 0: + writer_val = SummaryWriter('log/validate_EMAVFI') + + psnr = [] + for _, imgs in enumerate(val_data): + imgs = imgs.to(device, non_blocking=True) / 255. + imgs, gt = imgs[:, 0:6], imgs[:, 6:] + with torch.no_grad(): + pred, _ = model.update(imgs, gt, training=False) + for j in range(gt.shape[0]): + psnr.append(-10 * math.log10(((gt[j] - pred[j]) * (gt[j] - pred[j])).mean().cpu().item())) + + psnr = np.array(psnr).mean() + if local_rank == 0: + print(str(nr_eval), psnr) + writer_val.add_scalar('psnr', psnr, nr_eval) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--local_rank', default=0, type=int, help='local rank') + parser.add_argument('--world_size', default=4, type=int, help='world size') + parser.add_argument('--batch_size', default=8, type=int, help='batch size') + parser.add_argument('--data_path', type=str, help='data path of vimeo90k') + args = parser.parse_args() + torch.distributed.init_process_group(backend="nccl", world_size=args.world_size) + torch.cuda.set_device(args.local_rank) + if args.local_rank == 0 and not os.path.exists('log'): + os.mkdir('log') + seed = 1234 + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.benchmark = True + model = Model(args.local_rank) + train(model, args.local_rank, args.batch_size, args.data_path) + diff --git a/lib/__init__.py b/lib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/farancia/__init__.py b/lib/farancia/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6abd0b5504787a34830c7aa4dae2fd9f149ea95 --- /dev/null +++ b/lib/farancia/__init__.py @@ -0,0 +1,4 @@ +from .libimage import IImage + +from os.path import dirname, pardir, realpath +import os diff --git a/lib/farancia/animation.py b/lib/farancia/animation.py new file mode 100644 index 0000000000000000000000000000000000000000..78f40d41c562c0ed1218ee055e0e1583cff053f7 --- /dev/null +++ b/lib/farancia/animation.py @@ -0,0 +1,43 @@ +import matplotlib.pyplot as plt +from matplotlib import animation + + +class Animation: + JS = 0 + HTML = 1 + ANIMATION_MODE = HTML + + def __init__(self, frames, fps=30): + """_summary_ + + Args: + frames (np.ndarray): _description_ + """ + self.frames = frames + self.fps = fps + self.anim_obj = None + self.anim_str = None + + def render(self): + size = (self.frames.shape[2], self.frames.shape[1]) + self.fig = plt.figure(figsize=size, dpi=1) + plt.axis('off') + img = plt.imshow(self.frames[0], cmap='gray', vmin=0, vmax=255) + self.fig.subplots_adjust(0, 0, 1, 1) + self.anim_obj = animation.FuncAnimation( + self.fig, + lambda i: img.set_data(self.frames[i, :, :, :]), + frames=self.frames.shape[0], + interval=1000 / self.fps + ) + plt.close() + if Animation.ANIMATION_MODE == Animation.HTML: + self.anim_str = self.anim_obj.to_html5_video() + elif Animation.ANIMATION_MODE == Animation.JS: + self.anim_str = self.anim_obj.to_jshtml() + return self.anim_obj + + def _repr_html_(self): + if self.anim_obj is None: + self.render() + return self.anim_str diff --git a/lib/farancia/config.py b/lib/farancia/config.py new file mode 100644 index 0000000000000000000000000000000000000000..f94efdb9d2c533c8f433f02beee612ee29fd519c --- /dev/null +++ b/lib/farancia/config.py @@ -0,0 +1 @@ +IMG_THUMBSIZE = None \ No newline at end of file diff --git a/lib/farancia/libimage/__init__.py b/lib/farancia/libimage/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9f4c4d47987cf8d1481f672f23e848311adcfb0d --- /dev/null +++ b/lib/farancia/libimage/__init__.py @@ -0,0 +1,45 @@ +from .iimage import IImage + +import math +import numpy as np +import warnings + +# ========= STATIC FUNCTIONS ============= +def find_max_h(images): + return max([x.size[1] for x in images]) +def find_max_w(images): + return max([x.size[0] for x in images]) +def find_max_size(images): + return find_max_w(images), find_max_h(images) + + +def stack(images, axis = 0): + return IImage(np.concatenate([x.data for x in images], axis)) +def tstack(images): + w,h = find_max_size(images) + images = [x.pad2wh(w,h) for x in images] + return IImage(np.concatenate([x.data for x in images], 0)) +def hstack(images): + h = find_max_h(images) + images = [x.pad2wh(h = h) for x in images] + return IImage(np.concatenate([x.data for x in images], 2)) +def vstack(images): + w = find_max_w(images) + images = [x.pad2wh(w = w) for x in images] + return IImage(np.concatenate([x.data for x in images], 1)) + +def grid(images, nrows = None, ncols = None): + combined = stack(images) + if nrows is not None: + ncols = math.ceil(combined.data.shape[0] / nrows) + elif ncols is not None: + nrows = math.ceil(combined.data.shape[0] / ncols) + else: + warnings.warn("No dimensions specified, creating a grid with 5 columns (default)") + ncols = 5 + nrows = math.ceil(combined.data.shape[0] / ncols) + + pad = nrows * ncols - combined.data.shape[0] + data = np.pad(combined.data, ((0,pad),(0,0),(0,0),(0,0))) + rows = [np.concatenate(x,1,dtype=np.uint8) for x in np.array_split(data, nrows)] + return IImage(np.concatenate(rows, 0, dtype = np.uint8)[None]) \ No newline at end of file diff --git a/lib/farancia/libimage/iimage.py b/lib/farancia/libimage/iimage.py new file mode 100644 index 0000000000000000000000000000000000000000..e9b1561012cf00775525ecfe296b0b73b371b2fd --- /dev/null +++ b/lib/farancia/libimage/iimage.py @@ -0,0 +1,511 @@ +import io +import math +import os +import PIL.Image +import numpy as np +import imageio.v3 as iio +import warnings +from torchvision.utils import flow_to_image + +import torch +import torchvision.transforms.functional as TF +from scipy.ndimage import binary_dilation, binary_erosion +import cv2 + +from ..animation import Animation +from .. import config +from .. import libimage +import re + + +def torch2np(x, vmin=-1, vmax=1): + if x.ndim != 4: + # raise Exception("Please only use (B,C,H,W) torch tensors!") + warnings.warn( + "Warning! Shape of the image was not provided in (B,C,H,W) format, the shape was inferred automatically!") + if x.ndim == 3: + x = x[None] + if x.ndim == 2: + x = x[None, None] + assert x.shape[1] == 3 or x.shape[1] == 1 + x = x.detach().cpu().float() + if x.dtype == torch.uint8: + return x.numpy().astype(np.uint8) + elif vmin is not None and vmax is not None: + x = (255 * (x.clip(vmin, vmax) - vmin) / (vmax - vmin)) + x = x.permute(0, 2, 3, 1).to(torch.uint8) + return x.numpy() + else: + raise NotImplementedError() + + +class IImage: + ''' + Generic media storage. Can store both images and videos. + Stores data as a numpy array by default. + Can be viewed in a jupyter notebook. + ''' + @staticmethod + def open(path): + + iio_obj = iio.imopen(path, 'r') + data = iio_obj.read() + try: + # .properties() does not work for images but for gif files + if not iio_obj.properties().is_batch: + data = data[None] + except AttributeError as e: + # this one works for gif files + if not "duration" in iio_obj.metadata(): + data = data[None] + if data.ndim == 3: + data = data[..., None] + image = IImage(data) + image.link = os.path.abspath(path) + return image + + @staticmethod + def flow_field(flow): + flow_images = flow_to_image(flow) + return IImage(flow_images, vmin=0, vmax=255) + + @staticmethod + def normalized(x, dims=[-1, -2]): + x = (x - x.amin(dims, True)) / \ + (x.amax(dims, True) - x.amin(dims, True)) + return IImage(x, 0) + + def numpy(self): return self.data + + def torch(self, vmin=-1, vmax=1): + if self.data.ndim == 3: + data = self.data.transpose(2, 0, 1) / 255. + else: + data = self.data.transpose(0, 3, 1, 2) / 255. + return vmin + torch.from_numpy(data).float().to(self.device) * (vmax - vmin) + + def cuda(self): + self.device = 'cuda' + return self + + def cpu(self): + self.device = 'cpu' + return self + + def pil(self): + ans = [] + for x in self.data: + if x.shape[-1] == 1: + x = x[..., 0] + + ans.append(PIL.Image.fromarray(x)) + if len(ans) == 1: + return ans[0] + return ans + + def is_iimage(self): + return True + + @property + def shape(self): return self.data.shape + @property + def size(self): return (self.data.shape[-2], self.data.shape[-3]) + + def setFps(self, fps): + self.fps = fps + self.generate_display() + return self + + def __init__(self, x, vmin=-1, vmax=1, fps=None): + + if isinstance(x, PIL.Image.Image): + self.data = np.array(x) + if self.data.ndim == 2: + self.data = self.data[..., None] # (H,W,C) + self.data = self.data[None] # (B,H,W,C) + elif isinstance(x, IImage): + self.data = x.data.copy() # Simple Copy + elif isinstance(x, np.ndarray): + self.data = x.copy().astype(np.uint8) + if self.data.ndim == 2: + self.data = self.data[None, ..., None] + if self.data.ndim == 3: + warnings.warn( + "Inferred dimensions for a 3D array as (H,W,C), but could've been (B,H,W)") + self.data = self.data[None] + elif isinstance(x, torch.Tensor): + assert x.min() >= vmin and x.max( + ) <= vmax, f"input data was [{x.min()},{x.max()}], but expected [{vmin},{vmax}]" + self.data = torch2np(x, vmin, vmax) + self.display_str = None + self.device = 'cpu' + self.fps = fps if fps is not None else ( + 1 if len(self.data) < 10 else 30) + self.link = None + + def generate_display(self): + if config.IMG_THUMBSIZE is not None: + if self.size[1] < self.size[0]: + thumb = self.resize( + (self.size[1]*config.IMG_THUMBSIZE//self.size[0], config.IMG_THUMBSIZE)) + else: + thumb = self.resize( + (config.IMG_THUMBSIZE, self.size[0]*config.IMG_THUMBSIZE//self.size[1])) + else: + thumb = self + if self.is_video(): + self.anim = Animation(thumb.data, fps=self.fps) + self.anim.render() + self.display_str = self.anim.anim_str + else: + b = io.BytesIO() + data = thumb.data[0] + if data.shape[-1] == 1: + data = data[..., 0] + PIL.Image.fromarray(data).save(b, "PNG") + self.display_str = b.getvalue() + return self.display_str + + def resize(self, size, *args, **kwargs): + if size is None: + return self + use_small_edge_when_int = kwargs.pop('use_small_edge_when_int', False) + + # Backward compatibility + resample = kwargs.pop('filter', PIL.Image.BICUBIC) + resample = kwargs.pop('resample', resample) + + if isinstance(size, int): + if use_small_edge_when_int: + h, w = self.data.shape[1:3] + aspect_ratio = h / w + size = (max(size, int(size * aspect_ratio)), + max(size, int(size / aspect_ratio))) + else: + h, w = self.data.shape[1:3] + aspect_ratio = h / w + size = (min(size, int(size * aspect_ratio)), + min(size, int(size / aspect_ratio))) + + if self.size == size[::-1]: + return self + return libimage.stack([IImage(x.pil().resize(size[::-1], *args, resample=resample, **kwargs)) for x in self]) + # return IImage(TF.resize(self.cpu().torch(0), size, *args, **kwargs), 0) + + def pad(self, padding, *args, **kwargs): + return IImage(TF.pad(self.torch(0), padding=padding, *args, **kwargs), 0) + + def padx(self, multiplier, *args, **kwargs): + size = np.array(self.size) + padding = np.concatenate( + [[0, 0], np.ceil(size / multiplier).astype(int) * multiplier - size]) + return self.pad(list(padding), *args, **kwargs) + + def pad2wh(self, w=0, h=0, **kwargs): + cw, ch = self.size + return self.pad([0, 0, max(0, w - cw), max(0, h-ch)], **kwargs) + + def pad2square(self, *args, **kwargs): + if self.size[0] > self.size[1]: + dx = self.size[0] - self.size[1] + return self.pad([0, dx//2, 0, dx-dx//2], *args, **kwargs) + elif self.size[0] < self.size[1]: + dx = self.size[1] - self.size[0] + return self.pad([dx//2, 0, dx-dx//2, 0], *args, **kwargs) + return self + + def crop2square(self, *args, **kwargs): + if self.size[0] > self.size[1]: + dx = self.size[0] - self.size[1] + return self.crop([dx//2, 0, self.size[1], self.size[1]], *args, **kwargs) + elif self.size[0] < self.size[1]: + dx = self.size[1] - self.size[0] + return self.crop([0, dx//2, self.size[0], self.size[0]], *args, **kwargs) + return self + + def alpha(self): + return IImage(self.data[..., -1, None], fps=self.fps) + + def rgb(self): + return IImage(self.pil().convert('RGB'), fps=self.fps) + + def png(self): + return IImage(np.concatenate([self.data, 255 * np.ones_like(self.data)[..., :1]], -1)) + + def grid(self, nrows=None, ncols=None): + if nrows is not None: + ncols = math.ceil(self.data.shape[0] / nrows) + elif ncols is not None: + nrows = math.ceil(self.data.shape[0] / ncols) + else: + warnings.warn( + "No dimensions specified, creating a grid with 5 columns (default)") + ncols = 5 + nrows = math.ceil(self.data.shape[0] / ncols) + + pad = nrows * ncols - self.data.shape[0] + data = np.pad(self.data, ((0, pad), (0, 0), (0, 0), (0, 0))) + rows = [np.concatenate(x, 1, dtype=np.uint8) + for x in np.array_split(data, nrows)] + return IImage(np.concatenate(rows, 0, dtype=np.uint8)[None]) + + def hstack(self): + return IImage(np.concatenate(self.data, 1, dtype=np.uint8)[None]) + + def vstack(self): + return IImage(np.concatenate(self.data, 0, dtype=np.uint8)[None]) + + def vsplit(self, number_of_splits): + return IImage(np.concatenate(np.split(self.data, number_of_splits, 1))) + + def hsplit(self, number_of_splits): + return IImage(np.concatenate(np.split(self.data, number_of_splits, 2))) + + def heatmap(self, resize=None, cmap=cv2.COLORMAP_JET): + data = np.stack([cv2.cvtColor(cv2.applyColorMap( + x, cmap), cv2.COLOR_BGR2RGB) for x in self.data]) + return IImage(data).resize(resize, use_small_edge_when_int=True) + + def display(self): + try: + display(self) + except: + print("No display") + return self + + def dilate(self, iterations=1, *args, **kwargs): + if iterations == 0: + return IImage(self.data) + return IImage((binary_dilation(self.data, iterations=iterations, *args, *kwargs)*255.).astype(np.uint8)) + + def erode(self, iterations=1, *args, **kwargs): + return IImage((binary_erosion(self.data, iterations=iterations, *args, *kwargs)*255.).astype(np.uint8)) + + def hull(self): + convex_hulls = [] + for frame in self.data: + contours, hierarchy = cv2.findContours( + frame, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + contours = [x.astype(np.int32) for x in contours] + mask_contours = [cv2.convexHull(np.concatenate(contours))] + canvas = np.zeros(self.data[0].shape, np.uint8) + convex_hull = cv2.drawContours( + canvas, mask_contours, -1, (255, 0, 0), -1) + convex_hulls.append(convex_hull) + return IImage(np.array(convex_hulls)) + + def is_video(self): + return self.data.shape[0] > 1 + + def __getitem__(self, idx): + return IImage(self.data[None, idx], fps=self.fps) + # if self.is_video(): return IImage(self.data[idx], fps = self.fps) + # return self + + def _repr_png_(self): + if self.is_video(): + return None + if self.display_str is None: + self.generate_display() + return self.display_str + + def _repr_html_(self): + if not self.is_video(): + return None + if self.display_str is None: + self.generate_display() + return self.display_str + + def save(self, path): + _, ext = os.path.splitext(path) + if self.is_video(): + # if ext in ['.jpg', '.png']: + if self.display_str is None: + self.generate_display() + if ext == ".apng": + self.anim.anim_obj.save(path, writer="pillow") + else: + self.anim.anim_obj.save(path) + else: + data = self.data if self.data.ndim == 3 else self.data[0] + if data.shape[-1] == 1: + data = data[:, :, 0] + PIL.Image.fromarray(data).save(path) + return self + + def to_html(self, width='auto', root_path='/'): + if self.display_str is None: + self.generate_display() + # print (self.display_str) + html_tag = bytes2html(self.display_str, width=width) + if self.link is not None: + link = os.path.relpath(self.link, root_path) + return f'{html_tag}' + return html_tag + + def write(self, text, center=(0, 25), font_scale=0.8, color=(255, 255, 255), thickness=2): + if not isinstance(text, list): + text = [text for _ in self.data] + data = np.stack([cv2.putText(x.copy(), t, center, cv2.FONT_HERSHEY_COMPLEX, + font_scale, color, thickness) for x, t in zip(self.data, text)]) + return IImage(data) + + def append_text(self, text, padding, font_scale=0.8, color=(255, 255, 255), thickness=2, scale_factor=0.9, center=(0, 0), fill=0): + + assert np.count_nonzero(padding) == 1 + axis_padding = np.nonzero(padding)[0][0] + scale_padding = padding[axis_padding] + + y_0 = 0 + x_0 = 0 + if axis_padding == 0: + width = scale_padding + y_max = self.shape[1] + elif axis_padding == 1: + width = self.shape[2] + y_max = scale_padding + elif axis_padding == 2: + x_0 = self.shape[2] + width = scale_padding + y_max = self.shape[1] + elif axis_padding == 3: + width = self.shape[2] + y_0 = self.shape[1] + y_max = self.shape[1]+scale_padding + + width -= center[0] + x_0 += center[0] + y_0 += center[1] + + self = self.pad(padding, fill=fill) + + def wrap_text(text, width, _font_scale): + allowed_seperator = ' |-|_|/|\n' + words = re.split(allowed_seperator, text) + # words = text.split() + lines = [] + current_line = words[0] + sep_list = [] + start_idx = 0 + for start_word in words[:-1]: + pos = text.find(start_word, start_idx) + pos += len(start_word) + sep_list.append(text[pos]) + start_idx = pos+1 + + for word, separator in zip(words[1:], sep_list): + if cv2.getTextSize(current_line + separator + word, cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][0] <= width: + current_line += separator + word + else: + if cv2.getTextSize(current_line, cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][0] <= width: + lines.append(current_line) + current_line = word + else: + return [] + + if cv2.getTextSize(current_line, cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][0] <= width: + lines.append(current_line) + else: + return [] + return lines + + def wrap_text_and_scale(text, width, _font_scale, y_0, y_max): + height = y_max+1 + while height > y_max: + text_lines = wrap_text(text, width, _font_scale) + if len(text) > 0 and len(text_lines) == 0: + + height = y_max+1 + else: + line_height = cv2.getTextSize( + text_lines[0], cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][1] + height = line_height * len(text_lines) + y_0 + + # scale font if out of frame + if height > y_max: + _font_scale = _font_scale * scale_factor + + return text_lines, line_height, _font_scale + + result = [] + if not isinstance(text, list): + text = [text for _ in self.data] + else: + assert len(text) == len(self.data) + + for x, t in zip(self.data, text): + x = x.copy() + text_lines, line_height, _font_scale = wrap_text_and_scale( + t, width, font_scale, y_0, y_max) + y = line_height + for line in text_lines: + x = cv2.putText( + x, line, (x_0, y_0+y), cv2.FONT_HERSHEY_COMPLEX, _font_scale, color, thickness) + y += line_height + result.append(x) + data = np.stack(result) + + return IImage(data) + + # ========== OPERATORS ============= + + def __or__(self, other): + # TODO: fix for variable sizes + return IImage(np.concatenate([self.data, other.data], 2)) + + def __truediv__(self, other): + # TODO: fix for variable sizes + return IImage(np.concatenate([self.data, other.data], 1)) + + def __and__(self, other): + return IImage(np.concatenate([self.data, other.data], 0)) + + def __add__(self, other): + return IImage(0.5 * self.data + 0.5 * other.data) + + def __mul__(self, other): + if isinstance(other, IImage): + return IImage(self.data / 255. * other.data) + return IImage(self.data * other / 255.) + + def __xor__(self, other): + return IImage(0.5 * self.data + 0.5 * other.data + 0.5 * self.data * (other.data.sum(-1, keepdims=True) == 0)) + + def __invert__(self): + return IImage(255 - self.data) + __rmul__ = __mul__ + + def bbox(self): + return [cv2.boundingRect(x) for x in self.data] + + def fill_bbox(self, bbox_list, fill=255): + data = self.data.copy() + for bbox in bbox_list: + x, y, w, h = bbox + data[:, y:y+h, x:x+w, :] = fill + return IImage(data) + + def crop(self, bbox): + assert len(bbox) in [2, 4] + if len(bbox) == 2: + x, y = 0, 0 + w, h = bbox + elif len(bbox) == 4: + x, y, w, h = bbox + return IImage(self.data[:, y:y+h, x:x+w, :]) + + # def alpha(self): + # return BetterImage(self.img.split()[-1]) + # def resize(self, size, *args, **kwargs): + # if size is None: return self + # return BetterImage(TF.resize(self.img, size, *args, **kwargs)) + # def pad(self, *args): + # return BetterImage(TF.pad(self.img, *args)) + # def padx(self, mult): + # size = np.array(self.img.size) + # padding = np.concatenate([[0,0],np.ceil(size / mult).astype(int) * mult - size]) + # return self.pad(list(padding)) + # def crop(self, *args): + # return BetterImage(self.img.crop(*args)) + # def torch(self, min = -1., max = 1.): + # return (max - min) * TF.to_tensor(self.img)[None] + min diff --git a/lib/farancia/libimage/utils.py b/lib/farancia/libimage/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..07b4424c2d38ad38aee679753c1a88f8d255dee6 --- /dev/null +++ b/lib/farancia/libimage/utils.py @@ -0,0 +1,8 @@ +from IPython.display import Image as IpyImage + +def bytes2html(data, width='auto'): + img_obj = IpyImage(data=data, format='JPG') + for bundle in img_obj._repr_mimebundle_(): + for mimetype, b64value in bundle.items(): + if mimetype.startswith('image/'): + return f'' \ No newline at end of file diff --git a/models/cam/conditioning.py b/models/cam/conditioning.py new file mode 100644 index 0000000000000000000000000000000000000000..1889fe3837e68df74b8b9e21cb8d54a413f27d36 --- /dev/null +++ b/models/cam/conditioning.py @@ -0,0 +1,150 @@ +import torch +import torch.nn as nn +from einops import rearrange +from diffusers.models.attention_processor import Attention + + +class CrossAttention(nn.Module): + """ + CrossAttention module implements per-pixel temporal attention to fuse the conditional attention module with the base module. + + Args: + input_channels (int): Number of input channels. + attention_head_dim (int): Dimension of attention head. + norm_num_groups (int): Number of groups for GroupNorm normalization (default is 32). + + Attributes: + attention (Attention): Attention module for computing attention scores. + norm (torch.nn.GroupNorm): Group normalization layer. + proj_in (nn.Linear): Linear layer for projecting input data. + proj_out (nn.Linear): Linear layer for projecting output data. + dropout (nn.Dropout): Dropout layer for regularization. + + Methods: + forward(hidden_state, encoder_hidden_states, num_frames, num_conditional_frames): + Forward pass of the CrossAttention module. + + """ + + def __init__(self, input_channels, attention_head_dim, norm_num_groups=32): + super().__init__() + self.attention = Attention( + query_dim=input_channels, cross_attention_dim=input_channels, heads=input_channels//attention_head_dim, dim_head=attention_head_dim, bias=False, upcast_attention=False) + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, num_channels=input_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(input_channels, input_channels) + self.proj_out = nn.Linear(input_channels, input_channels) + self.dropout = nn.Dropout(p=0.25) + + def forward(self, hidden_state, encoder_hidden_states, num_frames, num_conditional_frames): + """ + The input hidden state is normalized, then projected using a linear layer. + Multi-head cross attention is computed between the hidden state (latent of noisy video) and encoder hidden states (CLIP image encoder). + The output is projected using a linear layer. + We apply dropout to the newly generated frames (without the control frames). + + Args: + hidden_state (torch.Tensor): Input hidden state tensor. + encoder_hidden_states (torch.Tensor): Encoder hidden states tensor. + num_frames (int): Number of frames. + num_conditional_frames (int): Number of conditional frames. + + Returns: + output (torch.Tensor): Output tensor after processing with attention mechanism. + + """ + h, w = hidden_state.shape[2], hidden_state.shape[3] + hidden_state_norm = rearrange( + hidden_state, "(B F) C H W -> B C F H W", F=num_frames) + hidden_state_norm = self.norm(hidden_state_norm) + hidden_state_norm = rearrange( + hidden_state_norm, "B C F H W -> (B H W) F C") + + hidden_state_norm = self.proj_in(hidden_state_norm) + + attn = self.attention(hidden_state_norm, + encoder_hidden_states=encoder_hidden_states, + attention_mask=None, + ) + # proj_out + + residual = self.proj_out(attn) # (B H W) F C + hidden_state = rearrange( + hidden_state, "(B F) ... -> B F ...", F=num_frames) + hidden_state = torch.cat([hidden_state[:, :num_conditional_frames], self.dropout( + hidden_state[:, num_conditional_frames:])], dim=1) + hidden_state = rearrange(hidden_state, "B F ... -> (B F) ... ") + + residual = rearrange( + residual, "(B H W) F C -> (B F) C H W", H=h, W=w) + output = hidden_state + residual + return output + + +class ConditionalModel(nn.Module): + """ + ConditionalModel module performs the fusion of the conditional attention module to be base model. + + Args: + input_channels (int): Number of input channels. + conditional_model (str): Type of conditional model to use. Currently only "cross_attention" is implemented. + attention_head_dim (int): Dimension of attention head (default is 64). + + Attributes: + temporal_transformer (CrossAttention): CrossAttention module for temporal transformation. + conditional_model (str): Type of conditional model used. + + Methods: + forward(sample, conditioning, num_frames=None, num_conditional_frames=None): + Forward pass of the ConditionalModel module. + + """ + + def __init__(self, input_channels, conditional_model: str, attention_head_dim=64): + super().__init__() + + if conditional_model == "cross_attention": + self.temporal_transformer = CrossAttention( + input_channels=input_channels, attention_head_dim=attention_head_dim) + else: + raise NotImplementedError( + f"mode {conditional_model} not implemented") + + nn.init.zeros_(self.temporal_transformer.proj_out.weight) + nn.init.zeros_(self.temporal_transformer.proj_out.bias) + self.conditional_model = conditional_model + + def forward(self, sample, conditioning, num_frames=None, num_conditional_frames=None): + """ + Forward pass of the ConditionalModel module. + + Args: + sample (torch.Tensor): Input sample tensor. + conditioning (torch.Tensor): Conditioning tensor containing the enconding of the conditional frames. + num_frames (int): Number of frames in the sample. + num_conditional_frames (int): Number of conditional frames. + + Returns: + sample (torch.Tensor): Transformed sample tensor. + + """ + sample = rearrange(sample, "(B F) ... -> B F ...", F=num_frames) + batch_size = sample.shape[0] + conditioning = rearrange( + conditioning, "(B F) ... -> B F ...", B=batch_size) + + assert conditioning.ndim == 5 + assert sample.ndim == 5 + + conditioning = rearrange(conditioning, "B F C H W -> (B H W) F C") + + sample = rearrange(sample, "B F C H W -> (B F) C H W") + + sample = self.temporal_transformer( + sample, encoder_hidden_states=conditioning, num_frames=num_frames, num_conditional_frames=num_conditional_frames) + + return sample + + +if __name__ == "__main__": + model = CrossAttention(input_channels=320, attention_head_dim=32) diff --git a/models/control/controlnet.py b/models/control/controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..10fa5a8430773b83a77de8ea3edc6295ecef62b9 --- /dev/null +++ b/models/control/controlnet.py @@ -0,0 +1,581 @@ +import torch +import torch.nn as nn +from typing import List, Optional, Union +from models.svd.sgm.util import default +from models.svd.sgm.modules.video_attention import SpatialVideoTransformer +from models.svd.sgm.modules.diffusionmodules.openaimodel import * +from models.diffusion.video_model import VideoResBlock, VideoUNet +from einops import repeat, rearrange +from models.svd.sgm.modules.diffusionmodules.wrappers import OpenAIWrapper + + +class Merger(nn.Module): + """ + Merges the controlnet latents with the conditioning embedding (encoding of control frames). + + """ + + def __init__(self, merge_mode: str = "addition", input_channels=0, frame_expansion="last_frame") -> None: + super().__init__() + self.merge_mode = merge_mode + self.frame_expansion = frame_expansion + + def forward(self, x, condition_signal, num_video_frames, num_video_frames_conditional): + x = rearrange(x, "(B F) C H W -> B F C H W", F=num_video_frames) + + condition_signal = rearrange( + condition_signal, "(B F) C H W -> B F C H W", B=x.shape[0]) + + if x.shape[1] - condition_signal.shape[1] > 0: + if self.frame_expansion == "last_frame": + fillup_latent = repeat( + condition_signal[:, -1], "B C H W -> B F C H W", F=x.shape[1] - condition_signal.shape[1]) + elif self.frame_expansion == "zero": + fillup_latent = torch.zeros( + (x.shape[0], num_video_frames-num_video_frames_conditional, *x.shape[2:]), device=x.device, dtype=x.dtype) + + if self.frame_expansion != "none": + condition_signal = torch.cat( + [condition_signal, fillup_latent], dim=1) + + if self.merge_mode == "addition": + out = x + condition_signal + else: + raise NotImplementedError( + f"Merging mode {self.merge_mode} not implemented.") + + out = rearrange(out, "B F C H W -> (B F) C H W") + return out + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int] = (16, 32, 96, 256), + downsample: bool = True, + final_3d_conv: bool = False, + zero_init: bool = True, + use_controlnet_mask: bool = False, + use_normalization: bool = False, + ): + super().__init__() + + self.final_3d_conv = final_3d_conv + self.conv_in = nn.Conv2d( + conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + if final_3d_conv: + print("USING 3D CONV in ControlNET") + + self.blocks = nn.ModuleList([]) + if use_normalization: + self.norms = nn.ModuleList([]) + self.use_normalization = use_normalization + + stride = 2 if downsample else 1 + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append( + nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + if use_normalization: + self.norms.append(nn.LayerNorm((channel_in))) + self.blocks.append( + nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=stride)) + if use_normalization: + self.norms.append(nn.LayerNorm((channel_out))) + + self.conv_out = zero_module( + nn.Conv2d( + block_out_channels[-1]+int(use_controlnet_mask), conditioning_embedding_channels, kernel_size=3, padding=1), reset=zero_init + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + if self.use_normalization: + for block, norm in zip(self.blocks, self.norms): + embedding = block(embedding) + embedding = rearrange(embedding, " ... C W H -> ... W H C") + embedding = norm(embedding) + embedding = rearrange(embedding, "... W H C -> ... C W H") + embedding = F.silu(embedding) + else: + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + return embedding + + +class ControlNet(nn.Module): + + def __init__( + self, + in_channels: int, + model_channels: int, + out_channels: int, + num_res_blocks: int, + attention_resolutions: Union[List[int], int], + dropout: float = 0.0, + channel_mult: List[int] = (1, 2, 4, 8), + conv_resample: bool = True, + dims: int = 2, + num_classes: Optional[Union[int, str]] = None, + use_checkpoint: bool = False, + num_heads: int = -1, + num_head_channels: int = -1, + num_heads_upsample: int = -1, + use_scale_shift_norm: bool = False, + resblock_updown: bool = False, + transformer_depth: Union[List[int], int] = 1, + transformer_depth_middle: Optional[int] = None, + context_dim: Optional[int] = None, + time_downup: bool = False, + time_context_dim: Optional[int] = None, + extra_ff_mix_layer: bool = False, + use_spatial_context: bool = False, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + spatial_transformer_attn_type: str = "softmax", + video_kernel_size: Union[int, List[int]] = 3, + use_linear_in_transformer: bool = False, + adm_in_channels: Optional[int] = None, + disable_temporal_crossattention: bool = False, + max_ddpm_temb_period: int = 10000, + conditioning_embedding_out_channels: Optional[Tuple[int]] = ( + 16, 32, 96, 256), + condition_encoder: str = "", + use_controlnet_mask: bool = False, + downsample_controlnet_cond: bool = True, + use_image_encoder_normalization: bool = False, + zero_conv_mode: str = "Identity", + frame_expansion: str = "none", + merging_mode: str = "addition", + ): + super().__init__() + assert zero_conv_mode == "Identity", "Zero convolution not implemented" + + assert context_dim is not None + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1 + + if num_head_channels == -1: + assert num_heads != -1 + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(transformer_depth, int): + transformer_depth = len(channel_mult) * [transformer_depth] + transformer_depth_middle = default( + transformer_depth_middle, transformer_depth[-1] + ) + + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.dims = dims + self.use_scale_shift_norm = use_scale_shift_norm + self.resblock_updown = resblock_updown + self.transformer_depth = transformer_depth + self.transformer_depth_middle = transformer_depth_middle + self.context_dim = context_dim + self.time_downup = time_downup + self.time_context_dim = time_context_dim + self.extra_ff_mix_layer = extra_ff_mix_layer + self.use_spatial_context = use_spatial_context + self.merge_strategy = merge_strategy + self.merge_factor = merge_factor + self.spatial_transformer_attn_type = spatial_transformer_attn_type + self.video_kernel_size = video_kernel_size + self.use_linear_in_transformer = use_linear_in_transformer + self.adm_in_channels = adm_in_channels + self.disable_temporal_crossattention = disable_temporal_crossattention + self.max_ddpm_temb_period = max_ddpm_temb_period + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "timestep": + self.label_emb = nn.Sequential( + Timestep(model_channels), + nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ), + ) + + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + + def get_attention_layer( + ch, + num_heads, + dim_head, + depth=1, + context_dim=None, + use_checkpoint=False, + disabled_sa=False, + ): + return SpatialVideoTransformer( + ch, + num_heads, + dim_head, + depth=depth, + context_dim=context_dim, + time_context_dim=time_context_dim, + dropout=dropout, + ff_in=extra_ff_mix_layer, + use_spatial_context=use_spatial_context, + merge_strategy=merge_strategy, + merge_factor=merge_factor, + checkpoint=use_checkpoint, + use_linear=use_linear_in_transformer, + attn_mode=spatial_transformer_attn_type, + disable_self_attn=disabled_sa, + disable_temporal_crossattention=disable_temporal_crossattention, + max_time_embed_period=max_ddpm_temb_period, + ) + + def get_resblock( + merge_factor, + merge_strategy, + video_kernel_size, + ch, + time_embed_dim, + dropout, + out_ch, + dims, + use_checkpoint, + use_scale_shift_norm, + down=False, + up=False, + ): + return VideoResBlock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=down, + up=up, + ) + + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_ch=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + layers.append( + get_attention_layer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + use_checkpoint=use_checkpoint, + disabled_sa=False, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + ds *= 2 + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_ch=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, + conv_resample, + dims=dims, + out_channels=out_ch, + third_down=time_downup, + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + self.middle_block = TimestepEmbedSequential( + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + out_ch=None, + dropout=dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + get_attention_layer( + ch, + num_heads, + dim_head, + depth=transformer_depth_middle, + context_dim=context_dim, + use_checkpoint=use_checkpoint, + ), + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + out_ch=None, + time_embed_dim=time_embed_dim, + dropout=dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.merger = Merger( + merge_mode=merging_mode, input_channels=model_channels, frame_expansion=frame_expansion) + + conditioning_channels = 3 if downsample_controlnet_cond else 4 + block_out_channels = (320, 640, 1280, 1280) + + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + conditioning_channels=conditioning_channels, + block_out_channels=conditioning_embedding_out_channels, + downsample=downsample_controlnet_cond, + final_3d_conv=condition_encoder.endswith("3DConv"), + use_controlnet_mask=use_controlnet_mask, + use_normalization=use_image_encoder_normalization, + ) + + def forward( + self, + x: th.Tensor, + timesteps: th.Tensor, + controlnet_cond: th.Tensor, + context: Optional[th.Tensor] = None, + y: Optional[th.Tensor] = None, + time_context: Optional[th.Tensor] = None, + num_video_frames: Optional[int] = None, + num_video_frames_conditional: Optional[int] = None, + image_only_indicator: Optional[th.Tensor] = None, + ): + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional -> no, relax this TODO" + hs = [] + t_emb = timestep_embedding( + timesteps, self.model_channels, repeat_only=False).to(x.dtype) + + emb = self.time_embed(t_emb) + + # TODO restrict y to [:self.num_frames] (conditonal frames) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + + h = x + for idx, module in enumerate(self.input_blocks): + h = module( + h, + emb, + context=context, + image_only_indicator=image_only_indicator, + time_context=time_context, + num_video_frames=num_video_frames, + ) + if idx == 0: + h = self.merger(h, controlnet_cond, num_video_frames=num_video_frames, + num_video_frames_conditional=num_video_frames_conditional) + + hs.append(h) + h = self.middle_block( + h, + emb, + context=context, + image_only_indicator=image_only_indicator, + time_context=time_context, + num_video_frames=num_video_frames, + ) + + # 5. Control net blocks + + down_block_res_samples = hs + + mid_block_res_sample = h + + return (down_block_res_samples, mid_block_res_sample) + + @classmethod + def from_unet(cls, + model: OpenAIWrapper, + merging_mode: str = "addition", + zero_conv_mode: str = "Identity", + frame_expansion: str = "none", + downsample_controlnet_cond: bool = True, + use_image_encoder_normalization: bool = False, + use_controlnet_mask: bool = False, + condition_encoder: str = "", + conditioning_embedding_out_channels: List[int] = None, + + ): + + unet: VideoUNet = model.diffusion_model + + controlnet = cls(in_channels=unet.in_channels, + model_channels=unet.model_channels, + out_channels=unet.out_channels, + num_res_blocks=unet.num_res_blocks, + attention_resolutions=unet.attention_resolutions, + dropout=unet.dropout, + channel_mult=unet.channel_mult, + conv_resample=unet.conv_resample, + dims=unet.dims, + num_classes=unet.num_classes, + use_checkpoint=unet.use_checkpoint, + num_heads=unet.num_heads, + num_head_channels=unet.num_head_channels, + num_heads_upsample=unet.num_heads_upsample, + use_scale_shift_norm=unet.use_scale_shift_norm, + resblock_updown=unet.resblock_updown, + transformer_depth=unet.transformer_depth, + transformer_depth_middle=unet.transformer_depth_middle, + context_dim=unet.context_dim, + time_downup=unet.time_downup, + time_context_dim=unet.time_context_dim, + extra_ff_mix_layer=unet.extra_ff_mix_layer, + use_spatial_context=unet.use_spatial_context, + merge_strategy=unet.merge_strategy, + merge_factor=unet.merge_factor, + spatial_transformer_attn_type=unet.spatial_transformer_attn_type, + video_kernel_size=unet.video_kernel_size, + use_linear_in_transformer=unet.use_linear_in_transformer, + adm_in_channels=unet.adm_in_channels, + disable_temporal_crossattention=unet.disable_temporal_crossattention, + max_ddpm_temb_period=unet.max_ddpm_temb_period, # up to here unet params + merging_mode=merging_mode, + zero_conv_mode=zero_conv_mode, + frame_expansion=frame_expansion, + downsample_controlnet_cond=downsample_controlnet_cond, + use_image_encoder_normalization=use_image_encoder_normalization, + use_controlnet_mask=use_controlnet_mask, + condition_encoder=condition_encoder, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + ) + controlnet: ControlNet + + return controlnet + + +def zero_module(module, reset=True): + if reset: + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/models/diffusion/discretizer.py b/models/diffusion/discretizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5fa009aac381d49dce72335385f386fabe85da94 --- /dev/null +++ b/models/diffusion/discretizer.py @@ -0,0 +1,33 @@ +import numpy as np +import torch + +from models.svd.sgm.modules.diffusionmodules.discretizer import Discretization + + +# Implementation of https://arxiv.org/abs/2404.14507 +class AlignYourSteps(Discretization): + + def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0): + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.rho = rho + + def loglinear_interp(self, t_steps, num_steps): + """ + Performs log-linear interpolation of a given array of decreasing numbers. + """ + xs = np.linspace(0, 1, len(t_steps)) + ys = np.log(t_steps[::-1]) + + new_xs = np.linspace(0, 1, num_steps) + new_ys = np.interp(new_xs, xs, ys) + + interped_ys = np.exp(new_ys)[::-1].copy() + return interped_ys + + def get_sigmas(self, n, device="cpu"): + sampling_schedule = [700.00, 54.5, 15.886, 7.977, + 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002] + sigmas = torch.from_numpy(self.loglinear_interp( + sampling_schedule, n)).to(device) + return sigmas diff --git a/models/diffusion/video_model.py b/models/diffusion/video_model.py new file mode 100644 index 0000000000000000000000000000000000000000..550cf63f8bc7c3116fa2a67371328543d24e7a83 --- /dev/null +++ b/models/diffusion/video_model.py @@ -0,0 +1,574 @@ +# Adapted from https://github.com/Stability-AI/generative-models/blob/main/sgm/modules/diffusionmodules/video_model.py +from functools import partial +from typing import List, Optional, Union + +from einops import rearrange + +from models.svd.sgm.modules.diffusionmodules.openaimodel import * +from models.svd.sgm.modules.video_attention import SpatialVideoTransformer +from models.svd.sgm.util import default +from models.svd.sgm.modules.diffusionmodules.util import AlphaBlender +from functools import partial +from models.cam.conditioning import ConditionalModel + + +class VideoResBlock(ResBlock): + def __init__( + self, + channels: int, + emb_channels: int, + dropout: float, + video_kernel_size: Union[int, List[int]] = 3, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + out_channels: Optional[int] = None, + use_conv: bool = False, + use_scale_shift_norm: bool = False, + dims: int = 2, + use_checkpoint: bool = False, + up: bool = False, + down: bool = False, + ): + super().__init__( + channels, + emb_channels, + dropout, + out_channels=out_channels, + use_conv=use_conv, + use_scale_shift_norm=use_scale_shift_norm, + dims=dims, + use_checkpoint=use_checkpoint, + up=up, + down=down, + ) + + self.time_stack = ResBlock( + default(out_channels, channels), + emb_channels, + dropout=dropout, + dims=3, + out_channels=default(out_channels, channels), + use_scale_shift_norm=False, + use_conv=False, + up=False, + down=False, + kernel_size=video_kernel_size, + use_checkpoint=use_checkpoint, + exchange_temb_dims=True, + ) + self.time_mixer = AlphaBlender( + alpha=merge_factor, + merge_strategy=merge_strategy, + rearrange_pattern="b t -> b 1 t 1 1", + ) + + def forward( + self, + x: th.Tensor, + emb: th.Tensor, + num_video_frames: int, + image_only_indicator: Optional[th.Tensor] = None, + ) -> th.Tensor: + x = super().forward(x, emb) + + x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames) + x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames) + + x = self.time_stack( + x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames) + ) + x = self.time_mixer( + x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator + ) + x = rearrange(x, "b c t h w -> (b t) c h w") + return x + + +class VideoUNet(nn.Module): + ''' + Adapted from the vanilla SVD model. We add "cross_attention_merger_input_blocks" and "cross_attention_merger_mid_block" to incorporate the CAM control features. + + ''' + + def __init__( + self, + in_channels: int, + model_channels: int, + out_channels: int, + num_res_blocks: int, + num_conditional_frames: int, + attention_resolutions: Union[List[int], int], + dropout: float = 0.0, + channel_mult: List[int] = (1, 2, 4, 8), + conv_resample: bool = True, + dims: int = 2, + num_classes: Optional[Union[int, str]] = None, + use_checkpoint: bool = False, + num_heads: int = -1, + num_head_channels: int = -1, + num_heads_upsample: int = -1, + use_scale_shift_norm: bool = False, + resblock_updown: bool = False, + transformer_depth: Union[List[int], int] = 1, + transformer_depth_middle: Optional[int] = None, + context_dim: Optional[int] = None, + time_downup: bool = False, + time_context_dim: Optional[int] = None, + extra_ff_mix_layer: bool = False, + use_spatial_context: bool = False, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + spatial_transformer_attn_type: str = "softmax", + video_kernel_size: Union[int, List[int]] = 3, + use_linear_in_transformer: bool = False, + adm_in_channels: Optional[int] = None, + disable_temporal_crossattention: bool = False, + max_ddpm_temb_period: int = 10000, + merging_mode: str = "addition", + controlnet_mode: bool = False, + use_apm: bool = False, + ): + super().__init__() + assert context_dim is not None + self.controlnet_mode = controlnet_mode + if controlnet_mode: + assert merging_mode.startswith( + "attention"), "other merging modes not implemented" + AttentionCondModel = partial( + ConditionalModel, conditional_model=merging_mode.split("attention_")[1]) + self.cross_attention_merger_input_blocks = nn.ModuleList([]) + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1 + + if num_head_channels == -1: + assert num_heads != -1 + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(transformer_depth, int): + transformer_depth = len(channel_mult) * [transformer_depth] + transformer_depth_middle = default( + transformer_depth_middle, transformer_depth[-1] + ) + + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.dims = dims + self.use_scale_shift_norm = use_scale_shift_norm + self.resblock_updown = resblock_updown + self.transformer_depth = transformer_depth + self.transformer_depth_middle = transformer_depth_middle + self.context_dim = context_dim + self.time_downup = time_downup + self.time_context_dim = time_context_dim + self.extra_ff_mix_layer = extra_ff_mix_layer + self.use_spatial_context = use_spatial_context + self.merge_strategy = merge_strategy + self.merge_factor = merge_factor + self.spatial_transformer_attn_type = spatial_transformer_attn_type + self.video_kernel_size = video_kernel_size + self.use_linear_in_transformer = use_linear_in_transformer + self.adm_in_channels = adm_in_channels + self.disable_temporal_crossattention = disable_temporal_crossattention + self.max_ddpm_temb_period = max_ddpm_temb_period + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "timestep": + self.label_emb = nn.Sequential( + Timestep(model_channels), + nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ), + ) + + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + if controlnet_mode and merging_mode.startswith("attention"): + self.cross_attention_merger_input_blocks.append( + AttentionCondModel(input_channels=ch)) + + def get_attention_layer( + ch, + num_heads, + dim_head, + depth=1, + context_dim=None, + use_checkpoint=False, + disabled_sa=False, + use_apm: bool = False, + ): + return SpatialVideoTransformer( + ch, + num_heads, + dim_head, + depth=depth, + context_dim=context_dim, + time_context_dim=time_context_dim, + dropout=dropout, + ff_in=extra_ff_mix_layer, + use_spatial_context=use_spatial_context, + merge_strategy=merge_strategy, + merge_factor=merge_factor, + checkpoint=use_checkpoint, + use_linear=use_linear_in_transformer, + attn_mode=spatial_transformer_attn_type, + disable_self_attn=disabled_sa, + disable_temporal_crossattention=disable_temporal_crossattention, + max_time_embed_period=max_ddpm_temb_period, + use_apm=use_apm, + ) + + def get_resblock( + merge_factor, + merge_strategy, + video_kernel_size, + ch, + time_embed_dim, + dropout, + out_ch, + dims, + use_checkpoint, + use_scale_shift_norm, + down=False, + up=False, + ): + return VideoResBlock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=down, + up=up, + ) + + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_ch=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + layers.append( + get_attention_layer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + use_checkpoint=use_checkpoint, + disabled_sa=False, + use_apm=use_apm, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + if controlnet_mode and merging_mode.startswith("attention"): + self.cross_attention_merger_input_blocks.append( + AttentionCondModel(input_channels=ch)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + ds *= 2 + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_ch=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, + conv_resample, + dims=dims, + out_channels=out_ch, + third_down=time_downup, + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + + if controlnet_mode and merging_mode.startswith("attention"): + self.cross_attention_merger_input_blocks.append( + AttentionCondModel(input_channels=ch)) + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + self.middle_block = TimestepEmbedSequential( + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + out_ch=None, + dropout=dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + get_attention_layer( + ch, + num_heads, + dim_head, + depth=transformer_depth_middle, + context_dim=context_dim, + use_checkpoint=use_checkpoint, + use_apm=use_apm, + ), + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + out_ch=None, + time_embed_dim=time_embed_dim, + dropout=dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + if controlnet_mode and merging_mode.startswith("attention"): + self.cross_attention_merger_mid_block = AttentionCondModel( + input_channels=ch) + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch + ich, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_ch=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + layers.append( + get_attention_layer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + use_checkpoint=use_checkpoint, + disabled_sa=False, + use_apm=use_apm, + ) + ) + if level and i == num_res_blocks: + out_ch = ch + ds //= 2 + layers.append( + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_ch=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample( + ch, + conv_resample, + dims=dims, + out_channels=out_ch, + third_up=time_downup, + ) + ) + + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, + out_channels, 3, padding=1)), + ) + + def forward( + self, + # [28,8,72,128], i.e. (B F) (2 C) H W = concat([z_t,]) + x: th.Tensor, + timesteps: th.Tensor, # [28], i.e. (B F) + # [28, 1, 1024], i.e. (B F) 1 T, for cross attention from clip image encoder, + context: Optional[th.Tensor] = None, + # [28, 768], i.e. (B F) T ? concat([,,] + y: Optional[th.Tensor] = None, + time_context: Optional[th.Tensor] = None, # NONE + num_video_frames: Optional[int] = None, # 14 + num_conditional_frames: Optional[int] = None, # 8 + # zeros, [2,14], i.e. [B, F] + image_only_indicator: Optional[th.Tensor] = None, + hs_control_input: Optional[th.Tensor] = None, # cam features + hs_control_mid: Optional[th.Tensor] = None, # cam features + ): + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional -> no, relax this TODO" + hs = [] + t_emb = timestep_embedding( + timesteps, self.model_channels, repeat_only=False).to(x.dtype) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x + for module in self.input_blocks: + h = module( + h, + emb, + context=context, + image_only_indicator=image_only_indicator, + time_context=time_context, + num_video_frames=num_video_frames, + ) + hs.append(h) + + # fusion of cam features with base features + if hs_control_input is not None: + new_hs = [] + + assert len(hs) == len(hs_control_input) and len( + hs) == len(self.cross_attention_merger_input_blocks) + for h_no_ctrl, h_ctrl, merger in zip(hs, hs_control_input, self.cross_attention_merger_input_blocks): + merged_h = merger(h_no_ctrl, h_ctrl, num_frames=num_video_frames, + num_conditional_frames=num_conditional_frames) + new_hs.append(merged_h) + hs = new_hs + + h = self.middle_block( + h, + emb, + context=context, + image_only_indicator=image_only_indicator, + time_context=time_context, + num_video_frames=num_video_frames, + ) + + # fusion of cam features with base features + if hs_control_mid is not None: + h = self.cross_attention_merger_mid_block( + h, hs_control_mid, num_frames=num_video_frames, num_conditional_frames=num_conditional_frames) + + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module( + h, + emb, + context=context, + image_only_indicator=image_only_indicator, + time_context=time_context, + num_video_frames=num_video_frames, + ) + h = h.type(x.dtype) + return self.out(h) diff --git a/models/diffusion/wrappers.py b/models/diffusion/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..38647ccb30cc9d1e7baa7b8df8a58b81fafb9c40 --- /dev/null +++ b/models/diffusion/wrappers.py @@ -0,0 +1,78 @@ + +import torch +from models.svd.sgm.modules.diffusionmodules.wrappers import OpenAIWrapper +from einops import rearrange, repeat + + +class StreamingWrapper(OpenAIWrapper): + """ + Modelwrapper for StreamingSVD, which holds the CAM model and the base model + + """ + + def __init__(self, diffusion_model, controlnet, num_frame_conditioning: int, compile_model: bool = False, pipeline_offloading: bool = False): + super().__init__(diffusion_model=diffusion_model, + compile_model=compile_model) + self.controlnet = controlnet + self.num_frame_conditioning = num_frame_conditioning + self.pipeline_offloading = pipeline_offloading + if pipeline_offloading: + raise NotImplementedError( + "Pipeline offloading for StreamingI2V not implemented yet.") + + def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs): + + batch_size = kwargs.pop("batch_size") + + # We apply the controlnet model only to the control frames. + def reduce_to_cond_frames(input): + input = rearrange(input, "(B F) ... -> B F ...", B=batch_size) + input = input[:, :self.num_frame_conditioning] + return rearrange(input, "B F ... -> (B F) ...") + + x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) + x_ctrl = reduce_to_cond_frames(x) + t_ctrl = reduce_to_cond_frames(t) + + context = c.get("crossattn", None) + # controlnet is not using APM so we remove potentially additional tokens + context_ctrl = context[:, :1] + context_ctrl = reduce_to_cond_frames(context_ctrl) + y = c.get("vector", None) + y_ctrl = reduce_to_cond_frames(y) + num_video_frames = kwargs.pop("num_video_frames") + image_only_indicator = kwargs.pop("image_only_indicator") + ctrl_img_enc_frames = repeat( + kwargs['ctrl_frames'], "B ... -> (2 B) ... ") + controlnet_cond = rearrange( + ctrl_img_enc_frames, "B F ... -> (B F) ...") + + if self.diffusion_model.controlnet_mode: + hs_control_input, hs_control_mid = self.controlnet(x=x_ctrl, # video latent + timesteps=t_ctrl, # timestep + context=context_ctrl, # clip image conditioning + y=y_ctrl, # conditionigs, e.g. fps + controlnet_cond=controlnet_cond, # control frames + num_video_frames=self.num_frame_conditioning, + num_video_frames_conditional=self.num_frame_conditioning, + image_only_indicator=image_only_indicator[:, + :self.num_frame_conditioning] + ) + else: + hs_control_input = None + hs_control_mid = None + kwargs["hs_control_input"] = hs_control_input + kwargs["hs_control_mid"] = hs_control_mid + + out = self.diffusion_model( + x=x, + timesteps=t, + context=context, # must be (B F) T C + y=y, # must be (B F) 768 + num_video_frames=num_video_frames, + num_conditional_frames=self.num_frame_conditioning, + image_only_indicator=image_only_indicator, + hs_control_input=hs_control_input, + hs_control_mid=hs_control_mid, + ) + return out diff --git a/models/svd/sgm/__init__.py b/models/svd/sgm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a10f2a168623df34cf834a727762869ce307f405 --- /dev/null +++ b/models/svd/sgm/__init__.py @@ -0,0 +1,4 @@ +from models.svd.sgm.models import AutoencodingEngine, DiffusionEngine +from models.svd.sgm.util import get_configs_path, instantiate_from_config + +__version__ = "0.1.0" diff --git a/models/svd/sgm/data/__init__.py b/models/svd/sgm/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7664a25c655c376bd1a7b0ccbaca7b983a2bf9ad --- /dev/null +++ b/models/svd/sgm/data/__init__.py @@ -0,0 +1 @@ +from .dataset import StableDataModuleFromConfig diff --git a/models/svd/sgm/data/cifar10.py b/models/svd/sgm/data/cifar10.py new file mode 100644 index 0000000000000000000000000000000000000000..6083646f136bad308a0485843b89234cf7a9d6cd --- /dev/null +++ b/models/svd/sgm/data/cifar10.py @@ -0,0 +1,67 @@ +import pytorch_lightning as pl +import torchvision +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class CIFAR10DataDictWrapper(Dataset): + def __init__(self, dset): + super().__init__() + self.dset = dset + + def __getitem__(self, i): + x, y = self.dset[i] + return {"jpg": x, "cls": y} + + def __len__(self): + return len(self.dset) + + +class CIFAR10Loader(pl.LightningDataModule): + def __init__(self, batch_size, num_workers=0, shuffle=True): + super().__init__() + + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] + ) + + self.batch_size = batch_size + self.num_workers = num_workers + self.shuffle = shuffle + self.train_dataset = CIFAR10DataDictWrapper( + torchvision.datasets.CIFAR10( + root=".data/", train=True, download=True, transform=transform + ) + ) + self.test_dataset = CIFAR10DataDictWrapper( + torchvision.datasets.CIFAR10( + root=".data/", train=False, download=True, transform=transform + ) + ) + + def prepare_data(self): + pass + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + ) + + def val_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + ) diff --git a/models/svd/sgm/data/dataset.py b/models/svd/sgm/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b726149996591c6c3db69230e1bb68c07d2faa12 --- /dev/null +++ b/models/svd/sgm/data/dataset.py @@ -0,0 +1,80 @@ +from typing import Optional + +import torchdata.datapipes.iter +import webdataset as wds +from omegaconf import DictConfig +from pytorch_lightning import LightningDataModule + +try: + from sdata import create_dataset, create_dummy_dataset, create_loader +except ImportError as e: + print("#" * 100) + print("Datasets not yet available") + print("to enable, we need to add stable-datasets as a submodule") + print("please use ``git submodule update --init --recursive``") + print("and do ``pip install -e stable-datasets/`` from the root of this repo") + print("#" * 100) + exit(1) + + +class StableDataModuleFromConfig(LightningDataModule): + def __init__( + self, + train: DictConfig, + validation: Optional[DictConfig] = None, + test: Optional[DictConfig] = None, + skip_val_loader: bool = False, + dummy: bool = False, + ): + super().__init__() + self.train_config = train + assert ( + "datapipeline" in self.train_config and "loader" in self.train_config + ), "train config requires the fields `datapipeline` and `loader`" + + self.val_config = validation + if not skip_val_loader: + if self.val_config is not None: + assert ( + "datapipeline" in self.val_config and "loader" in self.val_config + ), "validation config requires the fields `datapipeline` and `loader`" + else: + print( + "Warning: No Validation datapipeline defined, using that one from training" + ) + self.val_config = train + + self.test_config = test + if self.test_config is not None: + assert ( + "datapipeline" in self.test_config and "loader" in self.test_config + ), "test config requires the fields `datapipeline` and `loader`" + + self.dummy = dummy + if self.dummy: + print("#" * 100) + print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)") + print("#" * 100) + + def setup(self, stage: str) -> None: + print("Preparing datasets") + if self.dummy: + data_fn = create_dummy_dataset + else: + data_fn = create_dataset + + self.train_datapipeline = data_fn(**self.train_config.datapipeline) + if self.val_config: + self.val_datapipeline = data_fn(**self.val_config.datapipeline) + if self.test_config: + self.test_datapipeline = data_fn(**self.test_config.datapipeline) + + def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe: + loader = create_loader(self.train_datapipeline, **self.train_config.loader) + return loader + + def val_dataloader(self) -> wds.DataPipeline: + return create_loader(self.val_datapipeline, **self.val_config.loader) + + def test_dataloader(self) -> wds.DataPipeline: + return create_loader(self.test_datapipeline, **self.test_config.loader) diff --git a/models/svd/sgm/data/mnist.py b/models/svd/sgm/data/mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..dea4d7e670666bec80ecb22aa89603345e173d09 --- /dev/null +++ b/models/svd/sgm/data/mnist.py @@ -0,0 +1,85 @@ +import pytorch_lightning as pl +import torchvision +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class MNISTDataDictWrapper(Dataset): + def __init__(self, dset): + super().__init__() + self.dset = dset + + def __getitem__(self, i): + x, y = self.dset[i] + return {"jpg": x, "cls": y} + + def __len__(self): + return len(self.dset) + + +class MNISTLoader(pl.LightningDataModule): + def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True): + super().__init__() + + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] + ) + + self.batch_size = batch_size + self.num_workers = num_workers + self.prefetch_factor = prefetch_factor if num_workers > 0 else 0 + self.shuffle = shuffle + self.train_dataset = MNISTDataDictWrapper( + torchvision.datasets.MNIST( + root=".data/", train=True, download=True, transform=transform + ) + ) + self.test_dataset = MNISTDataDictWrapper( + torchvision.datasets.MNIST( + root=".data/", train=False, download=True, transform=transform + ) + ) + + def prepare_data(self): + pass + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + ) + + def val_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + ) + + +if __name__ == "__main__": + dset = MNISTDataDictWrapper( + torchvision.datasets.MNIST( + root=".data/", + train=False, + download=True, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] + ), + ) + ) + ex = dset[0] diff --git a/models/svd/sgm/inference/api.py b/models/svd/sgm/inference/api.py new file mode 100644 index 0000000000000000000000000000000000000000..a359a67bcd9740acc9e320d2f26dc6a3befb36e0 --- /dev/null +++ b/models/svd/sgm/inference/api.py @@ -0,0 +1,385 @@ +import pathlib +from dataclasses import asdict, dataclass +from enum import Enum +from typing import Optional + +from omegaconf import OmegaConf + +from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img, + do_sample) +from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler, + DPMPP2SAncestralSampler, + EulerAncestralSampler, + EulerEDMSampler, + HeunEDMSampler, + LinearMultistepSampler) +from sgm.util import load_model_from_config + + +class ModelArchitecture(str, Enum): + SD_2_1 = "stable-diffusion-v2-1" + SD_2_1_768 = "stable-diffusion-v2-1-768" + SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base" + SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner" + SDXL_V1_BASE = "stable-diffusion-xl-v1-base" + SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner" + + +class Sampler(str, Enum): + EULER_EDM = "EulerEDMSampler" + HEUN_EDM = "HeunEDMSampler" + EULER_ANCESTRAL = "EulerAncestralSampler" + DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler" + DPMPP2M = "DPMPP2MSampler" + LINEAR_MULTISTEP = "LinearMultistepSampler" + + +class Discretization(str, Enum): + LEGACY_DDPM = "LegacyDDPMDiscretization" + EDM = "EDMDiscretization" + + +class Guider(str, Enum): + VANILLA = "VanillaCFG" + IDENTITY = "IdentityGuider" + + +class Thresholder(str, Enum): + NONE = "None" + + +@dataclass +class SamplingParams: + width: int = 1024 + height: int = 1024 + steps: int = 50 + sampler: Sampler = Sampler.DPMPP2M + discretization: Discretization = Discretization.LEGACY_DDPM + guider: Guider = Guider.VANILLA + thresholder: Thresholder = Thresholder.NONE + scale: float = 6.0 + aesthetic_score: float = 5.0 + negative_aesthetic_score: float = 5.0 + img2img_strength: float = 1.0 + orig_width: int = 1024 + orig_height: int = 1024 + crop_coords_top: int = 0 + crop_coords_left: int = 0 + sigma_min: float = 0.0292 + sigma_max: float = 14.6146 + rho: float = 3.0 + s_churn: float = 0.0 + s_tmin: float = 0.0 + s_tmax: float = 999.0 + s_noise: float = 1.0 + eta: float = 1.0 + order: int = 4 + + +@dataclass +class SamplingSpec: + width: int + height: int + channels: int + factor: int + is_legacy: bool + config: str + ckpt: str + is_guided: bool + + +model_specs = { + ModelArchitecture.SD_2_1: SamplingSpec( + height=512, + width=512, + channels=4, + factor=8, + is_legacy=True, + config="sd_2_1.yaml", + ckpt="v2-1_512-ema-pruned.safetensors", + is_guided=True, + ), + ModelArchitecture.SD_2_1_768: SamplingSpec( + height=768, + width=768, + channels=4, + factor=8, + is_legacy=True, + config="sd_2_1_768.yaml", + ckpt="v2-1_768-ema-pruned.safetensors", + is_guided=True, + ), + ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec( + height=1024, + width=1024, + channels=4, + factor=8, + is_legacy=False, + config="sd_xl_base.yaml", + ckpt="sd_xl_base_0.9.safetensors", + is_guided=True, + ), + ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec( + height=1024, + width=1024, + channels=4, + factor=8, + is_legacy=True, + config="sd_xl_refiner.yaml", + ckpt="sd_xl_refiner_0.9.safetensors", + is_guided=True, + ), + ModelArchitecture.SDXL_V1_BASE: SamplingSpec( + height=1024, + width=1024, + channels=4, + factor=8, + is_legacy=False, + config="sd_xl_base.yaml", + ckpt="sd_xl_base_1.0.safetensors", + is_guided=True, + ), + ModelArchitecture.SDXL_V1_REFINER: SamplingSpec( + height=1024, + width=1024, + channels=4, + factor=8, + is_legacy=True, + config="sd_xl_refiner.yaml", + ckpt="sd_xl_refiner_1.0.safetensors", + is_guided=True, + ), +} + + +class SamplingPipeline: + def __init__( + self, + model_id: ModelArchitecture, + model_path="checkpoints", + config_path="configs/inference", + device="cuda", + use_fp16=True, + ) -> None: + if model_id not in model_specs: + raise ValueError(f"Model {model_id} not supported") + self.model_id = model_id + self.specs = model_specs[self.model_id] + self.config = str(pathlib.Path(config_path, self.specs.config)) + self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt)) + self.device = device + self.model = self._load_model(device=device, use_fp16=use_fp16) + + def _load_model(self, device="cuda", use_fp16=True): + config = OmegaConf.load(self.config) + model = load_model_from_config(config, self.ckpt) + if model is None: + raise ValueError(f"Model {self.model_id} could not be loaded") + model.to(device) + if use_fp16: + model.conditioner.half() + model.model.half() + return model + + def text_to_image( + self, + params: SamplingParams, + prompt: str, + negative_prompt: str = "", + samples: int = 1, + return_latents: bool = False, + ): + sampler = get_sampler_config(params) + value_dict = asdict(params) + value_dict["prompt"] = prompt + value_dict["negative_prompt"] = negative_prompt + value_dict["target_width"] = params.width + value_dict["target_height"] = params.height + return do_sample( + self.model, + sampler, + value_dict, + samples, + params.height, + params.width, + self.specs.channels, + self.specs.factor, + force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], + return_latents=return_latents, + filter=None, + ) + + def image_to_image( + self, + params: SamplingParams, + image, + prompt: str, + negative_prompt: str = "", + samples: int = 1, + return_latents: bool = False, + ): + sampler = get_sampler_config(params) + + if params.img2img_strength < 1.0: + sampler.discretization = Img2ImgDiscretizationWrapper( + sampler.discretization, + strength=params.img2img_strength, + ) + height, width = image.shape[2], image.shape[3] + value_dict = asdict(params) + value_dict["prompt"] = prompt + value_dict["negative_prompt"] = negative_prompt + value_dict["target_width"] = width + value_dict["target_height"] = height + return do_img2img( + image, + self.model, + sampler, + value_dict, + samples, + force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], + return_latents=return_latents, + filter=None, + ) + + def refiner( + self, + params: SamplingParams, + image, + prompt: str, + negative_prompt: Optional[str] = None, + samples: int = 1, + return_latents: bool = False, + ): + sampler = get_sampler_config(params) + value_dict = { + "orig_width": image.shape[3] * 8, + "orig_height": image.shape[2] * 8, + "target_width": image.shape[3] * 8, + "target_height": image.shape[2] * 8, + "prompt": prompt, + "negative_prompt": negative_prompt, + "crop_coords_top": 0, + "crop_coords_left": 0, + "aesthetic_score": 6.0, + "negative_aesthetic_score": 2.5, + } + + return do_img2img( + image, + self.model, + sampler, + value_dict, + samples, + skip_encode=True, + return_latents=return_latents, + filter=None, + ) + + +def get_guider_config(params: SamplingParams): + if params.guider == Guider.IDENTITY: + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" + } + elif params.guider == Guider.VANILLA: + scale = params.scale + + thresholder = params.thresholder + + if thresholder == Thresholder.NONE: + dyn_thresh_config = { + "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding" + } + else: + raise NotImplementedError + + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG", + "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config}, + } + else: + raise NotImplementedError + return guider_config + + +def get_discretization_config(params: SamplingParams): + if params.discretization == Discretization.LEGACY_DDPM: + discretization_config = { + "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", + } + elif params.discretization == Discretization.EDM: + discretization_config = { + "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization", + "params": { + "sigma_min": params.sigma_min, + "sigma_max": params.sigma_max, + "rho": params.rho, + }, + } + else: + raise ValueError(f"unknown discretization {params.discretization}") + return discretization_config + + +def get_sampler_config(params: SamplingParams): + discretization_config = get_discretization_config(params) + guider_config = get_guider_config(params) + sampler = None + if params.sampler == Sampler.EULER_EDM: + return EulerEDMSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + s_churn=params.s_churn, + s_tmin=params.s_tmin, + s_tmax=params.s_tmax, + s_noise=params.s_noise, + verbose=True, + ) + if params.sampler == Sampler.HEUN_EDM: + return HeunEDMSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + s_churn=params.s_churn, + s_tmin=params.s_tmin, + s_tmax=params.s_tmax, + s_noise=params.s_noise, + verbose=True, + ) + if params.sampler == Sampler.EULER_ANCESTRAL: + return EulerAncestralSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + eta=params.eta, + s_noise=params.s_noise, + verbose=True, + ) + if params.sampler == Sampler.DPMPP2S_ANCESTRAL: + return DPMPP2SAncestralSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + eta=params.eta, + s_noise=params.s_noise, + verbose=True, + ) + if params.sampler == Sampler.DPMPP2M: + return DPMPP2MSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + verbose=True, + ) + if params.sampler == Sampler.LINEAR_MULTISTEP: + return LinearMultistepSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + order=params.order, + verbose=True, + ) + + raise ValueError(f"unknown sampler {params.sampler}!") diff --git a/models/svd/sgm/inference/helpers.py b/models/svd/sgm/inference/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..31b0ec3dc414bf522261e35f73805810cd35582d --- /dev/null +++ b/models/svd/sgm/inference/helpers.py @@ -0,0 +1,305 @@ +import math +import os +from typing import List, Optional, Union + +import numpy as np +import torch +from einops import rearrange +from imwatermark import WatermarkEncoder +from omegaconf import ListConfig +from PIL import Image +from torch import autocast + +from sgm.util import append_dims + + +class WatermarkEmbedder: + def __init__(self, watermark): + self.watermark = watermark + self.num_bits = len(WATERMARK_BITS) + self.encoder = WatermarkEncoder() + self.encoder.set_watermark("bits", self.watermark) + + def __call__(self, image: torch.Tensor) -> torch.Tensor: + """ + Adds a predefined watermark to the input image + + Args: + image: ([N,] B, RGB, H, W) in range [0, 1] + + Returns: + same as input but watermarked + """ + squeeze = len(image.shape) == 4 + if squeeze: + image = image[None, ...] + n = image.shape[0] + image_np = rearrange( + (255 * image).detach().cpu(), "n b c h w -> (n b) h w c" + ).numpy()[:, :, :, ::-1] + # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] + # watermarking libary expects input as cv2 BGR format + for k in range(image_np.shape[0]): + image_np[k] = self.encoder.encode(image_np[k], "dwtDct") + image = torch.from_numpy( + rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n) + ).to(image.device) + image = torch.clamp(image / 255, min=0.0, max=1.0) + if squeeze: + image = image[0] + return image + + +# A fixed 48-bit message that was choosen at random +# WATERMARK_MESSAGE = 0xB3EC907BB19E +WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 +# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 +WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] +embed_watermark = WatermarkEmbedder(WATERMARK_BITS) + + +def get_unique_embedder_keys_from_conditioner(conditioner): + return list({x.input_key for x in conditioner.embedders}) + + +def perform_save_locally(save_path, samples): + os.makedirs(os.path.join(save_path), exist_ok=True) + base_count = len(os.listdir(os.path.join(save_path))) + samples = embed_watermark(samples) + for sample in samples: + sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c") + Image.fromarray(sample.astype(np.uint8)).save( + os.path.join(save_path, f"{base_count:09}.png") + ) + base_count += 1 + + +class Img2ImgDiscretizationWrapper: + """ + wraps a discretizer, and prunes the sigmas + params: + strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned) + """ + + def __init__(self, discretization, strength: float = 1.0): + self.discretization = discretization + self.strength = strength + assert 0.0 <= self.strength <= 1.0 + + def __call__(self, *args, **kwargs): + # sigmas start large first, and decrease then + sigmas = self.discretization(*args, **kwargs) + print(f"sigmas after discretization, before pruning img2img: ", sigmas) + sigmas = torch.flip(sigmas, (0,)) + sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)] + print("prune index:", max(int(self.strength * len(sigmas)), 1)) + sigmas = torch.flip(sigmas, (0,)) + print(f"sigmas after pruning: ", sigmas) + return sigmas + + +def do_sample( + model, + sampler, + value_dict, + num_samples, + H, + W, + C, + F, + force_uc_zero_embeddings: Optional[List] = None, + batch2model_input: Optional[List] = None, + return_latents=False, + filter=None, + device="cuda", +): + if force_uc_zero_embeddings is None: + force_uc_zero_embeddings = [] + if batch2model_input is None: + batch2model_input = [] + + with torch.no_grad(): + with autocast(device) as precision_scope: + with model.ema_scope(): + num_samples = [num_samples] + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + num_samples, + ) + for key in batch: + if isinstance(batch[key], torch.Tensor): + print(key, batch[key].shape) + elif isinstance(batch[key], list): + print(key, [len(l) for l in batch[key]]) + else: + print(key, batch[key]) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=force_uc_zero_embeddings, + ) + + for k in c: + if not k == "crossattn": + c[k], uc[k] = map( + lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc) + ) + + additional_model_inputs = {} + for k in batch2model_input: + additional_model_inputs[k] = batch[k] + + shape = (math.prod(num_samples), C, H // F, W // F) + randn = torch.randn(shape).to(device) + + def denoiser(input, sigma, c): + return model.denoiser( + model.model, input, sigma, c, **additional_model_inputs + ) + + samples_z = sampler(denoiser, randn, cond=c, uc=uc) + samples_x = model.decode_first_stage(samples_z) + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + + if filter is not None: + samples = filter(samples) + + if return_latents: + return samples, samples_z + return samples + + +def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): + # Hardcoded demo setups; might undergo some changes in the future + + batch = {} + batch_uc = {} + + for key in keys: + if key == "txt": + batch["txt"] = ( + np.repeat([value_dict["prompt"]], repeats=math.prod(N)) + .reshape(N) + .tolist() + ) + batch_uc["txt"] = ( + np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)) + .reshape(N) + .tolist() + ) + elif key == "original_size_as_tuple": + batch["original_size_as_tuple"] = ( + torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]) + .to(device) + .repeat(*N, 1) + ) + elif key == "crop_coords_top_left": + batch["crop_coords_top_left"] = ( + torch.tensor( + [value_dict["crop_coords_top"], value_dict["crop_coords_left"]] + ) + .to(device) + .repeat(*N, 1) + ) + elif key == "aesthetic_score": + batch["aesthetic_score"] = ( + torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) + ) + batch_uc["aesthetic_score"] = ( + torch.tensor([value_dict["negative_aesthetic_score"]]) + .to(device) + .repeat(*N, 1) + ) + + elif key == "target_size_as_tuple": + batch["target_size_as_tuple"] = ( + torch.tensor([value_dict["target_height"], value_dict["target_width"]]) + .to(device) + .repeat(*N, 1) + ) + else: + batch[key] = value_dict[key] + + for key in batch.keys(): + if key not in batch_uc and isinstance(batch[key], torch.Tensor): + batch_uc[key] = torch.clone(batch[key]) + return batch, batch_uc + + +def get_input_image_tensor(image: Image.Image, device="cuda"): + w, h = image.size + print(f"loaded input image of size ({w}, {h})") + width, height = map( + lambda x: x - x % 64, (w, h) + ) # resize to integer multiple of 64 + image = image.resize((width, height)) + image_array = np.array(image.convert("RGB")) + image_array = image_array[None].transpose(0, 3, 1, 2) + image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0 + return image_tensor.to(device) + + +def do_img2img( + img, + model, + sampler, + value_dict, + num_samples, + force_uc_zero_embeddings=[], + additional_kwargs={}, + offset_noise_level: float = 0.0, + return_latents=False, + skip_encode=False, + filter=None, + device="cuda", +): + with torch.no_grad(): + with autocast(device) as precision_scope: + with model.ema_scope(): + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + [num_samples], + ) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=force_uc_zero_embeddings, + ) + + for k in c: + c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc)) + + for k in additional_kwargs: + c[k] = uc[k] = additional_kwargs[k] + if skip_encode: + z = img + else: + z = model.encode_first_stage(img) + noise = torch.randn_like(z) + sigmas = sampler.discretization(sampler.num_steps) + sigma = sigmas[0].to(z.device) + + if offset_noise_level > 0.0: + noise = noise + offset_noise_level * append_dims( + torch.randn(z.shape[0], device=z.device), z.ndim + ) + noised_z = z + noise * append_dims(sigma, z.ndim) + noised_z = noised_z / torch.sqrt( + 1.0 + sigmas[0] ** 2.0 + ) # Note: hardcoded to DDPM-like scaling. need to generalize later. + + def denoiser(x, sigma, c): + return model.denoiser(model.model, x, sigma, c) + + samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) + samples_x = model.decode_first_stage(samples_z) + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + + if filter is not None: + samples = filter(samples) + + if return_latents: + return samples, samples_z + return samples diff --git a/models/svd/sgm/lr_scheduler.py b/models/svd/sgm/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..b2f4d384c1fcaff0df13e0564450d3fa972ace42 --- /dev/null +++ b/models/svd/sgm/lr_scheduler.py @@ -0,0 +1,135 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + + def __init__( + self, + warm_up_steps, + lr_min, + lr_max, + lr_start, + max_decay_steps, + verbosity_interval=0, + ): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0.0 + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = ( + self.lr_max - self.lr_start + ) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / ( + self.lr_max_decay_steps - self.lr_warm_up_steps + ) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi) + ) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. + """ + + def __init__( + self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0 + ): + assert ( + len(warm_up_steps) + == len(f_min) + == len(f_max) + == len(f_start) + == len(cycle_lengths) + ) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0.0 + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print( + f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}" + ) + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ + cycle + ] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / ( + self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] + ) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi) + ) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print( + f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}" + ) + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ + cycle + ] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * ( + self.cycle_lengths[cycle] - n + ) / (self.cycle_lengths[cycle]) + self.last_f = f + return f diff --git a/models/svd/sgm/models/__init__.py b/models/svd/sgm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..00d567b82ee7a4450ca8db96d2ffaa9489746890 --- /dev/null +++ b/models/svd/sgm/models/__init__.py @@ -0,0 +1,2 @@ +from models.svd.sgm.models.autoencoder import AutoencodingEngine +from models.svd.sgm.models.diffusion import DiffusionEngine diff --git a/models/svd/sgm/models/autoencoder.py b/models/svd/sgm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..dba55db09125e023ab8ecfe6a2b7316324540222 --- /dev/null +++ b/models/svd/sgm/models/autoencoder.py @@ -0,0 +1,615 @@ +import logging +import math +import re +from abc import abstractmethod +from contextlib import contextmanager +from typing import Any, Dict, List, Optional, Tuple, Union + +import pytorch_lightning as pl +import torch +import torch.nn as nn +from einops import rearrange +from packaging import version + +from models.svd.sgm.modules.autoencoding.regularizers import AbstractRegularizer +from models.svd.sgm.modules.ema import LitEma +from models.svd.sgm.util import (default, get_nested_attribute, get_obj_from_str, + instantiate_from_config) + +logpy = logging.getLogger(__name__) + + +class AbstractAutoencoder(pl.LightningModule): + """ + This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators, + unCLIP models, etc. Hence, it is fairly general, and specific features + (e.g. discriminator training, encoding, decoding) must be implemented in subclasses. + """ + + def __init__( + self, + ema_decay: Union[None, float] = None, + monitor: Union[None, str] = None, + input_key: str = "jpg", + ): + super().__init__() + + self.input_key = input_key + self.use_ema = ema_decay is not None + if monitor is not None: + self.monitor = monitor + + if self.use_ema: + self.model_ema = LitEma(self, decay=ema_decay) + logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if version.parse(torch.__version__) >= version.parse("2.0.0"): + self.automatic_optimization = False + + def apply_ckpt(self, ckpt: Union[None, str, dict]): + if ckpt is None: + return + if isinstance(ckpt, str): + ckpt = { + "target": "sgm.modules.checkpoint.CheckpointEngine", + "params": {"ckpt_path": ckpt}, + } + engine = instantiate_from_config(ckpt) + engine(self) + + @abstractmethod + def get_input(self, batch) -> Any: + raise NotImplementedError() + + def on_train_batch_end(self, *args, **kwargs): + # for EMA computation + if self.use_ema: + self.model_ema(self) + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + logpy.info(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + logpy.info(f"{context}: Restored training weights") + + @abstractmethod + def encode(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError("encode()-method of abstract base class called") + + @abstractmethod + def decode(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError("decode()-method of abstract base class called") + + def instantiate_optimizer_from_config(self, params, lr, cfg): + logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config") + return get_obj_from_str(cfg["target"])( + params, lr=lr, **cfg.get("params", dict()) + ) + + def configure_optimizers(self) -> Any: + raise NotImplementedError() + + +class AutoencodingEngine(AbstractAutoencoder): + """ + Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL + (we also restore them explicitly as special cases for legacy reasons). + Regularizations such as KL or VQ are moved to the regularizer class. + """ + + def __init__( + self, + *args, + encoder_config: Dict, + decoder_config: Dict, + loss_config: Dict, + regularizer_config: Dict, + optimizer_config: Union[Dict, None] = None, + lr_g_factor: float = 1.0, + trainable_ae_params: Optional[List[List[str]]] = None, + ae_optimizer_args: Optional[List[dict]] = None, + trainable_disc_params: Optional[List[List[str]]] = None, + disc_optimizer_args: Optional[List[dict]] = None, + disc_start_iter: int = 0, + diff_boost_factor: float = 3.0, + ckpt_engine: Union[None, str, dict] = None, + ckpt_path: Optional[str] = None, + additional_decode_keys: Optional[List[str]] = None, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.automatic_optimization = False # pytorch lightning + + self.encoder: torch.nn.Module = instantiate_from_config(encoder_config) + self.decoder: torch.nn.Module = instantiate_from_config(decoder_config) + self.loss: torch.nn.Module = instantiate_from_config(loss_config) + self.regularization: AbstractRegularizer = instantiate_from_config( + regularizer_config + ) + self.optimizer_config = default( + optimizer_config, {"target": "torch.optim.Adam"} + ) + self.diff_boost_factor = diff_boost_factor + self.disc_start_iter = disc_start_iter + self.lr_g_factor = lr_g_factor + self.trainable_ae_params = trainable_ae_params + if self.trainable_ae_params is not None: + self.ae_optimizer_args = default( + ae_optimizer_args, + [{} for _ in range(len(self.trainable_ae_params))], + ) + assert len(self.ae_optimizer_args) == len(self.trainable_ae_params) + else: + self.ae_optimizer_args = [{}] # makes type consitent + + self.trainable_disc_params = trainable_disc_params + if self.trainable_disc_params is not None: + self.disc_optimizer_args = default( + disc_optimizer_args, + [{} for _ in range(len(self.trainable_disc_params))], + ) + assert len(self.disc_optimizer_args) == len(self.trainable_disc_params) + else: + self.disc_optimizer_args = [{}] # makes type consitent + + if ckpt_path is not None: + assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path" + logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead") + self.apply_ckpt(default(ckpt_path, ckpt_engine)) + self.additional_decode_keys = set(default(additional_decode_keys, [])) + + def get_input(self, batch: Dict) -> torch.Tensor: + # assuming unified data format, dataloader returns a dict. + # image tensors should be scaled to -1 ... 1 and in channels-first + # format (e.g., bchw instead if bhwc) + return batch[self.input_key] + + def get_autoencoder_params(self) -> list: + params = [] + if hasattr(self.loss, "get_trainable_autoencoder_parameters"): + params += list(self.loss.get_trainable_autoencoder_parameters()) + if hasattr(self.regularization, "get_trainable_parameters"): + params += list(self.regularization.get_trainable_parameters()) + params = params + list(self.encoder.parameters()) + params = params + list(self.decoder.parameters()) + return params + + def get_discriminator_params(self) -> list: + if hasattr(self.loss, "get_trainable_parameters"): + params = list(self.loss.get_trainable_parameters()) # e.g., discriminator + else: + params = [] + return params + + def get_last_layer(self): + return self.decoder.get_last_layer() + + def encode( + self, + x: torch.Tensor, + return_reg_log: bool = False, + unregularized: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + z = self.encoder(x) + if unregularized: + return z, dict() + z, reg_log = self.regularization(z) + if return_reg_log: + return z, reg_log + return z + + def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: + x = self.decoder(z, **kwargs) + return x + + def forward( + self, x: torch.Tensor, **additional_decode_kwargs + ) -> Tuple[torch.Tensor, torch.Tensor, dict]: + z, reg_log = self.encode(x, return_reg_log=True) + dec = self.decode(z, **additional_decode_kwargs) + return z, dec, reg_log + + def inner_training_step( + self, batch: dict, batch_idx: int, optimizer_idx: int = 0 + ) -> torch.Tensor: + x = self.get_input(batch) + additional_decode_kwargs = { + key: batch[key] for key in self.additional_decode_keys.intersection(batch) + } + z, xrec, regularization_log = self(x, **additional_decode_kwargs) + if hasattr(self.loss, "forward_keys"): + extra_info = { + "z": z, + "optimizer_idx": optimizer_idx, + "global_step": self.global_step, + "last_layer": self.get_last_layer(), + "split": "train", + "regularization_log": regularization_log, + "autoencoder": self, + } + extra_info = {k: extra_info[k] for k in self.loss.forward_keys} + else: + extra_info = dict() + + if optimizer_idx == 0: + # autoencode + out_loss = self.loss(x, xrec, **extra_info) + if isinstance(out_loss, tuple): + aeloss, log_dict_ae = out_loss + else: + # simple loss function + aeloss = out_loss + log_dict_ae = {"train/loss/rec": aeloss.detach()} + + self.log_dict( + log_dict_ae, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=True, + sync_dist=False, + ) + self.log( + "loss", + aeloss.mean().detach(), + prog_bar=True, + logger=False, + on_epoch=False, + on_step=True, + ) + return aeloss + elif optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(x, xrec, **extra_info) + # -> discriminator always needs to return a tuple + self.log_dict( + log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True + ) + return discloss + else: + raise NotImplementedError(f"Unknown optimizer {optimizer_idx}") + + def training_step(self, batch: dict, batch_idx: int): + opts = self.optimizers() + if not isinstance(opts, list): + # Non-adversarial case + opts = [opts] + optimizer_idx = batch_idx % len(opts) + if self.global_step < self.disc_start_iter: + optimizer_idx = 0 + opt = opts[optimizer_idx] + opt.zero_grad() + with opt.toggle_model(): + loss = self.inner_training_step( + batch, batch_idx, optimizer_idx=optimizer_idx + ) + self.manual_backward(loss) + opt.step() + + def validation_step(self, batch: dict, batch_idx: int) -> Dict: + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") + log_dict.update(log_dict_ema) + return log_dict + + def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict: + x = self.get_input(batch) + + z, xrec, regularization_log = self(x) + if hasattr(self.loss, "forward_keys"): + extra_info = { + "z": z, + "optimizer_idx": 0, + "global_step": self.global_step, + "last_layer": self.get_last_layer(), + "split": "val" + postfix, + "regularization_log": regularization_log, + "autoencoder": self, + } + extra_info = {k: extra_info[k] for k in self.loss.forward_keys} + else: + extra_info = dict() + out_loss = self.loss(x, xrec, **extra_info) + if isinstance(out_loss, tuple): + aeloss, log_dict_ae = out_loss + else: + # simple loss function + aeloss = out_loss + log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()} + full_log_dict = log_dict_ae + + if "optimizer_idx" in extra_info: + extra_info["optimizer_idx"] = 1 + discloss, log_dict_disc = self.loss(x, xrec, **extra_info) + full_log_dict.update(log_dict_disc) + self.log( + f"val{postfix}/loss/rec", + log_dict_ae[f"val{postfix}/loss/rec"], + sync_dist=True, + ) + self.log_dict(full_log_dict, sync_dist=True) + return full_log_dict + + def get_param_groups( + self, parameter_names: List[List[str]], optimizer_args: List[dict] + ) -> Tuple[List[Dict[str, Any]], int]: + groups = [] + num_params = 0 + for names, args in zip(parameter_names, optimizer_args): + params = [] + for pattern_ in names: + pattern_params = [] + pattern = re.compile(pattern_) + for p_name, param in self.named_parameters(): + if re.match(pattern, p_name): + pattern_params.append(param) + num_params += param.numel() + if len(pattern_params) == 0: + logpy.warn(f"Did not find parameters for pattern {pattern_}") + params.extend(pattern_params) + groups.append({"params": params, **args}) + return groups, num_params + + def configure_optimizers(self) -> List[torch.optim.Optimizer]: + if self.trainable_ae_params is None: + ae_params = self.get_autoencoder_params() + else: + ae_params, num_ae_params = self.get_param_groups( + self.trainable_ae_params, self.ae_optimizer_args + ) + logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}") + if self.trainable_disc_params is None: + disc_params = self.get_discriminator_params() + else: + disc_params, num_disc_params = self.get_param_groups( + self.trainable_disc_params, self.disc_optimizer_args + ) + logpy.info( + f"Number of trainable discriminator parameters: {num_disc_params:,}" + ) + opt_ae = self.instantiate_optimizer_from_config( + ae_params, + default(self.lr_g_factor, 1.0) * self.learning_rate, + self.optimizer_config, + ) + opts = [opt_ae] + if len(disc_params) > 0: + opt_disc = self.instantiate_optimizer_from_config( + disc_params, self.learning_rate, self.optimizer_config + ) + opts.append(opt_disc) + + return opts + + @torch.no_grad() + def log_images( + self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs + ) -> dict: + log = dict() + additional_decode_kwargs = {} + x = self.get_input(batch) + additional_decode_kwargs.update( + {key: batch[key] for key in self.additional_decode_keys.intersection(batch)} + ) + + _, xrec, _ = self(x, **additional_decode_kwargs) + log["inputs"] = x + log["reconstructions"] = xrec + diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x) + diff.clamp_(0, 1.0) + log["diff"] = 2.0 * diff - 1.0 + # diff_boost shows location of small errors, by boosting their + # brightness. + log["diff_boost"] = ( + 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1 + ) + if hasattr(self.loss, "log_images"): + log.update(self.loss.log_images(x, xrec)) + with self.ema_scope(): + _, xrec_ema, _ = self(x, **additional_decode_kwargs) + log["reconstructions_ema"] = xrec_ema + diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x) + diff_ema.clamp_(0, 1.0) + log["diff_ema"] = 2.0 * diff_ema - 1.0 + log["diff_boost_ema"] = ( + 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1 + ) + if additional_log_kwargs: + additional_decode_kwargs.update(additional_log_kwargs) + _, xrec_add, _ = self(x, **additional_decode_kwargs) + log_str = "reconstructions-" + "-".join( + [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs] + ) + log[log_str] = xrec_add + return log + + +class AutoencodingEngineLegacy(AutoencodingEngine): + def __init__(self, embed_dim: int, **kwargs): + self.max_batch_size = kwargs.pop("max_batch_size", None) + ddconfig = kwargs.pop("ddconfig") + ckpt_path = kwargs.pop("ckpt_path", None) + ckpt_engine = kwargs.pop("ckpt_engine", None) + super().__init__( + encoder_config={ + "target": "models.svd.sgm.modules.diffusionmodules.model.Encoder", + "params": ddconfig, + }, + decoder_config={ + "target": "models.svd.sgm.modules.diffusionmodules.model.Decoder", + "params": ddconfig, + }, + **kwargs, + ) + self.quant_conv = torch.nn.Conv2d( + (1 + ddconfig["double_z"]) * ddconfig["z_channels"], + (1 + ddconfig["double_z"]) * embed_dim, + 1, + ) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + + self.apply_ckpt(default(ckpt_path, ckpt_engine)) + + def get_autoencoder_params(self) -> list: + params = super().get_autoencoder_params() + return params + + def encode( + self, x: torch.Tensor, return_reg_log: bool = False + ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + if self.max_batch_size is None: + z = self.encoder(x) + z = self.quant_conv(z) + else: + N = x.shape[0] + bs = self.max_batch_size + n_batches = int(math.ceil(N / bs)) + z = list() + for i_batch in range(n_batches): + z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs]) + z_batch = self.quant_conv(z_batch) + z.append(z_batch) + z = torch.cat(z, 0) + + z, reg_log = self.regularization(z) + if return_reg_log: + return z, reg_log + return z + + def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor: + if self.max_batch_size is None: + dec = self.post_quant_conv(z) + dec = self.decoder(dec, **decoder_kwargs) + else: + N = z.shape[0] + bs = self.max_batch_size + n_batches = int(math.ceil(N / bs)) + dec = list() + for i_batch in range(n_batches): + dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs]) + dec_batch = self.decoder(dec_batch, **decoder_kwargs) + dec.append(dec_batch) + dec = torch.cat(dec, 0) + + return dec + + +class AutoencoderKL(AutoencodingEngineLegacy): + def __init__(self, **kwargs): + if "lossconfig" in kwargs: + kwargs["loss_config"] = kwargs.pop("lossconfig") + super().__init__( + regularizer_config={ + "target": ( + "sgm.modules.autoencoding.regularizers" + ".DiagonalGaussianRegularizer" + ) + }, + **kwargs, + ) + + +class AutoencoderLegacyVQ(AutoencodingEngineLegacy): + def __init__( + self, + embed_dim: int, + n_embed: int, + sane_index_shape: bool = False, + **kwargs, + ): + if "lossconfig" in kwargs: + logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.") + kwargs["loss_config"] = kwargs.pop("lossconfig") + super().__init__( + regularizer_config={ + "target": ( + "sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer" + ), + "params": { + "n_e": n_embed, + "e_dim": embed_dim, + "sane_index_shape": sane_index_shape, + }, + }, + **kwargs, + ) + + +class IdentityFirstStage(AbstractAutoencoder): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def get_input(self, x: Any) -> Any: + return x + + def encode(self, x: Any, *args, **kwargs) -> Any: + return x + + def decode(self, x: Any, *args, **kwargs) -> Any: + return x + + +class AEIntegerWrapper(nn.Module): + def __init__( + self, + model: nn.Module, + shape: Union[None, Tuple[int, int], List[int]] = (16, 16), + regularization_key: str = "regularization", + encoder_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__() + self.model = model + assert hasattr(model, "encode") and hasattr( + model, "decode" + ), "Need AE interface" + self.regularization = get_nested_attribute(model, regularization_key) + self.shape = shape + self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True}) + + def encode(self, x) -> torch.Tensor: + assert ( + not self.training + ), f"{self.__class__.__name__} only supports inference currently" + _, log = self.model.encode(x, **self.encoder_kwargs) + assert isinstance(log, dict) + inds = log["min_encoding_indices"] + return rearrange(inds, "b ... -> b (...)") + + def decode( + self, inds: torch.Tensor, shape: Union[None, tuple, list] = None + ) -> torch.Tensor: + # expect inds shape (b, s) with s = h*w + shape = default(shape, self.shape) # Optional[(h, w)] + if shape is not None: + assert len(shape) == 2, f"Unhandeled shape {shape}" + inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1]) + h = self.regularization.get_codebook_entry(inds) # (b, h, w, c) + h = rearrange(h, "b h w c -> b c h w") + return self.model.decode(h) + + +class AutoencoderKLModeOnly(AutoencodingEngineLegacy): + def __init__(self, **kwargs): + if "lossconfig" in kwargs: + kwargs["loss_config"] = kwargs.pop("lossconfig") + super().__init__( + regularizer_config={ + "target": ( + "models.svd.sgm.modules.autoencoding.regularizers" + ".DiagonalGaussianRegularizer" + ), + "params": {"sample": False}, + }, + **kwargs, + ) diff --git a/models/svd/sgm/models/diffusion.py b/models/svd/sgm/models/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..04fe794f54a29d947cbd1470bb9f42060b4376ab --- /dev/null +++ b/models/svd/sgm/models/diffusion.py @@ -0,0 +1,341 @@ +import math +from contextlib import contextmanager +from typing import Any, Dict, List, Optional, Tuple, Union + +import pytorch_lightning as pl +import torch +from omegaconf import ListConfig, OmegaConf +from safetensors.torch import load_file as load_safetensors +from torch.optim.lr_scheduler import LambdaLR + +from models.svd.sgm.modules import UNCONDITIONAL_CONFIG +from models.svd.sgm.modules.autoencoding.temporal_ae import VideoDecoder +from models.svd.sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER +from models.svd.sgm.modules.ema import LitEma +from models.svd.sgm.util import (default, disabled_train, get_obj_from_str, + instantiate_from_config, log_txt_as_img) + + +class DiffusionEngine(pl.LightningModule): + def __init__( + self, + network_config, + denoiser_config, + first_stage_config, + conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None, + sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None, + optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None, + scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None, + loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None, + network_wrapper: Union[None, str] = None, + ckpt_path: Union[None, str] = None, + use_ema: bool = False, + ema_decay_rate: float = 0.9999, + scale_factor: float = 1.0, + disable_first_stage_autocast=False, + input_key: str = "jpg", + log_keys: Union[List, None] = None, + no_cond_log: bool = False, + compile_model: bool = False, + en_and_decode_n_samples_a_time: Optional[int] = None, + ): + super().__init__() + self.log_keys = log_keys + self.input_key = input_key + self.optimizer_config = default( + optimizer_config, {"target": "torch.optim.AdamW"} + ) + model = instantiate_from_config(network_config) + self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))( + model, compile_model=compile_model + ) + + self.denoiser = instantiate_from_config(denoiser_config) + self.sampler = ( + instantiate_from_config(sampler_config) + if sampler_config is not None + else None + ) + self.conditioner = instantiate_from_config( + default(conditioner_config, UNCONDITIONAL_CONFIG) + ) + self.scheduler_config = scheduler_config + self._init_first_stage(first_stage_config) + + self.loss_fn = ( + instantiate_from_config(loss_fn_config) + if loss_fn_config is not None + else None + ) + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model, decay=ema_decay_rate) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.scale_factor = scale_factor + self.disable_first_stage_autocast = disable_first_stage_autocast + self.no_cond_log = no_cond_log + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path) + + self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time + + def init_from_ckpt( + self, + path: str, + ) -> None: + if path.endswith("ckpt"): + sd = torch.load(path, map_location="cpu")["state_dict"] + elif path.endswith("safetensors"): + sd = load_safetensors(path) + else: + raise NotImplementedError + + missing, unexpected = self.load_state_dict(sd, strict=False) + print( + f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" + ) + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def _init_first_stage(self, config): + model = instantiate_from_config(config).eval() + model.train = disabled_train + for param in model.parameters(): + param.requires_grad = False + self.first_stage_model = model + + def get_input(self, batch): + # assuming unified data format, dataloader returns a dict. + # image tensors should be scaled to -1 ... 1 and in bchw format + return batch[self.input_key] + + @torch.no_grad() + def decode_first_stage(self, z): + z = 1.0 / self.scale_factor * z + n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0]) + + n_rounds = math.ceil(z.shape[0] / n_samples) + all_out = [] + with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): + for n in range(n_rounds): + if isinstance(self.first_stage_model.decoder, VideoDecoder): + kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])} + else: + kwargs = {} + out = self.first_stage_model.decode( + z[n * n_samples : (n + 1) * n_samples], **kwargs + ) + all_out.append(out) + out = torch.cat(all_out, dim=0) + return out + + @torch.no_grad() + def encode_first_stage(self, x): + n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0]) + n_rounds = math.ceil(x.shape[0] / n_samples) + all_out = [] + with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): + for n in range(n_rounds): + out = self.first_stage_model.encode( + x[n * n_samples : (n + 1) * n_samples] + ) + all_out.append(out) + z = torch.cat(all_out, dim=0) + z = self.scale_factor * z + return z + + def forward(self, x, batch): + loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch) + loss_mean = loss.mean() + loss_dict = {"loss": loss_mean} + return loss_mean, loss_dict + + def shared_step(self, batch: Dict) -> Any: + x = self.get_input(batch) + x = self.encode_first_stage(x) + batch["global_step"] = self.global_step + loss, loss_dict = self(x, batch) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + self.log_dict( + loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False + ) + + self.log( + "global_step", + self.global_step, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=False, + ) + + if self.scheduler_config is not None: + lr = self.optimizers().param_groups[0]["lr"] + self.log( + "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False + ) + + return loss + + def on_train_start(self, *args, **kwargs): + if self.sampler is None or self.loss_fn is None: + raise ValueError("Sampler and loss function need to be set for training.") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def instantiate_optimizer_from_config(self, params, lr, cfg): + return get_obj_from_str(cfg["target"])( + params, lr=lr, **cfg.get("params", dict()) + ) + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + for embedder in self.conditioner.embedders: + if embedder.is_trainable: + params = params + list(embedder.parameters()) + opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config) + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), + "interval": "step", + "frequency": 1, + } + ] + return [opt], scheduler + return opt + + @torch.no_grad() + def sample( + self, + cond: Dict, + uc: Union[Dict, None] = None, + batch_size: int = 16, + shape: Union[None, Tuple, List] = None, + **kwargs, + ): + randn = torch.randn(batch_size, *shape).to(self.device) + + denoiser = lambda input, sigma, c: self.denoiser( + self.model, input, sigma, c, **kwargs + ) + samples = self.sampler(denoiser, randn, cond, uc=uc) + return samples + + @torch.no_grad() + def log_conditionings(self, batch: Dict, n: int) -> Dict: + """ + Defines heuristics to log different conditionings. + These can be lists of strings (text-to-image), tensors, ints, ... + """ + image_h, image_w = batch[self.input_key].shape[2:] + log = dict() + + for embedder in self.conditioner.embedders: + if ( + (self.log_keys is None) or (embedder.input_key in self.log_keys) + ) and not self.no_cond_log: + x = batch[embedder.input_key][:n] + if isinstance(x, torch.Tensor): + if x.dim() == 1: + # class-conditional, convert integer to string + x = [str(x[i].item()) for i in range(x.shape[0])] + xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4) + elif x.dim() == 2: + # size and crop cond and the like + x = [ + "x".join([str(xx) for xx in x[i].tolist()]) + for i in range(x.shape[0]) + ] + xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) + else: + raise NotImplementedError() + elif isinstance(x, (List, ListConfig)): + if isinstance(x[0], str): + # strings + xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) + else: + raise NotImplementedError() + else: + raise NotImplementedError() + log[embedder.input_key] = xc + return log + + @torch.no_grad() + def log_images( + self, + batch: Dict, + N: int = 8, + sample: bool = True, + ucg_keys: List[str] = None, + **kwargs, + ) -> Dict: + conditioner_input_keys = [e.input_key for e in self.conditioner.embedders] + if ucg_keys: + assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), ( + "Each defined ucg key for sampling must be in the provided conditioner input keys," + f"but we have {ucg_keys} vs. {conditioner_input_keys}" + ) + else: + ucg_keys = conditioner_input_keys + log = dict() + + x = self.get_input(batch) + + c, uc = self.conditioner.get_unconditional_conditioning( + batch, + force_uc_zero_embeddings=ucg_keys + if len(self.conditioner.embedders) > 0 + else [], + ) + + sampling_kwargs = {} + + N = min(x.shape[0], N) + x = x.to(self.device)[:N] + log["inputs"] = x + z = self.encode_first_stage(x) + log["reconstructions"] = self.decode_first_stage(z) + log.update(self.log_conditionings(batch, N)) + + for k in c: + if isinstance(c[k], torch.Tensor): + c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc)) + + if sample: + with self.ema_scope("Plotting"): + samples = self.sample( + c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs + ) + samples = self.decode_first_stage(samples) + log["samples"] = samples + return log diff --git a/models/svd/sgm/modules/__init__.py b/models/svd/sgm/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e6e52d7fd95ec7bbd8b5ca4b367405f517f3a319 --- /dev/null +++ b/models/svd/sgm/modules/__init__.py @@ -0,0 +1,6 @@ +from models.svd.sgm.modules.encoders.modules import GeneralConditioner + +UNCONDITIONAL_CONFIG = { + "target": "sgm.modules.GeneralConditioner", + "params": {"emb_models": []}, +} diff --git a/models/svd/sgm/modules/attention.py b/models/svd/sgm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..3f3e00f0af97b90bf20a5ed3821762a7921cb140 --- /dev/null +++ b/models/svd/sgm/modules/attention.py @@ -0,0 +1,809 @@ +import logging +import math +from inspect import isfunction +from typing import Any, Optional + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from packaging import version +from torch import nn +from torch.utils.checkpoint import checkpoint + +logpy = logging.getLogger(__name__) + +if version.parse(torch.__version__) >= version.parse("2.0.0"): + SDP_IS_AVAILABLE = True + from torch.backends.cuda import SDPBackend, sdp_kernel + + BACKEND_MAP = { + SDPBackend.MATH: { + "enable_math": True, + "enable_flash": False, + "enable_mem_efficient": False, + }, + SDPBackend.FLASH_ATTENTION: { + "enable_math": False, + "enable_flash": True, + "enable_mem_efficient": False, + }, + SDPBackend.EFFICIENT_ATTENTION: { + "enable_math": False, + "enable_flash": False, + "enable_mem_efficient": True, + }, + None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True}, + } +else: + from contextlib import nullcontext + + SDP_IS_AVAILABLE = False + sdp_kernel = nullcontext + BACKEND_MAP = {} + logpy.warn( + f"No SDP backend available, likely because you are running in pytorch " + f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. " + f"You might want to consider upgrading." + ) + +try: + import xformers + import xformers.ops + + XFORMERS_IS_AVAILABLE = True +except: + XFORMERS_IS_AVAILABLE = False + logpy.warn("no module 'xformers'. Processing without...") + +# from .diffusionmodules.util import mixed_checkpoint as checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + + +class SelfAttention(nn.Module): + ATTENTION_MODES = ("xformers", "torch", "math") + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: Optional[float] = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + attn_mode: str = "xformers", + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + assert attn_mode in self.ATTENTION_MODES + self.attn_mode = attn_mode + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, L, C = x.shape + + qkv = self.qkv(x) + if self.attn_mode == "torch": + qkv = rearrange( + qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads + ).float() + q, k, v = qkv[0], qkv[1], qkv[2] # B H L D + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "B H L D -> B L (H D)") + elif self.attn_mode == "xformers": + qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads) + q, k, v = qkv[0], qkv[1], qkv[2] # B L H D + x = xformers.ops.memory_efficient_attention(q, k, v) + x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads) + elif self.attn_mode == "math": + qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k, v = qkv[0], qkv[1], qkv[2] # B H L D + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, L, C) + else: + raise NotImplemented + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +class CrossAttention(nn.Module): + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + backend=None, + ): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + self.backend = backend + + def forward( + self, + x, + context=None, + mask=None, + additional_tokens=None, + n_times_crossframe_attn_in_self=0, + ): + h = self.heads + + if additional_tokens is not None: + # get the number of masked tokens at the beginning of the output sequence + n_tokens_to_mask = additional_tokens.shape[1] + # add additional token + x = torch.cat([additional_tokens, x], dim=1) + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + if n_times_crossframe_attn_in_self: + # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439 + assert x.shape[0] % n_times_crossframe_attn_in_self == 0 + n_cp = x.shape[0] // n_times_crossframe_attn_in_self + k = repeat( + k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp + ) + v = repeat( + v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp + ) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + ## old + """ + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim, v) + """ + ## new + with sdp_kernel(**BACKEND_MAP[self.backend]): + # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape) + out = F.scaled_dot_product_attention( + q, k, v, attn_mask=mask + ) # scale is dim_head ** -0.5 per default + + del q, k, v + out = rearrange(out, "b h n d -> b n (h d)", h=h) + + if additional_tokens is not None: + # remove additional token + out = out[:, n_tokens_to_mask:] + return self.to_out(out) + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__( + self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs + ): + super().__init__() + logpy.debug( + f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, " + f"context_dim is {context_dim} and using {heads} heads with a " + f"dimension of {dim_head}." + ) + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + self.attention_op: Optional[Any] = None + + def forward( + self, + x, + context=None, + mask=None, + additional_tokens=None, + n_times_crossframe_attn_in_self=0, + ): + if additional_tokens is not None: + # get the number of masked tokens at the beginning of the output sequence + n_tokens_to_mask = additional_tokens.shape[1] + # add additional token + x = torch.cat([additional_tokens, x], dim=1) + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + if n_times_crossframe_attn_in_self: + # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439 + assert x.shape[0] % n_times_crossframe_attn_in_self == 0 + # n_cp = x.shape[0]//n_times_crossframe_attn_in_self + k = repeat( + k[::n_times_crossframe_attn_in_self], + "b ... -> (b n) ...", + n=n_times_crossframe_attn_in_self, + ) + v = repeat( + v[::n_times_crossframe_attn_in_self], + "b ... -> (b n) ...", + n=n_times_crossframe_attn_in_self, + ) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + if version.parse(xformers.__version__) >= version.parse("0.0.21"): + # NOTE: workaround for + # https://github.com/facebookresearch/xformers/issues/845 + max_bs = 32768 + N = q.shape[0] + n_batches = math.ceil(N / max_bs) + out = list() + for i_batch in range(n_batches): + batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs) + out.append( + xformers.ops.memory_efficient_attention( + q[batch], + k[batch], + v[batch], + attn_bias=None, + op=self.attention_op, + ) + ) + out = torch.cat(out, 0) + else: + out = xformers.ops.memory_efficient_attention( + q, k, v, attn_bias=None, op=self.attention_op + ) + + # TODO: Use this directly in the attention operation, as a bias + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + if additional_tokens is not None: + # remove additional token + out = out[:, n_tokens_to_mask:] + return self.to_out(out) + + + +class BasicTransformerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, # vanilla attention + "softmax-xformers": MemoryEfficientCrossAttention, # ampere + } + + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False, + attn_mode="softmax", + sdp_backend=None, + ): + super().__init__() + assert attn_mode in self.ATTENTION_MODES + if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE: + logpy.warn( + f"Attention mode '{attn_mode}' is not available. Falling " + f"back to native attention. This is not a problem in " + f"Pytorch >= 2.0. FYI, you are running with PyTorch " + f"version {torch.__version__}." + ) + attn_mode = "softmax" + elif attn_mode == "softmax" and not SDP_IS_AVAILABLE: + logpy.warn( + "We do not support vanilla attention anymore, as it is too " + "expensive. Sorry." + ) + if not XFORMERS_IS_AVAILABLE: + assert ( + False + ), "Please install xformers via e.g. 'pip install xformers==0.0.16'" + else: + logpy.info("Falling back to xformers efficient attention.") + attn_mode = "softmax-xformers" + attn_cls = self.ATTENTION_MODES[attn_mode] + if version.parse(torch.__version__) >= version.parse("2.0.0"): + assert sdp_backend is None or isinstance(sdp_backend, SDPBackend) + else: + assert sdp_backend is None + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None, + backend=sdp_backend, + ) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + backend=sdp_backend, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + if self.checkpoint: + logpy.debug(f"{self.__class__.__name__} is using checkpointing") + + + def forward( + self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0 + ): + kwargs = {"x": x} + + if context is not None: + kwargs.update({"context": context}) + + if additional_tokens is not None: + kwargs.update({"additional_tokens": additional_tokens}) + + if n_times_crossframe_attn_in_self: + kwargs.update( + {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self} + ) + + # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint) + if self.checkpoint: + # inputs = {"x": x, "context": context} + return checkpoint(self._forward, x, context) + # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint) + else: + return self._forward(**kwargs) + + def _forward( + self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0 + ): + x = ( + self.attn1( + self.norm1(x), + context=context if self.disable_self_attn else None, + additional_tokens=additional_tokens, + n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self + if not self.disable_self_attn + else 0, + ) + + x + ) + x = ( + self.attn2( + self.norm2(x), context=context, additional_tokens=additional_tokens + ) + + x + ) + x = self.ff(self.norm3(x)) + x + return x + + +class BasicTransformerBlockWithAPM(BasicTransformerBlock): + + def __init__(self, dim, n_heads, d_head, dropout=0, context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False, attn_mode="softmax", sdp_backend=None,use_apm=False): + super().__init__(dim, n_heads, d_head, dropout, context_dim, gated_ff, checkpoint, disable_self_attn, attn_mode, sdp_backend) + # APM Addition + assert disable_self_attn == False + self.use_apm = use_apm + if use_apm: + tokens_apm_clip = 16+1 + self.apm_conv = torch.nn.Conv1d( + tokens_apm_clip, 1, kernel_size=3, padding="same") + channel_dim_context = 1024 + self.apm_ln = nn.LayerNorm(channel_dim_context) + self.apm_alpha = nn.Parameter(torch.tensor(0.)) + + + def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0 + ): + if context is not None and context.shape[1]>1 and self.use_apm: + print("using APM CONTEXT !!!!") + context_svd = context[:,:1] + context_mixed = self.apm_conv(context) + context_mixed = self.apm_ln(context_mixed) + context = context_svd + context_mixed * F.silu(self.apm_alpha) + return super().forward(x=x,context=context,additional_tokens=additional_tokens,n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self) + + +class BasicTransformerSingleLayerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, # vanilla attention + "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version + # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128]) + } + + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + attn_mode="softmax", + ): + super().__init__() + assert attn_mode in self.ATTENTION_MODES + attn_cls = self.ATTENTION_MODES[attn_mode] + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim, + ) + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + # inputs = {"x": x, "context": context} + # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint) + return checkpoint(self._forward, x, context) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context) + x + x = self.ff(self.norm2(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + disable_self_attn=False, + use_linear=False, + attn_type="softmax", + use_checkpoint=True, + # sdp_backend=SDPBackend.FLASH_ATTENTION + sdp_backend=None, + use_apm:bool =False, + ): + super().__init__() + logpy.debug( + f"constructing {self.__class__.__name__} of depth {depth} w/ " + f"{in_channels} channels and {n_heads} heads." + ) + + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + if exists(context_dim) and isinstance(context_dim, list): + if depth != len(context_dim): + logpy.warn( + f"{self.__class__.__name__}: Found context dims " + f"{context_dim} of depth {len(context_dim)}, which does not " + f"match the specified 'depth' of {depth}. Setting context_dim " + f"to {depth * [context_dim[0]]} now." + ) + # depth does not match context dims. + assert all( + map(lambda x: x == context_dim[0], context_dim) + ), "need homogenous context_dim to match depth automatically" + context_dim = depth * [context_dim[0]] + elif context_dim is None: + context_dim = [None] * depth + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + if use_apm: + print("APM TRANSFORMER BLOCK") + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlockWithAPM( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + disable_self_attn=disable_self_attn, + attn_mode=attn_type, + checkpoint=use_checkpoint, + sdp_backend=sdp_backend, + use_apm=use_apm, + ) + for d in range(depth) + ] + ) + else: + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + disable_self_attn=disable_self_attn, + attn_mode=attn_type, + checkpoint=use_checkpoint, + sdp_backend=sdp_backend, + ) + for d in range(depth) + ] + ) + if not use_linear: + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + else: + # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + if i > 0 and len(context) == 1: + i = 0 # use same context for each block + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + +class SimpleTransformer(nn.Module): + def __init__( + self, + dim: int, + depth: int, + heads: int, + dim_head: int, + context_dim: Optional[int] = None, + dropout: float = 0.0, + checkpoint: bool = True, + ): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + BasicTransformerBlock( + dim, + heads, + dim_head, + dropout=dropout, + context_dim=context_dim, + attn_mode="softmax-xformers", + checkpoint=checkpoint, + ) + ) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + for layer in self.layers: + x = layer(x, context) + return x diff --git a/models/svd/sgm/modules/autoencoding/__init__.py b/models/svd/sgm/modules/autoencoding/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/svd/sgm/modules/autoencoding/losses/__init__.py b/models/svd/sgm/modules/autoencoding/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b316c7aa6ea1c5e31a58987aa3b37b2933eb7e2 --- /dev/null +++ b/models/svd/sgm/modules/autoencoding/losses/__init__.py @@ -0,0 +1,7 @@ +__all__ = [ + "GeneralLPIPSWithDiscriminator", + "LatentLPIPS", +] + +from .discriminator_loss import GeneralLPIPSWithDiscriminator +from .lpips import LatentLPIPS diff --git a/models/svd/sgm/modules/autoencoding/losses/discriminator_loss.py b/models/svd/sgm/modules/autoencoding/losses/discriminator_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..09b6829267bf8e4d98c3f29abdc19e58dcbcbe64 --- /dev/null +++ b/models/svd/sgm/modules/autoencoding/losses/discriminator_loss.py @@ -0,0 +1,306 @@ +from typing import Dict, Iterator, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torchvision +from einops import rearrange +from matplotlib import colormaps +from matplotlib import pyplot as plt + +from ....util import default, instantiate_from_config +from ..lpips.loss.lpips import LPIPS +from ..lpips.model.model import weights_init +from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss + + +class GeneralLPIPSWithDiscriminator(nn.Module): + def __init__( + self, + disc_start: int, + logvar_init: float = 0.0, + disc_num_layers: int = 3, + disc_in_channels: int = 3, + disc_factor: float = 1.0, + disc_weight: float = 1.0, + perceptual_weight: float = 1.0, + disc_loss: str = "hinge", + scale_input_to_tgt_size: bool = False, + dims: int = 2, + learn_logvar: bool = False, + regularization_weights: Union[None, Dict[str, float]] = None, + additional_log_keys: Optional[List[str]] = None, + discriminator_config: Optional[Dict] = None, + ): + super().__init__() + self.dims = dims + if self.dims > 2: + print( + f"running with dims={dims}. This means that for perceptual loss " + f"calculation, the LPIPS loss will be applied to each frame " + f"independently." + ) + self.scale_input_to_tgt_size = scale_input_to_tgt_size + assert disc_loss in ["hinge", "vanilla"] + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + # output log variance + self.logvar = nn.Parameter( + torch.full((), logvar_init), requires_grad=learn_logvar + ) + self.learn_logvar = learn_logvar + + discriminator_config = default( + discriminator_config, + { + "target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator", + "params": { + "input_nc": disc_in_channels, + "n_layers": disc_num_layers, + "use_actnorm": False, + }, + }, + ) + + self.discriminator = instantiate_from_config(discriminator_config).apply( + weights_init + ) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.regularization_weights = default(regularization_weights, {}) + + self.forward_keys = [ + "optimizer_idx", + "global_step", + "last_layer", + "split", + "regularization_log", + ] + + self.additional_log_keys = set(default(additional_log_keys, [])) + self.additional_log_keys.update(set(self.regularization_weights.keys())) + + def get_trainable_parameters(self) -> Iterator[nn.Parameter]: + return self.discriminator.parameters() + + def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]: + if self.learn_logvar: + yield self.logvar + yield from () + + @torch.no_grad() + def log_images( + self, inputs: torch.Tensor, reconstructions: torch.Tensor + ) -> Dict[str, torch.Tensor]: + # calc logits of real/fake + logits_real = self.discriminator(inputs.contiguous().detach()) + if len(logits_real.shape) < 4: + # Non patch-discriminator + return dict() + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + # -> (b, 1, h, w) + + # parameters for colormapping + high = max(logits_fake.abs().max(), logits_real.abs().max()).item() + cmap = colormaps["PiYG"] # diverging colormap + + def to_colormap(logits: torch.Tensor) -> torch.Tensor: + """(b, 1, ...) -> (b, 3, ...)""" + logits = (logits + high) / (2 * high) + logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel + # -> (b, 1, ..., 3) + logits = torch.from_numpy(logits_np).to(logits.device) + return rearrange(logits, "b 1 ... c -> b c ...") + + logits_real = torch.nn.functional.interpolate( + logits_real, + size=inputs.shape[-2:], + mode="nearest", + antialias=False, + ) + logits_fake = torch.nn.functional.interpolate( + logits_fake, + size=reconstructions.shape[-2:], + mode="nearest", + antialias=False, + ) + + # alpha value of logits for overlay + alpha_real = torch.abs(logits_real) / high + alpha_fake = torch.abs(logits_fake) / high + # -> (b, 1, h, w) in range [0, 0.5] + # alpha value of lines don't really matter, since the values are the same + # for both images and logits anyway + grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4) + grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4) + grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1) + # -> (1, h, w) + # blend logits and images together + + # prepare logits for plotting + logits_real = to_colormap(logits_real) + logits_fake = to_colormap(logits_fake) + # resize logits + # -> (b, 3, h, w) + + # make some grids + # add all logits to one plot + logits_real = torchvision.utils.make_grid(logits_real, nrow=4) + logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4) + # I just love how torchvision calls the number of columns `nrow` + grid_logits = torch.cat((logits_real, logits_fake), dim=1) + # -> (3, h, w) + + grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4) + grid_images_fake = torchvision.utils.make_grid( + 0.5 * reconstructions + 0.5, nrow=4 + ) + grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1) + # -> (3, h, w) in range [0, 1] + + grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images + + # Create labeled colorbar + dpi = 100 + height = 128 / dpi + width = grid_logits.shape[2] / dpi + fig, ax = plt.subplots(figsize=(width, height), dpi=dpi) + img = ax.imshow(np.array([[-high, high]]), cmap=cmap) + plt.colorbar( + img, + cax=ax, + orientation="horizontal", + fraction=0.9, + aspect=width / height, + pad=0.0, + ) + img.set_visible(False) + fig.tight_layout() + fig.canvas.draw() + # manually convert figure to numpy + cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0 + cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device) + + # Add colorbar to plot + annotated_grid = torch.cat((grid_logits, cbar), dim=1) + blended_grid = torch.cat((grid_blend, cbar), dim=1) + return { + "vis_logits": 2 * annotated_grid[None, ...] - 1, + "vis_logits_blended": 2 * blended_grid[None, ...] - 1, + } + + def calculate_adaptive_weight( + self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor + ) -> torch.Tensor: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward( + self, + inputs: torch.Tensor, + reconstructions: torch.Tensor, + *, # added because I changed the order here + regularization_log: Dict[str, torch.Tensor], + optimizer_idx: int, + global_step: int, + last_layer: torch.Tensor, + split: str = "train", + weights: Union[None, float, torch.Tensor] = None, + ) -> Tuple[torch.Tensor, dict]: + if self.scale_input_to_tgt_size: + inputs = torch.nn.functional.interpolate( + inputs, reconstructions.shape[2:], mode="bicubic", antialias=True + ) + + if self.dims > 2: + inputs, reconstructions = map( + lambda x: rearrange(x, "b c t h w -> (b t) c h w"), + (inputs, reconstructions), + ) + + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss( + inputs.contiguous(), reconstructions.contiguous() + ) + rec_loss = rec_loss + self.perceptual_weight * p_loss + + nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights) + + # now the GAN part + if optimizer_idx == 0: + # generator update + if global_step >= self.discriminator_iter_start or not self.training: + logits_fake = self.discriminator(reconstructions.contiguous()) + g_loss = -torch.mean(logits_fake) + if self.training: + d_weight = self.calculate_adaptive_weight( + nll_loss, g_loss, last_layer=last_layer + ) + else: + d_weight = torch.tensor(1.0) + else: + d_weight = torch.tensor(0.0) + g_loss = torch.tensor(0.0, requires_grad=True) + + loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss + log = dict() + for k in regularization_log: + if k in self.regularization_weights: + loss = loss + self.regularization_weights[k] * regularization_log[k] + if k in self.additional_log_keys: + log[f"{split}/{k}"] = regularization_log[k].detach().float().mean() + + log.update( + { + f"{split}/loss/total": loss.clone().detach().mean(), + f"{split}/loss/nll": nll_loss.detach().mean(), + f"{split}/loss/rec": rec_loss.detach().mean(), + f"{split}/loss/g": g_loss.detach().mean(), + f"{split}/scalars/logvar": self.logvar.detach(), + f"{split}/scalars/d_weight": d_weight.detach(), + } + ) + + return loss, log + elif optimizer_idx == 1: + # second pass for discriminator update + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + + if global_step >= self.discriminator_iter_start or not self.training: + d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake) + else: + d_loss = torch.tensor(0.0, requires_grad=True) + + log = { + f"{split}/loss/disc": d_loss.clone().detach().mean(), + f"{split}/logits/real": logits_real.detach().mean(), + f"{split}/logits/fake": logits_fake.detach().mean(), + } + return d_loss, log + else: + raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}") + + def get_nll_loss( + self, + rec_loss: torch.Tensor, + weights: Optional[Union[float, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights * nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + + return nll_loss, weighted_nll_loss diff --git a/models/svd/sgm/modules/autoencoding/losses/lpips.py b/models/svd/sgm/modules/autoencoding/losses/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..b329fcc2ee9477f0122aa7d066866cdfe71ce521 --- /dev/null +++ b/models/svd/sgm/modules/autoencoding/losses/lpips.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn + +from ....util import default, instantiate_from_config +from ..lpips.loss.lpips import LPIPS + + +class LatentLPIPS(nn.Module): + def __init__( + self, + decoder_config, + perceptual_weight=1.0, + latent_weight=1.0, + scale_input_to_tgt_size=False, + scale_tgt_to_input_size=False, + perceptual_weight_on_inputs=0.0, + ): + super().__init__() + self.scale_input_to_tgt_size = scale_input_to_tgt_size + self.scale_tgt_to_input_size = scale_tgt_to_input_size + self.init_decoder(decoder_config) + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + self.latent_weight = latent_weight + self.perceptual_weight_on_inputs = perceptual_weight_on_inputs + + def init_decoder(self, config): + self.decoder = instantiate_from_config(config) + if hasattr(self.decoder, "encoder"): + del self.decoder.encoder + + def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"): + log = dict() + loss = (latent_inputs - latent_predictions) ** 2 + log[f"{split}/latent_l2_loss"] = loss.mean().detach() + image_reconstructions = None + if self.perceptual_weight > 0.0: + image_reconstructions = self.decoder.decode(latent_predictions) + image_targets = self.decoder.decode(latent_inputs) + perceptual_loss = self.perceptual_loss( + image_targets.contiguous(), image_reconstructions.contiguous() + ) + loss = ( + self.latent_weight * loss.mean() + + self.perceptual_weight * perceptual_loss.mean() + ) + log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach() + + if self.perceptual_weight_on_inputs > 0.0: + image_reconstructions = default( + image_reconstructions, self.decoder.decode(latent_predictions) + ) + if self.scale_input_to_tgt_size: + image_inputs = torch.nn.functional.interpolate( + image_inputs, + image_reconstructions.shape[2:], + mode="bicubic", + antialias=True, + ) + elif self.scale_tgt_to_input_size: + image_reconstructions = torch.nn.functional.interpolate( + image_reconstructions, + image_inputs.shape[2:], + mode="bicubic", + antialias=True, + ) + + perceptual_loss2 = self.perceptual_loss( + image_inputs.contiguous(), image_reconstructions.contiguous() + ) + loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean() + log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach() + return loss, log diff --git a/models/svd/sgm/modules/autoencoding/lpips/__init__.py b/models/svd/sgm/modules/autoencoding/lpips/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/svd/sgm/modules/autoencoding/lpips/loss/.gitignore b/models/svd/sgm/modules/autoencoding/lpips/loss/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..a92958a1cd4ffe005e1f5448ab3e6fd9c795a43a --- /dev/null +++ b/models/svd/sgm/modules/autoencoding/lpips/loss/.gitignore @@ -0,0 +1 @@ +vgg.pth \ No newline at end of file diff --git a/models/svd/sgm/modules/autoencoding/lpips/loss/LICENSE b/models/svd/sgm/modules/autoencoding/lpips/loss/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..924cfc85b8d63ef538f5676f830a2a8497932108 --- /dev/null +++ b/models/svd/sgm/modules/autoencoding/lpips/loss/LICENSE @@ -0,0 +1,23 @@ +Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/models/svd/sgm/modules/autoencoding/lpips/loss/__init__.py b/models/svd/sgm/modules/autoencoding/lpips/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/svd/sgm/modules/autoencoding/lpips/loss/lpips.py b/models/svd/sgm/modules/autoencoding/lpips/loss/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..3e34f3d083674f675a5ca024e9bd27fb77e2b6b5 --- /dev/null +++ b/models/svd/sgm/modules/autoencoding/lpips/loss/lpips.py @@ -0,0 +1,147 @@ +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + +from collections import namedtuple + +import torch +import torch.nn as nn +from torchvision import models + +from ..util import get_ckpt_path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss") + self.load_state_dict( + torch.load(ckpt, map_location=torch.device("cpu")), strict=False + ) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name != "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name) + model.load_state_dict( + torch.load(ckpt, map_location=torch.device("cpu")), strict=False + ) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( + outs1[kk] + ) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [ + spatial_average(lins[kk].model(diffs[kk]), keepdim=True) + for kk in range(len(self.chns)) + ] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer( + "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] + ) + self.register_buffer( + "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] + ) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """A single linear layer which does a 1x1 conv""" + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = ( + [ + nn.Dropout(), + ] + if (use_dropout) + else [] + ) + layers += [ + nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), + ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] + ) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) diff --git a/models/svd/sgm/modules/autoencoding/lpips/model/LICENSE b/models/svd/sgm/modules/autoencoding/lpips/model/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4b356e66b5aa689b339f1a80a9f1b5ba378003bb --- /dev/null +++ b/models/svd/sgm/modules/autoencoding/lpips/model/LICENSE @@ -0,0 +1,58 @@ +Copyright (c) 2017, Jun-Yan Zhu and Taesung Park +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +--------------------------- LICENSE FOR pix2pix -------------------------------- +BSD License + +For pix2pix software +Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +----------------------------- LICENSE FOR DCGAN -------------------------------- +BSD License + +For dcgan.torch software + +Copyright (c) 2015, Facebook, Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/models/svd/sgm/modules/autoencoding/lpips/model/__init__.py b/models/svd/sgm/modules/autoencoding/lpips/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/svd/sgm/modules/autoencoding/lpips/model/model.py b/models/svd/sgm/modules/autoencoding/lpips/model/model.py new file mode 100644 index 0000000000000000000000000000000000000000..66357d4e627f9a69a5abbbad15546c96fcd758fe --- /dev/null +++ b/models/svd/sgm/modules/autoencoding/lpips/model/model.py @@ -0,0 +1,88 @@ +import functools + +import torch.nn as nn + +from ..util import ActNorm + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + + def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + norm_layer = ActNorm + if ( + type(norm_layer) == functools.partial + ): # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = 4 + padw = 1 + sequence = [ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True), + ] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=2, + padding=padw, + bias=use_bias, + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=1, + padding=padw, + bias=use_bias, + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) + ] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) diff --git a/models/svd/sgm/modules/autoencoding/lpips/util.py b/models/svd/sgm/modules/autoencoding/lpips/util.py new file mode 100644 index 0000000000000000000000000000000000000000..49c76e370bf16888ab61f42844b3c9f14ad9014c --- /dev/null +++ b/models/svd/sgm/modules/autoencoding/lpips/util.py @@ -0,0 +1,128 @@ +import hashlib +import os + +import requests +import torch +import torch.nn as nn +from tqdm import tqdm + +URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} + +CKPT_MAP = {"vgg_lpips": "vgg.pth"} + +MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class ActNorm(nn.Module): + def __init__( + self, num_features, logdet=False, affine=True, allow_reverse_init=False + ): + assert affine + super().__init__() + self.logdet = logdet + self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.allow_reverse_init = allow_reverse_init + + self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = ( + flatten.mean(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + std = ( + flatten.std(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + + self.loc.data.copy_(-mean) + self.scale.data.copy_(1 / (std + 1e-6)) + + def forward(self, input, reverse=False): + if reverse: + return self.reverse(input) + if len(input.shape) == 2: + input = input[:, :, None, None] + squeeze = True + else: + squeeze = False + + _, _, height, width = input.shape + + if self.training and self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + h = self.scale * (input + self.loc) + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + + if self.logdet: + log_abs = torch.log(torch.abs(self.scale)) + logdet = height * width * torch.sum(log_abs) + logdet = logdet * torch.ones(input.shape[0]).to(input) + return h, logdet + + return h + + def reverse(self, output): + if self.training and self.initialized.item() == 0: + if not self.allow_reverse_init: + raise RuntimeError( + "Initializing ActNorm in reverse direction is " + "disabled by default. Use allow_reverse_init=True to enable." + ) + else: + self.initialize(output) + self.initialized.fill_(1) + + if len(output.shape) == 2: + output = output[:, :, None, None] + squeeze = True + else: + squeeze = False + + h = output / self.scale - self.loc + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + return h diff --git a/models/svd/sgm/modules/autoencoding/lpips/vqperceptual.py b/models/svd/sgm/modules/autoencoding/lpips/vqperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..6195f0a6ed7ee6fd32c1bccea071e6075e95ee43 --- /dev/null +++ b/models/svd/sgm/modules/autoencoding/lpips/vqperceptual.py @@ -0,0 +1,17 @@ +import torch +import torch.nn.functional as F + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1.0 - logits_real)) + loss_fake = torch.mean(F.relu(1.0 + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake)) + ) + return d_loss diff --git a/models/svd/sgm/modules/autoencoding/regularizers/__init__.py b/models/svd/sgm/modules/autoencoding/regularizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..279aba838fe783b220da06544fbd731eaa15ad10 --- /dev/null +++ b/models/svd/sgm/modules/autoencoding/regularizers/__init__.py @@ -0,0 +1,31 @@ +from abc import abstractmethod +from typing import Any, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from models.svd.sgm.modules.distributions.distributions import \ + DiagonalGaussianDistribution +from models.svd.sgm.modules.autoencoding.regularizers.base import AbstractRegularizer + + +class DiagonalGaussianRegularizer(AbstractRegularizer): + def __init__(self, sample: bool = True): + super().__init__() + self.sample = sample + + def get_trainable_parameters(self) -> Any: + yield from () + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + log = dict() + posterior = DiagonalGaussianDistribution(z) + if self.sample: + z = posterior.sample() + else: + z = posterior.mode() + kl_loss = posterior.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + log["kl_loss"] = kl_loss + return z, log diff --git a/models/svd/sgm/modules/autoencoding/regularizers/base.py b/models/svd/sgm/modules/autoencoding/regularizers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..fca681bb3c1f4818b57e956e31b98f76077ccb67 --- /dev/null +++ b/models/svd/sgm/modules/autoencoding/regularizers/base.py @@ -0,0 +1,40 @@ +from abc import abstractmethod +from typing import Any, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + + +class AbstractRegularizer(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + raise NotImplementedError() + + @abstractmethod + def get_trainable_parameters(self) -> Any: + raise NotImplementedError() + + +class IdentityRegularizer(AbstractRegularizer): + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + return z, dict() + + def get_trainable_parameters(self) -> Any: + yield from () + + +def measure_perplexity( + predicted_indices: torch.Tensor, num_centroids: int +) -> Tuple[torch.Tensor, torch.Tensor]: + # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py + # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally + encodings = ( + F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) + ) + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + return perplexity, cluster_use diff --git a/models/svd/sgm/modules/autoencoding/regularizers/quantize.py b/models/svd/sgm/modules/autoencoding/regularizers/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..6ea4cf803e763a8778641cea655fc3c88b012610 --- /dev/null +++ b/models/svd/sgm/modules/autoencoding/regularizers/quantize.py @@ -0,0 +1,487 @@ +import logging +from abc import abstractmethod +from typing import Dict, Iterator, Literal, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch import einsum + +from models.svd.sgm.modules.autoencoding.regularizers.base import AbstractRegularizer, measure_perplexity + +logpy = logging.getLogger(__name__) + + +class AbstractQuantizer(AbstractRegularizer): + def __init__(self): + super().__init__() + # Define these in your init + # shape (N,) + self.used: Optional[torch.Tensor] + self.re_embed: int + self.unknown_index: Union[Literal["random"], int] + + def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor: + assert self.used is not None, "You need to define used indices for remap" + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to( + device=new.device + ) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor: + assert self.used is not None, "You need to define used indices for remap" + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + @abstractmethod + def get_codebook_entry( + self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None + ) -> torch.Tensor: + raise NotImplementedError() + + def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]: + yield from self.parameters() + + +class GumbelQuantizer(AbstractQuantizer): + """ + credit to @karpathy: + https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) + Gumbel Softmax trick quantizer + Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 + https://arxiv.org/abs/1611.01144 + """ + + def __init__( + self, + num_hiddens: int, + embedding_dim: int, + n_embed: int, + straight_through: bool = True, + kl_weight: float = 5e-4, + temp_init: float = 1.0, + remap: Optional[str] = None, + unknown_index: str = "random", + loss_key: str = "loss/vq", + ) -> None: + super().__init__() + + self.loss_key = loss_key + self.embedding_dim = embedding_dim + self.n_embed = n_embed + + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + + self.proj = nn.Conv2d(num_hiddens, n_embed, 1) + self.embed = nn.Embedding(n_embed, embedding_dim) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + else: + self.used = None + self.re_embed = n_embed + if unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + else: + assert unknown_index == "random" or isinstance( + unknown_index, int + ), "unknown index needs to be 'random', 'extra' or any integer" + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.remap is not None: + logpy.info( + f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + + def forward( + self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False + ) -> Tuple[torch.Tensor, Dict]: + # force hard = True when we are in eval mode, as we must quantize. + # actually, always true seems to work + hard = self.straight_through if self.training else True + temp = self.temperature if temp is None else temp + out_dict = {} + logits = self.proj(z) + if self.remap is not None: + # continue only with used logits + full_zeros = torch.zeros_like(logits) + logits = logits[:, self.used, ...] + + soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) + if self.remap is not None: + # go back to all entries but unused set to zero + full_zeros[:, self.used, ...] = soft_one_hot + soft_one_hot = full_zeros + z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = ( + self.kl_weight + * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() + ) + out_dict[self.loss_key] = diff + + ind = soft_one_hot.argmax(dim=1) + out_dict["indices"] = ind + if self.remap is not None: + ind = self.remap_to_used(ind) + + if return_logits: + out_dict["logits"] = logits + + return z_q, out_dict + + def get_codebook_entry(self, indices, shape): + # TODO: shape not yet optional + b, h, w, c = shape + assert b * h * w == indices.shape[0] + indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w) + if self.remap is not None: + indices = self.unmap_to_all(indices) + one_hot = ( + F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() + ) + z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight) + return z_q + + +class VectorQuantizer(AbstractQuantizer): + """ + ____________________________________________ + Discretization bottleneck part of the VQ-VAE. + Inputs: + - n_e : number of embeddings + - e_dim : dimension of embedding + - beta : commitment cost used in loss term, + beta * ||z_e(x)-sg[e]||^2 + _____________________________________________ + """ + + def __init__( + self, + n_e: int, + e_dim: int, + beta: float = 0.25, + remap: Optional[str] = None, + unknown_index: str = "random", + sane_index_shape: bool = False, + log_perplexity: bool = False, + embedding_weight_norm: bool = False, + loss_key: str = "loss/vq", + ): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.loss_key = loss_key + + if not embedding_weight_norm: + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + else: + self.embedding = torch.nn.utils.weight_norm( + nn.Embedding(self.n_e, self.e_dim), dim=1 + ) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + else: + self.used = None + self.re_embed = n_e + if unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + else: + assert unknown_index == "random" or isinstance( + unknown_index, int + ), "unknown index needs to be 'random', 'extra' or any integer" + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.remap is not None: + logpy.info( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + + self.sane_index_shape = sane_index_shape + self.log_perplexity = log_perplexity + + def forward( + self, + z: torch.Tensor, + ) -> Tuple[torch.Tensor, Dict]: + do_reshape = z.ndim == 4 + if do_reshape: + # # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, "b c h w -> b h w c").contiguous() + + else: + assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined" + z = z.contiguous() + + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 + * torch.einsum( + "bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n") + ) + ) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + loss_dict = {} + if self.log_perplexity: + perplexity, cluster_usage = measure_perplexity( + min_encoding_indices.detach(), self.n_e + ) + loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage}) + + # compute loss for embedding + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean( + (z_q - z.detach()) ** 2 + ) + loss_dict[self.loss_key] = loss + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + if do_reshape: + z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape( + z.shape[0], -1 + ) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + if do_reshape: + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3] + ) + else: + min_encoding_indices = rearrange( + min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0] + ) + + loss_dict["min_encoding_indices"] = min_encoding_indices + + return z_q, loss_dict + + def get_codebook_entry( + self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None + ) -> torch.Tensor: + # shape specifying (batch, height, width, channel) + if self.remap is not None: + assert shape is not None, "Need to give shape for remap" + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class EmbeddingEMA(nn.Module): + def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): + super().__init__() + self.decay = decay + self.eps = eps + weight = torch.randn(num_tokens, codebook_dim) + self.weight = nn.Parameter(weight, requires_grad=False) + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False) + self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) + self.update = True + + def forward(self, embed_id): + return F.embedding(embed_id, self.weight) + + def cluster_size_ema_update(self, new_cluster_size): + self.cluster_size.data.mul_(self.decay).add_( + new_cluster_size, alpha=1 - self.decay + ) + + def embed_avg_ema_update(self, new_embed_avg): + self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) + + def weight_update(self, num_tokens): + n = self.cluster_size.sum() + smoothed_cluster_size = ( + (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n + ) + # normalize embedding average with smoothed cluster size + embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) + self.weight.data.copy_(embed_normalized) + + +class EMAVectorQuantizer(AbstractQuantizer): + def __init__( + self, + n_embed: int, + embedding_dim: int, + beta: float, + decay: float = 0.99, + eps: float = 1e-5, + remap: Optional[str] = None, + unknown_index: str = "random", + loss_key: str = "loss/vq", + ): + super().__init__() + self.codebook_dim = embedding_dim + self.num_tokens = n_embed + self.beta = beta + self.loss_key = loss_key + + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + else: + self.used = None + self.re_embed = n_embed + if unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + else: + assert unknown_index == "random" or isinstance( + unknown_index, int + ), "unknown index needs to be 'random', 'extra' or any integer" + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.remap is not None: + logpy.info( + f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]: + # reshape z -> (batch, height, width, channel) and flatten + # z, 'b c h w -> b h w c' + z = rearrange(z, "b c h w -> b h w c") + z_flattened = z.reshape(-1, self.codebook_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = ( + z_flattened.pow(2).sum(dim=1, keepdim=True) + + self.embedding.weight.pow(2).sum(dim=1) + - 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight) + ) # 'n d -> d n' + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + if self.training and self.embedding.update: + # EMA cluster size + encodings_sum = encodings.sum(0) + self.embedding.cluster_size_ema_update(encodings_sum) + # EMA embedding average + embed_sum = encodings.transpose(0, 1) @ z_flattened + self.embedding.embed_avg_ema_update(embed_sum) + # normalize embed_avg and update weight + self.embedding.weight_update(self.num_tokens) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + # z_q, 'b h w c -> b c h w' + z_q = rearrange(z_q, "b h w c -> b c h w") + + out_dict = { + self.loss_key: loss, + "encodings": encodings, + "encoding_indices": encoding_indices, + "perplexity": perplexity, + } + + return z_q, out_dict + + +class VectorQuantizerWithInputProjection(VectorQuantizer): + def __init__( + self, + input_dim: int, + n_codes: int, + codebook_dim: int, + beta: float = 1.0, + output_dim: Optional[int] = None, + **kwargs, + ): + super().__init__(n_codes, codebook_dim, beta, **kwargs) + self.proj_in = nn.Linear(input_dim, codebook_dim) + self.output_dim = output_dim + if output_dim is not None: + self.proj_out = nn.Linear(codebook_dim, output_dim) + else: + self.proj_out = nn.Identity() + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]: + rearr = False + in_shape = z.shape + + if z.ndim > 3: + rearr = self.output_dim is not None + z = rearrange(z, "b c ... -> b (...) c") + z = self.proj_in(z) + z_q, loss_dict = super().forward(z) + + z_q = self.proj_out(z_q) + if rearr: + if len(in_shape) == 4: + z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1]) + elif len(in_shape) == 5: + z_q = rearrange( + z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2] + ) + else: + raise NotImplementedError( + f"rearranging not available for {len(in_shape)}-dimensional input." + ) + + return z_q, loss_dict diff --git a/models/svd/sgm/modules/autoencoding/temporal_ae.py b/models/svd/sgm/modules/autoencoding/temporal_ae.py new file mode 100644 index 0000000000000000000000000000000000000000..647291b7ffc6562469ca570fd8dac779809f38f9 --- /dev/null +++ b/models/svd/sgm/modules/autoencoding/temporal_ae.py @@ -0,0 +1,347 @@ +from typing import Callable, Iterable, Union + +import torch +from einops import rearrange, repeat + +from models.svd.sgm.modules.diffusionmodules.model import (XFORMERS_IS_AVAILABLE, + AttnBlock, Decoder, + MemoryEfficientAttnBlock, + ResnetBlock) +from models.svd.sgm.modules.diffusionmodules.openaimodel import (ResBlock, + timestep_embedding) +from models.svd.sgm.modules.video_attention import VideoTransformerBlock +from models.svd.sgm.util import partialclass + + +class VideoResBlock(ResnetBlock): + def __init__( + self, + out_channels, + *args, + dropout=0.0, + video_kernel_size=3, + alpha=0.0, + merge_strategy="learned", + **kwargs, + ): + super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs) + if video_kernel_size is None: + video_kernel_size = [3, 1, 1] + self.time_stack = ResBlock( + channels=out_channels, + emb_channels=0, + dropout=dropout, + dims=3, + use_scale_shift_norm=False, + use_conv=False, + up=False, + down=False, + kernel_size=video_kernel_size, + use_checkpoint=False, + skip_t_emb=True, + ) + + self.merge_strategy = merge_strategy + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned": + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) + ) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def get_alpha(self, bs): + if self.merge_strategy == "fixed": + return self.mix_factor + elif self.merge_strategy == "learned": + return torch.sigmoid(self.mix_factor) + else: + raise NotImplementedError() + + def forward(self, x, temb, skip_video=False, timesteps=None): + if timesteps is None: + timesteps = self.timesteps + + b, c, h, w = x.shape + + x = super().forward(x, temb) + + if not skip_video: + x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + + x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + + x = self.time_stack(x, temb) + + alpha = self.get_alpha(bs=b // timesteps) + x = alpha * x + (1.0 - alpha) * x_mix + + x = rearrange(x, "b c t h w -> (b t) c h w") + return x + + +class AE3DConv(torch.nn.Conv2d): + def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): + super().__init__(in_channels, out_channels, *args, **kwargs) + if isinstance(video_kernel_size, Iterable): + padding = [int(k // 2) for k in video_kernel_size] + else: + padding = int(video_kernel_size // 2) + + self.time_mix_conv = torch.nn.Conv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=video_kernel_size, + padding=padding, + ) + + def forward(self, input, timesteps, skip_video=False): + x = super().forward(input) + if skip_video: + return x + x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + x = self.time_mix_conv(x) + return rearrange(x, "b c t h w -> (b t) c h w") + + +class VideoBlock(AttnBlock): + def __init__( + self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned" + ): + super().__init__(in_channels) + # no context, single headed, as in base class + self.time_mix_block = VideoTransformerBlock( + dim=in_channels, + n_heads=1, + d_head=in_channels, + checkpoint=False, + ff_in=True, + attn_mode="softmax", + ) + + time_embed_dim = self.in_channels * 4 + self.video_time_embed = torch.nn.Sequential( + torch.nn.Linear(self.in_channels, time_embed_dim), + torch.nn.SiLU(), + torch.nn.Linear(time_embed_dim, self.in_channels), + ) + + self.merge_strategy = merge_strategy + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned": + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) + ) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def forward(self, x, timesteps, skip_video=False): + if skip_video: + return super().forward(x) + + x_in = x + x = self.attention(x) + h, w = x.shape[2:] + x = rearrange(x, "b c h w -> b (h w) c") + + x_mix = x + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) + emb = self.video_time_embed(t_emb) # b, n_channels + emb = emb[:, None, :] + x_mix = x_mix + emb + + alpha = self.get_alpha() + x_mix = self.time_mix_block(x_mix, timesteps=timesteps) + x = alpha * x + (1.0 - alpha) * x_mix # alpha merge + + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + + return x_in + x + + def get_alpha( + self, + ): + if self.merge_strategy == "fixed": + return self.mix_factor + elif self.merge_strategy == "learned": + return torch.sigmoid(self.mix_factor) + else: + raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") + + +class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock): + def __init__( + self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned" + ): + super().__init__(in_channels) + # no context, single headed, as in base class + self.time_mix_block = VideoTransformerBlock( + dim=in_channels, + n_heads=1, + d_head=in_channels, + checkpoint=False, + ff_in=True, + attn_mode="softmax-xformers", + ) + + time_embed_dim = self.in_channels * 4 + self.video_time_embed = torch.nn.Sequential( + torch.nn.Linear(self.in_channels, time_embed_dim), + torch.nn.SiLU(), + torch.nn.Linear(time_embed_dim, self.in_channels), + ) + + self.merge_strategy = merge_strategy + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned": + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) + ) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def forward(self, x, timesteps, skip_time_block=False): + if skip_time_block: + return super().forward(x) + + x_in = x + x = self.attention(x) + h, w = x.shape[2:] + x = rearrange(x, "b c h w -> b (h w) c") + + x_mix = x + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) + emb = self.video_time_embed(t_emb) # b, n_channels + emb = emb[:, None, :] + x_mix = x_mix + emb + + alpha = self.get_alpha() + x_mix = self.time_mix_block(x_mix, timesteps=timesteps) + x = alpha * x + (1.0 - alpha) * x_mix # alpha merge + + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + + return x_in + x + + def get_alpha( + self, + ): + if self.merge_strategy == "fixed": + return self.mix_factor + elif self.merge_strategy == "learned": + return torch.sigmoid(self.mix_factor) + else: + raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") + + +def make_time_attn( + in_channels, + attn_type="vanilla", + attn_kwargs=None, + alpha: float = 0, + merge_strategy: str = "learned", +): + assert attn_type in [ + "vanilla", + "vanilla-xformers", + ], f"attn_type {attn_type} not supported for spatio-temporal attention" + print( + f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels" + ) + if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers": + print( + f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. " + f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}" + ) + attn_type = "vanilla" + + if attn_type == "vanilla": + assert attn_kwargs is None + return partialclass( + VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy + ) + elif attn_type == "vanilla-xformers": + print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") + return partialclass( + MemoryEfficientVideoBlock, + in_channels, + alpha=alpha, + merge_strategy=merge_strategy, + ) + else: + return NotImplementedError() + + +class Conv2DWrapper(torch.nn.Conv2d): + def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor: + return super().forward(input) + + +class VideoDecoder(Decoder): + available_time_modes = ["all", "conv-only", "attn-only"] + + def __init__( + self, + *args, + video_kernel_size: Union[int, list] = 3, + alpha: float = 0.0, + merge_strategy: str = "learned", + time_mode: str = "conv-only", + **kwargs, + ): + self.video_kernel_size = video_kernel_size + self.alpha = alpha + self.merge_strategy = merge_strategy + self.time_mode = time_mode + assert ( + self.time_mode in self.available_time_modes + ), f"time_mode parameter has to be in {self.available_time_modes}" + super().__init__(*args, **kwargs) + + def get_last_layer(self, skip_time_mix=False, **kwargs): + if self.time_mode == "attn-only": + raise NotImplementedError("TODO") + else: + return ( + self.conv_out.time_mix_conv.weight + if not skip_time_mix + else self.conv_out.weight + ) + + def _make_attn(self) -> Callable: + if self.time_mode not in ["conv-only", "only-last-conv"]: + return partialclass( + make_time_attn, + alpha=self.alpha, + merge_strategy=self.merge_strategy, + ) + else: + return super()._make_attn() + + def _make_conv(self) -> Callable: + if self.time_mode != "attn-only": + return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size) + else: + return Conv2DWrapper + + def _make_resblock(self) -> Callable: + if self.time_mode not in ["attn-only", "only-last-conv"]: + return partialclass( + VideoResBlock, + video_kernel_size=self.video_kernel_size, + alpha=self.alpha, + merge_strategy=self.merge_strategy, + ) + else: + return super()._make_resblock() diff --git a/models/svd/sgm/modules/diffusionmodules/__init__.py b/models/svd/sgm/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/svd/sgm/modules/diffusionmodules/denoiser.py b/models/svd/sgm/modules/diffusionmodules/denoiser.py new file mode 100644 index 0000000000000000000000000000000000000000..36ed66443401d590cb87810b863f28444a8cb9d4 --- /dev/null +++ b/models/svd/sgm/modules/diffusionmodules/denoiser.py @@ -0,0 +1,75 @@ +from typing import Dict, Union + +import torch +import torch.nn as nn + +from models.svd.sgm.util import append_dims, instantiate_from_config +from models.svd.sgm.modules.diffusionmodules.denoiser_scaling import DenoiserScaling +from models.svd.sgm.modules.diffusionmodules.discretizer import Discretization + + +class Denoiser(nn.Module): + def __init__(self, scaling_config: Dict): + super().__init__() + + self.scaling: DenoiserScaling = instantiate_from_config(scaling_config) + + def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor: + return sigma + + def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor: + return c_noise + + def forward( + self, + network: nn.Module, + input: torch.Tensor, + sigma: torch.Tensor, + cond: Dict, + **additional_model_inputs, + ) -> torch.Tensor: + sigma = self.possibly_quantize_sigma(sigma) + sigma_shape = sigma.shape + sigma = append_dims(sigma, input.ndim) + c_skip, c_out, c_in, c_noise = self.scaling(sigma) + c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) + return ( + network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out + + input * c_skip + ) + + +class DiscreteDenoiser(Denoiser): + def __init__( + self, + scaling_config: Dict, + num_idx: int, + discretization_config: Dict, + do_append_zero: bool = False, + quantize_c_noise: bool = True, + flip: bool = True, + ): + super().__init__(scaling_config) + self.discretization: Discretization = instantiate_from_config( + discretization_config + ) + sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip) + self.register_buffer("sigmas", sigmas) + self.quantize_c_noise = quantize_c_noise + self.num_idx = num_idx + + def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor: + dists = sigma - self.sigmas[:, None] + return dists.abs().argmin(dim=0).view(sigma.shape) + + def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor: + return self.sigmas[idx] + + def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor: + return self.idx_to_sigma(self.sigma_to_idx(sigma)) + + def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor: + if self.quantize_c_noise: + return self.sigma_to_idx(c_noise) + else: + return c_noise diff --git a/models/svd/sgm/modules/diffusionmodules/denoiser_scaling.py b/models/svd/sgm/modules/diffusionmodules/denoiser_scaling.py new file mode 100644 index 0000000000000000000000000000000000000000..f4e287bfe8a82839a9a12fbd25c3446f43ab493b --- /dev/null +++ b/models/svd/sgm/modules/diffusionmodules/denoiser_scaling.py @@ -0,0 +1,59 @@ +from abc import ABC, abstractmethod +from typing import Tuple + +import torch + + +class DenoiserScaling(ABC): + @abstractmethod + def __call__( + self, sigma: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + pass + + +class EDMScaling: + def __init__(self, sigma_data: float = 0.5): + self.sigma_data = sigma_data + + def __call__( + self, sigma: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 + c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 + c_noise = 0.25 * sigma.log() + return c_skip, c_out, c_in, c_noise + + +class EpsScaling: + def __call__( + self, sigma: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = torch.ones_like(sigma, device=sigma.device) + c_out = -sigma + c_in = 1 / (sigma**2 + 1.0) ** 0.5 + c_noise = sigma.clone() + return c_skip, c_out, c_in, c_noise + + +class VScaling: + def __call__( + self, sigma: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = 1.0 / (sigma**2 + 1.0) + c_out = -sigma / (sigma**2 + 1.0) ** 0.5 + c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 + c_noise = sigma.clone() + return c_skip, c_out, c_in, c_noise + + +class VScalingWithEDMcNoise(DenoiserScaling): + def __call__( + self, sigma: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = 1.0 / (sigma**2 + 1.0) + c_out = -sigma / (sigma**2 + 1.0) ** 0.5 + c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 + c_noise = 0.25 * sigma.log() + return c_skip, c_out, c_in, c_noise diff --git a/models/svd/sgm/modules/diffusionmodules/denoiser_weighting.py b/models/svd/sgm/modules/diffusionmodules/denoiser_weighting.py new file mode 100644 index 0000000000000000000000000000000000000000..b8b03ca58f17ea3d7374f4bbb7bf1d2994755e00 --- /dev/null +++ b/models/svd/sgm/modules/diffusionmodules/denoiser_weighting.py @@ -0,0 +1,24 @@ +import torch + + +class UnitWeighting: + def __call__(self, sigma): + return torch.ones_like(sigma, device=sigma.device) + + +class EDMWeighting: + def __init__(self, sigma_data=0.5): + self.sigma_data = sigma_data + + def __call__(self, sigma): + return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + +class VWeighting(EDMWeighting): + def __init__(self): + super().__init__(sigma_data=1.0) + + +class EpsWeighting: + def __call__(self, sigma): + return sigma**-2.0 diff --git a/models/svd/sgm/modules/diffusionmodules/discretizer.py b/models/svd/sgm/modules/diffusionmodules/discretizer.py new file mode 100644 index 0000000000000000000000000000000000000000..d4d14421744af47250b5080a3dd25194d8477ccf --- /dev/null +++ b/models/svd/sgm/modules/diffusionmodules/discretizer.py @@ -0,0 +1,69 @@ +from abc import abstractmethod +from functools import partial + +import numpy as np +import torch + +from models.svd.sgm.modules.diffusionmodules.util import make_beta_schedule +from models.svd.sgm.util import append_zero + + +def generate_roughly_equally_spaced_steps( + num_substeps: int, max_step: int +) -> np.ndarray: + return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] + + +class Discretization: + def __call__(self, n, do_append_zero=True, device="cpu", flip=False): + sigmas = self.get_sigmas(n, device=device) + sigmas = append_zero(sigmas) if do_append_zero else sigmas + return sigmas if not flip else torch.flip(sigmas, (0,)) + + @abstractmethod + def get_sigmas(self, n, device): + pass + + +class EDMDiscretization(Discretization): + def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0): + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.rho = rho + + def get_sigmas(self, n, device="cpu"): + ramp = torch.linspace(0, 1, n, device=device) + min_inv_rho = self.sigma_min ** (1 / self.rho) + max_inv_rho = self.sigma_max ** (1 / self.rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho + return sigmas + + +class LegacyDDPMDiscretization(Discretization): + def __init__( + self, + linear_start=0.00085, + linear_end=0.0120, + num_timesteps=1000, + ): + super().__init__() + self.num_timesteps = num_timesteps + betas = make_beta_schedule( + "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end + ) + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.to_torch = partial(torch.tensor, dtype=torch.float32) + + def get_sigmas(self, n, device="cpu"): + if n < self.num_timesteps: + timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) + alphas_cumprod = self.alphas_cumprod[timesteps] + elif n == self.num_timesteps: + alphas_cumprod = self.alphas_cumprod + else: + raise ValueError + + to_torch = partial(torch.tensor, dtype=torch.float32, device=device) + sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 + return torch.flip(sigmas, (0,)) diff --git a/models/svd/sgm/modules/diffusionmodules/guiders.py b/models/svd/sgm/modules/diffusionmodules/guiders.py new file mode 100644 index 0000000000000000000000000000000000000000..b1507e0eab3aac0b6b510f9c19844fd8c5a1a7f4 --- /dev/null +++ b/models/svd/sgm/modules/diffusionmodules/guiders.py @@ -0,0 +1,131 @@ +import logging +from abc import ABC, abstractmethod +from typing import Dict, List, Literal, Optional, Tuple, Union + +import torch +from einops import rearrange, repeat + +from models.svd.sgm.util import append_dims, default + +logpy = logging.getLogger(__name__) + + +class Guider(ABC): + @abstractmethod + def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: + pass + + def prepare_inputs( + self, x: torch.Tensor, s: float, c: Dict, uc: Dict + ) -> Tuple[torch.Tensor, float, Dict]: + pass + + +class VanillaCFG(Guider): + def __init__(self, scale: float): + self.scale = scale + + def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + x_u, x_c = x.chunk(2) + x_pred = x_u + self.scale * (x_c - x_u) + return x_pred + + def prepare_inputs(self, x, s, c, uc): + c_out = dict() + + for k in c: + if k in ["vector", "crossattn", "concat"]: + c_out[k] = torch.cat((uc[k], c[k]), 0) + else: + assert c[k] == uc[k] + c_out[k] = c[k] + return torch.cat([x] * 2), torch.cat([s] * 2), c_out + + +class IdentityGuider(Guider): + def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: + return x + + def prepare_inputs( + self, x: torch.Tensor, s: float, c: Dict, uc: Dict + ) -> Tuple[torch.Tensor, float, Dict]: + c_out = dict() + + for k in c: + c_out[k] = c[k] + + return x, s, c_out + + +class LinearPredictionGuider(Guider): + def __init__( + self, + max_scale: float, + num_frames: int, + min_scale: float = 1.0, + additional_cond_keys: Optional[Union[List[str], str]] = None, + ): + self.min_scale = min_scale + self.max_scale = max_scale + self.num_frames = num_frames + self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0) + + additional_cond_keys = default(additional_cond_keys, []) + if isinstance(additional_cond_keys, str): + additional_cond_keys = [additional_cond_keys] + self.additional_cond_keys = additional_cond_keys + + def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + x_u, x_c = x.chunk(2) + + x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames) + x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames) + scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0]) + scale = append_dims(scale, x_u.ndim).to(x_u.device) + scale = scale.to(x_u.dtype) + return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...") + + def prepare_inputs( + self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict + ) -> Tuple[torch.Tensor, torch.Tensor, dict]: + c_out = dict() + + for k in c: + if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys: + c_out[k] = torch.cat((uc[k], c[k]), 0) + else: + assert c[k] == uc[k] + c_out[k] = c[k] + return torch.cat([x] * 2), torch.cat([s] * 2), c_out + + +class TrianglePredictionGuider(LinearPredictionGuider): + def __init__( + self, + max_scale: float, + num_frames: int, + min_scale: float = 1.0, + period: Union[float,List[float]] = 1.0, + period_fusing: Literal["mean", "multiply", "max"] = "max", + additional_cond_keys: Optional[Union[List[str], str]] = None, + ): + super().__init__(max_scale, num_frames, min_scale, additional_cond_keys) + values = torch.linspace(0, 1, num_frames) + # Constructs a triangle wave + if isinstance(period, float): + period = [period] + + scales = [] + for p in period: + scales.append(self.triangle_wave(values, p)) + + if period_fusing == "mean": + scale = sum(scales) / len(period) + elif period_fusing == "multiply": + scale = torch.prod(torch.stack(scales), dim=0) + elif period_fusing == "max": + scale = torch.max(torch.stack(scales), dim=0).values + self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0) + + def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor: + return 2 * (values / period - torch.floor(values / period + 0.5)).abs() diff --git a/models/svd/sgm/modules/diffusionmodules/loss.py b/models/svd/sgm/modules/diffusionmodules/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..8f585f62f420fe4bf5978005b8181f23d442f013 --- /dev/null +++ b/models/svd/sgm/modules/diffusionmodules/loss.py @@ -0,0 +1,105 @@ +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from models.svd.sgm.modules.autoencoding.lpips.loss.lpips import LPIPS +from models.svd.sgm.modules.encoders.modules import GeneralConditioner +from models.svd.sgm.util import append_dims, instantiate_from_config +from models.svd.sgm.modules.diffusionmodules.denoiser import Denoiser + + +class StandardDiffusionLoss(nn.Module): + def __init__( + self, + sigma_sampler_config: dict, + loss_weighting_config: dict, + loss_type: str = "l2", + offset_noise_level: float = 0.0, + batch2model_keys: Optional[Union[str, List[str]]] = None, + ): + super().__init__() + + assert loss_type in ["l2", "l1", "lpips"] + + self.sigma_sampler = instantiate_from_config(sigma_sampler_config) + self.loss_weighting = instantiate_from_config(loss_weighting_config) + + self.loss_type = loss_type + self.offset_noise_level = offset_noise_level + + if loss_type == "lpips": + self.lpips = LPIPS().eval() + + if not batch2model_keys: + batch2model_keys = [] + + if isinstance(batch2model_keys, str): + batch2model_keys = [batch2model_keys] + + self.batch2model_keys = set(batch2model_keys) + + def get_noised_input( + self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor + ) -> torch.Tensor: + noised_input = input + noise * sigmas_bc + return noised_input + + def forward( + self, + network: nn.Module, + denoiser: Denoiser, + conditioner: GeneralConditioner, + input: torch.Tensor, + batch: Dict, + ) -> torch.Tensor: + cond = conditioner(batch) + return self._forward(network, denoiser, cond, input, batch) + + def _forward( + self, + network: nn.Module, + denoiser: Denoiser, + cond: Dict, + input: torch.Tensor, + batch: Dict, + ) -> Tuple[torch.Tensor, Dict]: + additional_model_inputs = { + key: batch[key] for key in self.batch2model_keys.intersection(batch) + } + sigmas = self.sigma_sampler(input.shape[0]).to(input) + + noise = torch.randn_like(input) + if self.offset_noise_level > 0.0: + offset_shape = ( + (input.shape[0], 1, input.shape[2]) + if self.n_frames is not None + else (input.shape[0], input.shape[1]) + ) + noise = noise + self.offset_noise_level * append_dims( + torch.randn(offset_shape, device=input.device), + input.ndim, + ) + sigmas_bc = append_dims(sigmas, input.ndim) + noised_input = self.get_noised_input(sigmas_bc, noise, input) + + model_output = denoiser( + network, noised_input, sigmas, cond, **additional_model_inputs + ) + w = append_dims(self.loss_weighting(sigmas), input.ndim) + return self.get_loss(model_output, input, w) + + def get_loss(self, model_output, target, w): + if self.loss_type == "l2": + return torch.mean( + (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1 + ) + elif self.loss_type == "l1": + return torch.mean( + (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1 + ) + elif self.loss_type == "lpips": + loss = self.lpips(model_output, target).reshape(-1) + return loss + else: + raise NotImplementedError(f"Unknown loss type {self.loss_type}") diff --git a/models/svd/sgm/modules/diffusionmodules/loss_weighting.py b/models/svd/sgm/modules/diffusionmodules/loss_weighting.py new file mode 100644 index 0000000000000000000000000000000000000000..e12c0a76635435babd1af33969e82fa284525af8 --- /dev/null +++ b/models/svd/sgm/modules/diffusionmodules/loss_weighting.py @@ -0,0 +1,32 @@ +from abc import ABC, abstractmethod + +import torch + + +class DiffusionLossWeighting(ABC): + @abstractmethod + def __call__(self, sigma: torch.Tensor) -> torch.Tensor: + pass + + +class UnitWeighting(DiffusionLossWeighting): + def __call__(self, sigma: torch.Tensor) -> torch.Tensor: + return torch.ones_like(sigma, device=sigma.device) + + +class EDMWeighting(DiffusionLossWeighting): + def __init__(self, sigma_data: float = 0.5): + self.sigma_data = sigma_data + + def __call__(self, sigma: torch.Tensor) -> torch.Tensor: + return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + +class VWeighting(EDMWeighting): + def __init__(self): + super().__init__(sigma_data=1.0) + + +class EpsWeighting(DiffusionLossWeighting): + def __call__(self, sigma: torch.Tensor) -> torch.Tensor: + return sigma**-2.0 diff --git a/models/svd/sgm/modules/diffusionmodules/model.py b/models/svd/sgm/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..4cf9d92140dee8443a0ea6b5cf218f2879ad88f4 --- /dev/null +++ b/models/svd/sgm/modules/diffusionmodules/model.py @@ -0,0 +1,748 @@ +# pytorch_diffusion + derived encoder decoder +import logging +import math +from typing import Any, Callable, Optional + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from packaging import version + +logpy = logging.getLogger(__name__) + +try: + import xformers + import xformers.ops + + XFORMERS_IS_AVAILABLE = True +except: + XFORMERS_IS_AVAILABLE = False + logpy.warning("no module 'xformers'. Processing without...") + +from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def attention(self, h_: torch.Tensor) -> torch.Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q, k, v = map( + lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v) + ) + h_ = torch.nn.functional.scaled_dot_product_attention( + q, k, v + ) # scale is dim ** -0.5 per default + # compute attention + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x, **kwargs): + h_ = x + h_ = self.attention(h_) + h_ = self.proj_out(h_) + return x + h_ + + +class MemoryEfficientAttnBlock(nn.Module): + """ + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation + """ + + # + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.attention_op: Optional[Any] = None + + def attention(self, h_: torch.Tensor) -> torch.Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), + (q, k, v), + ) + out = xformers.ops.memory_efficient_attention( + q, k, v, attn_bias=None, op=self.attention_op + ) + + out = ( + out.unsqueeze(0) + .reshape(B, 1, out.shape[1], C) + .permute(0, 2, 1, 3) + .reshape(B, out.shape[1], C) + ) + return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) + + def forward(self, x, **kwargs): + h_ = x + h_ = self.attention(h_) + h_ = self.proj_out(h_) + return x + h_ + + +class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): + def forward(self, x, context=None, mask=None, **unused_kwargs): + b, c, h, w = x.shape + x = rearrange(x, "b c h w -> b (h w) c") + out = super().forward(x, context=context, mask=mask) + out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c) + return x + out + + +def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): + assert attn_type in [ + "vanilla", + "vanilla-xformers", + "memory-efficient-cross-attn", + "linear", + "none", + ], f"attn_type {attn_type} unknown" + if ( + version.parse(torch.__version__) < version.parse("2.0.0") + and attn_type != "none" + ): + assert XFORMERS_IS_AVAILABLE, ( + f"We do not support vanilla attention in {torch.__version__} anymore, " + f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'" + ) + attn_type = "vanilla-xformers" + logpy.info(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + assert attn_kwargs is None + return AttnBlock(in_channels) + elif attn_type == "vanilla-xformers": + logpy.info( + f"building MemoryEfficientAttnBlock with {in_channels} in_channels..." + ) + return MemoryEfficientAttnBlock(in_channels) + elif type == "memory-efficient-cross-attn": + attn_kwargs["query_dim"] = in_channels + return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb + ) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + logpy.info( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) + + make_attn_cls = self._make_attn() + make_resblock_cls = self._make_resblock() + make_conv_cls = self._make_conv() + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = make_resblock_cls( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type) + self.mid.block_2 = make_resblock_cls( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + make_resblock_cls( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn_cls(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = make_conv_cls( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def _make_attn(self) -> Callable: + return make_attn + + def _make_resblock(self) -> Callable: + return ResnetBlock + + def _make_conv(self) -> Callable: + return torch.nn.Conv2d + + def get_last_layer(self, **kwargs): + return self.conv_out.weight + + def forward(self, z, **kwargs): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb, **kwargs) + h = self.mid.attn_1(h, **kwargs) + h = self.mid.block_2(h, temb, **kwargs) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb, **kwargs) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h, **kwargs) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h, **kwargs) + if self.tanh_out: + h = torch.tanh(h) + return h diff --git a/models/svd/sgm/modules/diffusionmodules/openaimodel.py b/models/svd/sgm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..514d0832d872c41bd85523639ebcfbc708a2c3c8 --- /dev/null +++ b/models/svd/sgm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,853 @@ +import logging +import math +from abc import abstractmethod +from typing import Iterable, List, Optional, Tuple, Union + +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.utils.checkpoint import checkpoint + +from models.svd.sgm.modules.attention import SpatialTransformer +from models.svd.sgm.modules.diffusionmodules.util import (avg_pool_nd, conv_nd, linear, + normalization, + timestep_embedding, zero_module) +from models.svd.sgm.modules.video_attention import SpatialVideoTransformer +from models.svd.sgm.util import exists + +logpy = logging.getLogger(__name__) + + +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: Optional[int] = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x: th.Tensor) -> th.Tensor: + b, c, _ = x.shape + x = x.reshape(b, c, -1) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x: th.Tensor, emb: th.Tensor): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward( + self, + x: th.Tensor, + emb: th.Tensor, + context: Optional[th.Tensor] = None, + image_only_indicator: Optional[th.Tensor] = None, + time_context: Optional[int] = None, + num_video_frames: Optional[int] = None, + ): + from models.diffusion.video_model import VideoResBlock + + for layer in self: + module = layer + + if isinstance(module, TimestepBlock) and not isinstance( + module, VideoResBlock + ): + x = layer(x, emb) + elif isinstance(module, VideoResBlock): + x = layer(x, emb, num_video_frames, image_only_indicator) + elif isinstance(module, SpatialVideoTransformer): + x = layer( + x, + context, + time_context, + num_video_frames, + image_only_indicator, + ) + elif isinstance(module, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__( + self, + channels: int, + use_conv: bool, + dims: int = 2, + out_channels: Optional[int] = None, + padding: int = 1, + third_up: bool = False, + kernel_size: int = 3, + scale_factor: int = 2, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + self.third_up = third_up + self.scale_factor = scale_factor + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, kernel_size, padding=padding + ) + + def forward(self, x: th.Tensor) -> th.Tensor: + assert x.shape[1] == self.channels + + if self.dims == 3: + t_factor = 1 if not self.third_up else self.scale_factor + x = F.interpolate( + x, + ( + t_factor * x.shape[2], + x.shape[3] * self.scale_factor, + x.shape[4] * self.scale_factor, + ), + mode="nearest", + ) + else: + x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__( + self, + channels: int, + use_conv: bool, + dims: int = 2, + out_channels: Optional[int] = None, + padding: int = 1, + third_down: bool = False, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2)) + if use_conv: + logpy.info(f"Building a Downsample layer with {dims} dims.") + logpy.info( + f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, " + f"kernel-size: 3, stride: {stride}, padding: {padding}" + ) + if dims == 3: + logpy.info(f" --> Downsampling third axis (time): {third_down}") + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x: th.Tensor) -> th.Tensor: + assert x.shape[1] == self.channels + + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels: int, + emb_channels: int, + dropout: float, + out_channels: Optional[int] = None, + use_conv: bool = False, + use_scale_shift_norm: bool = False, + dims: int = 2, + use_checkpoint: bool = False, + up: bool = False, + down: bool = False, + kernel_size: int = 3, + exchange_temb_dims: bool = False, + skip_t_emb: bool = False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + self.exchange_temb_dims = exchange_temb_dims + + if isinstance(kernel_size, Iterable): + padding = [k // 2 for k in kernel_size] + else: + padding = kernel_size // 2 + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.skip_t_emb = skip_t_emb + self.emb_out_channels = ( + 2 * self.out_channels if use_scale_shift_norm else self.out_channels + ) + if self.skip_t_emb: + logpy.info(f"Skipping timestep embedding in {self.__class__.__name__}") + assert not self.use_scale_shift_norm + self.emb_layers = None + self.exchange_temb_dims = False + else: + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + self.emb_out_channels, + ), + ) + + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd( + dims, + self.out_channels, + self.out_channels, + kernel_size, + padding=padding, + ) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, kernel_size, padding=padding + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor: + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + if self.use_checkpoint: + return checkpoint(self._forward, x, emb) + else: + return self._forward(x, emb) + + def _forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor: + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + + if self.skip_t_emb: + emb_out = th.zeros_like(h) + else: + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + if self.exchange_temb_dims: + emb_out = rearrange(emb_out, "b t c ... -> b c t ...") + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels: int, + num_heads: int = 1, + num_head_channels: int = -1, + use_checkpoint: bool = False, + use_new_attention_order: bool = False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x: th.Tensor, **kwargs) -> th.Tensor: + return checkpoint(self._forward, x) + + def _forward(self, x: th.Tensor) -> th.Tensor: + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads: int): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv: th.Tensor) -> th.Tensor: + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads: int): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv: th.Tensor) -> th.Tensor: + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + +class Timestep(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, t: th.Tensor) -> th.Tensor: + return timestep_embedding(t, self.dim) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + in_channels: int, + model_channels: int, + out_channels: int, + num_res_blocks: int, + attention_resolutions: int, + dropout: float = 0.0, + channel_mult: Union[List, Tuple] = (1, 2, 4, 8), + conv_resample: bool = True, + dims: int = 2, + num_classes: Optional[Union[int, str]] = None, + use_checkpoint: bool = False, + num_heads: int = -1, + num_head_channels: int = -1, + num_heads_upsample: int = -1, + use_scale_shift_norm: bool = False, + resblock_updown: bool = False, + transformer_depth: int = 1, + context_dim: Optional[int] = None, + disable_self_attentions: Optional[List[bool]] = None, + num_attention_blocks: Optional[List[int]] = None, + disable_middle_self_attn: bool = False, + disable_middle_transformer: bool = False, + use_linear_in_transformer: bool = False, + spatial_transformer_attn_type: str = "softmax", + adm_in_channels: Optional[int] = None, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(transformer_depth, int): + transformer_depth = len(channel_mult) * [transformer_depth] + transformer_depth_middle = transformer_depth[-1] + + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError( + "provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult" + ) + self.num_res_blocks = num_res_blocks + + if disable_self_attentions is not None: + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all( + map( + lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], + range(len(num_attention_blocks)), + ) + ) + logpy.info( + f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set." + ) + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + logpy.info("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "timestep": + self.label_emb = nn.Sequential( + Timestep(model_channels), + nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ), + ) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + if context_dim is not None and exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if ( + not exists(num_attention_blocks) + or nr < num_attention_blocks[level] + ): + layers.append( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth_middle, + context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + if not disable_middle_transformer + else th.nn.Identity(), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if ( + not exists(num_attention_blocks) + or i < num_attention_blocks[level] + ): + layers.append( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + + def forward( + self, + x: th.Tensor, + timesteps: Optional[th.Tensor] = None, + context: Optional[th.Tensor] = None, + y: Optional[th.Tensor] = None, + **kwargs, + ) -> th.Tensor: + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + + return self.out(h) diff --git a/models/svd/sgm/modules/diffusionmodules/sampling.py b/models/svd/sgm/modules/diffusionmodules/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..d4d4ff03e389edccc25285773cdbe61ae3cd4525 --- /dev/null +++ b/models/svd/sgm/modules/diffusionmodules/sampling.py @@ -0,0 +1,362 @@ +""" + Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py +""" + + +from typing import Dict, Union + +import torch +from omegaconf import ListConfig, OmegaConf +from tqdm import tqdm + +from models.svd.sgm.modules.diffusionmodules.sampling_utils import (get_ancestral_step, + linear_multistep_coeff, + to_d, to_neg_log_sigma, + to_sigma) +from models.svd.sgm.util import append_dims, default, instantiate_from_config + +DEFAULT_GUIDER = {"target": "models.svd.sgm.modules.diffusionmodules.guiders.IdentityGuider"} + + +class BaseDiffusionSampler: + def __init__( + self, + discretization_config: Union[Dict, ListConfig, OmegaConf], + num_steps: Union[int, None] = None, + guider_config: Union[Dict, ListConfig, OmegaConf, None] = None, + verbose: bool = False, + device: str = "cuda", + ): + self.num_steps = num_steps + self.discretization = instantiate_from_config(discretization_config) + self.guider = instantiate_from_config( + default( + guider_config, + DEFAULT_GUIDER, + ) + ) + self.verbose = verbose + self.device = device + + def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): + sigmas = self.discretization( + self.num_steps if num_steps is None else num_steps, device=self.device + ) + uc = default(uc, cond) + + x *= torch.sqrt(1.0 + sigmas[0] ** 2.0) + num_sigmas = len(sigmas) + + s_in = x.new_ones([x.shape[0]]) + + return x, s_in, sigmas, num_sigmas, cond, uc + + def denoise(self, x, denoiser, sigma, cond, uc): + denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc)) + denoised = self.guider(denoised, sigma) + return denoised + + def get_sigma_gen(self, num_sigmas): + sigma_generator = range(num_sigmas - 1) + if self.verbose: + print("#" * 30, " Sampling setting ", "#" * 30) + print(f"Sampler: {self.__class__.__name__}") + print(f"Discretization: {self.discretization.__class__.__name__}") + print(f"Guider: {self.guider.__class__.__name__}") + sigma_generator = tqdm( + sigma_generator, + total=num_sigmas, + desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps", + ) + return sigma_generator + + +class SingleStepDiffusionSampler(BaseDiffusionSampler): + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs): + raise NotImplementedError + + def euler_step(self, x, d, dt): + return x + dt * d + + +class EDMSampler(SingleStepDiffusionSampler): + def __init__( + self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs + ): + super().__init__(*args, **kwargs) + + self.s_churn = s_churn + self.s_tmin = s_tmin + self.s_tmax = s_tmax + self.s_noise = s_noise + + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0): + sigma_hat = sigma * (gamma + 1.0) + if gamma > 0: + eps = torch.randn_like(x) * self.s_noise + x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 + + denoised = self.denoise(x, denoiser, sigma_hat, cond, uc) + d = to_d(x, sigma_hat, denoised) + dt = append_dims(next_sigma - sigma_hat, x.ndim) + + euler_step = self.euler_step(x, d, dt) + x = self.possible_correction_step( + euler_step, x, d, dt, next_sigma, denoiser, cond, uc + ) + return x + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + for i in self.get_sigma_gen(num_sigmas): + gamma = ( + min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) + if self.s_tmin <= sigmas[i] <= self.s_tmax + else 0.0 + ) + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + gamma, + ) + + return x + + +class AncestralSampler(SingleStepDiffusionSampler): + def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.eta = eta + self.s_noise = s_noise + self.noise_sampler = lambda x: torch.randn_like(x) + + def ancestral_euler_step(self, x, denoised, sigma, sigma_down): + d = to_d(x, sigma, denoised) + dt = append_dims(sigma_down - sigma, x.ndim) + + return self.euler_step(x, d, dt) + + def ancestral_step(self, x, sigma, next_sigma, sigma_up): + x = torch.where( + append_dims(next_sigma, x.ndim) > 0.0, + x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim), + x, + ) + return x + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + for i in self.get_sigma_gen(num_sigmas): + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + ) + + return x + + +class LinearMultistepSampler(BaseDiffusionSampler): + def __init__( + self, + order=4, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.order = order + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + ds = [] + sigmas_cpu = sigmas.detach().cpu().numpy() + for i in self.get_sigma_gen(num_sigmas): + sigma = s_in * sigmas[i] + denoised = denoiser( + *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs + ) + denoised = self.guider(denoised, sigma) + d = to_d(x, sigma, denoised) + ds.append(d) + if len(ds) > self.order: + ds.pop(0) + cur_order = min(i + 1, self.order) + coeffs = [ + linear_multistep_coeff(cur_order, sigmas_cpu, i, j) + for j in range(cur_order) + ] + x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) + + return x + + +class EulerEDMSampler(EDMSampler): + def possible_correction_step( + self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc + ): + return euler_step + + +class HeunEDMSampler(EDMSampler): + def possible_correction_step( + self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc + ): + if torch.sum(next_sigma) < 1e-14: + # Save a network evaluation if all noise levels are 0 + return euler_step + else: + denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc) + d_new = to_d(euler_step, next_sigma, denoised) + d_prime = (d + d_new) / 2.0 + + # apply correction if noise level is not 0 + x = torch.where( + append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step + ) + return x + + +class EulerAncestralSampler(AncestralSampler): + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc): + sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) + denoised = self.denoise(x, denoiser, sigma, cond, uc) + x = self.ancestral_euler_step(x, denoised, sigma, sigma_down) + x = self.ancestral_step(x, sigma, next_sigma, sigma_up) + + return x + + +class DPMPP2SAncestralSampler(AncestralSampler): + def get_variables(self, sigma, sigma_down): + t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)] + h = t_next - t + s = t + 0.5 * h + return h, s, t, t_next + + def get_mult(self, h, s, t, t_next): + mult1 = to_sigma(s) / to_sigma(t) + mult2 = (-0.5 * h).expm1() + mult3 = to_sigma(t_next) / to_sigma(t) + mult4 = (-h).expm1() + + return mult1, mult2, mult3, mult4 + + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs): + sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) + denoised = self.denoise(x, denoiser, sigma, cond, uc) + x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down) + + if torch.sum(sigma_down) < 1e-14: + # Save a network evaluation if all noise levels are 0 + x = x_euler + else: + h, s, t, t_next = self.get_variables(sigma, sigma_down) + mult = [ + append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next) + ] + + x2 = mult[0] * x - mult[1] * denoised + denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc) + x_dpmpp2s = mult[2] * x - mult[3] * denoised2 + + # apply correction if noise level is not 0 + x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler) + + x = self.ancestral_step(x, sigma, next_sigma, sigma_up) + return x + + +class DPMPP2MSampler(BaseDiffusionSampler): + def get_variables(self, sigma, next_sigma, previous_sigma=None): + t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)] + h = t_next - t + + if previous_sigma is not None: + h_last = t - to_neg_log_sigma(previous_sigma) + r = h_last / h + return h, r, t, t_next + else: + return h, None, t, t_next + + def get_mult(self, h, r, t, t_next, previous_sigma): + mult1 = to_sigma(t_next) / to_sigma(t) + mult2 = (-h).expm1() + + if previous_sigma is not None: + mult3 = 1 + 1 / (2 * r) + mult4 = 1 / (2 * r) + return mult1, mult2, mult3, mult4 + else: + return mult1, mult2 + + def sampler_step( + self, + old_denoised, + previous_sigma, + sigma, + next_sigma, + denoiser, + x, + cond, + uc=None, + ): + denoised = self.denoise(x, denoiser, sigma, cond, uc) + + h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) + mult = [ + append_dims(mult, x.ndim) + for mult in self.get_mult(h, r, t, t_next, previous_sigma) + ] + + x_standard = mult[0] * x - mult[1] * denoised + if old_denoised is None or torch.sum(next_sigma) < 1e-14: + # Save a network evaluation if all noise levels are 0 or on the first step + return x_standard, denoised + else: + denoised_d = mult[2] * denoised - mult[3] * old_denoised + x_advanced = mult[0] * x - mult[1] * denoised_d + + # apply correction if noise level is not 0 and not first step + x = torch.where( + append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard + ) + + return x, denoised + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + old_denoised = None + for i in self.get_sigma_gen(num_sigmas): + x, old_denoised = self.sampler_step( + old_denoised, + None if i == 0 else s_in * sigmas[i - 1], + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc=uc, + ) + + return x diff --git a/models/svd/sgm/modules/diffusionmodules/sampling_utils.py b/models/svd/sgm/modules/diffusionmodules/sampling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4e12506e3f3817035eb8b138616783a1291b564d --- /dev/null +++ b/models/svd/sgm/modules/diffusionmodules/sampling_utils.py @@ -0,0 +1,43 @@ +import torch +from scipy import integrate + +from models.svd.sgm.util import append_dims + + +def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): + if order - 1 > i: + raise ValueError(f"Order {order} too high for step {i}") + + def fn(tau): + prod = 1.0 + for k in range(order): + if j == k: + continue + prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) + return prod + + return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0] + + +def get_ancestral_step(sigma_from, sigma_to, eta=1.0): + if not eta: + return sigma_to, 0.0 + sigma_up = torch.minimum( + sigma_to, + eta + * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, + ) + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + return sigma_down, sigma_up + + +def to_d(x, sigma, denoised): + return (x - denoised) / append_dims(sigma, x.ndim) + + +def to_neg_log_sigma(sigma): + return sigma.log().neg() + + +def to_sigma(neg_log_sigma): + return neg_log_sigma.neg().exp() diff --git a/models/svd/sgm/modules/diffusionmodules/sigma_sampling.py b/models/svd/sgm/modules/diffusionmodules/sigma_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..0166c99ce36b3e4b652383f49bec216bfab777ec --- /dev/null +++ b/models/svd/sgm/modules/diffusionmodules/sigma_sampling.py @@ -0,0 +1,31 @@ +import torch + +from models.svd.sgm.util import default, instantiate_from_config + + +class EDMSampling: + def __init__(self, p_mean=-1.2, p_std=1.2): + self.p_mean = p_mean + self.p_std = p_std + + def __call__(self, n_samples, rand=None): + log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,))) + return log_sigma.exp() + + +class DiscreteSampling: + def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True): + self.num_idx = num_idx + self.sigmas = instantiate_from_config(discretization_config)( + num_idx, do_append_zero=do_append_zero, flip=flip + ) + + def idx_to_sigma(self, idx): + return self.sigmas[idx] + + def __call__(self, n_samples, rand=None): + idx = default( + rand, + torch.randint(0, self.num_idx, (n_samples,)), + ) + return self.idx_to_sigma(idx) diff --git a/models/svd/sgm/modules/diffusionmodules/util.py b/models/svd/sgm/modules/diffusionmodules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..88565b0c33c9eb8e46bc251630e5a22e86f3182a --- /dev/null +++ b/models/svd/sgm/modules/diffusionmodules/util.py @@ -0,0 +1,375 @@ +""" +partially adopted from +https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +and +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +and +https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py + +thanks! +""" +import os +import math +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange, repeat + + +def make_beta_schedule( + schedule, + n_timestep, + linear_start=1e-4, + linear_end=2e-2, +): + if schedule == "linear": + betas = ( + torch.linspace( + linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 + ) + ** 2 + ) + return betas.numpy() + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def mixed_checkpoint(func, inputs: dict, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function + borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that + it also works with non-tensor inputs + :param func: the function to evaluate. + :param inputs: the argument dictionary to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)] + tensor_inputs = [ + inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor) + ] + non_tensor_keys = [ + key for key in inputs if not isinstance(inputs[key], torch.Tensor) + ] + non_tensor_inputs = [ + inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor) + ] + args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params) + return MixedCheckpointFunction.apply( + func, + len(tensor_inputs), + len(non_tensor_inputs), + tensor_keys, + non_tensor_keys, + *args, + ) + else: + return func(**inputs) + + +class MixedCheckpointFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + run_function, + length_tensors, + length_non_tensors, + tensor_keys, + non_tensor_keys, + *args, + ): + ctx.end_tensors = length_tensors + ctx.end_non_tensors = length_tensors + length_non_tensors + ctx.gpu_autocast_kwargs = { + "enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled(), + } + assert ( + len(tensor_keys) == length_tensors + and len(non_tensor_keys) == length_non_tensors + ) + + ctx.input_tensors = { + key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors])) + } + ctx.input_non_tensors = { + key: val + for (key, val) in zip( + non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors]) + ) + } + ctx.run_function = run_function + ctx.input_params = list(args[ctx.end_non_tensors :]) + + with torch.no_grad(): + output_tensors = ctx.run_function( + **ctx.input_tensors, **ctx.input_non_tensors + ) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)} + ctx.input_tensors = { + key: ctx.input_tensors[key].detach().requires_grad_(True) + for key in ctx.input_tensors + } + + with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = { + key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) + for key in ctx.input_tensors + } + # shallow_copies.update(additional_args) + output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors) + input_grads = torch.autograd.grad( + output_tensors, + list(ctx.input_tensors.values()) + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return ( + (None, None, None, None, None) + + input_grads[: ctx.end_tensors] + + (None,) * (ctx.end_non_tensors - ctx.end_tensors) + + input_grads[ctx.end_tensors :] + ) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + ctx.gpu_autocast_kwargs = { + "enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled(), + } + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_fp16 = os.getenv("STREAMING_USE_FP16", "False") == "True" + def forward(self, x): + if self.is_fp16: + return super().forward(x).type(x.dtype) + else: + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class AlphaBlender(nn.Module): + strategies = ["learned", "fixed", "learned_with_images"] + + def __init__( + self, + alpha: float, + merge_strategy: str = "learned_with_images", + rearrange_pattern: str = "b t -> (b t) 1 1", + ): + super().__init__() + self.merge_strategy = merge_strategy + self.rearrange_pattern = rearrange_pattern + + assert ( + merge_strategy in self.strategies + ), f"merge_strategy needs to be in {self.strategies}" + + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif ( + self.merge_strategy == "learned" + or self.merge_strategy == "learned_with_images" + ): + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) + ) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor: + if self.merge_strategy == "fixed": + alpha = self.mix_factor + elif self.merge_strategy == "learned": + alpha = torch.sigmoid(self.mix_factor) + elif self.merge_strategy == "learned_with_images": + assert image_only_indicator is not None, "need image_only_indicator ..." + alpha = torch.where( + image_only_indicator.bool(), + torch.ones(1, 1, device=image_only_indicator.device), + rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"), + ) + alpha = rearrange(alpha, self.rearrange_pattern) + else: + raise NotImplementedError + return alpha + + def forward( + self, + x_spatial: torch.Tensor, + x_temporal: torch.Tensor, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + alpha = self.get_alpha(image_only_indicator) + + x = ( + alpha.to(x_spatial.dtype) * x_spatial + + (1.0 - alpha).to(x_spatial.dtype) * x_temporal + ) + return x diff --git a/models/svd/sgm/modules/diffusionmodules/wrappers.py b/models/svd/sgm/modules/diffusionmodules/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..fe211fbc971eb168e866dce4cc4e3d7cc962594e --- /dev/null +++ b/models/svd/sgm/modules/diffusionmodules/wrappers.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn +from packaging import version +OPENAIUNETWRAPPER = "models.svd.sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" + + +class IdentityWrapper(nn.Module): + def __init__(self, diffusion_model, compile_model: bool = False): + super().__init__() + compile = ( + torch.compile + if (version.parse(torch.__version__) >= version.parse("2.0.0")) + and compile_model + else lambda x: x + ) + self.diffusion_model = compile(diffusion_model) + + def forward(self, *args, **kwargs): + return self.diffusion_model(*args, **kwargs) + + +class OpenAIWrapper(IdentityWrapper): + def forward( + self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs + ) -> torch.Tensor: + x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) + return self.diffusion_model( + x, + timesteps=t, + context=c.get("crossattn", None), + y=c.get("vector", None), + **kwargs, + ) \ No newline at end of file diff --git a/models/svd/sgm/modules/distributions/__init__.py b/models/svd/sgm/modules/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/svd/sgm/modules/distributions/distributions.py b/models/svd/sgm/modules/distributions/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..016be35523187ea366db9ade391fe8ee276db60b --- /dev/null +++ b/models/svd/sgm/modules/distributions/distributions.py @@ -0,0 +1,102 @@ +import numpy as np +import torch + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device + ) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to( + device=self.parameters.device + ) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/models/svd/sgm/modules/ema.py b/models/svd/sgm/modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..97b5ae2b230f89b4dba57e44c4f851478ad86f68 --- /dev/null +++ b/models/svd/sgm/modules/ema.py @@ -0,0 +1,86 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.m_name2s_name = {} + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", + torch.tensor(0, dtype=torch.int) + if use_num_upates + else torch.tensor(-1, dtype=torch.int), + ) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace(".", "") + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def reset_num_updates(self): + del self.num_updates + self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_( + one_minus_decay * (shadow_params[sname] - m_param[key]) + ) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/models/svd/sgm/modules/encoders/__init__.py b/models/svd/sgm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/svd/sgm/modules/encoders/modules.py b/models/svd/sgm/modules/encoders/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..bfeb1175f562c82c5ff5e43c7ce4ac9288e0ad47 --- /dev/null +++ b/models/svd/sgm/modules/encoders/modules.py @@ -0,0 +1,1050 @@ +import math +from contextlib import nullcontext +from functools import partial +from typing import Dict, List, Optional, Tuple, Union + +import kornia +import numpy as np +import open_clip +import torch +import torch.nn as nn +from einops import rearrange, repeat +from omegaconf import ListConfig +from torch.utils.checkpoint import checkpoint +from transformers import (ByT5Tokenizer, CLIPTextModel, CLIPTokenizer, + T5EncoderModel, T5Tokenizer) + +from models.svd.sgm.modules.autoencoding.regularizers import DiagonalGaussianRegularizer +from models.svd.sgm.modules.diffusionmodules.model import Encoder +from models.svd.sgm.modules.diffusionmodules.openaimodel import Timestep +from models.svd.sgm.modules.diffusionmodules.util import (extract_into_tensor, + make_beta_schedule) +from models.svd.sgm.modules.distributions.distributions import DiagonalGaussianDistribution +from models.svd.sgm.util import (append_dims, autocast, count_params, default, + disabled_train, expand_dims_like, instantiate_from_config) + + +class AbstractEmbModel(nn.Module): + def __init__(self): + super().__init__() + self._is_trainable = None + self._ucg_rate = None + self._input_key = None + + @property + def is_trainable(self) -> bool: + return self._is_trainable + + @property + def ucg_rate(self) -> Union[float, torch.Tensor]: + return self._ucg_rate + + @property + def input_key(self) -> str: + return self._input_key + + @is_trainable.setter + def is_trainable(self, value: bool): + self._is_trainable = value + + @ucg_rate.setter + def ucg_rate(self, value: Union[float, torch.Tensor]): + self._ucg_rate = value + + @input_key.setter + def input_key(self, value: str): + self._input_key = value + + @is_trainable.deleter + def is_trainable(self): + del self._is_trainable + + @ucg_rate.deleter + def ucg_rate(self): + del self._ucg_rate + + @input_key.deleter + def input_key(self): + del self._input_key + + +class GeneralConditioner(nn.Module): + OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"} + KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1} + + def __init__(self, emb_models: Union[List, ListConfig]): + super().__init__() + embedders = [] + for n, embconfig in enumerate(emb_models): + embedder = instantiate_from_config(embconfig) + assert isinstance( + embedder, AbstractEmbModel + ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel" + embedder.is_trainable = embconfig.get("is_trainable", False) + embedder.ucg_rate = embconfig.get("ucg_rate", 0.0) + if not embedder.is_trainable: + embedder.train = disabled_train + for param in embedder.parameters(): + param.requires_grad = False + embedder.eval() + print( + f"Initialized embedder #{n}: {embedder.__class__.__name__} " + f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}" + ) + + if "input_key" in embconfig: + embedder.input_key = embconfig["input_key"] + elif "input_keys" in embconfig: + embedder.input_keys = embconfig["input_keys"] + else: + raise KeyError( + f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}" + ) + + embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None) + if embedder.legacy_ucg_val is not None: + embedder.ucg_prng = np.random.RandomState() + + embedders.append(embedder) + + self.embedders = nn.ModuleList(embedders) + + def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict: + assert embedder.legacy_ucg_val is not None + p = embedder.ucg_rate + val = embedder.legacy_ucg_val + for i in range(len(batch[embedder.input_key])): + if embedder.ucg_prng.choice(2, p=[1 - p, p]): + batch[embedder.input_key][i] = val + return batch + + def forward( + self, batch: Dict, force_zero_embeddings: Optional[List] = None + ) -> Dict: + output = dict() + if force_zero_embeddings is None: + force_zero_embeddings = [] + for embedder in self.embedders: + embedding_context = nullcontext if embedder.is_trainable else torch.no_grad + with embedding_context(): + if hasattr(embedder, "input_key") and (embedder.input_key is not None): + if embedder.legacy_ucg_val is not None: + batch = self.possibly_get_ucg_val(embedder, batch) + emb_out = embedder(batch[embedder.input_key]) + elif hasattr(embedder, "input_keys"): + emb_out = embedder(*[batch[k] for k in embedder.input_keys]) + assert isinstance( + emb_out, (torch.Tensor, list, tuple) + ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}" + if not isinstance(emb_out, (list, tuple)): + emb_out = [emb_out] + for emb in emb_out: + out_key = self.OUTPUT_DIM2KEYS[emb.dim()] + # print(f"{embedder.input_key} -> {out_key}") + if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None: + emb = ( + expand_dims_like( + torch.bernoulli( + (1.0 - embedder.ucg_rate) + * torch.ones(emb.shape[0], device=emb.device) + ), + emb, + ) + * emb + ) + if ( + hasattr(embedder, "input_key") + and embedder.input_key in force_zero_embeddings + ): + emb = torch.zeros_like(emb) + if out_key in output: + #print(f"Embedder {embedder.input_key} -> {out_key}") + output[out_key] = torch.cat( + (output[out_key], emb), self.KEY2CATDIM[out_key] + ) + else: + #print(f"Embedder {embedder.input_key} -> {out_key}") + output[out_key] = emb + return output + + def get_unconditional_conditioning( + self, + batch_c: Dict, + batch_uc: Optional[Dict] = None, + force_uc_zero_embeddings: Optional[List[str]] = None, + force_cond_zero_embeddings: Optional[List[str]] = None, + ): + if force_uc_zero_embeddings is None: + force_uc_zero_embeddings = [] + ucg_rates = list() + for embedder in self.embedders: + ucg_rates.append(embedder.ucg_rate) + embedder.ucg_rate = 0.0 + c = self(batch_c, force_cond_zero_embeddings) + uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings) + + for embedder, rate in zip(self.embedders, ucg_rates): + embedder.ucg_rate = rate + return c, uc + + +class InceptionV3(nn.Module): + """Wrapper around the https://github.com/mseitzer/pytorch-fid inception + port with an additional squeeze at the end""" + + def __init__(self, normalize_input=False, **kwargs): + super().__init__() + from pytorch_fid import inception + + kwargs["resize_input"] = True + self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs) + + def forward(self, inp): + outp = self.model(inp) + + if len(outp) == 1: + return outp[0].squeeze() + + return outp + + +class IdentityEncoder(AbstractEmbModel): + def encode(self, x): + return x + + def forward(self, x): + return x + + +class ClassEmbedder(AbstractEmbModel): + def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False): + super().__init__() + self.embedding = nn.Embedding(n_classes, embed_dim) + self.n_classes = n_classes + self.add_sequence_dim = add_sequence_dim + + def forward(self, c): + c = self.embedding(c) + if self.add_sequence_dim: + c = c[:, None, :] + return c + + def get_unconditional_conditioning(self, bs, device="cuda"): + uc_class = ( + self.n_classes - 1 + ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) + uc = torch.ones((bs,), device=device) * uc_class + uc = {self.key: uc.long()} + return uc + + +class ClassEmbedderForMultiCond(ClassEmbedder): + def forward(self, batch, key=None, disable_dropout=False): + out = batch + key = default(key, self.key) + islist = isinstance(batch[key], list) + if islist: + batch[key] = batch[key][0] + c_out = super().forward(batch, key, disable_dropout) + out[key] = [c_out] if islist else c_out + return out + + +class FrozenT5Embedder(AbstractEmbModel): + """Uses the T5 transformer encoder for text""" + + def __init__( + self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True + ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + with torch.autocast("cuda", enabled=False): + outputs = self.transformer(input_ids=tokens) + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenByT5Embedder(AbstractEmbModel): + """ + Uses the ByT5 transformer encoder for text. Is character-aware. + """ + + def __init__( + self, version="google/byt5-base", device="cuda", max_length=77, freeze=True + ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = ByT5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + with torch.autocast("cuda", enabled=False): + outputs = self.transformer(input_ids=tokens) + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPEmbedder(AbstractEmbModel): + """Uses the CLIP transformer encoder for text (from huggingface)""" + + LAYERS = ["last", "pooled", "hidden"] + + def __init__( + self, + version="openai/clip-vit-large-patch14", + device="cuda", + max_length=77, + freeze=True, + layer="last", + layer_idx=None, + always_return_pooled=False, + ): # clip-vit-base-patch32 + super().__init__() + assert layer in self.LAYERS + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + self.layer_idx = layer_idx + self.return_pooled = always_return_pooled + if layer == "hidden": + assert layer_idx is not None + assert 0 <= abs(layer_idx) <= 12 + + def freeze(self): + self.transformer = self.transformer.eval() + + for param in self.parameters(): + param.requires_grad = False + + @autocast + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer( + input_ids=tokens, output_hidden_states=self.layer == "hidden" + ) + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + if self.return_pooled: + return z, outputs.pooler_output + return z + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPEmbedder2(AbstractEmbModel): + """ + Uses the OpenCLIP transformer encoder for text + """ + + LAYERS = ["pooled", "last", "penultimate"] + + def __init__( + self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device="cuda", + max_length=77, + freeze=True, + layer="last", + always_return_pooled=False, + legacy=True, + ): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms( + arch, + device=torch.device("cpu"), + pretrained=version, + ) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + self.return_pooled = always_return_pooled + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + self.legacy = legacy + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + @autocast + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + if not self.return_pooled and self.legacy: + return z + if self.return_pooled: + assert not self.legacy + return z[self.layer], z["pooled"] + return z[self.layer] + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + if self.legacy: + x = x[self.layer] + x = self.model.ln_final(x) + return x + else: + # x is a dict and will stay a dict + o = x["last"] + o = self.model.ln_final(o) + pooled = self.pool(o, text) + x["pooled"] = pooled + return x + + def pool(self, x, text): + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = ( + x[torch.arange(x.shape[0]), text.argmax(dim=-1)] + @ self.model.text_projection + ) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + outputs = {} + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - 1: + outputs["penultimate"] = x.permute(1, 0, 2) # LND -> NLD + if ( + self.model.transformer.grad_checkpointing + and not torch.jit.is_scripting() + ): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + outputs["last"] = x.permute(1, 0, 2) # LND -> NLD + return outputs + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPEmbedder(AbstractEmbModel): + LAYERS = [ + # "pooled", + "last", + "penultimate", + ] + + def __init__( + self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device="cuda", + max_length=77, + freeze=True, + layer="last", + ): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms( + arch, device=torch.device("cpu"), pretrained=version + ) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if ( + self.model.transformer.grad_checkpointing + and not torch.jit.is_scripting() + ): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPImageEmbedder(AbstractEmbModel): + """ + Uses the OpenCLIP vision transformer encoder for images + """ + + def __init__( + self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device="cuda", + max_length=77, + freeze=True, + antialias=True, + ucg_rate=0.0, + unsqueeze_dim=False, + repeat_to_max_len=False, + num_image_crops=0, + output_tokens=False, + init_device=None, + ): + super().__init__() + model, _, _ = open_clip.create_model_and_transforms( + arch, + device=torch.device(default(init_device, "cpu")), + pretrained=version, + ) + del model.transformer + self.model = model + self.max_crops = num_image_crops + self.pad_to_max_len = self.max_crops > 0 + self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + + self.antialias = antialias + + self.register_buffer( + "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False + ) + self.register_buffer( + "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False + ) + self.ucg_rate = ucg_rate + self.unsqueeze_dim = unsqueeze_dim + self.stored_batch = None + self.model.visual.output_tokens = output_tokens + self.output_tokens = output_tokens + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize( + x, + (224, 224), + interpolation="bicubic", + align_corners=True, + antialias=self.antialias, + ) + x = (x + 1.0) / 2.0 + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + @autocast + def forward(self, image, no_dropout=False): + z = self.encode_with_vision_transformer(image) + tokens = None + if self.output_tokens: + z, tokens = z[0], z[1] + z = z.to(image.dtype) + if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0): + z = ( + torch.bernoulli( + (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device) + )[:, None] + * z + ) + if tokens is not None: + tokens = ( + expand_dims_like( + torch.bernoulli( + (1.0 - self.ucg_rate) + * torch.ones(tokens.shape[0], device=tokens.device) + ), + tokens, + ) + * tokens + ) + if self.unsqueeze_dim: + z = z[:, None, :] + if self.output_tokens: + assert not self.repeat_to_max_len + assert not self.pad_to_max_len + return tokens, z + if self.repeat_to_max_len: + if z.dim() == 2: + z_ = z[:, None, :] + else: + z_ = z + return repeat(z_, "b 1 d -> b n d", n=self.max_length), z + elif self.pad_to_max_len: + assert z.dim() == 3 + z_pad = torch.cat( + ( + z, + torch.zeros( + z.shape[0], + self.max_length - z.shape[1], + z.shape[2], + device=z.device, + ), + ), + 1, + ) + return z_pad, z_pad[:, 0, ...] + return z + + def encode_with_vision_transformer(self, img): + # if self.max_crops > 0: + # img = self.preprocess_by_cropping(img) + if img.dim() == 5: + assert self.max_crops == img.shape[1] + img = rearrange(img, "b n c h w -> (b n) c h w") + img = self.preprocess(img) + if not self.output_tokens: + assert not self.model.visual.output_tokens + x = self.model.visual(img) + tokens = None + else: + assert self.model.visual.output_tokens + x, tokens = self.model.visual(img) + if self.max_crops > 0: + x = rearrange(x, "(b n) d -> b n d", n=self.max_crops) + # drop out between 0 and all along the sequence axis + x = ( + torch.bernoulli( + (1.0 - self.ucg_rate) + * torch.ones(x.shape[0], x.shape[1], 1, device=x.device) + ) + * x + ) + if tokens is not None: + tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops) + print( + f"You are running very experimental token-concat in {self.__class__.__name__}. " + f"Check what you are doing, and then remove this message." + ) + if self.output_tokens: + return x, tokens + return x + + def encode(self, text): + return self(text) + + +class FrozenCLIPT5Encoder(AbstractEmbModel): + def __init__( + self, + clip_version="openai/clip-vit-large-patch14", + t5_version="google/t5-v1_1-xl", + device="cuda", + clip_max_length=77, + t5_max_length=77, + ): + super().__init__() + self.clip_encoder = FrozenCLIPEmbedder( + clip_version, device, max_length=clip_max_length + ) + self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) + print( + f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params." + ) + + def encode(self, text): + return self(text) + + def forward(self, text): + clip_z = self.clip_encoder.encode(text) + t5_z = self.t5_encoder.encode(text) + return [clip_z, t5_z] + + +class SpatialRescaler(nn.Module): + def __init__( + self, + n_stages=1, + method="bilinear", + multiplier=0.5, + in_channels=3, + out_channels=None, + bias=False, + wrap_video=False, + kernel_size=1, + remap_output=False, + ): + super().__init__() + self.n_stages = n_stages + assert self.n_stages >= 0 + assert method in [ + "nearest", + "linear", + "bilinear", + "trilinear", + "bicubic", + "area", + ] + self.multiplier = multiplier + self.interpolator = partial(torch.nn.functional.interpolate, mode=method) + self.remap_output = out_channels is not None or remap_output + if self.remap_output: + print( + f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing." + ) + self.channel_mapper = nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + bias=bias, + padding=kernel_size // 2, + ) + self.wrap_video = wrap_video + + def forward(self, x): + if self.wrap_video and x.ndim == 5: + B, C, T, H, W = x.shape + x = rearrange(x, "b c t h w -> b t c h w") + x = rearrange(x, "b t c h w -> (b t) c h w") + + for stage in range(self.n_stages): + x = self.interpolator(x, scale_factor=self.multiplier) + + if self.wrap_video: + x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T, c=C) + x = rearrange(x, "b t c h w -> b c t h w") + if self.remap_output: + x = self.channel_mapper(x) + return x + + def encode(self, x): + return self(x) + + +class LowScaleEncoder(nn.Module): + def __init__( + self, + model_config, + linear_start, + linear_end, + timesteps=1000, + max_noise_level=250, + output_size=64, + scale_factor=1.0, + ): + super().__init__() + self.max_noise_level = max_noise_level + self.model = instantiate_from_config(model_config) + self.augmentation_schedule = self.register_schedule( + timesteps=timesteps, linear_start=linear_start, linear_end=linear_end + ) + self.out_size = output_size + self.scale_factor = scale_factor + + def register_schedule( + self, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert ( + alphas_cumprod.shape[0] == self.num_timesteps + ), "alphas have to be defined for each timestep" + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) + ) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def forward(self, x): + z = self.model.encode(x) + if isinstance(z, DiagonalGaussianDistribution): + z = z.sample() + z = z * self.scale_factor + noise_level = torch.randint( + 0, self.max_noise_level, (x.shape[0],), device=x.device + ).long() + z = self.q_sample(z, noise_level) + if self.out_size is not None: + z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") + return z, noise_level + + def decode(self, z): + z = z / self.scale_factor + return self.model.decode(z) + + +class ConcatTimestepEmbedderND(AbstractEmbModel): + """embeds each dimension independently and concatenates them""" + + def __init__(self, outdim): + super().__init__() + self.timestep = Timestep(outdim) + self.outdim = outdim + + def forward(self, x): + if x.ndim == 1: + x = x[:, None] + assert len(x.shape) == 2 + b, dims = x.shape[0], x.shape[1] + x = rearrange(x, "b d -> (b d)") + emb = self.timestep(x) + emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim) + return emb + + +class GaussianEncoder(Encoder, AbstractEmbModel): + def __init__( + self, weight: float = 1.0, flatten_output: bool = True, *args, **kwargs + ): + super().__init__(*args, **kwargs) + self.posterior = DiagonalGaussianRegularizer() + self.weight = weight + self.flatten_output = flatten_output + + def forward(self, x) -> Tuple[Dict, torch.Tensor]: + z = super().forward(x) + z, log = self.posterior(z) + log["loss"] = log["kl_loss"] + log["weight"] = self.weight + if self.flatten_output: + z = rearrange(z, "b c h w -> b (h w ) c") + return log, z + + +class VideoPredictionEmbedderWithEncoder(AbstractEmbModel): + def __init__( + self, + n_cond_frames: int, + n_copies: int, + encoder_config: dict, + sigma_sampler_config: Optional[dict] = None, + sigma_cond_config: Optional[dict] = None, + is_ae: bool = False, + scale_factor: float = 1.0, + disable_encoder_autocast: bool = False, + en_and_decode_n_samples_a_time: Optional[int] = None, + ): + super().__init__() + + self.n_cond_frames = n_cond_frames + self.n_copies = n_copies + self.encoder = instantiate_from_config(encoder_config) + self.sigma_sampler = ( + instantiate_from_config(sigma_sampler_config) + if sigma_sampler_config is not None + else None + ) + self.sigma_cond = ( + instantiate_from_config(sigma_cond_config) + if sigma_cond_config is not None + else None + ) + self.is_ae = is_ae + self.scale_factor = scale_factor + self.disable_encoder_autocast = disable_encoder_autocast + self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time + + def forward( + self, vid: torch.Tensor + ) -> Union[ + torch.Tensor, + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, dict], + Tuple[Tuple[torch.Tensor, torch.Tensor], dict], + ]: + if self.sigma_sampler is not None: + b = vid.shape[0] // self.n_cond_frames + sigmas = self.sigma_sampler(b).to(vid.device) + if self.sigma_cond is not None: + sigma_cond = self.sigma_cond(sigmas) + sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies) + sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames) + noise = torch.randn_like(vid) + vid = vid + noise * append_dims(sigmas, vid.ndim) + + with torch.autocast("cuda", enabled=not self.disable_encoder_autocast): + n_samples = ( + self.en_and_decode_n_samples_a_time + if self.en_and_decode_n_samples_a_time is not None + else vid.shape[0] + ) + n_rounds = math.ceil(vid.shape[0] / n_samples) + all_out = [] + for n in range(n_rounds): + if self.is_ae: + out = self.encoder.encode(vid[n * n_samples : (n + 1) * n_samples]) + else: + out = self.encoder(vid[n * n_samples : (n + 1) * n_samples]) + all_out.append(out) + + vid = torch.cat(all_out, dim=0) + vid *= self.scale_factor + + vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames) + vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies) + + return_val = (vid, sigma_cond) if self.sigma_cond is not None else vid + + return return_val + + +class FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel): + def __init__( + self, + open_clip_embedding_config: Dict, + n_cond_frames: int, + n_copies: int, + ): + super().__init__() + + self.n_cond_frames = n_cond_frames + self.n_copies = n_copies + self.open_clip = instantiate_from_config(open_clip_embedding_config) + + def forward(self, vid): + vid = self.open_clip(vid) + vid = rearrange(vid, "(b t) d -> b t d", t=self.n_cond_frames) + vid = repeat(vid, "b t d -> (b s) t d", s=self.n_copies) + + return vid diff --git a/models/svd/sgm/modules/video_attention.py b/models/svd/sgm/modules/video_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..e2f7d006a60f952dd095ccb95ea0eab697fe5b36 --- /dev/null +++ b/models/svd/sgm/modules/video_attention.py @@ -0,0 +1,312 @@ +import os + +import torch + +from models.svd.sgm.modules.attention import * +from models.svd.sgm.modules.diffusionmodules.util import (AlphaBlender, linear, + timestep_embedding) + + +class TimeMixSequential(nn.Sequential): + def forward(self, x, context=None, timesteps=None): + for layer in self: + x = layer(x, context, timesteps) + + return x + + +class VideoTransformerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, + "softmax-xformers": MemoryEfficientCrossAttention, + } + + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + timesteps=None, + ff_in=False, + inner_dim=None, + attn_mode="softmax", + disable_self_attn=False, + disable_temporal_crossattention=False, + switch_temporal_ca_to_sa=False, + ): + super().__init__() + + attn_cls = self.ATTENTION_MODES[attn_mode] + + self.ff_in = ff_in or inner_dim is not None + if inner_dim is None: + inner_dim = dim + + assert int(n_heads * d_head) == inner_dim + + self.is_res = inner_dim == dim + + if self.ff_in: + self.norm_in = nn.LayerNorm(dim) + self.ff_in = FeedForward( + dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff + ) + + self.timesteps = timesteps + self.disable_self_attn = disable_self_attn + if self.disable_self_attn: + self.attn1 = attn_cls( + query_dim=inner_dim, + heads=n_heads, + dim_head=d_head, + context_dim=context_dim, + dropout=dropout, + ) # is a cross-attention + else: + self.attn1 = attn_cls( + query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + + self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff) + + if disable_temporal_crossattention: + if switch_temporal_ca_to_sa: + raise ValueError + else: + self.attn2 = None + else: + self.norm2 = nn.LayerNorm(inner_dim) + if switch_temporal_ca_to_sa: + self.attn2 = attn_cls( + query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + else: + self.attn2 = attn_cls( + query_dim=inner_dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + + self.norm1 = nn.LayerNorm(inner_dim) + self.norm3 = nn.LayerNorm(inner_dim) + self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa + + self.checkpoint = checkpoint + if self.checkpoint: + print(f"{self.__class__.__name__} is using checkpointing") + + def forward( + self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None + ) -> torch.Tensor: + if self.checkpoint: + return checkpoint(self._forward, x, context, timesteps) + else: + return self._forward(x, context, timesteps=timesteps) + + def _forward(self, x, context=None, timesteps=None): + + assert self.timesteps or timesteps + assert not (self.timesteps and timesteps) or self.timesteps == timesteps + timesteps = self.timesteps or timesteps + B, S, C = x.shape + x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps) + + if self.ff_in: + x_skip = x + x = self.ff_in(self.norm_in(x)) + if self.is_res: + x += x_skip + #import pdb + #pdb.set_trace() + if self.disable_self_attn: + x = self.attn1(self.norm1(x), context=context) + x + else: + x = self.attn1(self.norm1(x)) + x + + if self.attn2 is not None: + if self.switch_temporal_ca_to_sa: + x = self.attn2(self.norm2(x)) + x + else: + x = self.attn2(self.norm2(x), context=context) + x + x_skip = x + x = self.ff(self.norm3(x)) + if self.is_res: + x += x_skip + + x = rearrange( + x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps + ) + return x + + def get_last_layer(self): + return self.ff.net[-1].weight + + +class SpatialVideoTransformer(SpatialTransformer): + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + use_linear=False, + context_dim=None, + use_spatial_context=False, + timesteps=None, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + time_context_dim=None, + ff_in=False, + checkpoint=False, + time_depth=1, + attn_mode="softmax", + disable_self_attn=False, + disable_temporal_crossattention=False, + max_time_embed_period: int = 10000, + use_apm: bool =False, + ): + super().__init__( + in_channels, + n_heads, + d_head, + depth=depth, + dropout=dropout, + attn_type=attn_mode, + use_checkpoint=checkpoint, + context_dim=context_dim, + use_linear=use_linear, + disable_self_attn=disable_self_attn, + use_apm=use_apm + ) + self.time_depth = time_depth + self.depth = depth + self.max_time_embed_period = max_time_embed_period + + time_mix_d_head = d_head + n_time_mix_heads = n_heads + + time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads) + + inner_dim = n_heads * d_head + if use_spatial_context: + time_context_dim = context_dim + + self.time_stack = nn.ModuleList( + [ + VideoTransformerBlock( + inner_dim, + n_time_mix_heads, + time_mix_d_head, + dropout=dropout, + context_dim=time_context_dim, + timesteps=timesteps, + checkpoint=checkpoint, + ff_in=ff_in, + inner_dim=time_mix_inner_dim, + attn_mode=attn_mode, + disable_self_attn=disable_self_attn, + disable_temporal_crossattention=disable_temporal_crossattention, + ) + for _ in range(self.depth) + ] + ) + + assert len(self.time_stack) == len(self.transformer_blocks) + + self.use_spatial_context = use_spatial_context + self.in_channels = in_channels + + time_embed_dim = self.in_channels * 4 + self.time_pos_embed = nn.Sequential( + linear(self.in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, self.in_channels), + ) + + self.time_mixer = AlphaBlender( + alpha=merge_factor, merge_strategy=merge_strategy + ) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + time_context: Optional[torch.Tensor] = None, + timesteps: Optional[int] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + _, _, h, w = x.shape + + + x_in = x + spatial_context = None + if exists(context): + spatial_context = context + + if self.use_spatial_context: + assert ( + context.ndim == 3 + ), f"n dims of spatial context should be 3 but are {context.ndim}" + + time_context = context + time_context_first_timestep = time_context[::timesteps] + time_context = repeat( + time_context_first_timestep, "b ... -> (b n) ...", n=h * w + ) + elif time_context is not None and not self.use_spatial_context: + time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w) + if time_context.ndim == 2: + time_context = rearrange(time_context, "b c -> b 1 c") + + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") + if self.use_linear: + x = self.proj_in(x) + + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding( + num_frames, + self.in_channels, + repeat_only=False, + max_period=self.max_time_embed_period, + ) + if os.getenv("STREAMING_USE_FP16", "False") == "True": + t_emb = t_emb.to(x.dtype) + emb = self.time_pos_embed(t_emb) + emb = emb[:, None, :] + + for it_, (block, mix_block) in enumerate( + zip(self.transformer_blocks, self.time_stack) + ): + x = block( + x, + context=spatial_context, + ) + + x_mix = x + x_mix = x_mix + emb + + x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps) + x = self.time_mixer( + x_spatial=x, + x_temporal=x_mix, + image_only_indicator=image_only_indicator, + ) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + if not self.use_linear: + x = self.proj_out(x) + out = x + x_in + return out diff --git a/models/svd/sgm/util.py b/models/svd/sgm/util.py new file mode 100644 index 0000000000000000000000000000000000000000..79d4557f99731bd13cd1a43e2a37707065d2210a --- /dev/null +++ b/models/svd/sgm/util.py @@ -0,0 +1,275 @@ +import functools +import importlib +import os +from functools import partial +from inspect import isfunction + +import fsspec +import numpy as np +import torch +from PIL import Image, ImageDraw, ImageFont +from safetensors.torch import load_file as load_safetensors + + +def disabled_train(self, mode=True): + """Overwrite models.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def get_string_from_tuple(s): + try: + # Check if the string starts and ends with parentheses + if s[0] == "(" and s[-1] == ")": + # Convert the string to a tuple + t = eval(s) + # Check if the type of t is tuple + if type(t) == tuple: + return t[0] + else: + pass + except: + pass + return s + + +def is_power_of_two(n): + """ + chat.openai.com/chat + Return True if n is a power of 2, otherwise return False. + + The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False. + The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False. + If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise. + Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False. + + """ + if n <= 0: + return False + return (n & (n - 1)) == 0 + + +def autocast(f, enabled=True): + def do_autocast(*args, **kwargs): + with torch.cuda.amp.autocast( + enabled=enabled, + dtype=torch.get_autocast_gpu_dtype(), + cache_enabled=torch.is_autocast_cache_enabled(), + ): + return f(*args, **kwargs) + + return do_autocast + + +def load_partial_from_config(config): + return partial(get_obj_from_str(config["target"]), **config.get("params", dict())) + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) + nc = int(40 * (wh[0] / 256)) + if isinstance(xc[bi], list): + text_seq = xc[bi][0] + else: + text_seq = xc[bi] + lines = "\n".join( + text_seq[start : start + nc] for start in range(0, len(text_seq), nc) + ) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def partialclass(cls, *args, **kwargs): + class NewCls(cls): + __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) + + return NewCls + + +def make_path_absolute(path): + fs, p = fsspec.core.url_to_fs(path) + if fs.protocol == "file": + return os.path.abspath(p) + return path + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def isheatmap(x): + if not isinstance(x, torch.Tensor): + return False + + return x.ndim == 2 + + +def isneighbors(x): + if not isinstance(x, torch.Tensor): + return False + return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1) + + +def exists(x): + return x is not None + + +def expand_dims_like(x, y): + while x.dim() != y.dim(): + x = x.unsqueeze(-1) + return x + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False, invalidate_cache=True): + module, cls = string.rsplit(".", 1) + if invalidate_cache: + importlib.invalidate_caches() + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def append_zero(x): + return torch.cat([x, x.new_zeros([1])]) + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError( + f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" + ) + return x[(...,) + (None,) * dims_to_append] + + +def load_model_from_config(config, ckpt, verbose=True, freeze=True): + print(f"Loading model from {ckpt}") + if ckpt.endswith("ckpt"): + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + elif ckpt.endswith("safetensors"): + sd = load_safetensors(ckpt) + else: + raise NotImplementedError + + model = instantiate_from_config(config.model) + + m, u = model.load_state_dict(sd, strict=False) + + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + if freeze: + for param in model.parameters(): + param.requires_grad = False + + model.eval() + return model + + +def get_configs_path() -> str: + """ + Get the `configs` directory. + For a working copy, this is the one in the root of the repository, + but for an installed copy, it's in the `sgm` package (see pyproject.toml). + """ + this_dir = os.path.dirname(__file__) + candidates = ( + os.path.join(this_dir, "configs"), + os.path.join(this_dir, "..", "configs"), + ) + for candidate in candidates: + candidate = os.path.abspath(candidate) + if os.path.isdir(candidate): + return candidate + raise FileNotFoundError(f"Could not find SGM configs in {candidates}") + + +def get_nested_attribute(obj, attribute_path, depth=None, return_key=False): + """ + Will return the result of a recursive get attribute call. + E.g.: + a.b.c + = getattr(getattr(a, "b"), "c") + = get_nested_attribute(a, "b.c") + If any part of the attribute call is an integer x with current obj a, will + try to call a[x] instead of a.x first. + """ + attributes = attribute_path.split(".") + if depth is not None and depth > 0: + attributes = attributes[:depth] + assert len(attributes) > 0, "At least one attribute should be selected" + current_attribute = obj + current_key = None + for level, attribute in enumerate(attributes): + current_key = ".".join(attributes[: level + 1]) + try: + id_ = int(attribute) + current_attribute = current_attribute[id_] + except ValueError: + current_attribute = getattr(current_attribute, attribute) + + return (current_attribute, current_key) if return_key else current_attribute diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/loader/__init__.py b/modules/loader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/loader/module_loader.py b/modules/loader/module_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..bf8dd4fd3ac241c277dec45fc3d975d479ade936 --- /dev/null +++ b/modules/loader/module_loader.py @@ -0,0 +1,232 @@ +from diffusers import DDPMScheduler, DiffusionPipeline +from typing import List, Any, Union, Type +from utils.loader import get_class +from copy import deepcopy +from modules.loader.module_loader_config import ModuleLoaderConfig +import torch +import pytorch_lightning as pl +import jsonargparse + + +class bcolors: + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKCYAN = '\033[96m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + + +class GenericModuleLoader(): + + def __init__(self, + pipeline_repo: str = None, + pipeline_obj: str = None, + set_prediction_type: str = "", + module_names: List[str] = [ + "scheduler", "text_encoder", "tokenizer", "vae", "unet",], + module_config: dict[str, + Union[ModuleLoaderConfig, torch.nn.Module, Any]] = None, + fast_dev_run: Union[int, bool] = False, + root_cls: Type[Any] = None, + ) -> None: + self.module_config = module_config + self.pipeline_repo = pipeline_repo + self.pipeline_obj = pipeline_obj + self.set_prediction_type = set_prediction_type + self.module_names = module_names + self.fast_dev_run = fast_dev_run + self.root_cls = root_cls + + def load_custom_scheduler(self): + module_obj = DDPMScheduler.from_pretrained( + self.pipeline_repo, subfolder="scheduler") + + if len(self.set_prediction_type) > 0: + scheduler_config = module_obj.load_config( + self.pipeline_repo, subfolder="scheduler") + scheduler_config["prediction_type"] = self.set_prediction_type + module_obj = module_obj.from_config(scheduler_config) + return module_obj + + def load_pipeline(self): + return DiffusionPipeline.from_pretrained(self.pipeline_repo) if self.pipeline_repo is not None else None + + def __call__(self, trainer: pl.LightningModule, diff_trainer_params): + # load diffusers pipeline object if set + if self.pipeline_obj is not None: + pipe = self.load_pipeline() + else: + pipe = None + + if pipe is not None and self.pipeline_obj is not None: + # store the entire diffusers pipeline object under the name given by pipeline_obj + setattr(trainer, self.pipeline_obj, self.load_pipeline()) + + for module_name in self.module_names: + print(f" --- START: Loading module: {module_name} ---") + if module_name not in self.module_config.keys() and pipe is not None: + # stores models from already loaded diffusers pipeline + module_obj = getattr(pipe, module_name) + if module_name == "scheduler": + module_obj = self.load_custom_scheduler() + setattr(trainer, module_name, module_obj) + else: + if not isinstance(self.module_config[module_name], ModuleLoaderConfig): + # instantiate model by jsonargparse and store it + module = self.module_config[module_name] + # TODO we want to be able to load ckpt still. + config_obj = None + else: + # instantiate object from class method (as used by Diffusers, e.g. DiffusionPipeline.load_from_pretrained) + config_obj = self.module_config[module_name] + # retrieve loader class + loader_cls = get_class( + config_obj.loader_cls_path) + + # retrieve loader method + if config_obj.cls_func != "": + # we allow to specify a method for fast loading (e.g. in diffusers, from_config instead of from_pretrained) + # makes loading faster for quick testing + if not self.fast_dev_run or config_obj.cls_func_fast_dev_run == "": + cls_func = getattr( + loader_cls, config_obj.cls_func) + else: + print( + f"Model {module_name}: loading fast_dev_run class loader") + cls_func = getattr( + loader_cls, config_obj.cls_func_fast_dev_run) + else: + cls_func = loader_cls + + # retrieve parameters + # load parameters specified in diff_trainer_params (so it links them) + kwargs_trainer_params = config_obj.kwargs_diff_trainer_params + + kwargs_diffusers = config_obj.kwargs_diffusers + + # names of dependent modules that we need as input + dependent_modules = config_obj.dependent_modules + + # names of dependent modules that we need as input. Modules will be cloned + dependent_modules_cloned = config_obj.dependent_modules_cloned + + # model kwargs. Can be just a dict, or a parameter class (derived from modules.params.params_mixin.AsDictMixin) so we have verification of inputs + model_params = config_obj.model_params + + # kwargs used only if on fast_dev_run mode + model_params_fast_dev_run = config_obj.model_params_fast_dev_run + + if model_params is not None: + if isinstance(model_params, dict): + model_dict = model_params + else: + model_dict = model_params.to_dict() + else: + model_dict = {} + + if (model_params_fast_dev_run is None) or (not self.fast_dev_run): + model_params_fast_dev_run = {} + else: + print( + f"Module {module_name}: loading fast_dev_run params") + + loaded_modules_dict = {} + if dependent_modules is not None: + for key, dependent_module in dependent_modules.items(): + assert hasattr( + trainer, dependent_module), f"Module {dependent_module} not available. Set {dependent_module} before module {module_name} in module_loader.module_names. Current order: {self.module_names}" + loaded_modules_dict[key] = getattr( + trainer, dependent_module) + + if dependent_modules_cloned is not None: + for key, dependent_module in dependent_modules_cloned.items(): + assert hasattr( + trainer, dependent_module), f"Module {dependent_module} not available. Set {dependent_module} before module {module_name} in module_loader.module_names. Current order: {self.module_names}" + loaded_modules_dict[key] = getattr( + trainer, deepcopy(dependent_module)) + if kwargs_trainer_params is not None: + for key, param in kwargs_trainer_params.items(): + if param is not None: + kwargs_trainer_params[key] = getattr( + diff_trainer_params, param) + else: + kwargs_trainer_params[key] = diff_trainer_params + else: + kwargs_trainer_params = {} + + if kwargs_diffusers is None: + kwargs_diffusers = {} + else: + for key, value in kwargs_diffusers.items(): + if key == "torch_dtype": + if value == "torch.float16": + kwargs_diffusers[key] = torch.float16 + + kwargs = kwargs_diffusers | loaded_modules_dict | kwargs_trainer_params | model_dict | model_params_fast_dev_run + args = config_obj.args + # instantiate object + module = cls_func(*args, **kwargs) + module: torch.nn.Module + if self.root_cls is not None: + assert isinstance(module, self.root_cls) + + if config_obj is not None and config_obj.state_dict_path != "" and not self.fast_dev_run: + # TODO extend loading to hf spaces + print( + f" * Loading checkpoint {config_obj.state_dict_path} - STARTED") + module_state_dict = torch.load( + config_obj.state_dict_path, map_location=torch.device("cpu")) + module_state_dict = module_state_dict["state_dict"] + + if len(config_obj.state_dict_filters) > 0: + assert not config_obj.strict_loading + ckpt_params_dict = {} + for name, param in module.named_parameters(prefix=module_name): + for filter_str in config_obj.state_dict_filters: + filter_groups = filter_str.split("*") + has_all_parts = True + for filter_group in filter_groups: + has_all_parts = has_all_parts and filter_group in name + + if has_all_parts: + validate_name = name + for filter_group in filter_groups: + if filter_group in validate_name: + shift = validate_name.index( + filter_group) + validate_name = validate_name[shift+len( + filter_group):] + else: + has_all_parts = False + break + if has_all_parts: + ckpt_params_dict[name[len( + module_name+"."):]] = param + else: + ckpt_params_dict = dict(filter(lambda x: x[0].startswith( + module_name), module_state_dict.items())) + ckpt_params_dict = { + k.split(module_name+".")[1]: v for (k, v) in ckpt_params_dict.items()} + + if len(ckpt_params_dict) > 0: + miss, unex = module.load_state_dict( + ckpt_params_dict, strict=config_obj.strict_loading) + ckpt_params_dict = {} + assert len( + unex) == 0, f"Unexpected parameters in checkpoint: {unex}" + if len(miss) > 0: + print( + f"Checkpoint {config_obj.state_dict_path} is missing parameters for module {module_name}.") + print(miss) + print( + f" * Loading checkpoint {config_obj.state_dict_path} - FINISHED") + if isinstance(module, jsonargparse.Namespace) or isinstance(module, dict): + print(bcolors.WARNING + + f"Warning: Seems object {module_name} was not build correct." + bcolors.ENDC) + + setattr(trainer, module_name, module) + print(f" --- FINSHED: Loading module: {module_name} ---") diff --git a/modules/loader/module_loader_config.py b/modules/loader/module_loader_config.py new file mode 100644 index 0000000000000000000000000000000000000000..54911cf12003b82a0045ced5cf2d201c30a72d4f --- /dev/null +++ b/modules/loader/module_loader_config.py @@ -0,0 +1,39 @@ +from typing import Any, Union, List + + +class ModuleLoaderConfig: + + def __init__(self, + loader_cls_path: str, + cls_func: str = "", + cls_func_fast_dev_run: str = "", + kwargs_diffusers: dict[str, Any] = None, + # model kwargs. Can be just a dict, or a parameter class (derived from modules.params.params_mixin.AsDictMixin) so we have verification of inputs + model_params: Any = None, + # kwargs activated only if on fast_dev_run mode + model_params_fast_dev_run: Any = None, + # load parameters specified in diff_trainer_params (so it links them) + kwargs_diff_trainer_params: dict[str, + Union[str, None]] = None, + args: List[Any] = [], + # names of dependent modules that we need as input + dependent_modules: dict[str, str] = None, + # names of dependent modules that we need as input. Modules will be cloned + dependent_modules_cloned: dict[str, str] = None, + state_dict_path: str = "", + strict_loading: bool = True, + state_dict_filters: List[str] = [] + ) -> None: + self.loader_cls_path = loader_cls_path + self.cls_func = cls_func + self.cls_func_fast_dev_run = cls_func_fast_dev_run + self.kwargs_diffusers = kwargs_diffusers + self.dependent_modules = dependent_modules + self.dependent_modules_cloned = dependent_modules_cloned + self.kwargs_diff_trainer_params = kwargs_diff_trainer_params + self.model_params = model_params + self.state_dict_path = state_dict_path + self.strict_loading = strict_loading + self.state_dict_filters = state_dict_filters + self.model_params_fast_dev_run = model_params_fast_dev_run + self.args = args diff --git a/modules/params/diffusion/inference_params.py b/modules/params/diffusion/inference_params.py new file mode 100644 index 0000000000000000000000000000000000000000..9d01631924d11303f27722cf715b01dbf62dc51f --- /dev/null +++ b/modules/params/diffusion/inference_params.py @@ -0,0 +1,25 @@ +from typing import List +from pathlib import Path +from modules.params.params_mixin import AsDictMixin + + +class InferenceParams(AsDictMixin): + def __init__(self, + # reset seed (only for inference) at every start + reset_seed_per_generation: bool = True, + ): + super().__init__() + self.reset_seed_per_generation = reset_seed_per_generation + +class T2VInferenceParams(InferenceParams): + def __init__(self, + n_autoregressive_generations: int = 1, + num_conditional_frames: int = 8, # during GENERATION, take the last frames,i.e. [:-num_conditional_frames] + # can be "15", i.e. take always the 16th frame of the entire video, or a range "-8:-1", take always frames -8:-1 of the last chunk + anchor_frames: str = "15", + **kwargs + ): + super().__init__(**kwargs) + self.n_autoregressive_generations = n_autoregressive_generations + self.num_conditional_frames = num_conditional_frames + self.anchor_frames = anchor_frames diff --git a/modules/params/diffusion_trainer/params_streaming_diff_trainer.py b/modules/params/diffusion_trainer/params_streaming_diff_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..bf66800d92ac00b70049affe64e4fad6295ca89a --- /dev/null +++ b/modules/params/diffusion_trainer/params_streaming_diff_trainer.py @@ -0,0 +1,24 @@ +from modules.params.params_mixin import AsDictMixin + + +class CheckpointDescriptor(): + + def __init__(self, + ckpt_path_local: str = None, + ckpt_path_global: str = None): + self.ckpt_path_local = ckpt_path_local + self.ckpt_path_global = ckpt_path_global + + +class DiffusionTrainerParams(AsDictMixin): + + def __init__(self, + scale_factor: float = 0.18215, + streamingsvd_ckpt: CheckpointDescriptor = None, + disable_first_stage_autocast: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.scale_factor = scale_factor + self.streamingsvd_ckpt = streamingsvd_ckpt + self.disable_first_stage_autocast = disable_first_stage_autocast diff --git a/modules/params/i2v_enhance.py b/modules/params/i2v_enhance.py new file mode 100644 index 0000000000000000000000000000000000000000..824bade1dddaaebe13aeb724a85594c0d7e58e94 --- /dev/null +++ b/modules/params/i2v_enhance.py @@ -0,0 +1,12 @@ +from modules.params.params_mixin import AsDictMixin + + +class I2VEnhanceParams(AsDictMixin): + + def __init__(self, + ckpt_path_local: str = "", + ckpt_path_global: str = "", + ) -> None: + super().__init__() + self.ckpt_path_local = ckpt_path_local + self.ckpt_path_global = ckpt_path_global diff --git a/modules/params/params_mixin.py b/modules/params/params_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..b8fb4d614822047f28806298552f70bcdbd8d2dc --- /dev/null +++ b/modules/params/params_mixin.py @@ -0,0 +1,14 @@ + +class AsDictMixin: + def to_dict(self): + + keys = [entry for entry in dir(self) if not callable(getattr( + self, entry)) and not entry.startswith("__")] + + result_dict = {} + for key in keys: + result_dict[key] = getattr(self, key) + return result_dict + + def __str__(self) -> str: + return self.to_dict().__str__() \ No newline at end of file diff --git a/modules/params/vfi.py b/modules/params/vfi.py new file mode 100644 index 0000000000000000000000000000000000000000..4b2927088ea95b73234016d36b1dfce1d15a8e07 --- /dev/null +++ b/modules/params/vfi.py @@ -0,0 +1,11 @@ +from modules.params.params_mixin import AsDictMixin + +class VFIParams(AsDictMixin): + + def __init__(self, + ckpt_path_local: str = "", + ckpt_path_global: str = "", + ) -> None: + super().__init__() + self.ckpt_path_local = ckpt_path_local + self.ckpt_path_global = ckpt_path_global \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..45a9162b2c2c8982a61b82a09d1fa2f49847c2cb --- /dev/null +++ b/requirements.txt @@ -0,0 +1,28 @@ +torch==2.0.1 +torchvision==0.15.2 +diffusers==0.30.2 +imageio==2.34.0 +imageio-ffmpeg==0.4.9 +transformers==4.40.0 +numpy==1.26.4 +einops==0.7.0 +pillow==10.3.0 +timm==1.0.7 +tokenizers==0.19.1 +triton==2.0.0 +typing_extensions==4.11.0 +torchsde==0.2.6 +tqdm==4.66.2 +pytorch-lightning==2.2.2 +jsonargparse==4.28.0 +matplotlib==3.8.4 +omegaconf==2.3.0 +typeshed_client==2.5.1 +docstring_parser==0.16 +kornia==0.7.2 +open-clip-torch==2.24.0 +xformers==0.0.20 +opencv-python==4.10.0.84 +accelerate==0.29.3 +gdown==5.2.0 +gradio==4.43.0 \ No newline at end of file diff --git a/streaming_svd_inference.py b/streaming_svd_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..472ddb5c254d7104819034be1a810cdb609644c0 --- /dev/null +++ b/streaming_svd_inference.py @@ -0,0 +1,313 @@ +import numpy as np +from lib.farancia import IImage +from PIL import Image +from i2v_enhance import i2v_enhance_interface +from dataloader.dataset_factory import SingleImageDatasetFactory +from pytorch_lightning import Trainer, LightningDataModule, seed_everything +import math +from diffusion_trainer import streaming_svd as streaming_svd_model +import torch +from safetensors.torch import load_file as load_safetensors +from utils.loader import download_ckpt +from functools import partial +from dataloader.video_data_module import VideoDataModule +from pathlib import Path +from pytorch_lightning.cli import LightningCLI, LightningArgumentParser +from pytorch_lightning import LightningModule +import sys +import os +from copy import deepcopy +from utils.aux import ensure_annotation_class +from diffusers import FluxPipeline +from typing import Union + + +class CustomCLI(LightningCLI): + + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_argument("--image", type=Path, + help="Path to the input image(s)") + parser.add_argument("--output", type=Path, + help="Path to the output folder") + parser.add_argument("--num_frames", type=int, default=200, + help="Number of frames to generate.") + parser.add_argument("--out_fps", type=int, default=24, + help="Framerate of the generated video.") + parser.add_argument("--chunk_size", type=int, default=38, + help="Chunk size used in randomized blending.") + parser.add_argument("--overlap_size", type=int, default=12, + help="Overlap size used in randomized blending.") + parser.add_argument("--use_randomized_blending", action="store_true", + help="Wether to use randomized blending.") + parser.add_argument("--use_fp16", action="store_true", + help="Wether to use float16 quantization.") + parser.add_argument("--prompt", type=str, default = "") + + return parser + + +class StreamingSVD(): + + def __init__(self, load_argv = True) -> None: + + call_fol = Path(os.getcwd()).resolve() + + code_fol = Path(__file__).resolve().parent + code_fol = os.path.relpath(code_fol, call_fol) + argv_backup = deepcopy(sys.argv) + + if "--use_fp16" in sys.argv: + os.environ["STREAMING_USE_FP16"] = "True" + sys.argv = [__file__] + sys.argv.extend(self.__config_call(argv_backup[1:] if load_argv else [], code_fol)) + cli = CustomCLI(LightningModule, run=False, subclass_mode_model=True, parser_kwargs={ + "parser_mode": "omegaconf"}, save_config_callback=None) + self.__init_models(cli) + self.__init_fields(cli) + + sys.argv = argv_backup + + def __init_models(self, cli): + model = cli.model + trainer = cli.trainer + + path = download_ckpt( + local_path=model.diff_trainer_params.streamingsvd_ckpt.ckpt_path_local, + global_path=model.diff_trainer_params.streamingsvd_ckpt.ckpt_path_global + ) + if path.endswith(".safetensors"): + ckpt = load_safetensors(path) + else: + ckpt = torch.load(path, map_location="cpu")["state_dict"] + + model.load_state_dict(ckpt) # load trained model + trainer = cli.trainer + data_module_loader = partial(VideoDataModule, workers=2) + vfi = i2v_enhance_interface.vfi_init(model.vfi) + + enhance_pipeline, enhance_generator = i2v_enhance_interface.i2v_enhance_init( + model.i2v_enhance) + + flux_pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + flux_pipe.enable_model_cpu_offload() + + # store of objects + model: streaming_svd_model + data_module_loader: LightningDataModule + trainer: Trainer + + self.model = model + self.vfi = vfi + self.data_module_loader = data_module_loader + self.enhance_pipeline = enhance_pipeline + self.enhance_generator = enhance_generator + self.trainer = trainer + self.flux_pipe = flux_pipe + + def __init_fields(self, cli): + self.input_path = cli.config["image"] + self.output_path = cli.config["output"] + self.num_frames = cli.config["num_frames"] + self.fps = cli.config["out_fps"] + self.use_randomized_blending = cli.config["use_randomized_blending"] + self.chunk_size = cli.config["chunk_size"] + self.overlap_size = cli.config["overlap_size"] + self.prompt = cli.config["prompt"] + + def __config_call(self, config_cmds, code_fol): + cmds = [cmd for cmd in config_cmds if len(cmd) > 0] + cmd_init = [] + cmd_init.append(f"--config") + cmd_init.append(f"{code_fol}/config.yaml") + if "--use_fp16" in config_cmds: + cmd_init.append(f"--trainer.precision=16-true") + cmd_init.extend(cmds) + return cmd_init + + # interfaces + + def streaming_t2v(self, prompt, num_frames: int, use_randomized_blending: bool = False, chunk_size: int = 38, overlap_size: int = 12, seed=33): + image = self.text_to_image(prompt=prompt) + return self.streaming_i2v(image, num_frames=num_frames, use_randomized_blending=use_randomized_blending, chunk_size=chunk_size, overlap_size=overlap_size, seed=seed) + + def streaming_i2v(self, image, num_frames: int, use_randomized_blending: bool = False, chunk_size: int = 38, overlap_size: int = 12, seed=33) -> np.array: + video, scaled_outpainted_image, expanded_size = self.image_to_video( + image, num_frames=(num_frames+1)//2, seed=seed) + max_memory_allocated = torch.cuda.max_memory_allocated() + print( + f"max_memory_allocated at image_to_video: {max_memory_allocated}") + video = self.enhance_video(image=IImage(scaled_outpainted_image).numpy(), video=video, chunk_size=chunk_size, overlap_size=overlap_size, + use_randomized_blending=use_randomized_blending, seed=seed) + video = self.interpolate_video(video, dest_num_frames=num_frames) + + # scale/crop back to input size + if image.shape[0] == 1: + image = image[0] + video = IImage(video, vmin=0, vmax=255).resize(expanded_size[::-1]).crop((0, 0, image.shape[1], image.shape[0])).numpy() + + print( + f"max_memory_allocated at interpolate_video: {max_memory_allocated}") + return video + + # StreamingSVD pipeline + def streaming(self, image: np.ndarray): + + datamodule = self.data_module_loader(predict_dataset_factory=SingleImageDatasetFactory( + file=image)) + self.trainer.predict(model=self.model, datamodule=datamodule) + video = self.trainer.generated_video + expanded_size = self.trainer.expanded_size + scaled_outpainted_image = self.trainer.scaled_outpainted_image + + return video, scaled_outpainted_image, expanded_size + + def image_to_video(self, image: Union[np.ndarray, str], num_frames: int, seed=33) -> tuple[np.ndarray,Image,list[int]]: + seed_everything(seed) + if isinstance(image, str): + image = IImage.open(image).numpy() + + if image.shape[0] == 1 and image.ndim == 4: + image = image[0] + + assert image.shape[-1] == 3 and image.shape[0] > 1, "Wrong image format. Assuming shape [H W C], with C = 3." + assert image.dtype == "uint8", "Wrong dtype for input image. Must be uint8." + # compute necessary number of chunks + n_cond_frames = self.model.inference_params.num_conditional_frames + n_frames_per_gen = self.model.sampler.guider.num_frames + n_autoregressive_generations = math.ceil( + (num_frames - n_frames_per_gen) / (n_frames_per_gen - n_cond_frames)) + self.model.inference_params.n_autoregressive_generations = int( + n_autoregressive_generations) + + print(" --- STREAMING ----- [START]") + video, scaled_outpainted_image, expanded_size = self.streaming( + image=image) + print(f" --- STREAMING ----- [FINISHED]: {video.shape}") + + video = video[:num_frames] + return video, scaled_outpainted_image, expanded_size + + def enhance_video(self, video: Union[np.ndarray, str], image: np.ndarray = None, chunk_size = 38, overlap_size=12, strength=0.97, use_randomized_blending=False, seed=33,num_frames = None): + + seed_everything(seed) + if isinstance(video, str): + video = IImage.open(video).numpy() + if image is None: + image = video[0] + print("ATTENTION: We take first frame of previous stage as input frame for enhance. ") + + if num_frames is not None: + video = video[:num_frames, ...] + + if not use_randomized_blending: + chunk_size = video.shape[0] + overlap_size = 0 + if image.ndim == 3: + image = image[None] + image = [Image.fromarray( + IImage(image, vmin=0, vmax=255).resize((720, 1280)).numpy()[0])] + + video = np.split(video, video.shape[0]) + video = [Image.fromarray(frame[0]).resize((1280, 720)) + for frame in video] + + print( + f"---- ENHANCE ---- [START]. Video length = {len(video)}. Randomized Blending = {use_randomized_blending}. Chunk size = {chunk_size}. Overlap size = {overlap_size}.") + video_enhanced = i2v_enhance_interface.i2v_enhance_process( + image=image, video=video, pipeline=self.enhance_pipeline, generator=self.enhance_generator, + chunk_size=chunk_size, overlap_size=overlap_size, strength=strength, use_randomized_blending=use_randomized_blending) + video_enhanced = np.stack([np.asarray(frame) + for frame in video_enhanced], axis=0) + print("---- ENHANCE ---- [FINISHED].") + return video_enhanced + + def interpolate_video(self, video: np.ndarray, dest_num_frames: int): + video = np.split(video, len(video)) + video = [frame[0] for frame in video] + + print(" ---- VFI ---- [START]") + self.vfi.device() + video_vfi = i2v_enhance_interface.vfi_process( + video=video, vfi=self.vfi, video_len=dest_num_frames) + video_vfi = np.stack([np.asarray(frame) + for frame in video_vfi], axis=0) + self.vfi.unload() + print(f"---- VFI ---- [FINISHED]. Video length = {len(video_vfi)}") + return video_vfi + + # T2I method + + def text_to_image(self, prompt, seed=33): + # FLUX + print("[FLUX] Generating image from text prompt") + out = self.flux_pipe( + prompt=prompt, + guidance_scale=0, + height=720, + width=1280, + num_inference_steps=4, + max_sequence_length=256, + generator=torch.Generator( + device=self.model.device).manual_seed(seed), + ).images[0] + print("[FLUX] Finished") + return np.array(out) + + +if __name__ == "__main__": + + @ensure_annotation_class + def get_input_data(input_path: Path = None): + if input_path.is_file(): + inputs = [input_path] + else: + suffixes = ["*.[jJ][pP][gG]", "*.[pP][nN][gG]", + "*.[jJ][pP][eE][gG]", "*.[bB][mM][pP]"] # loading png, jpg and bmp images + inputs = [] + for suffix in suffixes: + inputs.extend(list(input_path.glob(suffix))) + assert len( + inputs) > 0, "No images found. Please make sure the input path is correct." + + img_as_np = [IImage.open(input).numpy() for input in inputs] + return zip(img_as_np, inputs) + + streaming_svd = StreamingSVD() + num_frames = streaming_svd.num_frames + chunk_size = streaming_svd.chunk_size + overlap_size = streaming_svd.overlap_size + use_randomized_blending = streaming_svd.use_randomized_blending + if not use_randomized_blending: + chunk_size = (num_frames + 1)//2 + overlap_size = 0 + result_path = Path(streaming_svd.output_path) + seed = 33 + + assert result_path.exists() is False or result_path.is_dir( + ), "Output path must be the path to a folder." + prompt = streaming_svd.prompt + if len(prompt) == 0: + for img, img_path in get_input_data(streaming_svd.input_path): + video = streaming_svd.streaming_i2v( + image=img, num_frames=num_frames, use_randomized_blending=use_randomized_blending, chunk_size=chunk_size, overlap_size=overlap_size, seed=33) + if not result_path.exists(): + result_path.mkdir(parents=True) + result_file = result_path / (img_path.stem+".mp4") + result_file = result_file.as_posix() + IImage(video, vmin=0, vmax=255).setFps( + streaming_svd.fps).save(result_file) + print(f"Video created at: {result_file}") + else: + video = streaming_svd.streaming_t2v( + prompt=prompt, num_frames=num_frames, use_randomized_blending=use_randomized_blending, chunk_size=chunk_size, overlap_size=overlap_size, seed=33) + prompt_file = prompt.replace(" ", "_").replace( + ".", "_").replace("/", "_").replace(":", "_") + prompt_file = prompt_file[:15] + if not result_path.exists(): + result_path.mkdir(parents=True) + result_file = result_path / (prompt_file+".mp4") + result_file = result_file.as_posix() + IImage(video, vmin=0, vmax=255).setFps( + streaming_svd.fps).save(result_file) + print(f"Video created at: {result_file}") diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/aux.py b/utils/aux.py new file mode 100644 index 0000000000000000000000000000000000000000..4557b6976a7f429c53d1652c2920df8e487bca18 --- /dev/null +++ b/utils/aux.py @@ -0,0 +1,25 @@ +from types import FunctionType + + +def ensure_annotation_class(f: FunctionType): + + def wrapper(*args, **kwargs): + keys = tuple(f.__annotations__.keys()) + args_converted = () + for ar in enumerate(args): + expected_class = f.__annotations__.get(keys[ar[0]]) + if not isinstance(ar[1], expected_class): + args_converted += (expected_class(ar[1]),) + else: + args_converted += (ar[1],) + + kwargs_ensured_class = {} + for k, v in kwargs.items(): + expected_class = f.__annotations__.get(k) + if not isinstance(v, expected_class): + v = expected_class(v) + kwargs_ensured_class[k] = v + + return f(*args_converted, **kwargs_ensured_class) + + return wrapper diff --git a/utils/gradio_utils.py b/utils/gradio_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3a252d4e4ec9286a77d1b7ece4c7fcae1c4b6014 --- /dev/null +++ b/utils/gradio_utils.py @@ -0,0 +1,201 @@ +from streaming_svd_inference import StreamingSVD +from lib.farancia import IImage + +import datetime +from pathlib import Path +import os +import ast +from typing import Tuple +import numpy as np +from PIL import Image + + +def get_uuid(asset: str,cache: str) -> Path: + """ + Generate a unique filename based on the current timestamp and save it in the specified root folder. + Root Folder will be under environment variable GRADIO_TEMP_DIR, if specified, otherwise the current working directory. + Args: + root_fol (str): The root folder where the file will be saved. + + Returns: + Path: The path to the saved file. + + """ + file_name = "_".join( + "_".join(str(datetime.datetime.now()).split('.')).split(" "))+".mp4" + file = Path(cache) / asset / file_name + if not file.parent.exists(): + file.parent.mkdir(parents=True) + print(f"Saving file to {file}") + return file + + +def retrieve_intermediate_data(video_file: str) -> Tuple[list[int],list[int],Image.Image]: + """ + Retrieve intermediate data related to a video file, including expansion size, original size, and outpainted image. + + Args: + video_file (str): The path to the video file with "__cropped__" in its name. + + Returns: + Tuple[list[int], list[int], Image.Image]: A tuple containing the expansion size, original size, and outpainted image. + + Raises: + AssertionError: If the video file path is not a string or does not contain "__cropped__" in its name. + + """ + assert isinstance(video_file,str) and "__cropped__" in video_file,f"File {video_file} is missing __cropped__ keyword" + + video_file_expanded = video_file.replace( + "__cropped__", "__expanded__") + + # get the expansion size to obtain 16:9 aspect ratio + expanded_size = ast.literal_eval(Path(video_file_expanded.replace( + "__expanded__", "__meta_expanded_size__").replace("mp4", "txt")).read_text()) + # get the original size + orig_size = ast.literal_eval(Path(video_file_expanded.replace( + "__expanded__", "__meta_orig_size__").replace("mp4", "txt")).read_text()) + # get the outpainted image + scaled_outpainted_image = IImage.open(video_file_expanded.replace( + "__expanded__", "__anchor__").replace("mp4", "png")).numpy() + return expanded_size, orig_size, scaled_outpainted_image + + +def save_intermediate_data(video: np.ndarray, user_image: np.ndarray, video_path: Path, expanded_size: list[int], fps: int, scaled_outpainted_image: Image.Image): + """ + Save intermediate data related to the generated video, including resolution information and scaled outpainted image. + + Args: + video (np.ndarray): The generated video. + user_image (np.ndarray): The user image used for generating the video. + video_path (Path): The path to the generated video file. + expanded_size (list[int]): The expansion size information. + fps (int): The frames per second of the video. + scaled_outpainted_image (Image.Image): The scaled outpainted image. + + """ + # save resolution of outpainting (before scaling) + meta = video_path.parent / \ + ("__meta_expanded_size__"+video_path.name.replace("mp4", "txt")) + meta.write_text(str(expanded_size)) + + # save original resolution of user image + meta = video_path.parent / \ + ("__meta_orig_size__"+video_path.name.replace("mp4", "txt")) + meta.write_text(str([user_image.shape[1], user_image.shape[0]])) + + # save scaled outpainted first frame + anchor = video_path.parent / \ + ("__anchor__"+video_path.name.replace("mp4", "png")) + IImage(scaled_outpainted_image).save(anchor) + + # save video generated from outpainted image + video_path_expanded = video_path.parent / \ + ("__expanded__" + video_path.name) + IImage(video, vmin=0, vmax=255).setFps(fps).save(video_path_expanded) + + +def image_to_video_gradio(img: np.ndarray, streaming_svd: StreamingSVD, gradio_cache: str, fps: int =24, asset: str="first_stage", **kwargs: dict) -> str: + """ + Convert an image to a video using the provided streaming_svd object and perform additional processing steps. + + Args: + img: The input image to convert to video. + streaming_svd: The object used for converting the image to video. + fps (int, optional): The frames per second of the output video (default is 24). + root_fol (str, optional): The root folder where the video will be saved (default is "first_stage"). + **kwargs: Additional keyword arguments to pass to the streaming_svd object. + + Returns: + str: The path to the saved cropped video file. + + Note: We save several additional files to hard-drive using a path derived from the cropped video file. + * image-to-video result using outpainted image (key = __cropped__ ) + * the size of the outpainted image (key = __meta_expanded_size__ ) + * the size of the input image (key = __meta_orig_size__ ) + * the input image (key = __anchor__ ) + """ + + video, scaled_outpainted_image, expanded_size = streaming_svd.image_to_video(img, **kwargs) + video_path = get_uuid(asset,cache=gradio_cache) + video_path_cropped = video_path.parent / ("__cropped__" + video_path.name) + IImage(video, vmin=0, vmax=255).resize(expanded_size[::-1]).crop( + (0, 0, img.shape[1], img.shape[0])).setFps(fps).save(video_path_cropped) + save_intermediate_data(video=video, video_path=video_path, expanded_size=expanded_size, + fps=fps, user_image=img, scaled_outpainted_image=scaled_outpainted_image) + return video_path_cropped.as_posix() + + +def image_to_video_vfi_gradio(img: np.ndarray, streaming_svd: StreamingSVD, gradio_cache: str, fps: int =24, asset: str="first_stage", num_frames: int=None, **kwargs: dict) -> str: + """ + Convert an image to a video using the provided streaming_svd object and perform additional processing steps. Then applies VFI + + Args: + img: The input image to convert to video. + streaming_svd: The object used for converting the image to video. + fps (int, optional): The frames per second of the output video (default is 24). + root_fol (str, optional): The root folder where the video will be saved (default is "first_stage"). + **kwargs: Additional keyword arguments to pass to the streaming_svd object. + + Returns: + str: The path to the saved cropped video file. + + Note: We save several additional files to hard-drive using a path derived from the cropped video file. + * image-to-video result using outpainted image (key = __cropped__ ) + * the size of the outpainted image (key = __meta_expanded_size__ ) + * the size of the input image (key = __meta_orig_size__ ) + * the input image (key = __anchor__ ) + """ + + video, scaled_outpainted_image, expanded_size = streaming_svd.image_to_video(img, num_frames = (num_frames+1) // 2, **kwargs) + video = streaming_svd.interpolate_video(video, dest_num_frames=num_frames) + + video_path = get_uuid(asset,cache=gradio_cache) + video_path_cropped = video_path.parent / ("__cropped__" + video_path.name) + IImage(video, vmin=0, vmax=255).resize(expanded_size[::-1]).crop( + (0, 0, img.shape[1], img.shape[0])).setFps(fps).save(video_path_cropped) + save_intermediate_data(video=video, video_path=video_path, expanded_size=expanded_size, + fps=fps, user_image=img, scaled_outpainted_image=scaled_outpainted_image) + return video_path_cropped.as_posix() + + +def text_to_image_gradio(prompt: str, streaming_svd: StreamingSVD, **kwargs: dict) -> np.ndarray: + """ + Generate an image from the provided text prompt using the specified streaming_svd object. + + Args: + prompt (str): The text prompt used to generate the image. + streaming_svd (StreamingSVD): The object used for converting the text to an image. + **kwargs (dict): Additional keyword arguments to pass to the streaming_svd object. + + Returns: + np.ndarray: The generated image based on the text prompt. + + """ + return streaming_svd.text_to_image(prompt, **kwargs) + + +def enhance_video_vfi_gradio(img: np.ndarray, video : str, expanded_size: list[int], num_frames: int,gradio_cache:str, streaming_svd: StreamingSVD, fps: int = 24, asset="second_stage", orig_size: list[int] = None, **kwargs: dict) -> str: + """ + Enhance a video by applying our proposed enhancement (including randomized blending) to the video. + + Args: + img (np.ndarray): The input image used for enhancing the video. + video (str): The path to the input video to be enhanced. + expanded_size (list[int]): The size to which the video will be expanded. + streaming_svd (StreamingSVD): The object used for enhancing the video. + fps (int, optional): The frames per second of the output video (default is 24). + root_fol (str, optional): The root folder where the enhanced video will be saved (default is "second_stage_preview"). + orig_size (list[int], optional): The original size of the image (default is None). + **kwargs (dict): Additional keyword arguments to pass to the streaming_svd object for enhancement. + + Returns: + str: The path to the saved enhanced video file. + + """ + video_enh = streaming_svd.enhance_video(image=img, video=video, num_frames=(num_frames+1) // 2, **kwargs) + video_int = streaming_svd.interpolate_video(video_enh, dest_num_frames=num_frames) + video_path = get_uuid(asset, cache=gradio_cache) + IImage(video_int, vmin=0, vmax=255).resize( + expanded_size[::-1]).crop((0, 0, orig_size[0], orig_size[1])).setFps(fps).save(video_path) + return video_path.as_posix() \ No newline at end of file diff --git a/utils/inference_utils.py b/utils/inference_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1881dee0c9d77bf1545690ee7729b942c15fd8a0 --- /dev/null +++ b/utils/inference_utils.py @@ -0,0 +1,131 @@ +from math import ceil, floor +from PIL import Image +import torchvision.transforms as transforms +import numpy as np +pil_to_torch = transforms.Compose([ + transforms.PILToTensor() +]) +from typing import Tuple + +def get_padding_for_aspect_ratio(img: Image.Image, target_aspect_ratio: float = 16/9) -> list[int]: + aspect_ratio = img.width / img.height + + if aspect_ratio != target_aspect_ratio: + w_target = ceil(target_aspect_ratio*img.height) # r = w /h = w_i / h_i + h_target = floor(img.width * (1/target_aspect_ratio)) + + if w_target >= img.width: + w_scale = w_target / img.width + else: + w_scale = np.inf + + if h_target >= img.height: + h_scale = h_target / img.height + else: + h_scale = np.inf + + if min([h_scale, w_scale]) == h_scale: + scale_axis = 1 + target_size = h_target + else: + scale_axis = 0 + target_size = w_target + + pad_size = [0, 0, 0, 0] + img_size = img.size + pad_size[2+scale_axis] = int(target_size - img_size[scale_axis]) + return pad_size + else: + return None + + +def get_padding_for_aspect_ratio(img: Image, target_aspect_ratio: float = 16/9): + aspect_ratio = img.width / img.height + + if aspect_ratio != target_aspect_ratio: + w_target = ceil(target_aspect_ratio*img.height) # r = w /h = w_i / h_i + h_target = floor(img.width * (1/target_aspect_ratio)) + + if w_target >= img.width: + w_scale = w_target / img.width + else: + w_scale = np.inf + + if h_target >= img.height: + h_scale = h_target / img.height + else: + h_scale = np.inf + + if min([h_scale, w_scale]) == h_scale: + scale_axis = 1 + target_size = h_target + else: + scale_axis = 0 + target_size = w_target + + pad_size = [0, 0, 0, 0] + img_size = img.size + pad_size[2+scale_axis] = int(target_size - img_size[scale_axis]) + return pad_size + else: + return None + + +def add_margin(pil_img, top, right, bottom, left, color): + width, height = pil_img.size + new_width = width + right + left + new_height = height + top + bottom + result = Image.new(pil_img.mode, (new_width, new_height), color) + result.paste(pil_img, (left, top)) + return result + + +def resize_to_fit(image, size): + W, H = size + w, h = image.size + if H / h > W / w: + H_ = int(h * W / w) + W_ = W + else: + W_ = int(w * H / h) + H_ = H + return image.resize((W_, H_)) + + +def pad_to_fit(image, size): + W, H = size + w, h = image.size + pad_h = (H - h) // 2 + pad_w = (W - w) // 2 + return add_margin(image, pad_h, pad_w, pad_h, pad_w, (0, 0, 0)) + + +def resize_and_keep(pil_img): + expanded_size = [pil_img.width, pil_img.height] + myheight = 576 + hpercent = (myheight/float(pil_img.size[1])) + wsize = int((float(pil_img.size[0])*float(hpercent))) + pil_img = pil_img.resize((wsize, myheight)) + + return pil_img, expanded_size + + +def resize_and_crop(pil_img: Image.Image) -> Tuple[Image.Image, Tuple[int, int]]: + img, expanded_size = resize_and_keep(pil_img) + assert img.width >= 1024 and img.height >= 576,f"Got {img.width} and {img.height}" + return img.crop((0, 0, 1024, 576)), expanded_size + + +def center_crop(pil_img): + width, height = pil_img.size + new_width = 576 + new_height = 576 + + left = (width - new_width)/2 + top = (height - new_height)/2 + right = (width + new_width)/2 + bottom = (height + new_height)/2 + + # Crop the center of the image + pil_img = pil_img.crop((left, top, right, bottom)) + return pil_img diff --git a/utils/loader.py b/utils/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..8dc159037d56410f5d56f233749a9aa19f085afc --- /dev/null +++ b/utils/loader.py @@ -0,0 +1,52 @@ +import importlib +from functools import partialmethod +from pathlib import Path +from torchvision.datasets.utils import download_url +import gdown +from utils.aux import ensure_annotation_class + + +def get_class(cls_path: str, *args, **kwargs): + module_name = ".".join(cls_path.split(".")[:-1]) + module = importlib.import_module(module_name) + + class_ = getattr(module, cls_path.split(".")[-1]) + class_.__init__ = partialmethod(class_.__init__, *args, **kwargs) + return class_ + + +@ensure_annotation_class +def download_ckpt(local_path: Path, global_path: str) -> str: + + if local_path.exists(): + return local_path.as_posix() + else: + if not local_path.parent.exists(): + local_path.parent.mkdir(parents=True) + + if "drive.google.com" in global_path and "file" in global_path: + url = global_path + dest = local_path.as_posix() + gdown.download(url=url, output=dest, fuzzy=True) + + elif "drive.google.com" in global_path and "folder" in global_path: + url = global_path + dest = local_path.parent.as_posix() + gdown.download_folder(url=url, output=dest) + + elif local_path.suffix == ".safetensors" or "." not in local_path.as_posix(): + ckpt_url = f"https://huggingface.co./{global_path}" + try: + download_url(ckpt_url, local_path.parent.as_posix(), + local_path.name) + except Exception as e: + print( + f"Error: Failed to download model from {ckpt_url} to {local_path}") + raise e + else: + raise NotImplementedError( + f"Download model file {global_path} not supported") + + assert local_path.exists(), f"Missing checkpoint {local_path}" + + return local_path.as_posix() diff --git a/utils/result_processor.py b/utils/result_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..dc3318f4ab11ed900ec5090b708d95b8a30338a7 --- /dev/null +++ b/utils/result_processor.py @@ -0,0 +1,29 @@ +from lib.farancia import IImage + + +def convert_range(video, output_range, input_range=None): + if input_range is None: + if video.min() < 0: + input_range = [-1, 1] + elif video.max() > 1: + input_range = [0, 255] + else: + input_range = [0, 1] + video = (video-input_range[0])/(input_range[1]-input_range[0]) # [0,1] + video = video * (output_range[1]-output_range[0]) + output_range[0] + return video + + +def concat_chunks(result_chunks): + if not isinstance(result_chunks, list): + result_chunks = [result_chunks] + concatenated_result = None + + for chunk in result_chunks: + assert chunk.min() >= 0 + chunk = IImage(chunk, vmin=0, vmax=255) + if concatenated_result is None: + concatenated_result = chunk + else: + concatenated_result &= chunk + return concatenated_result