|
import logging |
|
import subprocess |
|
from math import ceil |
|
from pathlib import Path |
|
from typing import Annotated, Optional |
|
|
|
import typer |
|
|
|
from animatediff import get_dir |
|
|
|
from .ffmpeg import FfmpegEncoder, VideoCodec, codec_extn |
|
from .ncnn import RifeNCNNOptions |
|
|
|
rife_dir = get_dir("data/rife") |
|
rife_ncnn_vulkan = rife_dir.joinpath("rife-ncnn-vulkan") |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
app: typer.Typer = typer.Typer( |
|
name="rife", |
|
context_settings=dict(help_option_names=["-h", "--help"]), |
|
rich_markup_mode="rich", |
|
pretty_exceptions_show_locals=False, |
|
help="RIFE motion flow interpolation (MORE FPS!)", |
|
) |
|
|
|
def rife_interpolate( |
|
input_frames_dir:str, |
|
output_frames_dir:str, |
|
frame_multiplier:int = 2, |
|
rife_model:str = "rife-v4.6", |
|
spatial_tta:bool = False, |
|
temporal_tta:bool = False, |
|
uhd:bool = False, |
|
): |
|
|
|
rife_model_dir = rife_dir.joinpath(rife_model) |
|
if not rife_model_dir.joinpath("flownet.bin").exists(): |
|
raise FileNotFoundError(f"RIFE model dir {rife_model_dir} does not have a model in it!") |
|
|
|
|
|
rife_opts = RifeNCNNOptions( |
|
model_path=rife_model_dir, |
|
input_path=input_frames_dir, |
|
output_path=output_frames_dir, |
|
time_step=1 / frame_multiplier, |
|
spatial_tta=spatial_tta, |
|
temporal_tta=temporal_tta, |
|
uhd=uhd, |
|
) |
|
rife_args = rife_opts.get_args(frame_multiplier=frame_multiplier) |
|
|
|
|
|
logger.info("Running RIFE, this may take a little while...") |
|
with subprocess.Popen( |
|
[rife_ncnn_vulkan, *rife_args], stdout=subprocess.PIPE, stderr=subprocess.PIPE |
|
) as proc: |
|
errs = [] |
|
for line in proc.stderr: |
|
line = line.decode("utf-8").strip() |
|
if line: |
|
logger.debug(line) |
|
stdout, _ = proc.communicate() |
|
if proc.returncode != 0: |
|
raise RuntimeError(f"RIFE failed with code {proc.returncode}:\n" + "\n".join(errs)) |
|
|
|
import glob |
|
import os |
|
org_images = sorted(glob.glob( os.path.join(output_frames_dir, "[0-9]*.png"), recursive=False)) |
|
for o in org_images: |
|
p = Path(o) |
|
new_no = int(p.stem) - 1 |
|
new_p = p.with_stem(f"{new_no:08d}") |
|
p.rename(new_p) |
|
|
|
|
|
|
|
@app.command(no_args_is_help=True) |
|
def interpolate( |
|
rife_model: Annotated[ |
|
str, |
|
typer.Option("--rife-model", "-m", help="RIFE model to use (subdirectory of data/rife/)"), |
|
] = "rife-v4.6", |
|
in_fps: Annotated[ |
|
int, |
|
typer.Option("--in-fps", "-I", help="Input frame FPS (8 for AnimateDiff)", show_default=True), |
|
] = 8, |
|
frame_multiplier: Annotated[ |
|
int, |
|
typer.Option( |
|
"--frame-multiplier", "-M", help="Multiply total frame count by this", show_default=True |
|
), |
|
] = 8, |
|
out_fps: Annotated[ |
|
int, |
|
typer.Option("--out-fps", "-F", help="Target FPS", show_default=True), |
|
] = 50, |
|
codec: Annotated[ |
|
VideoCodec, |
|
typer.Option("--codec", "-c", help="Output video codec", show_default=True), |
|
] = VideoCodec.webm, |
|
lossless: Annotated[ |
|
bool, |
|
typer.Option("--lossless", "-L", is_flag=True, help="Use lossless encoding (WebP only)"), |
|
] = False, |
|
spatial_tta: Annotated[ |
|
bool, |
|
typer.Option("--spatial-tta", "-x", is_flag=True, help="Enable RIFE Spatial TTA mode"), |
|
] = False, |
|
temporal_tta: Annotated[ |
|
bool, |
|
typer.Option("--temporal-tta", "-z", is_flag=True, help="Enable RIFE Temporal TTA mode"), |
|
] = False, |
|
uhd: Annotated[ |
|
bool, |
|
typer.Option("--uhd", "-u", is_flag=True, help="Enable RIFE UHD mode"), |
|
] = False, |
|
frames_dir: Annotated[ |
|
Path, |
|
typer.Argument(path_type=Path, file_okay=False, exists=True, help="Path to source frames directory"), |
|
] = ..., |
|
out_file: Annotated[ |
|
Optional[Path], |
|
typer.Argument( |
|
dir_okay=False, |
|
help="Path to output file (default: frames_dir/rife-output.<out_type>)", |
|
show_default=False, |
|
), |
|
] = None, |
|
): |
|
rife_model_dir = rife_dir.joinpath(rife_model) |
|
if not rife_model_dir.joinpath("flownet.bin").exists(): |
|
raise FileNotFoundError(f"RIFE model dir {rife_model_dir} does not have a model in it!") |
|
|
|
if not frames_dir.exists(): |
|
raise FileNotFoundError(f"Frames directory {frames_dir} does not exist!") |
|
|
|
|
|
|
|
rife_frames_dir = frames_dir.parent.joinpath(f"{frames_dir.name}-rife") |
|
rife_frames_dir.mkdir(exist_ok=True, parents=True) |
|
|
|
|
|
file_extn = codec_extn(codec) |
|
if out_file is None: |
|
out_file = frames_dir.parent.joinpath(f"{frames_dir.name}-rife.{file_extn}") |
|
elif out_file.suffix != file_extn: |
|
logger.warn("Output file extension does not match codec, changing extension") |
|
out_file = out_file.with_suffix(file_extn) |
|
|
|
|
|
|
|
rife_opts = RifeNCNNOptions( |
|
model_path=rife_model_dir, |
|
input_path=frames_dir, |
|
output_path=rife_frames_dir, |
|
time_step=1 / in_fps, |
|
spatial_tta=spatial_tta, |
|
temporal_tta=temporal_tta, |
|
uhd=uhd, |
|
) |
|
rife_args = rife_opts.get_args(frame_multiplier=frame_multiplier) |
|
|
|
|
|
logger.info("Running RIFE, this may take a little while...") |
|
with subprocess.Popen( |
|
[rife_ncnn_vulkan, *rife_args], stdout=subprocess.PIPE, stderr=subprocess.PIPE |
|
) as proc: |
|
errs = [] |
|
for line in proc.stderr: |
|
line = line.decode("utf-8").strip() |
|
if line: |
|
logger.debug(line) |
|
stdout, _ = proc.communicate() |
|
if proc.returncode != 0: |
|
raise RuntimeError(f"RIFE failed with code {proc.returncode}:\n" + "\n".join(errs)) |
|
|
|
|
|
logger.info("Creating ffmpeg encoder...") |
|
encoder = FfmpegEncoder( |
|
frames_dir=rife_frames_dir, |
|
out_file=out_file, |
|
codec=codec, |
|
in_fps=min(out_fps, in_fps * frame_multiplier), |
|
out_fps=out_fps, |
|
lossless=lossless, |
|
) |
|
logger.info("Encoding interpolated frames with ffmpeg...") |
|
result = encoder.encode() |
|
|
|
logger.debug(f"ffmpeg result: {result}") |
|
|
|
logger.info(f"Find the RIFE frames at: {rife_frames_dir.absolute().relative_to(Path.cwd())}") |
|
logger.info(f"Find the output file at: {out_file.absolute().relative_to(Path.cwd())}") |
|
logger.info("Done!") |
|
|