|
import logging |
|
from os import PathLike |
|
from pathlib import Path |
|
from typing import Optional |
|
|
|
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline |
|
from huggingface_hub import hf_hub_download, snapshot_download |
|
from tqdm.rich import tqdm |
|
|
|
from animatediff import HF_HUB_CACHE, HF_LIB_NAME, HF_LIB_VER, get_dir |
|
from animatediff.utils.util import path_from_cwd |
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Storage locations, all resolved beneath whatever get_dir("data") returns.
data_dir = get_dir("data")
checkpoint_dir = data_dir.joinpath("models/sd")
pipeline_dir = data_dir.joinpath("models/huggingface")

# Glob patterns for repo files that should be skipped when downloading.
# NOTE(review): presumably these match TensorFlow and Flax weight files - confirm.
IGNORE_TF = ["*.git*", "*.h5", "tf_*"]
IGNORE_FLAX = ["*.git*", "flax_*", "*.msgpack"]
# Concatenation of both lists, in order (the shared "*.git*" entry appears twice).
IGNORE_TF_FLAX = [*IGNORE_TF, *IGNORE_FLAX]
|
|
|
|
|
class DownloadTqdm(tqdm):
    """Progress bar subclass with a fixed set of display options.

    Whatever the caller passes, ``ncols`` is forced to 100, ``dynamic_ncols``
    to ``False`` and ``disable`` to ``None`` before the parent constructor runs.
    """

    def __init__(self, *args, **kwargs):
        # Override (or insert) the display options; all other kwargs pass through.
        kwargs["ncols"] = 100
        kwargs["dynamic_ncols"] = False
        kwargs["disable"] = None
        super().__init__(*args, **kwargs)
|
|
|
|
|
def get_hf_file(
    repo_id: Path,
    filename: str,
    target_dir: Path,
    subfolder: Optional[PathLike] = None,
    revision: Optional[str] = None,
    force: bool = False,
) -> Path:
    """Download a single file from a Hugging Face Hub repo into ``target_dir``.

    Args:
        repo_id: Hub repository identifier; converted with ``str()`` before use.
        filename: Name of the file to fetch.
        target_dir: Local directory the file is downloaded into (created if missing).
        subfolder: Optional folder inside the repo that contains ``filename``.
        revision: Git revision to download from; falls back to ``"main"``.
        force: Pass ``True`` to proceed even if the destination file already exists.

    Returns:
        Path to the downloaded file, as reported by ``hf_hub_download``.

    Raises:
        FileExistsError: If the destination file exists and ``force`` is not ``True``.
    """
    # hf_hub_download mirrors the repo layout under local_dir, so a file requested
    # with a subfolder lands in target_dir/<subfolder>/<filename>. The overwrite
    # guard has to look at that same location, not at target_dir/<filename>.
    if subfolder:
        target_path = target_dir.joinpath(subfolder, filename)
    else:
        target_path = target_dir.joinpath(filename)

    if target_path.exists() and force is not True:
        raise FileExistsError(
            f"File {path_from_cwd(target_path)} already exists! Pass force=True to overwrite"
        )

    target_dir.mkdir(exist_ok=True, parents=True)
    save_path = hf_hub_download(
        repo_id=str(repo_id),
        filename=filename,
        revision=revision or "main",
        subfolder=subfolder,
        local_dir=target_dir,
        local_dir_use_symlinks=False,
        cache_dir=HF_HUB_CACHE,
        resume_download=True,
    )
    return Path(save_path)
|
|
|
|
|
def get_hf_repo(
    repo_id: Path,
    target_dir: Path,
    subfolder: Optional[PathLike] = None,
    revision: Optional[str] = None,
    force: bool = False,
) -> Path:
    """Download a snapshot of a Hugging Face Hub repo into ``target_dir``.

    Files matching ``IGNORE_TF_FLAX`` are skipped.

    Args:
        repo_id: Hub repository identifier; converted with ``str()`` before use.
        target_dir: Local directory to download into (created if missing).
        subfolder: Optional folder inside the repo; when given, only files under
            that folder are downloaded.
        revision: Git revision to download; falls back to ``"main"``.
        force: Pass ``True`` to proceed even if ``target_dir`` already exists.

    Returns:
        Path returned by ``snapshot_download`` for the downloaded snapshot.

    Raises:
        FileExistsError: If ``target_dir`` exists and ``force`` is not ``True``.
    """
    if target_dir.exists() and force is not True:
        raise FileExistsError(
            f"Target dir {path_from_cwd(target_dir)} already exists! Pass force=True to overwrite"
        )

    # snapshot_download has no `subfolder` parameter (passing one is rejected as
    # an unexpected keyword argument); restricting a snapshot to one folder is
    # done with an allow_patterns filter instead.
    allow_patterns = None
    if subfolder:
        # Hub paths always use forward slashes, whatever the local OS separator is.
        prefix = Path(subfolder).as_posix().strip("/")
        if prefix not in ("", "."):
            allow_patterns = [f"{prefix}/*"]

    target_dir.mkdir(exist_ok=True, parents=True)
    save_path = snapshot_download(
        repo_id=str(repo_id),
        revision=revision or "main",
        library_name=HF_LIB_NAME,
        library_version=HF_LIB_VER,
        local_dir=target_dir,
        local_dir_use_symlinks=False,
        allow_patterns=allow_patterns,
        ignore_patterns=IGNORE_TF_FLAX,
        cache_dir=HF_HUB_CACHE,
        tqdm_class=DownloadTqdm,
        max_workers=2,
        resume_download=True,
    )
    return Path(save_path)
|
|
|
|
|
def get_hf_pipeline(
    repo_id: Path,
    target_dir: Path,
    save: bool = True,
    force_download: bool = False,
) -> StableDiffusionPipeline:
    """Load a ``StableDiffusionPipeline``, preferring a local copy in ``target_dir``.

    A local copy is considered present when ``target_dir/model_index.json`` exists.
    If it is present and ``force_download`` is not ``True`` the pipeline is loaded
    from disk only; otherwise it is fetched via ``repo_id`` and, when ``save`` is
    set, written to ``target_dir``.

    Args:
        repo_id: Hub repository identifier of the pipeline.
        target_dir: Directory holding (or receiving) the saved pipeline.
        save: Whether to save a freshly downloaded pipeline into ``target_dir``.
        force_download: Pass ``True`` to download even if a local copy exists.

    Returns:
        The loaded pipeline.
    """
    pipeline_exists = target_dir.joinpath("model_index.json").exists()

    if pipeline_exists and force_download is not True:
        return StableDiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path=target_dir,
            local_files_only=True,
        )

    target_dir.mkdir(exist_ok=True, parents=True)
    pipeline = StableDiffusionPipeline.from_pretrained(
        # Turn the Path-style id into a forward-slash repo id.
        # NOTE(review): lstrip("./") strips any run of leading '.' and '/' characters,
        # not just a literal "./" prefix - confirm that is intended.
        pretrained_model_name_or_path=str(repo_id).lstrip("./").replace("\\", "/"),
        cache_dir=HF_HUB_CACHE,
        resume_download=True,
    )

    if save:
        # Only claim an overwrite when a pipeline was actually on disk; a forced
        # download into an empty target is a plain first-time save.
        if pipeline_exists:
            logger.warning(f"Pipeline already exists at {path_from_cwd(target_dir)}. Overwriting!")
        else:
            logger.info(f"Saving pipeline to {path_from_cwd(target_dir)}")
        pipeline.save_pretrained(target_dir, safe_serialization=True)
    return pipeline
|
|
|
def get_hf_pipeline_sdxl(
    repo_id: Path,
    target_dir: Path,
    save: bool = True,
    force_download: bool = False,
) -> StableDiffusionXLPipeline:
    """Load a ``StableDiffusionXLPipeline`` in fp16, preferring a local copy.

    A local copy is considered present when ``target_dir/model_index.json`` exists.
    If it is present and ``force_download`` is not ``True`` the pipeline is loaded
    from disk only; otherwise it is fetched via ``repo_id`` and, when ``save`` is
    set, written to ``target_dir``.

    Args:
        repo_id: Hub repository identifier of the pipeline.
        target_dir: Directory holding (or receiving) the saved pipeline.
        save: Whether to save a freshly downloaded pipeline into ``target_dir``.
        force_download: Pass ``True`` to download even if a local copy exists.

    Returns:
        The loaded pipeline.
    """
    # Imported lazily so importing this module does not require torch up front.
    import torch

    # Both load paths use the same half-precision safetensors settings.
    variant = "fp16"
    load_kwargs = {
        "torch_dtype": torch.float16,
        "use_safetensors": True,
        "variant": variant,
    }

    pipeline_exists = target_dir.joinpath("model_index.json").exists()

    if pipeline_exists and force_download is not True:
        return StableDiffusionXLPipeline.from_pretrained(
            pretrained_model_name_or_path=target_dir,
            local_files_only=True,
            **load_kwargs,
        )

    target_dir.mkdir(exist_ok=True, parents=True)
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        # Turn the Path-style id into a forward-slash repo id.
        # NOTE(review): lstrip("./") strips any run of leading '.' and '/' characters,
        # not just a literal "./" prefix - confirm that is intended.
        pretrained_model_name_or_path=str(repo_id).lstrip("./").replace("\\", "/"),
        cache_dir=HF_HUB_CACHE,
        resume_download=True,
        **load_kwargs,
    )

    if save:
        # Only claim an overwrite when a pipeline was actually on disk; a forced
        # download into an empty target is a plain first-time save.
        if pipeline_exists:
            logger.warning(f"Pipeline already exists at {path_from_cwd(target_dir)}. Overwriting!")
        else:
            logger.info(f"Saving pipeline to {path_from_cwd(target_dir)}")
        # Save under the same variant the local-load branch asks for, so the
        # weight file names written here match what the reload looks up.
        pipeline.save_pretrained(target_dir, safe_serialization=True, variant=variant)
    return pipeline
|
|