TheNetherWatcher's picture
Upload folder using huggingface_hub
d0ffe9c verified
import logging
from os import PathLike
from pathlib import Path
from typing import Optional
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download, snapshot_download
from tqdm.rich import tqdm
from animatediff import HF_HUB_CACHE, HF_LIB_NAME, HF_LIB_VER, get_dir
from animatediff.utils.util import path_from_cwd
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Local storage locations, all rooted at the project's "data" directory.
data_dir = get_dir("data")
checkpoint_dir = data_dir.joinpath("models", "sd")
pipeline_dir = data_dir.joinpath("models", "huggingface")

# Glob patterns for repository files to skip when downloading. Judging by the
# names these cover TensorFlow and Flax artifacts respectively; the combined
# list is what get_hf_repo() passes as ``ignore_patterns``.
IGNORE_TF = ["*.git*", "*.h5", "tf_*"]
IGNORE_FLAX = ["*.git*", "flax_*", "*.msgpack"]
IGNORE_TF_FLAX = [*IGNORE_TF, *IGNORE_FLAX]
class DownloadTqdm(tqdm):
    """Progress bar with fixed display settings, used for Hub downloads.

    Whatever the caller passes for ``ncols``, ``dynamic_ncols`` or ``disable``
    is overridden before the arguments reach the parent constructor.
    """

    def __init__(self, *args, **kwargs):
        # Force a fixed 100-column bar; these always win over caller values.
        kwargs["ncols"] = 100
        kwargs["dynamic_ncols"] = False
        kwargs["disable"] = None
        super().__init__(*args, **kwargs)
def get_hf_file(
    repo_id: Path,
    filename: str,
    target_dir: Path,
    subfolder: Optional[PathLike] = None,
    revision: Optional[str] = None,
    force: bool = False,
) -> Path:
    """Download a single file from a Hugging Face Hub repo into ``target_dir``.

    Args:
        repo_id: Hub repository id (``owner/name``).
        filename: Name of the file to fetch, relative to ``subfolder`` if given.
        target_dir: Local directory the file is downloaded into.
        subfolder: Optional folder inside the repo that contains ``filename``.
        revision: Branch, tag or commit to fetch; defaults to ``"main"``.
        force: When True, an already-present local file is overwritten.

    Returns:
        Path reported by ``hf_hub_download`` for the downloaded file.

    Raises:
        FileExistsError: If the destination file exists and ``force`` is not True.
    """
    # Normalise the subfolder to a forward-slash string: the Hub API expects a
    # str, and an empty / "." subfolder means "repo root".
    subfolder_str = None if subfolder is None else str(subfolder).replace("\\", "/")
    if subfolder_str in ("", "."):
        subfolder_str = None

    # With ``local_dir`` set, the file is written to
    # <local_dir>/<subfolder>/<filename>, so the overwrite guard must check
    # that location rather than <local_dir>/<filename>.
    if subfolder_str is None:
        target_path = target_dir.joinpath(filename)
    else:
        target_path = target_dir.joinpath(subfolder_str, filename)

    if target_path.exists() and force is not True:
        raise FileExistsError(
            f"File {path_from_cwd(target_path)} already exists! Pass force=True to overwrite"
        )

    target_dir.mkdir(exist_ok=True, parents=True)
    save_path = hf_hub_download(
        # A Path renders with backslashes on Windows; Hub repo ids always use "/".
        repo_id=str(repo_id).replace("\\", "/"),
        filename=filename,
        revision=revision or "main",
        subfolder=subfolder_str,
        local_dir=target_dir,
        local_dir_use_symlinks=False,
        cache_dir=HF_HUB_CACHE,
        resume_download=True,
    )
    return Path(save_path)
def get_hf_repo(
    repo_id: Path,
    target_dir: Path,
    subfolder: Optional[PathLike] = None,
    revision: Optional[str] = None,
    force: bool = False,
) -> Path:
    """Download a snapshot of a Hugging Face Hub repo into ``target_dir``.

    Files matching ``IGNORE_TF_FLAX`` are skipped.

    Args:
        repo_id: Hub repository id (``owner/name``).
        target_dir: Local directory the snapshot is written to.
        subfolder: Optional folder inside the repo; when given, only files
            below it are downloaded (they keep their ``subfolder/`` prefix
            under ``target_dir``).
        revision: Branch, tag or commit to fetch; defaults to ``"main"``.
        force: When True, download even if ``target_dir`` already exists.

    Returns:
        Path reported by ``snapshot_download`` for the local snapshot.

    Raises:
        FileExistsError: If ``target_dir`` exists and ``force`` is not True.
    """
    if target_dir.exists() and force is not True:
        raise FileExistsError(
            f"Target dir {path_from_cwd(target_dir)} already exists! Pass force=True to overwrite"
        )

    # snapshot_download() has no ``subfolder`` keyword (passing one raises
    # TypeError); restricting the download to a folder is done with
    # ``allow_patterns`` instead.
    allow_patterns = None
    if subfolder is not None:
        prefix = str(subfolder).replace("\\", "/").strip("/")
        if prefix not in ("", "."):
            allow_patterns = [f"{prefix}/*"]

    target_dir.mkdir(exist_ok=True, parents=True)
    save_path = snapshot_download(
        # A Path renders with backslashes on Windows; Hub repo ids always use "/".
        repo_id=str(repo_id).replace("\\", "/"),
        revision=revision or "main",
        library_name=HF_LIB_NAME,
        library_version=HF_LIB_VER,
        local_dir=target_dir,
        local_dir_use_symlinks=False,
        allow_patterns=allow_patterns,
        ignore_patterns=IGNORE_TF_FLAX,
        cache_dir=HF_HUB_CACHE,
        tqdm_class=DownloadTqdm,
        max_workers=2,
        resume_download=True,
    )
    return Path(save_path)
def get_hf_pipeline(
    repo_id: Path,
    target_dir: Path,
    save: bool = True,
    force_download: bool = False,
) -> StableDiffusionPipeline:
    """Load a Stable Diffusion pipeline, preferring a local copy in ``target_dir``.

    A local copy is considered present when ``target_dir/model_index.json``
    exists. If it is present and ``force_download`` is not True, the pipeline
    is loaded from disk only; otherwise it is loaded from ``repo_id``.

    Args:
        repo_id: Hub repository id (``owner/name``) to load from when needed.
        target_dir: Directory holding (or receiving) the saved pipeline.
        save: When True, write the freshly loaded pipeline to ``target_dir``.
        force_download: When True, ignore any local copy and load from
            ``repo_id`` again.

    Returns:
        The loaded pipeline.
    """
    pipeline_exists = target_dir.joinpath("model_index.json").exists()

    if pipeline_exists and force_download is not True:
        # Local copy is used as-is; nothing is fetched and nothing is re-saved.
        return StableDiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path=target_dir,
            local_files_only=True,
        )

    target_dir.mkdir(exist_ok=True, parents=True)
    pipeline = StableDiffusionPipeline.from_pretrained(
        # Drop any leading "." / "/" characters and use forward slashes so a
        # Path-style value becomes a valid Hub repo id.
        pretrained_model_name_or_path=str(repo_id).lstrip("./").replace("\\", "/"),
        cache_dir=HF_HUB_CACHE,
        resume_download=True,
    )

    if save:
        # Only claim an overwrite when a local copy was actually there; a
        # forced download into an empty directory is a plain save.
        if pipeline_exists:
            logger.warning(f"Pipeline already exists at {path_from_cwd(target_dir)}. Overwriting!")
        else:
            logger.info(f"Saving pipeline to {path_from_cwd(target_dir)}")
        pipeline.save_pretrained(target_dir, safe_serialization=True)

    return pipeline
def get_hf_pipeline_sdxl(
    repo_id: Path,
    target_dir: Path,
    save: bool = True,
    force_download: bool = False,
) -> StableDiffusionXLPipeline:
    """Load an SDXL pipeline in fp16, preferring a local copy in ``target_dir``.

    A local copy is considered present when ``target_dir/model_index.json``
    exists. If it is present and ``force_download`` is not True, the pipeline
    is loaded from disk only; otherwise it is loaded from ``repo_id``. Both
    paths request ``torch.float16`` weights, safetensors files and the
    ``"fp16"`` variant.

    Args:
        repo_id: Hub repository id (``owner/name``) to load from when needed.
        target_dir: Directory holding (or receiving) the saved pipeline.
        save: When True, write the freshly loaded pipeline to ``target_dir``.
        force_download: When True, ignore any local copy and load from
            ``repo_id`` again.

    Returns:
        The loaded pipeline.
    """
    # Imported here rather than at module level, as in the original code.
    import torch

    # Loading options shared by the local and the remote path.
    fp16_kwargs = {
        "torch_dtype": torch.float16,
        "use_safetensors": True,
        "variant": "fp16",
    }

    pipeline_exists = target_dir.joinpath("model_index.json").exists()

    if pipeline_exists and force_download is not True:
        # Local copy is used as-is; nothing is fetched and nothing is re-saved.
        return StableDiffusionXLPipeline.from_pretrained(
            pretrained_model_name_or_path=target_dir,
            local_files_only=True,
            **fp16_kwargs,
        )

    target_dir.mkdir(exist_ok=True, parents=True)
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        # Drop any leading "." / "/" characters and use forward slashes so a
        # Path-style value becomes a valid Hub repo id.
        pretrained_model_name_or_path=str(repo_id).lstrip("./").replace("\\", "/"),
        cache_dir=HF_HUB_CACHE,
        resume_download=True,
        **fp16_kwargs,
    )

    if save:
        # Only claim an overwrite when a local copy was actually there; a
        # forced download into an empty directory is a plain save.
        if pipeline_exists:
            logger.warning(f"Pipeline already exists at {path_from_cwd(target_dir)}. Overwriting!")
        else:
            logger.info(f"Saving pipeline to {path_from_cwd(target_dir)}")
        pipeline.save_pretrained(target_dir, safe_serialization=True)

    return pipeline