import logging
from os import PathLike
from pathlib import Path
from typing import Optional

from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download, snapshot_download
from tqdm.rich import tqdm

from animatediff import HF_HUB_CACHE, HF_LIB_NAME, HF_LIB_VER, get_dir
from animatediff.utils.util import path_from_cwd

logger = logging.getLogger(__name__)

# presumably the package's data directory as resolved by animatediff.get_dir
# — TODO confirm against get_dir's implementation.
data_dir = get_dir("data")
checkpoint_dir = data_dir.joinpath("models/sd")
pipeline_dir = data_dir.joinpath("models/huggingface")

# Glob patterns handed to snapshot_download(ignore_patterns=...) so that git
# metadata and TensorFlow / Flax weight files are skipped.
IGNORE_TF = ["*.git*", "*.h5", "tf_*"]
IGNORE_FLAX = ["*.git*", "flax_*", "*.msgpack"]
IGNORE_TF_FLAX = IGNORE_TF + IGNORE_FLAX


class DownloadTqdm(tqdm):
    """Rich tqdm progress bar with fixed width, used for Hub snapshot downloads.

    Any ``ncols`` / ``dynamic_ncols`` / ``disable`` passed by the caller is
    overridden. ``disable=None`` tells tqdm to hide the bar on non-TTY output.
    """

    def __init__(self, *args, **kwargs):
        kwargs.update(
            {
                "ncols": 100,
                "dynamic_ncols": False,
                "disable": None,
            }
        )
        super().__init__(*args, **kwargs)


def get_hf_file(
    repo_id: Path,
    filename: str,
    target_dir: Path,
    subfolder: Optional[PathLike] = None,
    revision: Optional[str] = None,
    force: bool = False,
) -> Path:
    """Download a single file from a Hugging Face Hub repo into ``target_dir``.

    Args:
        repo_id: Hub repository id (converted with ``str()``).
        filename: Name of the file inside the repo (or inside ``subfolder``).
        target_dir: Local directory to download into; created if missing.
        subfolder: Optional folder within the repo that contains ``filename``.
        revision: Git revision to fetch; defaults to ``"main"``.
        force: If False, refuse to overwrite an already-downloaded file.

    Returns:
        Path of the downloaded file as reported by ``hf_hub_download``.

    Raises:
        FileExistsError: The destination file exists and ``force`` is not True.
    """
    # hf_hub_download(local_dir=...) keeps the repo layout, so a file inside a
    # subfolder lands at target_dir/subfolder/filename — check that path.
    if subfolder:
        target_path = target_dir.joinpath(subfolder, filename)
    else:
        target_path = target_dir.joinpath(filename)
    if target_path.exists() and force is not True:
        raise FileExistsError(
            f"File {path_from_cwd(target_path)} already exists! Pass force=True to overwrite"
        )

    target_dir.mkdir(exist_ok=True, parents=True)
    save_path = hf_hub_download(
        repo_id=str(repo_id),
        filename=filename,
        revision=revision or "main",
        subfolder=subfolder,
        library_name=HF_LIB_NAME,
        library_version=HF_LIB_VER,
        local_dir=target_dir,
        local_dir_use_symlinks=False,
        cache_dir=HF_HUB_CACHE,
        resume_download=True,
    )
    return Path(save_path)


def get_hf_repo(
    repo_id: Path,
    target_dir: Path,
    subfolder: Optional[PathLike] = None,
    revision: Optional[str] = None,
    force: bool = False,
) -> Path:
    """Download a snapshot of a Hugging Face Hub repo into ``target_dir``.

    TensorFlow / Flax weights and git metadata are skipped (``IGNORE_TF_FLAX``).

    Args:
        repo_id: Hub repository id (converted with ``str()``).
        target_dir: Local directory to download into; created if missing.
        subfolder: If given, only files under this repo folder are downloaded.
            They keep their relative layout (``target_dir/subfolder/...``).
        revision: Git revision to fetch; defaults to ``"main"``.
        force: If False, refuse to download into an existing directory.

    Returns:
        Path returned by ``snapshot_download``.

    Raises:
        FileExistsError: ``target_dir`` exists and ``force`` is not True.
    """
    if target_dir.exists() and force is not True:
        raise FileExistsError(
            f"Target dir {path_from_cwd(target_dir)} already exists! Pass force=True to overwrite"
        )

    # snapshot_download has no ``subfolder`` argument (passing one raises
    # TypeError); restrict the snapshot to that folder with allow_patterns.
    allow_patterns = None
    if subfolder:
        allow_patterns = [f"{Path(subfolder).as_posix()}/*"]

    target_dir.mkdir(exist_ok=True, parents=True)
    save_path = snapshot_download(
        repo_id=str(repo_id),
        revision=revision or "main",
        library_name=HF_LIB_NAME,
        library_version=HF_LIB_VER,
        local_dir=target_dir,
        local_dir_use_symlinks=False,
        allow_patterns=allow_patterns,
        ignore_patterns=IGNORE_TF_FLAX,
        cache_dir=HF_HUB_CACHE,
        tqdm_class=DownloadTqdm,
        max_workers=2,
        resume_download=True,
    )
    return Path(save_path)


def _load_or_download_pipeline(
    pipeline_cls,
    repo_id: Path,
    target_dir: Path,
    save: bool,
    force_download: bool,
    **load_kwargs,
):
    """Load ``pipeline_cls`` from ``target_dir``, or from the Hub and optionally save it.

    A pipeline counts as already saved when ``target_dir/model_index.json``
    exists. In that case, unless ``force_download`` is set, it is loaded with
    ``local_files_only=True`` and nothing is written. Otherwise it is fetched
    from ``repo_id``. If ``save`` is true, it is then written to
    ``target_dir`` with safetensors serialization.
    ``load_kwargs`` are forwarded to both ``from_pretrained`` calls.
    """
    pipeline_exists = target_dir.joinpath("model_index.json").exists()

    if pipeline_exists and not force_download:
        return pipeline_cls.from_pretrained(
            pretrained_model_name_or_path=target_dir,
            local_files_only=True,
            **load_kwargs,
        )

    target_dir.mkdir(exist_ok=True, parents=True)
    pipeline = pipeline_cls.from_pretrained(
        # Normalize to a forward-slash Hub id. Note: str.lstrip removes any
        # leading '.' / '/' characters, not just a literal "./" prefix.
        pretrained_model_name_or_path=str(repo_id).lstrip("./").replace("\\", "/"),
        cache_dir=HF_HUB_CACHE,
        resume_download=True,
        **load_kwargs,
    )

    if save:
        # Only claim an overwrite when a saved pipeline was actually present.
        if pipeline_exists:
            logger.warning(f"Pipeline already exists at {path_from_cwd(target_dir)}. Overwriting!")
        else:
            logger.info(f"Saving pipeline to {path_from_cwd(target_dir)}")
        pipeline.save_pretrained(target_dir, safe_serialization=True)
    return pipeline


def get_hf_pipeline(
    repo_id: Path,
    target_dir: Path,
    save: bool = True,
    force_download: bool = False,
) -> StableDiffusionPipeline:
    """Return a ``StableDiffusionPipeline``, loading it locally or from the Hub.

    Args:
        repo_id: Hub repository id of the pipeline.
        target_dir: Local pipeline directory (checked for ``model_index.json``).
        save: Save a freshly downloaded pipeline to ``target_dir``.
        force_download: Re-download even if ``target_dir`` holds a pipeline.
    """
    return _load_or_download_pipeline(
        StableDiffusionPipeline,
        repo_id,
        target_dir,
        save=save,
        force_download=force_download,
    )


def get_hf_pipeline_sdxl(
    repo_id: Path,
    target_dir: Path,
    save: bool = True,
    force_download: bool = False,
) -> StableDiffusionXLPipeline:
    """Return an fp16 ``StableDiffusionXLPipeline``, loading it locally or from the Hub.

    Both the local and the Hub load request ``torch_dtype=torch.float16``,
    ``use_safetensors=True`` and ``variant="fp16"``.

    Args:
        repo_id: Hub repository id of the pipeline.
        target_dir: Local pipeline directory (checked for ``model_index.json``).
        save: Save a freshly downloaded pipeline to ``target_dir``.
        force_download: Re-download even if ``target_dir`` holds a pipeline.
    """
    import torch  # deferred so importing this module does not import torch

    # NOTE(review): the pipeline is saved without ``variant``, yet reloaded
    # from target_dir with variant="fp16"; confirm the installed diffusers
    # version resolves non-variant weight files in that case.
    return _load_or_download_pipeline(
        StableDiffusionXLPipeline,
        repo_id,
        target_dir,
        save=save,
        force_download=force_download,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    )