import logging
from typing import Optional

import torch
import torch._dynamo as dynamo
from diffusers import DiffusionPipeline, StableDiffusionPipeline, StableDiffusionXLPipeline
from einops._torch_specific import allow_ops_in_compiled_graph

from animatediff.utils.device import get_memory_format, get_model_dtypes
from animatediff.utils.model import nop_train

logger = logging.getLogger(__name__)


def send_to_device(
    pipeline: DiffusionPipeline,
    device: torch.device,
    freeze: bool = True,
    force_half: bool = False,
    compile: bool = False,
    is_sdxl: bool = False,
) -> DiffusionPipeline:
    # SDXL pipelines carry extra components (e.g. a second text encoder) and use a
    # dedicated code path.
    if is_sdxl:
        return send_to_device_sdxl(
            pipeline=pipeline,
            device=device,
            freeze=freeze,
            force_half=force_half,
            compile=compile,
        )

    logger.info(f"Sending pipeline to device \"{device.type}{device.index if device.index else ''}\"")

    unet_dtype, tenc_dtype, vae_dtype = get_model_dtypes(device, force_half)
    model_memory_format = get_memory_format(device)

    if hasattr(pipeline, "controlnet"):
        # With a ControlNet attached, unify all component dtypes to the VAE dtype.
        unet_dtype = tenc_dtype = vae_dtype
        logger.info(f"-> Selected data types: {unet_dtype=},{tenc_dtype=},{vae_dtype=}")
        if hasattr(pipeline.controlnet, "nets"):
            # MultiControlNet: move each sub-network individually.
            for i in range(len(pipeline.controlnet.nets)):
                pipeline.controlnet.nets[i] = pipeline.controlnet.nets[i].to(
                    device=device, dtype=vae_dtype, memory_format=model_memory_format
                )
        elif pipeline.controlnet:
            pipeline.controlnet = pipeline.controlnet.to(
                device=device, dtype=vae_dtype, memory_format=model_memory_format
            )

    if hasattr(pipeline, "controlnet_map") and pipeline.controlnet_map:
        for c in pipeline.controlnet_map:
            # Cast dtype/memory format only; the device move is intentionally skipped here.
            # pipeline.controlnet_map[c] = pipeline.controlnet_map[c].to(device=device, dtype=unet_dtype, memory_format=model_memory_format)
            pipeline.controlnet_map[c] = pipeline.controlnet_map[c].to(
                dtype=unet_dtype, memory_format=model_memory_format
            )

    if hasattr(pipeline, "lora_map") and pipeline.lora_map:
        pipeline.lora_map.to(device=device, dtype=unet_dtype)

    if hasattr(pipeline, "lcm") and pipeline.lcm:
        pipeline.lcm.to(device=device, dtype=unet_dtype)

    pipeline.unet = pipeline.unet.to(device=device, dtype=unet_dtype, memory_format=model_memory_format)
    pipeline.text_encoder = pipeline.text_encoder.to(device=device, dtype=tenc_dtype)
    pipeline.vae = pipeline.vae.to(device=device, dtype=vae_dtype, memory_format=model_memory_format)

    # Compile the UNet if enabled (and not already compiled)
    if compile:
        if not isinstance(pipeline.unet, dynamo.OptimizedModule):
            allow_ops_in_compiled_graph()  # make einops behave
            logger.warning("Enabling model compilation with TorchDynamo, this may take a while...")
            logger.warning("Model compilation is experimental and may not work as expected!")
            pipeline.unet = torch.compile(
                pipeline.unet,
                backend="inductor",
                mode="reduce-overhead",
            )
        else:
            logger.debug("Skipping model compilation, already compiled!")

    return pipeline


def send_to_device_sdxl(
    pipeline: StableDiffusionXLPipeline,
    device: torch.device,
    freeze: bool = True,
    force_half: bool = False,
    compile: bool = False,
) -> StableDiffusionXLPipeline:
    logger.info(f"Sending pipeline to device \"{device.type}{device.index if device.index else ''}\"")

    # SDXL components are always cast to fp16 here.
    pipeline.unet = pipeline.unet.half()
    pipeline.text_encoder = pipeline.text_encoder.half()
    pipeline.text_encoder_2 = pipeline.text_encoder_2.half()

    # Direct device placement is disabled in favor of model CPU offload.
    if False:
        pipeline.to(device)
    else:
        pipeline.enable_model_cpu_offload()

    pipeline.enable_xformers_memory_efficient_attention()
    pipeline.enable_vae_slicing()
    pipeline.enable_vae_tiling()

    return pipeline


def get_context_params(
    length: int,
    context: Optional[int] = None,
    overlap: Optional[int] = None,
    stride: Optional[int] = None,
):
    # Fill in defaults for the sliding-window context scheduler.
    if context is None:
        context = min(length, 16)
    if overlap is None:
        overlap = context // 4
    if stride is None:
        stride = 0
    return context, overlap, stride
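

# Illustrative usage sketch (not part of this module's public surface; the model id
# and device below are assumptions for the example only):
#
#   pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
#   pipe = send_to_device(pipe, torch.device("cuda"), force_half=True)
#   context, overlap, stride = get_context_params(length=64)  # -> (16, 4, 0)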