import glob
import logging
import os.path
from datetime import datetime
from pathlib import Path
from typing import Annotated, Optional
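# Optional CUDA allocator tuning, disabled by default; flip these `if False:`
# guards to experiment with the cudaMallocAsync backend or to disable CUDA
# memory caching entirely.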
if False:
if 'PYTORCH_CUDA_ALLOC_CONF' in os.environ:
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] += ",backend:cudaMallocAsync"
else:
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "backend:cudaMallocAsync"
#"garbage_collection_threshold:0.6"
# max_split_size_mb:1024"
# "backend:cudaMallocAsync"
# roundup_power2_divisions:4
print(f"{os.environ['PYTORCH_CUDA_ALLOC_CONF']=}")
if False:
os.environ['PYTORCH_NO_CUDA_MEMORY_CACHING']="1"
import torch
import typer
from diffusers import DiffusionPipeline
from diffusers.utils.logging import \
set_verbosity_error as set_diffusers_verbosity_error
from rich.logging import RichHandler
from animatediff import __version__, console, get_dir
from animatediff.generate import (controlnet_preprocess, create_pipeline,
create_us_pipeline, img2img_preprocess,
ip_adapter_preprocess,
load_controlnet_models, prompt_preprocess,
region_preprocess, run_inference,
run_upscale, save_output,
unload_controlnet_models,
wild_card_conversion)
from animatediff.pipelines import AnimationPipeline, load_text_embeddings
from animatediff.settings import (CKPT_EXTENSIONS, InferenceConfig,
ModelConfig, get_infer_config,
get_model_config)
from animatediff.utils.civitai2config import generate_config_from_civitai_info
from animatediff.utils.model import (checkpoint_to_pipeline,
fix_checkpoint_if_needed, get_base_model)
from animatediff.utils.pipeline import get_context_params, send_to_device
from animatediff.utils.util import (extract_frames, is_sdxl_checkpoint,
is_v2_motion_module, path_from_cwd,
save_frames, save_imgs, save_video,
set_tensor_interpolation_method, show_gpu)
from animatediff.utils.wild_card import replace_wild_card
cli: typer.Typer = typer.Typer(
context_settings=dict(help_option_names=["-h", "--help"]),
rich_markup_mode="rich",
no_args_is_help=True,
pretty_exceptions_show_locals=False,
)
data_dir = get_dir("data")
checkpoint_dir = data_dir.joinpath("models/sd")
pipeline_dir = data_dir.joinpath("models/huggingface")
try:
import google.colab
IN_COLAB = True
except ImportError:
IN_COLAB = False
if IN_COLAB:
import sys
logging.basicConfig(
level=logging.INFO,
stream=sys.stdout,
format="%(message)s",
datefmt="%H:%M:%S",
force=True,
)
else:
logging.basicConfig(
level=logging.INFO,
format="%(message)s",
handlers=[
RichHandler(console=console, rich_tracebacks=True),
],
datefmt="%H:%M:%S",
force=True,
)
logger = logging.getLogger(__name__)
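# Enforce a minimum diffusers version (0.23.0) up front and fail with a clear
# message if the installed version is too old.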
from importlib.metadata import version as meta_version
from packaging import version
diffuser_ver = meta_version('diffusers')
logger.info(f"{diffuser_ver=}")
if version.parse(diffuser_ver) < version.parse('0.23.0'):
logger.error(f"The version of diffusers is out of date")
logger.error(f"python -m pip install diffusers==0.23.0")
raise ImportError("Please update diffusers to 0.23.0")
try:
from animatediff.rife import app as rife_app
cli.add_typer(rife_app, name="rife")
except ImportError:
logger.debug("RIFE not available, skipping...", exc_info=True)
rife_app = None
from animatediff.stylize import stylize
cli.add_typer(stylize, name="stylize")
# mildly cursed globals to allow for reuse of the pipeline if we're being called as a module
g_pipeline: Optional[DiffusionPipeline] = None
last_model_path: Optional[Path] = None
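# A sketch of how the reuse works when this module is imported elsewhere
# (paths below are the defaults shipped with the repo and purely illustrative):
# calling generate() repeatedly with the same model path keeps g_pipeline
# loaded instead of rebuilding it.
#   from animatediff.cli import generate
#   generate(config_path=Path("config/prompts/01-ToonYou.json"), length=16)
#   generate(config_path=Path("config/prompts/01-ToonYou.json"), length=32)  # pipeline reused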
def version_callback(value: bool):
if value:
console.print(f"AnimateDiff v{__version__}")
raise typer.Exit()
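# Returns a random non-negative int64 seed; used wherever a config seed is -1.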
def get_random():
import sys
import numpy as np
return int(np.random.randint(sys.maxsize, dtype=np.int64))
@cli.command()
def generate(
config_path: Annotated[
Path,
typer.Option(
"--config-path",
"-c",
path_type=Path,
exists=True,
readable=True,
dir_okay=False,
help="Path to a prompt configuration JSON file",
),
] = Path("config/prompts/01-ToonYou.json"),
width: Annotated[
int,
typer.Option(
"--width",
"-W",
min=64,
max=3840,
help="Width of generated frames",
rich_help_panel="Generation",
),
] = 512,
height: Annotated[
int,
typer.Option(
"--height",
"-H",
min=64,
max=2160,
help="Height of generated frames",
rich_help_panel="Generation",
),
] = 512,
length: Annotated[
int,
typer.Option(
"--length",
"-L",
min=1,
max=9999,
help="Number of frames to generate",
rich_help_panel="Generation",
),
] = 16,
context: Annotated[
Optional[int],
typer.Option(
"--context",
"-C",
min=1,
max=32,
help="Number of frames to condition on (default: 16)",
show_default=False,
rich_help_panel="Generation",
),
] = 16,
overlap: Annotated[
Optional[int],
typer.Option(
"--overlap",
"-O",
min=0,
max=12,
help="Number of frames to overlap in context (default: context//4)",
show_default=False,
rich_help_panel="Generation",
),
] = None,
stride: Annotated[
Optional[int],
typer.Option(
"--stride",
"-S",
min=0,
max=8,
help="Max motion stride as a power of 2 (default: 0)",
show_default=False,
rich_help_panel="Generation",
),
] = None,
repeats: Annotated[
int,
typer.Option(
"--repeats",
"-r",
min=1,
max=99,
help="Number of times to repeat the prompt (default: 1)",
show_default=False,
rich_help_panel="Generation",
),
] = 1,
device: Annotated[
str,
typer.Option(
"--device", "-d", help="Device to run on (cpu, cuda, cuda:id)", rich_help_panel="Advanced"
),
] = "cuda",
use_xformers: Annotated[
bool,
typer.Option(
"--xformers",
"-x",
is_flag=True,
help="Use XFormers instead of SDP Attention",
rich_help_panel="Advanced",
),
] = False,
force_half_vae: Annotated[
bool,
typer.Option(
"--half-vae",
is_flag=True,
help="Force VAE to use fp16 (not recommended)",
rich_help_panel="Advanced",
),
] = False,
out_dir: Annotated[
Path,
typer.Option(
"--out-dir",
"-o",
path_type=Path,
file_okay=False,
help="Directory for output folders (frames, gifs, etc)",
rich_help_panel="Output",
),
] = Path("output/"),
no_frames: Annotated[
bool,
typer.Option(
"--no-frames",
"-N",
is_flag=True,
help="Don't save frames, only the animation",
rich_help_panel="Output",
),
] = False,
save_merged: Annotated[
bool,
typer.Option(
"--save-merged",
"-m",
is_flag=True,
help="Save a merged animation of all prompts",
rich_help_panel="Output",
),
] = False,
version: Annotated[
Optional[bool],
typer.Option(
"--version",
"-v",
callback=version_callback,
is_eager=True,
is_flag=True,
help="Show version",
),
] = None,
):
"""
Do the thing. Make the animation happen. Waow.
"""
# be quiet, diffusers. we care not for your safety checker
set_diffusers_verbosity_error()
#torch.set_flush_denormal(True)
config_path = config_path.absolute()
logger.info(f"Using generation config: {path_from_cwd(config_path)}")
model_config: ModelConfig = get_model_config(config_path)
is_sdxl = is_sdxl_checkpoint(data_dir.joinpath(model_config.path))
if is_sdxl:
is_v2 = False
else:
is_v2 = is_v2_motion_module(data_dir.joinpath(model_config.motion_module))
infer_config: InferenceConfig = get_infer_config(is_v2, is_sdxl)
set_tensor_interpolation_method( model_config.tensor_interpolation_slerp )
# set sane defaults for context, overlap, and stride if not supplied
context, overlap, stride = get_context_params(length, context, overlap, stride)
if (not is_v2) and (not is_sdxl) and (context > 24):
logger.warning( "For motion module v1, the maximum value of context is 24. Set to 24" )
context = 24
# turn the device string into a torch.device
device: torch.device = torch.device(device)
model_name_or_path = Path("runwayml/stable-diffusion-v1-5") if not is_sdxl else Path("stabilityai/stable-diffusion-xl-base-1.0")
# Get the base model if we don't have it already
logger.info(f"Using base model: {model_name_or_path}")
base_model_path: Path = get_base_model(model_name_or_path, local_dir=get_dir("data/models/huggingface"), is_sdxl=is_sdxl)
# get a timestamp for the output directory
time_str = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
# make the output directory
save_dir = out_dir.joinpath(f"{time_str}-{model_config.save_name}")
save_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Will save outputs to ./{path_from_cwd(save_dir)}")
controlnet_image_map, controlnet_type_map, controlnet_ref_map, controlnet_no_shrink = controlnet_preprocess(model_config.controlnet_map, width, height, length, save_dir, device, is_sdxl)
img2img_map = img2img_preprocess(model_config.img2img_map, width, height, length, save_dir)
# beware the pipeline
global g_pipeline
global last_model_path
pipeline_already_loaded = False
if g_pipeline is None or last_model_path != model_config.path.resolve():
g_pipeline = create_pipeline(
base_model=base_model_path,
model_config=model_config,
infer_config=infer_config,
use_xformers=use_xformers,
video_length=length,
is_sdxl=is_sdxl
)
last_model_path = model_config.path.resolve()
else:
logger.info("Pipeline already loaded, skipping initialization")
# reload TIs; create_pipeline does this for us, but they may have changed
# since load time if we're being called from another package
#load_text_embeddings(g_pipeline, is_sdxl=is_sdxl)
pipeline_already_loaded = True
load_controlnet_models(pipe=g_pipeline, model_config=model_config, is_sdxl=is_sdxl)
# if g_pipeline.device == device:
if pipeline_already_loaded:
logger.info("Pipeline already on the correct device, skipping device transfer")
else:
g_pipeline = send_to_device(
g_pipeline, device, freeze=True, force_half=force_half_vae, compile=model_config.compile, is_sdxl=is_sdxl
)
torch.cuda.empty_cache()
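    # Apply the LCM LoRA only if the config's lcm_map explicitly enables it.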
apply_lcm_lora = False
if model_config.lcm_map:
if "enable" in model_config.lcm_map:
apply_lcm_lora = model_config.lcm_map["enable"]
# save raw config to output directory
save_config_path = save_dir.joinpath("raw_prompt.json")
save_config_path.write_text(model_config.json(indent=4), encoding="utf-8")
# fix seed
for i, s in enumerate(model_config.seed):
if s == -1:
model_config.seed[i] = get_random()
# wildcard conversion
wild_card_conversion(model_config)
    is_init_img_exist = img2img_map is not None
region_condi_list, region_list, ip_adapter_config_map, region2index = region_preprocess(model_config, width, height, length, save_dir, is_init_img_exist, is_sdxl)
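    # Translate each ControlNet's control_region_list from region names to
    # region indices, dropping entries that did not resolve (-1).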
if controlnet_type_map:
for c in controlnet_type_map:
tmp_r = [region2index[r] for r in controlnet_type_map[c]["control_region_list"]]
controlnet_type_map[c]["control_region_list"] = [r for r in tmp_r if r != -1]
logger.info(f"{c=} / {controlnet_type_map[c]['control_region_list']}")
# save config to output directory
logger.info("Saving prompt config to output directory")
save_config_path = save_dir.joinpath("prompt.json")
save_config_path.write_text(model_config.json(indent=4), encoding="utf-8")
num_negatives = len(model_config.n_prompt)
num_seeds = len(model_config.seed)
gen_total = repeats # total number of generations
logger.info("Initialization complete!")
logger.info(f"Generating {gen_total} animations")
outputs = []
gen_num = 0 # global generation index
# repeat the prompts if we're doing multiple runs
for _ in range(repeats):
if model_config.prompt_map:
# get the index of the prompt, negative, and seed
idx = gen_num
logger.info(f"Running generation {gen_num + 1} of {gen_total}")
# allow for reusing the same negative prompt(s) and seed(s) for multiple prompts
n_prompt = model_config.n_prompt[idx % num_negatives]
seed = model_config.seed[idx % num_seeds]
logger.info(f"Generation seed: {seed}")
output = run_inference(
pipeline=g_pipeline,
n_prompt=n_prompt,
seed=seed,
steps=model_config.steps,
guidance_scale=model_config.guidance_scale,
unet_batch_size=model_config.unet_batch_size,
width=width,
height=height,
duration=length,
idx=gen_num,
out_dir=save_dir,
context_schedule=model_config.context_schedule,
context_frames=context,
context_overlap=overlap,
context_stride=stride,
clip_skip=model_config.clip_skip,
controlnet_map=model_config.controlnet_map,
controlnet_image_map=controlnet_image_map,
controlnet_type_map=controlnet_type_map,
controlnet_ref_map=controlnet_ref_map,
controlnet_no_shrink=controlnet_no_shrink,
no_frames=no_frames,
img2img_map=img2img_map,
ip_adapter_config_map=ip_adapter_config_map,
region_list=region_list,
region_condi_list=region_condi_list,
output_map = model_config.output,
is_single_prompt_mode=model_config.is_single_prompt_mode,
is_sdxl=is_sdxl,
apply_lcm_lora=apply_lcm_lora,
gradual_latent_map=model_config.gradual_latent_hires_fix_map
)
outputs.append(output)
torch.cuda.empty_cache()
# increment the generation number
gen_num += 1
unload_controlnet_models(pipe=g_pipeline)
logger.info("Generation complete!")
if save_merged:
logger.info("Output merged output video...")
merged_output = torch.concat(outputs, dim=0)
save_video(merged_output, save_dir.joinpath("final.gif"))
logger.info("Done, exiting...")
return save_dir
@cli.command()
def tile_upscale(
frames_dir: Annotated[
Path,
typer.Argument(path_type=Path, file_okay=False, exists=True, help="Path to source frames directory"),
] = ...,
model_name_or_path: Annotated[
Path,
typer.Option(
...,
"--model-path",
"-m",
path_type=Path,
help="Base model to use (path or HF repo ID). You probably don't need to change this.",
),
] = Path("runwayml/stable-diffusion-v1-5"),
config_path: Annotated[
Path,
typer.Option(
"--config-path",
"-c",
path_type=Path,
exists=True,
readable=True,
dir_okay=False,
help="Path to a prompt configuration JSON file. default is frames_dir/../prompt.json",
),
] = None,
width: Annotated[
int,
typer.Option(
"--width",
"-W",
min=-1,
max=3840,
help="Width of generated frames",
rich_help_panel="Generation",
),
] = -1,
height: Annotated[
int,
typer.Option(
"--height",
"-H",
min=-1,
max=2160,
help="Height of generated frames",
rich_help_panel="Generation",
),
] = -1,
device: Annotated[
str,
typer.Option(
"--device", "-d", help="Device to run on (cpu, cuda, cuda:id)", rich_help_panel="Advanced"
),
] = "cuda",
use_xformers: Annotated[
bool,
typer.Option(
"--xformers",
"-x",
is_flag=True,
help="Use XFormers instead of SDP Attention",
rich_help_panel="Advanced",
),
] = False,
force_half_vae: Annotated[
bool,
typer.Option(
"--half-vae",
is_flag=True,
help="Force VAE to use fp16 (not recommended)",
rich_help_panel="Advanced",
),
] = False,
out_dir: Annotated[
Path,
typer.Option(
"--out-dir",
"-o",
path_type=Path,
file_okay=False,
help="Directory for output folders (frames, gifs, etc)",
rich_help_panel="Output",
),
] = Path("upscaled/"),
no_frames: Annotated[
bool,
typer.Option(
"--no-frames",
"-N",
is_flag=True,
help="Don't save frames, only the animation",
rich_help_panel="Output",
),
] = False,
):
"""Upscale frames using controlnet tile"""
# be quiet, diffusers. we care not for your safety checker
set_diffusers_verbosity_error()
if width < 0 and height < 0:
raise ValueError(f"invalid width,height: {width},{height} \n At least one of them must be specified.")
if not config_path:
tmp = frames_dir.parent.joinpath("prompt.json")
if tmp.is_file():
config_path = tmp
config_path = config_path.absolute()
logger.info(f"Using generation config: {path_from_cwd(config_path)}")
model_config: ModelConfig = get_model_config(config_path)
is_sdxl = is_sdxl_checkpoint(data_dir.joinpath(model_config.path))
if is_sdxl:
raise ValueError("Currently SDXL model is not available for this command.")
infer_config: InferenceConfig = get_infer_config(is_v2_motion_module(data_dir.joinpath(model_config.motion_module)), is_sdxl)
frames_dir = frames_dir.absolute()
set_tensor_interpolation_method( model_config.tensor_interpolation_slerp )
# turn the device string into a torch.device
device: torch.device = torch.device(device)
# get a timestamp for the output directory
time_str = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
# make the output directory
save_dir = out_dir.joinpath(f"{time_str}-{model_config.save_name}")
save_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Will save outputs to ./{path_from_cwd(save_dir)}")
if "controlnet_tile" not in model_config.upscale_config:
model_config.upscale_config["controlnet_tile"] = {
"enable": True,
"controlnet_conditioning_scale": 1.0,
"guess_mode": False,
"control_guidance_start": 0.0,
"control_guidance_end": 1.0,
}
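    # Determine which upscale ControlNets the config enables; at least one of
    # tile / line_anime / ip2p must be on for this command to work.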
use_controlnet_ref = False
use_controlnet_tile = False
use_controlnet_line_anime = False
use_controlnet_ip2p = False
if model_config.upscale_config:
use_controlnet_ref = model_config.upscale_config["controlnet_ref"]["enable"] if "controlnet_ref" in model_config.upscale_config else False
use_controlnet_tile = model_config.upscale_config["controlnet_tile"]["enable"] if "controlnet_tile" in model_config.upscale_config else False
use_controlnet_line_anime = model_config.upscale_config["controlnet_line_anime"]["enable"] if "controlnet_line_anime" in model_config.upscale_config else False
use_controlnet_ip2p = model_config.upscale_config["controlnet_ip2p"]["enable"] if "controlnet_ip2p" in model_config.upscale_config else False
    if not (use_controlnet_tile or use_controlnet_line_anime or use_controlnet_ip2p):
        raise ValueError(f"At least one must be enabled: {use_controlnet_tile=}, {use_controlnet_line_anime=}, {use_controlnet_ip2p=}")
# beware the pipeline
us_pipeline = create_us_pipeline(
model_config=model_config,
infer_config=infer_config,
use_xformers=use_xformers,
use_controlnet_ref=use_controlnet_ref,
use_controlnet_tile=use_controlnet_tile,
use_controlnet_line_anime=use_controlnet_line_anime,
use_controlnet_ip2p=use_controlnet_ip2p,
)
if us_pipeline.device == device:
logger.info("Pipeline already on the correct device, skipping device transfer")
else:
us_pipeline = send_to_device(
us_pipeline, device, freeze=True, force_half=force_half_vae, compile=model_config.compile
)
model_config.result = { "original_frames": str(frames_dir) }
# save config to output directory
logger.info("Saving prompt config to output directory")
save_config_path = save_dir.joinpath("prompt.json")
save_config_path.write_text(model_config.json(indent=4), encoding="utf-8")
num_prompts = 1
num_negatives = len(model_config.n_prompt)
num_seeds = len(model_config.seed)
logger.info("Initialization complete!")
gen_num = 0 # global generation index
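    # Collect the numerically named source frames ("<number>.png") to upscale.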
org_images = sorted(glob.glob( os.path.join(frames_dir, "[0-9]*.png"), recursive=False))
length = len(org_images)
if model_config.prompt_map:
# get the index of the prompt, negative, and seed
idx = gen_num % num_prompts
logger.info(f"Running generation {gen_num + 1} of {1} (prompt {idx + 1})")
# allow for reusing the same negative prompt(s) and seed(s) for multiple prompts
n_prompt = model_config.n_prompt[idx % num_negatives]
        seed = model_config.seed[idx % num_seeds]
if seed == -1:
seed = get_random()
logger.info(f"Generation seed: {seed}")
prompt_map = {}
for k in model_config.prompt_map.keys():
if int(k) < length:
pr = model_config.prompt_map[k]
if model_config.head_prompt:
pr = model_config.head_prompt + "," + pr
if model_config.tail_prompt:
pr = pr + "," + model_config.tail_prompt
prompt_map[int(k)]=pr
if model_config.upscale_config:
upscaled_output = run_upscale(
org_imgs=org_images,
pipeline=us_pipeline,
prompt_map=prompt_map,
n_prompt=n_prompt,
seed=seed,
steps=model_config.steps,
guidance_scale=model_config.guidance_scale,
clip_skip=model_config.clip_skip,
us_width=width,
us_height=height,
idx=gen_num,
out_dir=save_dir,
upscale_config=model_config.upscale_config,
use_controlnet_ref=use_controlnet_ref,
use_controlnet_tile=use_controlnet_tile,
use_controlnet_line_anime=use_controlnet_line_anime,
use_controlnet_ip2p=use_controlnet_ip2p,
no_frames = no_frames,
output_map = model_config.output,
)
torch.cuda.empty_cache()
# increment the generation number
gen_num += 1
logger.info("Generation complete!")
logger.info("Done, exiting...")
return save_dir
@cli.command()
def civitai2config(
lora_dir: Annotated[
Path,
typer.Argument(path_type=Path, file_okay=False, exists=True, help="Path to loras directory"),
] = ...,
config_org: Annotated[
Path,
typer.Option(
"--config-org",
"-c",
path_type=Path,
dir_okay=False,
exists=True,
help="Path to original config file",
),
] = Path("config/prompts/prompt_travel.json"),
out_dir: Annotated[
Optional[Path],
typer.Option(
"--out-dir",
"-o",
path_type=Path,
file_okay=False,
help="Target directory for generated configs",
),
] = Path("config/prompts/converted/"),
lora_weight: Annotated[
float,
typer.Option(
"--lora_weight",
"-l",
min=0.0,
max=3.0,
help="Lora weight",
),
] = 0.75,
):
"""Generate config file from *.civitai.info"""
out_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Generate config files from: {lora_dir}")
generate_config_from_civitai_info(lora_dir,config_org,out_dir, lora_weight)
logger.info(f"saved at: {out_dir.absolute()}")
@cli.command()
def convert(
checkpoint: Annotated[
Path,
typer.Option(
"--checkpoint",
"-i",
path_type=Path,
dir_okay=False,
exists=True,
help="Path to a model checkpoint file",
),
] = ...,
out_dir: Annotated[
Optional[Path],
typer.Option(
"--out-dir",
"-o",
path_type=Path,
file_okay=False,
help="Target directory for converted model",
),
] = None,
):
"""Convert a StableDiffusion checkpoint into a Diffusers pipeline"""
logger.info(f"Converting checkpoint: {checkpoint}")
_, pipeline_dir = checkpoint_to_pipeline(checkpoint, target_dir=out_dir)
logger.info(f"Converted to HuggingFace pipeline at {pipeline_dir}")
@cli.command()
def fix_checkpoint(
checkpoint: Annotated[
Path,
typer.Argument(path_type=Path, dir_okay=False, exists=True, help="Path to a model checkpoint file"),
] = ...,
debug: Annotated[
bool,
typer.Option(
"--debug",
"-d",
is_flag=True,
rich_help_panel="Debug",
),
] = False,
):
"""Fix checkpoint with error "AttributeError: 'Attention' object has no attribute 'to_to_k'" on loading"""
set_diffusers_verbosity_error()
logger.info(f"Converting checkpoint: {checkpoint}")
fix_checkpoint_if_needed(checkpoint, debug)
@cli.command()
def merge(
checkpoint: Annotated[
Path,
typer.Option(
"--checkpoint",
"-i",
path_type=Path,
dir_okay=False,
exists=True,
help="Path to a model checkpoint file",
),
] = ...,
out_dir: Annotated[
Optional[Path],
typer.Option(
"--out-dir",
"-o",
path_type=Path,
file_okay=False,
help="Target directory for converted model",
),
] = None,
):
"""Convert a StableDiffusion checkpoint into an AnimationPipeline"""
raise NotImplementedError("Sorry, haven't implemented this yet!")
# if we have a checkpoint, convert it to HF automagically
if checkpoint.is_file() and checkpoint.suffix in CKPT_EXTENSIONS:
logger.info(f"Loading model from checkpoint: {checkpoint}")
# check if we've already converted this model
model_dir = pipeline_dir.joinpath(checkpoint.stem)
if model_dir.joinpath("model_index.json").exists():
# we have, so just use that
logger.info("Found converted model in {model_dir}, will not convert")
logger.info("Delete the output directory to re-run conversion.")
else:
# we haven't, so convert it
logger.info("Converting checkpoint to HuggingFace pipeline...")
g_pipeline, model_dir = checkpoint_to_pipeline(checkpoint)
logger.info("Done!")
@cli.command(no_args_is_help=True)
def refine(
frames_dir: Annotated[
Path,
typer.Argument(path_type=Path, file_okay=False, exists=True, help="Path to source frames directory"),
] = ...,
config_path: Annotated[
Path,
typer.Option(
"--config-path",
"-c",
path_type=Path,
exists=True,
readable=True,
dir_okay=False,
help="Path to a prompt configuration JSON file. default is frames_dir/../prompt.json",
),
] = None,
interpolation_multiplier: Annotated[
int,
typer.Option(
"--interpolation-multiplier",
"-M",
min=1,
max=10,
help="Interpolate with RIFE before generation. (I'll leave it as is, but I think interpolation after generation is sufficient).",
rich_help_panel="Generation",
),
] = 1,
tile_conditioning_scale: Annotated[
float,
typer.Option(
"--tile",
"-t",
min= 0,
max= 1.0,
help="controlnet_tile conditioning scale",
rich_help_panel="Generation",
),
] = 0.75,
width: Annotated[
int,
typer.Option(
"--width",
"-W",
min=-1,
max=3840,
help="Width of generated frames",
rich_help_panel="Generation",
),
] = -1,
height: Annotated[
int,
typer.Option(
"--height",
"-H",
min=-1,
max=2160,
help="Height of generated frames",
rich_help_panel="Generation",
),
] = -1,
length: Annotated[
int,
typer.Option(
"--length",
"-L",
min=-1,
max=9999,
help="Number of frames to generate. -1 means using all frames in frames_dir.",
rich_help_panel="Generation",
),
] = -1,
context: Annotated[
Optional[int],
typer.Option(
"--context",
"-C",
min=1,
max=32,
help="Number of frames to condition on (default: 16)",
show_default=False,
rich_help_panel="Generation",
),
] = 16,
overlap: Annotated[
Optional[int],
typer.Option(
"--overlap",
"-O",
min=1,
max=12,
help="Number of frames to overlap in context (default: context//4)",
show_default=False,
rich_help_panel="Generation",
),
] = None,
stride: Annotated[
Optional[int],
typer.Option(
"--stride",
"-S",
min=0,
max=8,
help="Max motion stride as a power of 2 (default: 0)",
show_default=False,
rich_help_panel="Generation",
),
] = None,
repeats: Annotated[
int,
typer.Option(
"--repeats",
"-r",
min=1,
max=99,
help="Number of times to repeat the refine (default: 1)",
show_default=False,
rich_help_panel="Generation",
),
] = 1,
device: Annotated[
str,
typer.Option(
"--device", "-d", help="Device to run on (cpu, cuda, cuda:id)", rich_help_panel="Advanced"
),
] = "cuda",
use_xformers: Annotated[
bool,
typer.Option(
"--xformers",
"-x",
is_flag=True,
help="Use XFormers instead of SDP Attention",
rich_help_panel="Advanced",
),
] = False,
force_half_vae: Annotated[
bool,
typer.Option(
"--half-vae",
is_flag=True,
help="Force VAE to use fp16 (not recommended)",
rich_help_panel="Advanced",
),
] = False,
out_dir: Annotated[
Path,
typer.Option(
"--out-dir",
"-o",
path_type=Path,
file_okay=False,
help="Directory for output folders (frames, gifs, etc)",
rich_help_panel="Output",
),
] = Path("refine/"),
):
"""Create upscaled or improved video using pre-generated frames"""
import shutil
from PIL import Image
from animatediff.rife.rife import rife_interpolate
if not config_path:
tmp = frames_dir.parent.joinpath("prompt.json")
if tmp.is_file():
config_path = tmp
else:
raise ValueError(f"config_path invalid.")
org_frames = sorted(glob.glob( os.path.join(frames_dir, "[0-9]*.png"), recursive=False))
W,H = Image.open(org_frames[0]).size
if width == -1 and height == -1:
width = W
height = H
elif width == -1:
width = int(height * W / H) //8 * 8
elif height == -1:
height = int(width * H / W) //8 * 8
else:
pass
if length == -1:
length = len(org_frames)
else:
length = min(length, len(org_frames))
config_path = config_path.absolute()
logger.info(f"Using generation config: {path_from_cwd(config_path)}")
model_config: ModelConfig = get_model_config(config_path)
# get a timestamp for the output directory
time_str = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
# make the output directory
save_dir = out_dir.joinpath(f"{time_str}-{model_config.save_name}")
save_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Will save outputs to ./{path_from_cwd(save_dir)}")
seeds = [get_random() for i in range(repeats)]
rife_img_dir = None
for repeat_count in range(repeats):
if interpolation_multiplier > 1:
rife_img_dir = save_dir.joinpath(f"{repeat_count:02d}_rife_frame")
rife_img_dir.mkdir(parents=True, exist_ok=True)
rife_interpolate(frames_dir, rife_img_dir, interpolation_multiplier)
length *= interpolation_multiplier
if model_config.output:
model_config.output["fps"] *= interpolation_multiplier
if model_config.prompt_map:
model_config.prompt_map = { str(int(i)*interpolation_multiplier): model_config.prompt_map[i] for i in model_config.prompt_map }
frames_dir = rife_img_dir
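        # Lay out the ControlNet input directory expected by generate(): empty
        # subdirectories for every ControlNet type, with the source frames
        # copied in as controlnet_tile inputs.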
controlnet_img_dir = save_dir.joinpath(f"{repeat_count:02d}_controlnet_image")
for c in ["controlnet_canny","controlnet_depth","controlnet_inpaint","controlnet_ip2p","controlnet_lineart","controlnet_lineart_anime","controlnet_mlsd","controlnet_normalbae","controlnet_openpose","controlnet_scribble","controlnet_seg","controlnet_shuffle","controlnet_softedge","controlnet_tile"]:
c_dir = controlnet_img_dir.joinpath(c)
c_dir.mkdir(parents=True, exist_ok=True)
shutil.copytree(frames_dir, controlnet_img_dir.joinpath("controlnet_tile"), dirs_exist_ok=True)
model_config.controlnet_map["input_image_dir"] = os.path.relpath(controlnet_img_dir.absolute(), data_dir)
model_config.controlnet_map["is_loop"] = False
if "controlnet_tile" in model_config.controlnet_map:
model_config.controlnet_map["controlnet_tile"]["enable"] = True
model_config.controlnet_map["controlnet_tile"]["control_scale_list"] = []
model_config.controlnet_map["controlnet_tile"]["controlnet_conditioning_scale"] = tile_conditioning_scale
else:
model_config.controlnet_map["controlnet_tile"] = {
"enable": True,
"use_preprocessor":True,
"guess_mode":False,
"controlnet_conditioning_scale": tile_conditioning_scale,
"control_guidance_start": 0.0,
"control_guidance_end": 1.0,
"control_scale_list":[]
}
model_config.seed = [seeds[repeat_count]]
config_path = save_dir.joinpath(f"{repeat_count:02d}_prompt.json")
config_path.write_text(model_config.json(indent=4), encoding="utf-8")
generated_dir = generate(
config_path=config_path,
width=width,
height=height,
length=length,
context=context,
overlap=overlap,
stride=stride,
device=device,
use_xformers=use_xformers,
force_half_vae=force_half_vae,
out_dir=save_dir,
)
interpolation_multiplier = 1
torch.cuda.empty_cache()
generated_dir = generated_dir.rename(generated_dir.parent / f"{time_str}_{repeat_count:02d}")
frames_dir = glob.glob( os.path.join(generated_dir, "00-[0-9]*"), recursive=False)[0]
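    # If RIFE interpolation ran, also save the RIFE-only frames as an animation
    # for comparison with the refined result.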
if rife_img_dir:
frames = sorted(glob.glob( os.path.join(rife_img_dir, "[0-9]*.png"), recursive=False))
out_images = []
for f in frames:
out_images.append(Image.open(f))
        out_file = save_dir.joinpath("rife_only_for_comparison")
        save_output(out_images, rife_img_dir, out_file, model_config.output, True, save_frames=None, save_video=None)
logger.info(f"Refined results are output to {generated_dir}")