import glob import logging import os import re from functools import partial from itertools import chain from os import PathLike from pathlib import Path from typing import Any, Callable, Dict, List, Union import numpy as np import torch from controlnet_aux import LineartAnimeDetector from controlnet_aux.processor import MODELS from controlnet_aux.processor import Processor as ControlnetPreProcessor from controlnet_aux.util import HWC3, ade_palette from controlnet_aux.util import resize_image as aux_resize_image from diffusers import (AutoencoderKL, ControlNetModel, DiffusionPipeline, EulerDiscreteScheduler, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline, StableDiffusionXLPipeline) from PIL import Image from torchvision.datasets.folder import IMG_EXTENSIONS from tqdm.rich import tqdm from transformers import (AutoImageProcessor, CLIPImageProcessor, CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, UperNetForSemanticSegmentation) from animatediff import get_dir from animatediff.dwpose import DWposeDetector from animatediff.models.clip import CLIPSkipTextModel from animatediff.models.unet import UNet3DConditionModel from animatediff.pipelines import AnimationPipeline, load_text_embeddings from animatediff.pipelines.lora import load_lcm_lora, load_lora_map from animatediff.pipelines.pipeline_controlnet_img2img_reference import \ StableDiffusionControlNetImg2ImgReferencePipeline from animatediff.schedulers import DiffusionScheduler, get_scheduler from animatediff.settings import InferenceConfig, ModelConfig from animatediff.utils.control_net_lllite import (ControlNetLLLite, load_controlnet_lllite) from animatediff.utils.convert_from_ckpt import convert_ldm_vae_checkpoint from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora from animatediff.utils.model import (ensure_motion_modules, get_checkpoint_weights, get_checkpoint_weights_sdxl) from animatediff.utils.util import (get_resized_image, get_resized_image2, get_resized_images, get_tensor_interpolation_method, prepare_dwpose, prepare_extra_controlnet, prepare_ip_adapter, prepare_ip_adapter_sdxl, prepare_lcm_lora, prepare_lllite, prepare_motion_module, save_frames, save_imgs, save_video) controlnet_address_table={ "controlnet_tile" : ['lllyasviel/control_v11f1e_sd15_tile'], "controlnet_lineart_anime" : ['lllyasviel/control_v11p_sd15s2_lineart_anime'], "controlnet_ip2p" : ['lllyasviel/control_v11e_sd15_ip2p'], "controlnet_openpose" : ['lllyasviel/control_v11p_sd15_openpose'], "controlnet_softedge" : ['lllyasviel/control_v11p_sd15_softedge'], "controlnet_shuffle" : ['lllyasviel/control_v11e_sd15_shuffle'], "controlnet_depth" : ['lllyasviel/control_v11f1p_sd15_depth'], "controlnet_canny" : ['lllyasviel/control_v11p_sd15_canny'], "controlnet_inpaint" : ['lllyasviel/control_v11p_sd15_inpaint'], "controlnet_lineart" : ['lllyasviel/control_v11p_sd15_lineart'], "controlnet_mlsd" : ['lllyasviel/control_v11p_sd15_mlsd'], "controlnet_normalbae" : ['lllyasviel/control_v11p_sd15_normalbae'], "controlnet_scribble" : ['lllyasviel/control_v11p_sd15_scribble'], "controlnet_seg" : ['lllyasviel/control_v11p_sd15_seg'], "qr_code_monster_v1" : ['monster-labs/control_v1p_sd15_qrcode_monster'], "qr_code_monster_v2" : ['monster-labs/control_v1p_sd15_qrcode_monster', 'v2'], "controlnet_mediapipe_face" : ['CrucibleAI/ControlNetMediaPipeFace', "diffusion_sd15"], "animatediff_controlnet" : [None, "data/models/controlnet/animatediff_controlnet/controlnet_checkpoint.ckpt"] } # Edit this table if you want 
# to change to another controlnet checkpoint
controlnet_address_table_sdxl = {
    # "controlnet_openpose" : ['thibaud/controlnet-openpose-sdxl-1.0'],
    # "controlnet_softedge" : ['SargeZT/controlnet-sd-xl-1.0-softedge-dexined'],
    # "controlnet_depth" : ['diffusers/controlnet-depth-sdxl-1.0-small'],
    # "controlnet_canny" : ['diffusers/controlnet-canny-sdxl-1.0-small'],
    # "controlnet_seg" : ['SargeZT/sdxl-controlnet-seg'],
    "qr_code_monster_v1": ['monster-labs/control_v1p_sdxl_qrcode_monster'],
}

# Edit this table if you want to change to another lllite checkpoint
lllite_address_table_sdxl = {
    "controlnet_tile": ['models/lllite/bdsqlsz_controlllite_xl_tile_anime_β.safetensors'],
    "controlnet_lineart_anime": ['models/lllite/bdsqlsz_controlllite_xl_lineart_anime_denoise.safetensors'],
    # "controlnet_ip2p" : ('lllyasviel/control_v11e_sd15_ip2p'),
    "controlnet_openpose": ['models/lllite/bdsqlsz_controlllite_xl_dw_openpose.safetensors'],
    # "controlnet_openpose" : ['models/lllite/controllllite_v01032064e_sdxl_pose_anime.safetensors'],
    "controlnet_softedge": ['models/lllite/bdsqlsz_controlllite_xl_softedge.safetensors'],
    "controlnet_shuffle": ['models/lllite/bdsqlsz_controlllite_xl_t2i-adapter_color_shuffle.safetensors'],
    "controlnet_depth": ['models/lllite/bdsqlsz_controlllite_xl_depth.safetensors'],
    "controlnet_canny": ['models/lllite/bdsqlsz_controlllite_xl_canny.safetensors'],
    # "controlnet_canny" : ['models/lllite/controllllite_v01032064e_sdxl_canny.safetensors'],
    # "controlnet_inpaint" : ('lllyasviel/control_v11p_sd15_inpaint'),
    # "controlnet_lineart" : ('lllyasviel/control_v11p_sd15_lineart'),
    "controlnet_mlsd": ['models/lllite/bdsqlsz_controlllite_xl_mlsd_V2.safetensors'],
    "controlnet_normalbae": ['models/lllite/bdsqlsz_controlllite_xl_normal.safetensors'],
    "controlnet_scribble": ['models/lllite/bdsqlsz_controlllite_xl_sketch.safetensors'],
    "controlnet_seg": ['models/lllite/bdsqlsz_controlllite_xl_segment_animeface_V2.safetensors'],
    # "qr_code_monster_v1" : ['monster-labs/control_v1p_sdxl_qrcode_monster'],
    # "qr_code_monster_v2" : ('monster-labs/control_v1p_sd15_qrcode_monster', 'v2'),
    # "controlnet_mediapipe_face" : ('CrucibleAI/ControlNetMediaPipeFace', "diffusion_sd15"),
}

try:
    import onnxruntime
    onnxruntime_installed = True
except ImportError:
    onnxruntime_installed = False

logger = logging.getLogger(__name__)

data_dir = get_dir("data")
default_base_path = data_dir.joinpath("models/huggingface/stable-diffusion-v1-5")

re_clean_prompt = re.compile(r"[^\w\-, ]")

controlnet_preprocessor = {}


def load_safetensors_lora(text_encoder, unet, lora_path, alpha=0.75, is_animatediff=True):
    from safetensors.torch import load_file

    from animatediff.utils.lora_diffusers import (LoRANetwork,
                                                  create_network_from_weights)

    sd = load_file(lora_path)

    logger.info("create LoRA network")
    lora_network: LoRANetwork = create_network_from_weights(text_encoder, unet, sd, multiplier=alpha, is_animatediff=is_animatediff)
    logger.info("load LoRA network weights")
    lora_network.load_state_dict(sd, False)
    #lora_network.merge_to(alpha)
    lora_network.apply_to(alpha)
    return lora_network


def load_safetensors_lora2(text_encoder, unet, lora_path, alpha=0.75, is_animatediff=True):
    from safetensors.torch import load_file

    from animatediff.utils.lora_diffusers import (LoRANetwork,
                                                  create_network_from_weights)

    sd = load_file(lora_path)

    logger.info("create LoRA network")
    lora_network: LoRANetwork = create_network_from_weights(text_encoder, unet, sd, multiplier=alpha, is_animatediff=is_animatediff)
    logger.info("load LoRA network weights")
    lora_network.load_state_dict(sd,
False) lora_network.merge_to(alpha) def load_tensors(path:Path,framework="pt",device="cpu"): tensors = {} if path.suffix == ".safetensors": from safetensors import safe_open with safe_open(path, framework=framework, device=device) as f: for k in f.keys(): tensors[k] = f.get_tensor(k) # loads the full tensor given a key else: from torch import load tensors = load(path, device) if "state_dict" in tensors: tensors = tensors["state_dict"] return tensors def load_motion_lora(unet, lora_path:Path, alpha=1.0): state_dict = load_tensors(lora_path) # directly update weight in diffusers model for key in state_dict: # only process lora down key if "up." in key: continue up_key = key.replace(".down.", ".up.") model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") model_key = model_key.replace("to_out.", "to_out.0.") layer_infos = model_key.split(".")[:-1] curr_layer = unet try: while len(layer_infos) > 0: temp_name = layer_infos.pop(0) curr_layer = curr_layer.__getattr__(temp_name) except: logger.info(f"{model_key} not found") continue weight_down = state_dict[key] weight_up = state_dict[up_key] curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) class SegPreProcessor: def __init__(self): self.image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small") self.processor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small") def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs): input_array = np.array(input_image, dtype=np.uint8) input_array = HWC3(input_array) input_array = aux_resize_image(input_array, detect_resolution) pixel_values = self.image_processor(input_array, return_tensors="pt").pixel_values with torch.no_grad(): outputs = self.processor(pixel_values.to(self.processor.device)) outputs.loss = outputs.loss.to("cpu") if outputs.loss is not None else outputs.loss outputs.logits = outputs.logits.to("cpu") if outputs.logits is not None else outputs.logits outputs.hidden_states = outputs.hidden_states.to("cpu") if outputs.hidden_states is not None else outputs.hidden_states outputs.attentions = outputs.attentions.to("cpu") if outputs.attentions is not None else outputs.attentions seg = self.image_processor.post_process_semantic_segmentation(outputs, target_sizes=[input_image.size[::-1]])[0] color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3 for label, color in enumerate(ade_palette()): color_seg[seg == label, :] = color color_seg = color_seg.astype(np.uint8) color_seg = aux_resize_image(color_seg, image_resolution) color_seg = Image.fromarray(color_seg) return color_seg class NullPreProcessor: def __call__(self, input_image, **kwargs): return input_image class BlurPreProcessor: def __call__(self, input_image, sigma=5.0, **kwargs): import cv2 input_array = np.array(input_image, dtype=np.uint8) input_array = HWC3(input_array) dst = cv2.GaussianBlur(input_array, (0, 0), sigma) return Image.fromarray(dst) class TileResamplePreProcessor: def resize(self, input_image, resolution): import cv2 H, W, C = input_image.shape H = float(H) W = float(W) k = float(resolution) / min(H, W) H *= k W *= k img = cv2.resize(input_image, (int(W), int(H)), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) return img def __call__(self, input_image, down_sampling_rate = 1.0, **kwargs): input_array = np.array(input_image, dtype=np.uint8) input_array = HWC3(input_array) H, 
W, C = input_array.shape target_res = min(H,W) / down_sampling_rate dst = self.resize(input_array, target_res) return Image.fromarray(dst) def is_valid_controlnet_type(type_str, is_sdxl): if not is_sdxl: return type_str in controlnet_address_table else: return (type_str in controlnet_address_table_sdxl) or (type_str in lllite_address_table_sdxl) def load_controlnet_from_file(file_path, torch_dtype): from safetensors.torch import load_file prepare_extra_controlnet() file_path = Path(file_path) if file_path.exists() and file_path.is_file(): if file_path.suffix.lower() in [".pth", ".pt", ".ckpt"]: controlnet_state_dict = torch.load(file_path, map_location="cpu", weights_only=True) elif file_path.suffix.lower() == ".safetensors": controlnet_state_dict = load_file(file_path, device="cpu") else: raise RuntimeError( f"unknown file format for controlnet weights: {file_path.suffix}" ) else: raise FileNotFoundError(f"no controlnet weights found in {file_path}") if file_path.parent.name == "animatediff_controlnet": model = ControlNetModel(cross_attention_dim=768) else: model = ControlNetModel() missing, _ = model.load_state_dict(controlnet_state_dict["state_dict"], strict=False) if len(missing) > 0: logger.info(f"ControlNetModel has missing keys: {missing}") return model.to(dtype=torch_dtype) def create_controlnet_model(pipe, type_str, is_sdxl): if not is_sdxl: if type_str in controlnet_address_table: addr = controlnet_address_table[type_str] if addr[0] != None: if len(addr) == 1: return ControlNetModel.from_pretrained(addr[0], torch_dtype=torch.float16) else: return ControlNetModel.from_pretrained(addr[0], subfolder=addr[1], torch_dtype=torch.float16) else: return load_controlnet_from_file(addr[1],torch_dtype=torch.float16) else: raise ValueError(f"unknown controlnet type {type_str}") else: if type_str in controlnet_address_table_sdxl: addr = controlnet_address_table_sdxl[type_str] if len(addr) == 1: return ControlNetModel.from_pretrained(addr[0], torch_dtype=torch.float16) else: return ControlNetModel.from_pretrained(addr[0], subfolder=addr[1], torch_dtype=torch.float16) elif type_str in lllite_address_table_sdxl: addr = lllite_address_table_sdxl[type_str] model_path = data_dir.joinpath(addr[0]) return load_controlnet_lllite(model_path, pipe, torch_dtype=torch.float16) else: raise ValueError(f"unknown controlnet type {type_str}") default_preprocessor_table={ "controlnet_lineart_anime":"lineart_anime", "controlnet_openpose": "openpose_full" if onnxruntime_installed==False else "dwpose", "controlnet_softedge":"softedge_hedsafe", "controlnet_shuffle":"shuffle", "controlnet_depth":"depth_midas", "controlnet_canny":"canny", "controlnet_lineart":"lineart_realistic", "controlnet_mlsd":"mlsd", "controlnet_normalbae":"normal_bae", "controlnet_scribble":"scribble_pidsafe", "controlnet_seg":"upernet_seg", "controlnet_mediapipe_face":"mediapipe_face", "qr_code_monster_v1":"depth_midas", "qr_code_monster_v2":"depth_midas", } def create_preprocessor_from_name(pre_type): if pre_type == "dwpose": prepare_dwpose() return DWposeDetector() elif pre_type == "upernet_seg": return SegPreProcessor() elif pre_type == "blur": return BlurPreProcessor() elif pre_type == "tile_resample": return TileResamplePreProcessor() elif pre_type == "none": return NullPreProcessor() elif pre_type in MODELS: return ControlnetPreProcessor(pre_type) else: raise ValueError(f"unknown controlnet preprocessor type {pre_type}") def create_default_preprocessor(type_str): if type_str in default_preprocessor_table: pre_type = 
default_preprocessor_table[type_str]
    else:
        pre_type = "none"

    return create_preprocessor_from_name(pre_type)


def get_preprocessor(type_str, device_str, preprocessor_map):
    if type_str not in controlnet_preprocessor:
        if preprocessor_map:
            controlnet_preprocessor[type_str] = create_preprocessor_from_name(preprocessor_map["type"])

    if type_str not in controlnet_preprocessor:
        controlnet_preprocessor[type_str] = create_default_preprocessor(type_str)

    if hasattr(controlnet_preprocessor[type_str], "processor"):
        if hasattr(controlnet_preprocessor[type_str].processor, "to"):
            if device_str:
                controlnet_preprocessor[type_str].processor.to(device_str)
    elif hasattr(controlnet_preprocessor[type_str], "to"):
        if device_str:
            controlnet_preprocessor[type_str].to(device_str)

    return controlnet_preprocessor[type_str]


def clear_controlnet_preprocessor(type_str=None):
    global controlnet_preprocessor
    if type_str is None:
        for t in controlnet_preprocessor:
            controlnet_preprocessor[t] = None
        controlnet_preprocessor = {}
        torch.cuda.empty_cache()
    else:
        controlnet_preprocessor[type_str] = None
        torch.cuda.empty_cache()


def get_preprocessed_img(type_str, img, use_preprocessor, device_str, preprocessor_map):
    if use_preprocessor:
        param = {}
        if preprocessor_map:
            param = preprocessor_map["param"] if "param" in preprocessor_map else {}
        return get_preprocessor(type_str, device_str, preprocessor_map)(img, **param)
    else:
        return img


def create_pipeline_sdxl(
    base_model: Union[str, PathLike] = default_base_path,
    model_config: ModelConfig = ...,
    infer_config: InferenceConfig = ...,
    use_xformers: bool = True,
    video_length: int = 16,
    motion_module_path = ...,
):
    from animatediff.pipelines.sdxl_animation import AnimationPipeline
    from animatediff.sdxl_models.unet import UNet3DConditionModel

    logger.info("Loading tokenizer...")
    tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
    logger.info("Loading text encoder...")
    text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder", torch_dtype=torch.float16)
    logger.info("Loading VAE...")
    vae: AutoencoderKL = AutoencoderKL.from_pretrained(base_model, subfolder="vae")
    logger.info("Loading tokenizer two...")
    tokenizer_two = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer_2")
    logger.info("Loading text encoder two...")
    text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_model, subfolder="text_encoder_2", torch_dtype=torch.float16)
    logger.info("Loading UNet...")
    unet: UNet3DConditionModel = UNet3DConditionModel.from_pretrained_2d(
        pretrained_model_path=base_model,
        motion_module_path=motion_module_path,
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    )

    # set up scheduler
    sched_kwargs = infer_config.noise_scheduler_kwargs
    scheduler = get_scheduler(model_config.scheduler, sched_kwargs)
    logger.info(f'Using scheduler "{model_config.scheduler}" ({scheduler.__class__.__name__})')

    if model_config.gradual_latent_hires_fix_map:
        if "enable" in model_config.gradual_latent_hires_fix_map:
            if model_config.gradual_latent_hires_fix_map["enable"]:
                if model_config.scheduler not in (DiffusionScheduler.euler_a, DiffusionScheduler.lcm):
                    logger.warning("gradual_latent_hires_fix is enabled")
                    logger.warning(f"{model_config.scheduler=}")
                    logger.warning("If generation aborts with an error, switch the scheduler to euler_a or lcm")

    # Load the checkpoint weights into the pipeline
    if model_config.path is not None:
        model_path = data_dir.joinpath(model_config.path)
        logger.info(f"Loading weights from {model_path}")
        if
model_path.is_file(): logger.debug("Loading from single checkpoint file") unet_state_dict, tenc_state_dict, tenc2_state_dict, vae_state_dict = get_checkpoint_weights_sdxl(model_path) elif model_path.is_dir(): logger.debug("Loading from Diffusers model directory") temp_pipeline = StableDiffusionXLPipeline.from_pretrained(model_path) unet_state_dict, tenc_state_dict, tenc2_state_dict, vae_state_dict = ( temp_pipeline.unet.state_dict(), temp_pipeline.text_encoder.state_dict(), temp_pipeline.text_encoder_2.state_dict(), temp_pipeline.vae.state_dict(), ) del temp_pipeline else: raise FileNotFoundError(f"model_path {model_path} is not a file or directory") # Load into the unet, TE, and VAE logger.info("Merging weights into UNet...") _, unet_unex = unet.load_state_dict(unet_state_dict, strict=False) if len(unet_unex) > 0: raise ValueError(f"UNet has unexpected keys: {unet_unex}") tenc_missing, _ = text_encoder.load_state_dict(tenc_state_dict, strict=False) if len(tenc_missing) > 0: raise ValueError(f"TextEncoder has missing keys: {tenc_missing}") tenc2_missing, _ = text_encoder_two.load_state_dict(tenc2_state_dict, strict=False) if len(tenc2_missing) > 0: raise ValueError(f"TextEncoder2 has missing keys: {tenc2_missing}") vae_missing, _ = vae.load_state_dict(vae_state_dict, strict=False) if len(vae_missing) > 0: raise ValueError(f"VAE has missing keys: {vae_missing}") else: logger.info("Using base model weights (no checkpoint/LoRA)") if model_config.vae_path: vae_path = data_dir.joinpath(model_config.vae_path) logger.info(f"Loading vae from {vae_path}") if vae_path.is_dir(): vae = AutoencoderKL.from_pretrained(vae_path) else: tensors = load_tensors(vae_path) tensors = convert_ldm_vae_checkpoint(tensors, vae.config) vae.load_state_dict(tensors) unet.to(torch.float16) text_encoder.to(torch.float16) text_encoder_two.to(torch.float16) del unet_state_dict del tenc_state_dict del tenc2_state_dict del vae_state_dict # enable xformers if available if use_xformers: logger.info("Enabling xformers memory-efficient attention") unet.enable_xformers_memory_efficient_attention() # motion lora for l in model_config.motion_lora_map: lora_path = data_dir.joinpath(l) logger.info(f"loading motion lora {lora_path=}") if lora_path.is_file(): logger.info(f"Loading motion lora {lora_path}") logger.info(f"alpha = {model_config.motion_lora_map[l]}") load_motion_lora(unet, lora_path, alpha=model_config.motion_lora_map[l]) else: raise ValueError(f"{lora_path=} not found") logger.info("Creating AnimationPipeline...") pipeline = AnimationPipeline( vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_two, tokenizer=tokenizer, tokenizer_2=tokenizer_two, unet=unet, scheduler=scheduler, controlnet_map=None, ) del vae del text_encoder del text_encoder_two del tokenizer del tokenizer_two del unet torch.cuda.empty_cache() pipeline.lcm = None if model_config.lcm_map: if model_config.lcm_map["enable"]: prepare_lcm_lora() load_lcm_lora(pipeline, model_config.lcm_map, is_sdxl=True) load_lora_map(pipeline, model_config.lora_map, video_length, is_sdxl=True) pipeline.unet = pipeline.unet.half() pipeline.text_encoder = pipeline.text_encoder.half() pipeline.text_encoder_2 = pipeline.text_encoder_2.half() # Load TI embeddings pipeline.text_encoder = pipeline.text_encoder.to("cuda") pipeline.text_encoder_2 = pipeline.text_encoder_2.to("cuda") load_text_embeddings(pipeline, is_sdxl=True) pipeline.text_encoder = pipeline.text_encoder.to("cpu") pipeline.text_encoder_2 = pipeline.text_encoder_2.to("cpu") return pipeline def 
create_pipeline( base_model: Union[str, PathLike] = default_base_path, model_config: ModelConfig = ..., infer_config: InferenceConfig = ..., use_xformers: bool = True, video_length: int = 16, is_sdxl:bool = False, ) -> DiffusionPipeline: """Create an AnimationPipeline from a pretrained model. Uses the base_model argument to load or download the pretrained reference pipeline model.""" # make sure motion_module is a Path and exists logger.info("Checking motion module...") motion_module = data_dir.joinpath(model_config.motion_module) if not (motion_module.exists() and motion_module.is_file()): prepare_motion_module() if not (motion_module.exists() and motion_module.is_file()): # check for safetensors version motion_module = motion_module.with_suffix(".safetensors") if not (motion_module.exists() and motion_module.is_file()): # download from HuggingFace Hub if not found ensure_motion_modules() if not (motion_module.exists() and motion_module.is_file()): # this should never happen, but just in case... raise FileNotFoundError(f"Motion module {motion_module} does not exist or is not a file!") if is_sdxl: return create_pipeline_sdxl( base_model=base_model, model_config=model_config, infer_config=infer_config, use_xformers=use_xformers, video_length=video_length, motion_module_path=motion_module, ) logger.info("Loading tokenizer...") tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer") logger.info("Loading text encoder...") text_encoder: CLIPSkipTextModel = CLIPSkipTextModel.from_pretrained(base_model, subfolder="text_encoder") logger.info("Loading VAE...") vae: AutoencoderKL = AutoencoderKL.from_pretrained(base_model, subfolder="vae") logger.info("Loading UNet...") unet: UNet3DConditionModel = UNet3DConditionModel.from_pretrained_2d( pretrained_model_path=base_model, motion_module_path=motion_module, subfolder="unet", unet_additional_kwargs=infer_config.unet_additional_kwargs, ) feature_extractor = CLIPImageProcessor.from_pretrained(base_model, subfolder="feature_extractor") # set up scheduler if model_config.gradual_latent_hires_fix_map: if "enable" in model_config.gradual_latent_hires_fix_map: if model_config.gradual_latent_hires_fix_map["enable"]: if model_config.scheduler not in (DiffusionScheduler.euler_a, DiffusionScheduler.lcm): logger.warn("gradual_latent_hires_fix enable") logger.warn(f"{model_config.scheduler=}") logger.warn("If you are forced to exit with an error, change to euler_a or lcm") sched_kwargs = infer_config.noise_scheduler_kwargs scheduler = get_scheduler(model_config.scheduler, sched_kwargs) logger.info(f'Using scheduler "{model_config.scheduler}" ({scheduler.__class__.__name__})') # Load the checkpoint weights into the pipeline if model_config.path is not None: model_path = data_dir.joinpath(model_config.path) logger.info(f"Loading weights from {model_path}") if model_path.is_file(): logger.debug("Loading from single checkpoint file") unet_state_dict, tenc_state_dict, vae_state_dict = get_checkpoint_weights(model_path) elif model_path.is_dir(): logger.debug("Loading from Diffusers model directory") temp_pipeline = StableDiffusionPipeline.from_pretrained(model_path) unet_state_dict, tenc_state_dict, vae_state_dict = ( temp_pipeline.unet.state_dict(), temp_pipeline.text_encoder.state_dict(), temp_pipeline.vae.state_dict(), ) del temp_pipeline else: raise FileNotFoundError(f"model_path {model_path} is not a file or directory") # Load into the unet, TE, and VAE logger.info("Merging weights into UNet...") _, unet_unex = 
unet.load_state_dict(unet_state_dict, strict=False)
        if len(unet_unex) > 0:
            raise ValueError(f"UNet has unexpected keys: {unet_unex}")

        tenc_missing, _ = text_encoder.load_state_dict(tenc_state_dict, strict=False)
        if len(tenc_missing) > 0:
            raise ValueError(f"TextEncoder has missing keys: {tenc_missing}")

        vae_missing, _ = vae.load_state_dict(vae_state_dict, strict=False)
        if len(vae_missing) > 0:
            raise ValueError(f"VAE has missing keys: {vae_missing}")
    else:
        logger.info("Using base model weights (no checkpoint/LoRA)")

    if model_config.vae_path:
        vae_path = data_dir.joinpath(model_config.vae_path)
        logger.info(f"Loading vae from {vae_path}")
        if vae_path.is_dir():
            vae = AutoencoderKL.from_pretrained(vae_path)
        else:
            tensors = load_tensors(vae_path)
            tensors = convert_ldm_vae_checkpoint(tensors, vae.config)
            vae.load_state_dict(tensors)

    # enable xformers if available
    if use_xformers:
        logger.info("Enabling xformers memory-efficient attention")
        unet.enable_xformers_memory_efficient_attention()

    if False:  # disabled legacy path: per-LoRA loading is handled by load_lora_map below
        # lora
        for l in model_config.lora_map:
            lora_path = data_dir.joinpath(l)
            if lora_path.is_file():
                logger.info(f"Loading lora {lora_path}")
                logger.info(f"alpha = {model_config.lora_map[l]}")
                load_safetensors_lora(text_encoder, unet, lora_path, alpha=model_config.lora_map[l])

    # motion lora
    for l in model_config.motion_lora_map:
        lora_path = data_dir.joinpath(l)
        logger.info(f"loading motion lora {lora_path=}")
        if lora_path.is_file():
            logger.info(f"Loading motion lora {lora_path}")
            logger.info(f"alpha = {model_config.motion_lora_map[l]}")
            load_motion_lora(unet, lora_path, alpha=model_config.motion_lora_map[l])
        else:
            raise ValueError(f"{lora_path=} not found")

    logger.info("Creating AnimationPipeline...")
    pipeline = AnimationPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
        controlnet_map=None,
    )

    pipeline.lcm = None
    if model_config.lcm_map:
        if model_config.lcm_map["enable"]:
            prepare_lcm_lora()
            load_lcm_lora(pipeline, model_config.lcm_map, is_sdxl=False)

    load_lora_map(pipeline, model_config.lora_map, video_length)

    # Load TI embeddings
    pipeline.unet = pipeline.unet.half()
    pipeline.text_encoder = pipeline.text_encoder.half()

    pipeline.text_encoder = pipeline.text_encoder.to("cuda")
    load_text_embeddings(pipeline)
    pipeline.text_encoder = pipeline.text_encoder.to("cpu")

    return pipeline


def load_controlnet_models(pipe: DiffusionPipeline, model_config: ModelConfig = ..., is_sdxl: bool = False):
    # controlnet
    if is_sdxl:
        prepare_lllite()

    controlnet_map = {}
    if model_config.controlnet_map:
        c_image_dir = data_dir.joinpath(model_config.controlnet_map["input_image_dir"])
        for c in model_config.controlnet_map:
            item = model_config.controlnet_map[c]
            if type(item) is dict:
                if item["enable"] == True:
                    if is_valid_controlnet_type(c, is_sdxl):
                        img_dir = c_image_dir.joinpath(c)
                        cond_imgs = sorted(glob.glob(os.path.join(img_dir, "[0-9]*.png"), recursive=False))
                        if len(cond_imgs) > 0:
                            logger.info(f"loading {c=} model")
                            controlnet_map[c] = create_controlnet_model(pipe, c, is_sdxl)
                    else:
                        logger.info(f"invalid controlnet type for {'sdxl' if is_sdxl else 'sd15'} : {c}")

    if not controlnet_map:
        controlnet_map = None

    pipe.controlnet_map = controlnet_map


def unload_controlnet_models(pipe: AnimationPipeline):
    from animatediff.utils.util import show_gpu

    if pipe.controlnet_map:
        for c in pipe.controlnet_map:
            controlnet = pipe.controlnet_map[c]
            if isinstance(controlnet, ControlNetLLLite):
                controlnet.unapply_to()
                del controlnet

    #show_gpu("before unload controlnet")
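    # Dropping the map reference below is what lets the ControlNet weights be
    # garbage-collected; torch.cuda.empty_cache() then returns the cached VRAM blocks.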
pipe.controlnet_map = None torch.cuda.empty_cache() #show_gpu("after unload controlnet") def create_us_pipeline( model_config: ModelConfig = ..., infer_config: InferenceConfig = ..., use_xformers: bool = True, use_controlnet_ref: bool = False, use_controlnet_tile: bool = False, use_controlnet_line_anime: bool = False, use_controlnet_ip2p: bool = False, ) -> DiffusionPipeline: # set up scheduler sched_kwargs = infer_config.noise_scheduler_kwargs scheduler = get_scheduler(model_config.scheduler, sched_kwargs) logger.info(f'Using scheduler "{model_config.scheduler}" ({scheduler.__class__.__name__})') controlnet = [] if use_controlnet_tile: controlnet.append( ControlNetModel.from_pretrained('lllyasviel/control_v11f1e_sd15_tile') ) if use_controlnet_line_anime: controlnet.append( ControlNetModel.from_pretrained('lllyasviel/control_v11p_sd15s2_lineart_anime') ) if use_controlnet_ip2p: controlnet.append( ControlNetModel.from_pretrained('lllyasviel/control_v11e_sd15_ip2p') ) if len(controlnet) == 1: controlnet = controlnet[0] elif len(controlnet) == 0: controlnet = None # Load the checkpoint weights into the pipeline pipeline:DiffusionPipeline if model_config.path is not None: model_path = data_dir.joinpath(model_config.path) logger.info(f"Loading weights from {model_path}") if model_path.is_file(): def is_empty_dir(path): import os return len(os.listdir(path)) == 0 save_path = data_dir.joinpath("models/huggingface/" + model_path.stem + "_" + str(model_path.stat().st_size)) save_path.mkdir(exist_ok=True) if save_path.is_dir() and is_empty_dir(save_path): # StableDiffusionControlNetImg2ImgPipeline.from_single_file does not exist in version 18.2 logger.debug("Loading from single checkpoint file") tmp_pipeline = StableDiffusionPipeline.from_single_file( pretrained_model_link_or_path=str(model_path.absolute()) ) tmp_pipeline.save_pretrained(save_path, safe_serialization=True) del tmp_pipeline if use_controlnet_ref: pipeline = StableDiffusionControlNetImg2ImgReferencePipeline.from_pretrained( save_path, controlnet=controlnet, local_files_only=False, load_safety_checker=False, safety_checker=None, ) else: pipeline = StableDiffusionControlNetImg2ImgPipeline.from_pretrained( save_path, controlnet=controlnet, local_files_only=False, load_safety_checker=False, safety_checker=None, ) elif model_path.is_dir(): logger.debug("Loading from Diffusers model directory") if use_controlnet_ref: pipeline = StableDiffusionControlNetImg2ImgReferencePipeline.from_pretrained( model_path, controlnet=controlnet, local_files_only=True, load_safety_checker=False, safety_checker=None, ) else: pipeline = StableDiffusionControlNetImg2ImgPipeline.from_pretrained( model_path, controlnet=controlnet, local_files_only=True, load_safety_checker=False, safety_checker=None, ) else: raise FileNotFoundError(f"model_path {model_path} is not a file or directory") else: raise ValueError("model_config.path is invalid") pipeline.scheduler = scheduler # enable xformers if available if use_xformers: logger.info("Enabling xformers memory-efficient attention") pipeline.enable_xformers_memory_efficient_attention() # lora for l in model_config.lora_map: lora_path = data_dir.joinpath(l) if lora_path.is_file(): alpha = model_config.lora_map[l] if isinstance(alpha, dict): alpha = 0.75 logger.info(f"Loading lora {lora_path}") logger.info(f"alpha = {alpha}") load_safetensors_lora2(pipeline.text_encoder, pipeline.unet, lora_path, alpha=alpha,is_animatediff=False) # Load TI embeddings pipeline.unet = pipeline.unet.half() pipeline.text_encoder = 
pipeline.text_encoder.half() pipeline.text_encoder = pipeline.text_encoder.to("cuda") load_text_embeddings(pipeline) pipeline.text_encoder = pipeline.text_encoder.to("cpu") return pipeline def seed_everything(seed): import random import numpy as np torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed % (2**32)) random.seed(seed) def controlnet_preprocess( controlnet_map: Dict[str, Any] = None, width: int = 512, height: int = 512, duration: int = 16, out_dir: PathLike = ..., device_str:str=None, is_sdxl:bool = False, ): if not controlnet_map: return None, None, None, None out_dir = Path(out_dir) # ensure out_dir is a Path # { 0 : { "type_str" : IMAGE, "type_str2" : IMAGE } } controlnet_image_map={} controlnet_type_map={} c_image_dir = data_dir.joinpath( controlnet_map["input_image_dir"] ) save_detectmap = controlnet_map["save_detectmap"] if "save_detectmap" in controlnet_map else True preprocess_on_gpu = controlnet_map["preprocess_on_gpu"] if "preprocess_on_gpu" in controlnet_map else True device_str = device_str if preprocess_on_gpu else None for c in controlnet_map: if c == "controlnet_ref": continue item = controlnet_map[c] processed = False if type(item) is dict: if item["enable"] == True: if is_valid_controlnet_type(c, is_sdxl): preprocessor_map = item["preprocessor"] if "preprocessor" in item else {} img_dir = c_image_dir.joinpath( c ) cond_imgs = sorted(glob.glob( os.path.join(img_dir, "[0-9]*.png"), recursive=False)) if len(cond_imgs) > 0: controlnet_type_map[c] = { "controlnet_conditioning_scale" : item["controlnet_conditioning_scale"], "control_guidance_start" : item["control_guidance_start"], "control_guidance_end" : item["control_guidance_end"], "control_scale_list" : item["control_scale_list"], "guess_mode" : item["guess_mode"] if "guess_mode" in item else False, "control_region_list" : item["control_region_list"] if "control_region_list" in item else [] } use_preprocessor = item["use_preprocessor"] if "use_preprocessor" in item else True for img_path in tqdm(cond_imgs, desc=f"Preprocessing images ({c})"): frame_no = int(Path(img_path).stem) if frame_no < duration: if frame_no not in controlnet_image_map: controlnet_image_map[frame_no] = {} controlnet_image_map[frame_no][c] = get_preprocessed_img( c, get_resized_image2(img_path, 512) , use_preprocessor, device_str, preprocessor_map) processed = True else: logger.info(f"invalid controlnet type for {'sdxl' if is_sdxl else 'sd15'} : {c}") if save_detectmap and processed: det_dir = out_dir.joinpath(f"{0:02d}_detectmap/{c}") det_dir.mkdir(parents=True, exist_ok=True) for frame_no in tqdm(controlnet_image_map, desc=f"Saving Preprocessed images ({c})"): save_path = det_dir.joinpath(f"{frame_no:08d}.png") if c in controlnet_image_map[frame_no]: controlnet_image_map[frame_no][c].save(save_path) clear_controlnet_preprocessor(c) clear_controlnet_preprocessor() controlnet_ref_map = None if "controlnet_ref" in controlnet_map: r = controlnet_map["controlnet_ref"] if r["enable"] == True: org_name = data_dir.joinpath( r["ref_image"]).stem # ref_image = get_resized_image( data_dir.joinpath( r["ref_image"] ) , width, height) ref_image = get_resized_image2( data_dir.joinpath( r["ref_image"] ) , 512) if ref_image is not None: controlnet_ref_map = { "ref_image" : ref_image, "style_fidelity" : r["style_fidelity"], "attention_auto_machine_weight" : r["attention_auto_machine_weight"], "gn_auto_machine_weight" : r["gn_auto_machine_weight"], "reference_attn" : r["reference_attn"], "reference_adain" : r["reference_adain"], 
"scale_pattern" : r["scale_pattern"] } if save_detectmap: det_dir = out_dir.joinpath(f"{0:02d}_detectmap/controlnet_ref") det_dir.mkdir(parents=True, exist_ok=True) save_path = det_dir.joinpath(f"{org_name}.png") ref_image.save(save_path) controlnet_no_shrink = ["controlnet_tile","animatediff_controlnet","controlnet_canny","controlnet_normalbae","controlnet_depth","controlnet_lineart","controlnet_lineart_anime","controlnet_scribble","controlnet_seg","controlnet_softedge","controlnet_mlsd"] if "no_shrink_list" in controlnet_map: controlnet_no_shrink = controlnet_map["no_shrink_list"] return controlnet_image_map, controlnet_type_map, controlnet_ref_map, controlnet_no_shrink def ip_adapter_preprocess( ip_adapter_config_map: Dict[str, Any] = None, width: int = 512, height: int = 512, duration: int = 16, out_dir: PathLike = ..., is_sdxl: bool = False, ): ip_adapter_map={} processed = False if ip_adapter_config_map: if ip_adapter_config_map["enable"] == True: resized_to_square = ip_adapter_config_map["resized_to_square"] if "resized_to_square" in ip_adapter_config_map else False image_dir = data_dir.joinpath( ip_adapter_config_map["input_image_dir"] ) imgs = sorted(chain.from_iterable([glob.glob(os.path.join(image_dir, f"[0-9]*{ext}")) for ext in IMG_EXTENSIONS])) if len(imgs) > 0: prepare_ip_adapter_sdxl() if is_sdxl else prepare_ip_adapter() ip_adapter_map["images"] = {} for img_path in tqdm(imgs, desc=f"Preprocessing images (ip_adapter)"): frame_no = int(Path(img_path).stem) if frame_no < duration: if resized_to_square: ip_adapter_map["images"][frame_no] = get_resized_image(img_path, 256, 256) else: ip_adapter_map["images"][frame_no] = get_resized_image2(img_path, 256) processed = True if processed: ip_adapter_config_map["prompt_fixed_ratio"] = max(min(1.0, ip_adapter_config_map["prompt_fixed_ratio"]),0) prompt_fixed_ratio = ip_adapter_config_map["prompt_fixed_ratio"] prompt_map = ip_adapter_map["images"] prompt_map = dict(sorted(prompt_map.items())) key_list = list(prompt_map.keys()) for k0,k1 in zip(key_list,key_list[1:]+[duration]): k05 = k0 + round((k1-k0) * prompt_fixed_ratio) if k05 == k1: k05 -= 1 if k05 != k0: prompt_map[k05] = prompt_map[k0] ip_adapter_map["images"] = prompt_map if (ip_adapter_config_map["save_input_image"] == True) and processed: det_dir = out_dir.joinpath(f"{0:02d}_ip_adapter/") det_dir.mkdir(parents=True, exist_ok=True) for frame_no in tqdm(ip_adapter_map["images"], desc=f"Saving Preprocessed images (ip_adapter)"): save_path = det_dir.joinpath(f"{frame_no:08d}.png") ip_adapter_map["images"][frame_no].save(save_path) return ip_adapter_map if processed else None def prompt_preprocess( prompt_config_map: Dict[str, Any], head_prompt: str, tail_prompt: str, prompt_fixed_ratio: float, video_length: int, ): prompt_map = {} for k in prompt_config_map.keys(): if int(k) < video_length: pr = prompt_config_map[k] if head_prompt: pr = head_prompt + "," + pr if tail_prompt: pr = pr + "," + tail_prompt prompt_map[int(k)]=pr prompt_map = dict(sorted(prompt_map.items())) key_list = list(prompt_map.keys()) for k0,k1 in zip(key_list,key_list[1:]+[video_length]): k05 = k0 + round((k1-k0) * prompt_fixed_ratio) if k05 == k1: k05 -= 1 if k05 != k0: prompt_map[k05] = prompt_map[k0] return prompt_map def region_preprocess( model_config: ModelConfig = ..., width: int = 512, height: int = 512, duration: int = 16, out_dir: PathLike = ..., is_init_img_exist: bool = False, is_sdxl:bool = False, ): is_bg_init_img = False if is_init_img_exist: if model_config.region_map: if "background" in 
model_config.region_map:
                is_bg_init_img = model_config.region_map["background"]["is_init_img"]

    region_condi_list = []
    region2index = {}

    condi_index = 0

    prev_ip_map = None

    if not is_bg_init_img:
        ip_map = ip_adapter_preprocess(model_config.ip_adapter_map, width, height, duration, out_dir, is_sdxl)
        if ip_map:
            prev_ip_map = ip_map

        condition_map = {
            "prompt_map": prompt_preprocess(
                model_config.prompt_map,
                model_config.head_prompt,
                model_config.tail_prompt,
                model_config.prompt_fixed_ratio,
                duration
            ),
            "ip_adapter_map": ip_map
        }

        region_condi_list.append(condition_map)

        bg_src = condi_index
        condi_index += 1
    else:
        bg_src = -1

    region_list = [
        {
            "mask_images": None,
            "src": bg_src,
            "crop_generation_rate": 0
        }
    ]
    region2index["background"] = bg_src

    if model_config.region_map:
        for r in model_config.region_map:
            if r == "background":
                continue
            if model_config.region_map[r]["enable"] != True:
                continue

            region_dir = out_dir.joinpath(f"region_{int(r):05d}/")
            region_dir.mkdir(parents=True, exist_ok=True)

            mask_map = mask_preprocess(model_config.region_map[r], width, height, duration, region_dir)
            if not mask_map:
                continue

            if model_config.region_map[r]["is_init_img"] == False:
                ip_map = ip_adapter_preprocess(
                    model_config.region_map[r]["condition"]["ip_adapter_map"],
                    width, height, duration, region_dir, is_sdxl
                )
                if ip_map:
                    prev_ip_map = ip_map

                condition_map = {
                    "prompt_map": prompt_preprocess(
                        model_config.region_map[r]["condition"]["prompt_map"],
                        model_config.region_map[r]["condition"]["head_prompt"],
                        model_config.region_map[r]["condition"]["tail_prompt"],
                        model_config.region_map[r]["condition"]["prompt_fixed_ratio"],
                        duration
                    ),
                    "ip_adapter_map": ip_map
                }

                region_condi_list.append(condition_map)

                src = condi_index
                condi_index += 1
            else:
                if is_init_img_exist == False:
                    logger.warning("'is_init_img' is true, but no init image exists -> ignoring region")
                    continue
                src = -1

            region_list.append(
                {
                    "mask_images": mask_map,
                    "src": src,
                    "crop_generation_rate": model_config.region_map[r]["crop_generation_rate"] if "crop_generation_rate" in model_config.region_map[r] else 0
                }
            )
            region2index[r] = src

    ip_adapter_config_map = None

    if prev_ip_map is not None:
        ip_adapter_config_map = {}
        ip_adapter_config_map["scale"] = model_config.ip_adapter_map["scale"]
        ip_adapter_config_map["is_plus"] = model_config.ip_adapter_map["is_plus"]
        ip_adapter_config_map["is_plus_face"] = model_config.ip_adapter_map["is_plus_face"] if "is_plus_face" in model_config.ip_adapter_map else False
        ip_adapter_config_map["is_light"] = model_config.ip_adapter_map["is_light"] if "is_light" in model_config.ip_adapter_map else False
        ip_adapter_config_map["is_full_face"] = model_config.ip_adapter_map["is_full_face"] if "is_full_face" in model_config.ip_adapter_map else False
        for c in region_condi_list:
            if c["ip_adapter_map"] is None:
                logger.info("filling empty ip_adapter_map from the previous region")
                c["ip_adapter_map"] = prev_ip_map

    #for c in region_condi_list:
    #    logger.info(f"{c['prompt_map']=}")

    if not region_condi_list:
        raise ValueError("error!
There is not a single valid region") return region_condi_list, region_list, ip_adapter_config_map, region2index def img2img_preprocess( img2img_config_map: Dict[str, Any] = None, width: int = 512, height: int = 512, duration: int = 16, out_dir: PathLike = ..., ): img2img_map={} processed = False if img2img_config_map: if img2img_config_map["enable"] == True: image_dir = data_dir.joinpath( img2img_config_map["init_img_dir"] ) imgs = sorted(glob.glob( os.path.join(image_dir, "[0-9]*.png"), recursive=False)) if len(imgs) > 0: img2img_map["images"] = {} img2img_map["denoising_strength"] = img2img_config_map["denoising_strength"] for img_path in tqdm(imgs, desc=f"Preprocessing images (img2img)"): frame_no = int(Path(img_path).stem) if frame_no < duration: img2img_map["images"][frame_no] = get_resized_image(img_path, width, height) processed = True if (img2img_config_map["save_init_image"] == True) and processed: det_dir = out_dir.joinpath(f"{0:02d}_img2img_init_img/") det_dir.mkdir(parents=True, exist_ok=True) for frame_no in tqdm(img2img_map["images"], desc=f"Saving Preprocessed images (img2img)"): save_path = det_dir.joinpath(f"{frame_no:08d}.png") img2img_map["images"][frame_no].save(save_path) return img2img_map if processed else None def mask_preprocess( region_config_map: Dict[str, Any] = None, width: int = 512, height: int = 512, duration: int = 16, out_dir: PathLike = ..., ): mask_map={} processed = False size = None mode = None if region_config_map: image_dir = data_dir.joinpath( region_config_map["mask_dir"] ) imgs = sorted(glob.glob( os.path.join(image_dir, "[0-9]*.png"), recursive=False)) if len(imgs) > 0: for img_path in tqdm(imgs, desc=f"Preprocessing images (mask)"): frame_no = int(Path(img_path).stem) if frame_no < duration: mask_map[frame_no] = get_resized_image(img_path, width, height) if size is None: size = mask_map[frame_no].size mode = mask_map[frame_no].mode processed = True if processed: if 0 in mask_map: prev_img = mask_map[0] else: prev_img = Image.new(mode, size, color=0) for i in range(duration): if i in mask_map: prev_img = mask_map[i] else: mask_map[i] = prev_img if (region_config_map["save_mask"] == True) and processed: det_dir = out_dir.joinpath(f"mask/") det_dir.mkdir(parents=True, exist_ok=True) for frame_no in tqdm(mask_map, desc=f"Saving Preprocessed images (mask)"): save_path = det_dir.joinpath(f"{frame_no:08d}.png") mask_map[frame_no].save(save_path) return mask_map if processed else None def wild_card_conversion(model_config: ModelConfig = ...,): from animatediff.utils.wild_card import replace_wild_card wild_card_dir = get_dir("wildcards") for k in model_config.prompt_map.keys(): model_config.prompt_map[k] = replace_wild_card(model_config.prompt_map[k], wild_card_dir) if model_config.head_prompt: model_config.head_prompt = replace_wild_card(model_config.head_prompt, wild_card_dir) if model_config.tail_prompt: model_config.tail_prompt = replace_wild_card(model_config.tail_prompt, wild_card_dir) model_config.prompt_fixed_ratio = max(min(1.0, model_config.prompt_fixed_ratio),0) if model_config.region_map: for r in model_config.region_map: if r == "background": continue if "condition" in model_config.region_map[r]: c = model_config.region_map[r]["condition"] for k in c["prompt_map"].keys(): c["prompt_map"][k] = replace_wild_card(c["prompt_map"][k], wild_card_dir) if "head_prompt" in c: c["head_prompt"] = replace_wild_card(c["head_prompt"], wild_card_dir) if "tail_prompt" in c: c["tail_prompt"] = replace_wild_card(c["tail_prompt"], wild_card_dir) if 
"prompt_fixed_ratio" in c: c["prompt_fixed_ratio"] = max(min(1.0, c["prompt_fixed_ratio"]),0) def save_output( pipeline_output, frame_dir:str, out_file:str, output_map : Dict[str,Any] = {}, no_frames : bool = False, save_frames=save_frames, save_video=None, ): output_format = "gif" output_fps = 8 if output_map: output_format = output_map["format"] if "format" in output_map else output_format output_fps = output_map["fps"] if "fps" in output_map else output_fps if output_format == "mp4": output_format = "h264" if output_format == "gif": out_file = out_file.with_suffix(".gif") if no_frames is not True: if save_frames: save_frames(pipeline_output,frame_dir) # generate the output filename and save the video if save_video: save_video(pipeline_output, out_file, output_fps) else: pipeline_output[0].save( fp=out_file, format="GIF", append_images=pipeline_output[1:], save_all=True, duration=(1 / output_fps * 1000), loop=0 ) else: if save_frames: save_frames(pipeline_output,frame_dir) from animatediff.rife.ffmpeg import (FfmpegEncoder, VideoCodec, codec_extn) out_file = out_file.with_suffix( f".{codec_extn(output_format)}" ) logger.info("Creating ffmpeg encoder...") encoder = FfmpegEncoder( frames_dir=frame_dir, out_file=out_file, codec=output_format, in_fps=output_fps, out_fps=output_fps, lossless=False, param= output_map["encode_param"] if "encode_param" in output_map else {} ) logger.info("Encoding interpolated frames with ffmpeg...") result = encoder.encode() logger.debug(f"ffmpeg result: {result}") def run_inference( pipeline: DiffusionPipeline, n_prompt: str = ..., seed: int = -1, steps: int = 25, guidance_scale: float = 7.5, unet_batch_size: int = 1, width: int = 512, height: int = 512, duration: int = 16, idx: int = 0, out_dir: PathLike = ..., context_frames: int = -1, context_stride: int = 3, context_overlap: int = 4, context_schedule: str = "uniform", clip_skip: int = 1, controlnet_map: Dict[str, Any] = None, controlnet_image_map: Dict[str,Any] = None, controlnet_type_map: Dict[str,Any] = None, controlnet_ref_map: Dict[str,Any] = None, controlnet_no_shrink:List[str]=None, no_frames :bool = False, img2img_map: Dict[str,Any] = None, ip_adapter_config_map: Dict[str,Any] = None, region_list: List[Any] = None, region_condi_list: List[Any] = None, output_map: Dict[str,Any] = None, is_single_prompt_mode: bool = False, is_sdxl:bool=False, apply_lcm_lora:bool=False, gradual_latent_map: Dict[str,Any] = None, ): out_dir = Path(out_dir) # ensure out_dir is a Path # Trim and clean up the prompt for filename use prompt_map = region_condi_list[0]["prompt_map"] prompt_tags = [re_clean_prompt.sub("", tag).strip().replace(" ", "-") for tag in prompt_map[list(prompt_map.keys())[0]].split(",")] prompt_str = "_".join((prompt_tags[:6]))[:50] frame_dir = out_dir.joinpath(f"{idx:02d}-{seed}") out_file = out_dir.joinpath(f"{idx:02d}_{seed}_{prompt_str}") def preview_callback(i: int, video: torch.Tensor, save_fn: Callable[[torch.Tensor], None], out_file: str) -> None: save_fn(video, out_file=Path(f"{out_file}_preview@{i}")) save_fn = partial( save_output, frame_dir=frame_dir, output_map=output_map, no_frames=no_frames, save_frames=partial(save_frames, show_progress=False), save_video=save_video ) callback = partial(preview_callback, save_fn=save_fn, out_file=out_file) seed_everything(seed) logger.info(f"{len( region_condi_list )=}") logger.info(f"{len( region_list )=}") pipeline_output = pipeline( negative_prompt=n_prompt, num_inference_steps=steps, guidance_scale=guidance_scale, unet_batch_size=unet_batch_size, 
width=width,
        height=height,
        video_length=duration,
        return_dict=False,
        context_frames=context_frames,
        context_stride=context_stride + 1,
        context_overlap=context_overlap,
        context_schedule=context_schedule,
        clip_skip=clip_skip,
        controlnet_type_map=controlnet_type_map,
        controlnet_image_map=controlnet_image_map,
        controlnet_ref_map=controlnet_ref_map,
        controlnet_no_shrink=controlnet_no_shrink,
        controlnet_max_samples_on_vram=controlnet_map["max_samples_on_vram"] if "max_samples_on_vram" in controlnet_map else 999,
        controlnet_max_models_on_vram=controlnet_map["max_models_on_vram"] if "max_models_on_vram" in controlnet_map else 99,
        controlnet_is_loop=controlnet_map["is_loop"] if "is_loop" in controlnet_map else True,
        img2img_map=img2img_map,
        ip_adapter_config_map=ip_adapter_config_map,
        region_list=region_list,
        region_condi_list=region_condi_list,
        interpolation_factor=1,
        is_single_prompt_mode=is_single_prompt_mode,
        apply_lcm_lora=apply_lcm_lora,
        gradual_latent_map=gradual_latent_map,
        callback=callback,
        callback_steps=output_map.get("preview_steps"),
    )
    logger.info("Generation complete, saving...")

    save_fn(pipeline_output, out_file=out_file)

    logger.info(f"Saved sample to {out_file}")
    return pipeline_output


def run_upscale(
    org_imgs: List[str],
    pipeline: DiffusionPipeline,
    prompt_map: Dict[int, str] = None,
    n_prompt: str = ...,
    seed: int = -1,
    steps: int = 25,
    strength: float = 0.5,
    guidance_scale: float = 7.5,
    clip_skip: int = 1,
    us_width: int = 512,
    us_height: int = 512,
    idx: int = 0,
    out_dir: PathLike = ...,
    upscale_config: Dict[str, Any] = None,
    use_controlnet_ref: bool = False,
    use_controlnet_tile: bool = False,
    use_controlnet_line_anime: bool = False,
    use_controlnet_ip2p: bool = False,
    no_frames: bool = False,
    output_map: Dict[str, Any] = None,
):
    from animatediff.utils.lpw_stable_diffusion import lpw_encode_prompt

    pipeline.set_progress_bar_config(disable=True)

    images = get_resized_images(org_imgs, us_width, us_height)

    steps = steps if "steps" not in upscale_config else upscale_config["steps"]
    # NOTE: the original `scheduler = scheduler if ...` referenced an undefined local
    # and would raise UnboundLocalError; the value is not used below, so only read it.
    scheduler = upscale_config.get("scheduler", None)
    guidance_scale = guidance_scale if "guidance_scale" not in upscale_config else upscale_config["guidance_scale"]
    clip_skip = clip_skip if "clip_skip" not in upscale_config else upscale_config["clip_skip"]
    strength = strength if "strength" not in upscale_config else upscale_config["strength"]

    controlnet_conditioning_scale = []
    guess_mode = []
    control_guidance_start = []
    control_guidance_end = []

    # for controlnet tile
    if use_controlnet_tile:
        controlnet_conditioning_scale.append(upscale_config["controlnet_tile"]["controlnet_conditioning_scale"])
        guess_mode.append(upscale_config["controlnet_tile"]["guess_mode"])
        control_guidance_start.append(upscale_config["controlnet_tile"]["control_guidance_start"])
        control_guidance_end.append(upscale_config["controlnet_tile"]["control_guidance_end"])
    # for controlnet line_anime
    if use_controlnet_line_anime:
        controlnet_conditioning_scale.append(upscale_config["controlnet_line_anime"]["controlnet_conditioning_scale"])
        guess_mode.append(upscale_config["controlnet_line_anime"]["guess_mode"])
        control_guidance_start.append(upscale_config["controlnet_line_anime"]["control_guidance_start"])
        control_guidance_end.append(upscale_config["controlnet_line_anime"]["control_guidance_end"])
    # for controlnet ip2p
    if use_controlnet_ip2p:
        controlnet_conditioning_scale.append(upscale_config["controlnet_ip2p"]["controlnet_conditioning_scale"])
        guess_mode.append(upscale_config["controlnet_ip2p"]["guess_mode"])
control_guidance_start.append(upscale_config["controlnet_ip2p"]["control_guidance_start"]) control_guidance_end.append(upscale_config["controlnet_ip2p"]["control_guidance_end"]) # for controlnet ref ref_image = None if use_controlnet_ref: if not upscale_config["controlnet_ref"]["use_frame_as_ref_image"] and not upscale_config["controlnet_ref"]["use_1st_frame_as_ref_image"]: ref_image = get_resized_images([ data_dir.joinpath( upscale_config["controlnet_ref"]["ref_image"] ) ], us_width, us_height)[0] generator = torch.manual_seed(seed) seed_everything(seed) prompt_embeds_map = {} prompt_map = dict(sorted(prompt_map.items())) negative = None do_classifier_free_guidance=guidance_scale > 1.0 prompt_list = [prompt_map[key_frame] for key_frame in prompt_map.keys()] prompt_embeds,neg_embeds = lpw_encode_prompt( pipe=pipeline, prompt=prompt_list, do_classifier_free_guidance=do_classifier_free_guidance, negative_prompt=n_prompt, ) if do_classifier_free_guidance: negative = neg_embeds.chunk(neg_embeds.shape[0], 0) positive = prompt_embeds.chunk(prompt_embeds.shape[0], 0) else: negative = [None] positive = prompt_embeds.chunk(prompt_embeds.shape[0], 0) for i, key_frame in enumerate(prompt_map): prompt_embeds_map[key_frame] = positive[i] key_first =list(prompt_map.keys())[0] key_last =list(prompt_map.keys())[-1] def get_current_prompt_embeds( center_frame: int = 0, video_length : int = 0 ): key_prev = key_last key_next = key_first for p in prompt_map.keys(): if p > center_frame: key_next = p break key_prev = p dist_prev = center_frame - key_prev if dist_prev < 0: dist_prev += video_length dist_next = key_next - center_frame if dist_next < 0: dist_next += video_length if key_prev == key_next or dist_prev + dist_next == 0: return prompt_embeds_map[key_prev] rate = dist_prev / (dist_prev + dist_next) return get_tensor_interpolation_method()(prompt_embeds_map[key_prev],prompt_embeds_map[key_next], rate) line_anime_processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators") out_images=[] logger.info(f"{use_controlnet_tile=}") logger.info(f"{use_controlnet_line_anime=}") logger.info(f"{use_controlnet_ip2p=}") logger.info(f"{controlnet_conditioning_scale=}") logger.info(f"{guess_mode=}") logger.info(f"{control_guidance_start=}") logger.info(f"{control_guidance_end=}") for i, org_image in enumerate(tqdm(images, desc=f"Upscaling...")): cur_positive = get_current_prompt_embeds(i, len(images)) # logger.info(f"w {condition_image.size[0]}") # logger.info(f"h {condition_image.size[1]}") condition_image = [] if use_controlnet_tile: condition_image.append( org_image ) if use_controlnet_line_anime: condition_image.append( line_anime_processor(org_image) ) if use_controlnet_ip2p: condition_image.append( org_image ) if not use_controlnet_ref: out_image = pipeline( prompt_embeds=cur_positive, negative_prompt_embeds=negative[0], image=org_image, control_image=condition_image, width=org_image.size[0], height=org_image.size[1], strength=strength, num_inference_steps=steps, guidance_scale=guidance_scale, generator=generator, controlnet_conditioning_scale= controlnet_conditioning_scale if len(controlnet_conditioning_scale) > 1 else controlnet_conditioning_scale[0], guess_mode= guess_mode[0], control_guidance_start= control_guidance_start if len(control_guidance_start) > 1 else control_guidance_start[0], control_guidance_end= control_guidance_end if len(control_guidance_end) > 1 else control_guidance_end[0], ).images[0] else: if upscale_config["controlnet_ref"]["use_1st_frame_as_ref_image"]: if i == 0: ref_image = 
org_image elif upscale_config["controlnet_ref"]["use_frame_as_ref_image"]: ref_image = org_image out_image = pipeline( prompt_embeds=cur_positive, negative_prompt_embeds=negative[0], image=org_image, control_image=condition_image, width=org_image.size[0], height=org_image.size[1], strength=strength, num_inference_steps=steps, guidance_scale=guidance_scale, generator=generator, controlnet_conditioning_scale= controlnet_conditioning_scale if len(controlnet_conditioning_scale) > 1 else controlnet_conditioning_scale[0], guess_mode= guess_mode[0], # control_guidance_start= control_guidance_start, # control_guidance_end= control_guidance_end, ### for controlnet ref ref_image=ref_image, attention_auto_machine_weight = upscale_config["controlnet_ref"]["attention_auto_machine_weight"], gn_auto_machine_weight = upscale_config["controlnet_ref"]["gn_auto_machine_weight"], style_fidelity = upscale_config["controlnet_ref"]["style_fidelity"], reference_attn= upscale_config["controlnet_ref"]["reference_attn"], reference_adain= upscale_config["controlnet_ref"]["reference_adain"], ).images[0] out_images.append(out_image) # Trim and clean up the prompt for filename use prompt_tags = [re_clean_prompt.sub("", tag).strip().replace(" ", "-") for tag in prompt_map[list(prompt_map.keys())[0]].split(",")] prompt_str = "_".join((prompt_tags[:6]))[:50] # generate the output filename and save the video out_file = out_dir.joinpath(f"{idx:02d}_{seed}_{prompt_str}") frame_dir = out_dir.joinpath(f"{idx:02d}-{seed}-upscaled") save_output( out_images, frame_dir, out_file, output_map, no_frames, save_imgs, None ) logger.info(f"Saved sample to {out_file}") return out_images
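

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; kept as comments so importing this module
# stays side-effect free). It assumes `model_config` (ModelConfig) and
# `infer_config` (InferenceConfig) have already been loaded elsewhere (see
# animatediff.settings) and that region_preprocess()/controlnet_preprocess()
# were used to build the per-region inputs. Values such as the negative prompt
# and output settings below are placeholders, not project defaults.
#
#   pipeline = create_pipeline(
#       base_model=default_base_path,
#       model_config=model_config,
#       infer_config=infer_config,
#       use_xformers=True,
#       video_length=16,
#       is_sdxl=False,
#   )
#   load_controlnet_models(pipeline, model_config=model_config, is_sdxl=False)
#
#   run_inference(
#       pipeline=pipeline,
#       n_prompt="low quality",
#       seed=42,
#       steps=25,
#       width=512,
#       height=512,
#       duration=16,
#       out_dir=Path("output"),
#       controlnet_map=model_config.controlnet_map,
#       region_condi_list=region_condi_list,   # from region_preprocess()
#       region_list=region_list,               # from region_preprocess()
#       output_map={"format": "gif", "fps": 8},
#   )
#
#   unload_controlnet_models(pipeline)
# ---------------------------------------------------------------------------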