File size: 4,335 Bytes
d0ffe9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import logging
from typing import Optional
import torch
import torch._dynamo as dynamo
from diffusers import (DiffusionPipeline, StableDiffusionPipeline,
StableDiffusionXLPipeline)
from einops._torch_specific import allow_ops_in_compiled_graph
from animatediff.utils.device import get_memory_format, get_model_dtypes
from animatediff.utils.model import nop_train
# Module-level logger shared by the device-placement helpers in this module.
logger = logging.getLogger(__name__)
def send_to_device(
    pipeline: DiffusionPipeline,
    device: torch.device,
    freeze: bool = True,
    force_half: bool = False,
    compile: bool = False,
    is_sdxl: bool = False,
) -> DiffusionPipeline:
    """Move a pipeline's sub-models to ``device`` using per-model dtypes.

    The pipeline is modified in place (its sub-model attributes are
    reassigned) and also returned.

    Args:
        pipeline: Pipeline to move. The optional attributes ``controlnet``,
            ``controlnet_map``, ``lora_map`` and ``lcm`` are handled when
            present.
        device: Target device.
        freeze: Currently ignored by this implementation; kept for API
            compatibility.
        force_half: Forwarded to ``get_model_dtypes`` when choosing dtypes.
        compile: If True, wrap ``pipeline.unet`` with ``torch.compile``
            (inductor backend, "reduce-overhead" mode) unless it is already a
            ``torch._dynamo.OptimizedModule``.
        is_sdxl: If True, delegate entirely to ``send_to_device_sdxl``.

    Returns:
        The same pipeline object.
    """
    if is_sdxl:
        return send_to_device_sdxl(
            pipeline=pipeline,
            device=device,
            freeze=freeze,
            force_half=force_half,
            compile=compile,
        )

    # ``device.index`` is 0 for the first GPU, so test against None instead of
    # relying on truthiness (which would silently drop the "0").
    device_index = "" if device.index is None else device.index
    logger.info(f"Sending pipeline to device \"{device.type}{device_index}\"")

    unet_dtype, tenc_dtype, vae_dtype = get_model_dtypes(device, force_half)
    model_memory_format = get_memory_format(device)

    has_controlnet = hasattr(pipeline, "controlnet")
    if has_controlnet:
        # When the pipeline carries a ControlNet attribute, every model uses the
        # VAE dtype. NOTE(review): presumably so ControlNet residuals match the
        # UNet's precision — confirm.
        unet_dtype = tenc_dtype = vae_dtype

    logger.info(f"-> Selected data types: {unet_dtype=},{tenc_dtype=},{vae_dtype=}")

    if has_controlnet:
        controlnet = pipeline.controlnet
        if hasattr(controlnet, "nets"):
            # Wrapper holding several ControlNets: move each child in place.
            nets = controlnet.nets
            for i, net in enumerate(list(nets)):
                nets[i] = net.to(device=device, dtype=vae_dtype, memory_format=model_memory_format)
        elif controlnet:
            pipeline.controlnet = controlnet.to(device=device, dtype=vae_dtype, memory_format=model_memory_format)

    controlnet_map = getattr(pipeline, "controlnet_map", None)
    if controlnet_map:
        # Only dtype/memory format are changed; the device is deliberately not
        # passed (an earlier revision did). NOTE(review): presumably these
        # ControlNets are moved to the device elsewhere when needed — confirm.
        for key, net in controlnet_map.items():
            controlnet_map[key] = net.to(dtype=unet_dtype, memory_format=model_memory_format)

    # These optional helpers are moved in place; the ``.to`` result is not reassigned.
    for attr_name in ("lora_map", "lcm"):
        helper = getattr(pipeline, attr_name, None)
        if helper:
            helper.to(device=device, dtype=unet_dtype)

    pipeline.unet = pipeline.unet.to(device=device, dtype=unet_dtype, memory_format=model_memory_format)
    pipeline.text_encoder = pipeline.text_encoder.to(device=device, dtype=tenc_dtype)
    pipeline.vae = pipeline.vae.to(device=device, dtype=vae_dtype, memory_format=model_memory_format)

    # Compile the UNet if requested (never twice).
    if compile:
        if isinstance(pipeline.unet, dynamo.OptimizedModule):
            logger.debug("Skipping model compilation, already compiled!")
        else:
            allow_ops_in_compiled_graph()  # make einops behave
            logger.warning("Enabling model compilation with TorchDynamo, this may take a while...")
            logger.warning("Model compilation is experimental and may not work as expected!")
            pipeline.unet = torch.compile(
                pipeline.unet,
                backend="inductor",
                mode="reduce-overhead",
            )

    return pipeline
def send_to_device_sdxl(
    pipeline: StableDiffusionXLPipeline,
    device: torch.device,
    freeze: bool = True,
    force_half: bool = False,
    compile: bool = False,
) -> StableDiffusionXLPipeline:
    """Prepare an SDXL pipeline for memory-constrained inference.

    The UNet and both text encoders are always cast to float16. The pipeline is
    then configured with model CPU offload, xformers attention, and VAE
    slicing/tiling, instead of being moved wholesale to ``device``.

    Args:
        pipeline: SDXL pipeline to prepare (modified in place).
        device: Target device. It is used for logging and, for a CUDA device
            with an explicit index, to select the GPU that offload runs on.
        freeze: Currently ignored; kept for signature parity with
            ``send_to_device``.
        force_half: Currently ignored; half precision is always applied.
        compile: Currently ignored; no ``torch.compile`` is performed.

    Returns:
        The same pipeline object.
    """
    # ``device.index`` is 0 for the first GPU, so test against None instead of
    # relying on truthiness (which would silently drop the "0").
    device_index = "" if device.index is None else device.index
    logger.info(f"Sending pipeline to device \"{device.type}{device_index}\"")

    pipeline.unet = pipeline.unet.half()
    pipeline.text_encoder = pipeline.text_encoder.half()
    pipeline.text_encoder_2 = pipeline.text_encoder_2.half()
    # The VAE keeps its current dtype. NOTE(review): presumably because the
    # SDXL VAE is prone to overflow in float16 — confirm before changing.

    # Model CPU offload handles device placement itself, so pipeline.to(device)
    # is deliberately not called. Honour an explicit CUDA index; otherwise keep
    # diffusers' default GPU.
    if device.type == "cuda" and device.index is not None:
        pipeline.enable_model_cpu_offload(gpu_id=device.index)
    else:
        pipeline.enable_model_cpu_offload()
    pipeline.enable_xformers_memory_efficient_attention()
    pipeline.enable_vae_slicing()
    pipeline.enable_vae_tiling()
    return pipeline
def get_context_params(
    length: int,
    context: Optional[int] = None,
    overlap: Optional[int] = None,
    stride: Optional[int] = None,
):
    """Resolve sliding-context parameters, filling in defaults for any left as None.

    Defaults:
        context: ``min(length, 16)``.
        overlap: a quarter of the resolved context (floor division), so it
            follows an explicitly supplied ``context`` too.
        stride: ``0``.

    Returns:
        A ``(context, overlap, stride)`` tuple.
    """
    window = min(length, 16) if context is None else context
    window_overlap = window // 4 if overlap is None else overlap
    window_stride = 0 if stride is None else stride
    return window, window_overlap, window_stride
|