# predict.py
import subprocess
import time
from cog import BasePredictor, Input, Path
import os
import torch
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from datetime import datetime
from torchvision.transforms.functional import pil_to_tensor, resize, center_crop
from constants import ASPECT_RATIO

MODEL_CACHE = "models"
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_HOME"] = MODEL_CACHE
os.environ["TORCH_HOME"] = MODEL_CACHE
os.environ["HF_DATASETS_CACHE"] = MODEL_CACHE
os.environ["TRANSFORMERS_CACHE"] = MODEL_CACHE
os.environ["HUGGINGFACE_HUB_CACHE"] = MODEL_CACHE

BASE_URL = f"https://weights.replicate.delivery/default/MimicMotion/{MODEL_CACHE}/"


def download_weights(url: str, dest: str) -> None:
    # NOTE: when extracting a tar archive, specify the parent folder as the destination.
    start = time.time()
    print("[!] Initiating download from URL: ", url)
    print("[~] Destination path: ", dest)
    if ".tar" in dest:
        dest = os.path.dirname(dest)
    command = ["pget", "-vf" + ("x" if ".tar" in url else ""), url, dest]
    try:
        print(f"[~] Running command: {' '.join(command)}")
        subprocess.check_call(command, close_fds=False)
    except subprocess.CalledProcessError as e:
        print(
            f"[ERROR] Failed to download weights. Command '{' '.join(e.cmd)}' returned non-zero exit status {e.returncode}."
        )
        raise
    print("[+] Download completed in: ", time.time() - start, "seconds")


class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        if not os.path.exists(MODEL_CACHE):
            os.makedirs(MODEL_CACHE)

        model_files = [
            "DWPose.tar",
            "MimicMotion.pth",
            "MimicMotion_1-1.pth",
            "SVD.tar",
        ]
        for model_file in model_files:
            url = BASE_URL + model_file
            filename = url.split("/")[-1]
            dest_path = os.path.join(MODEL_CACHE, filename)
            if not os.path.exists(dest_path.replace(".tar", "")):
                download_weights(url, dest_path)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Move imports here and make them global.
        # This ensures model files are downloaded before importing mimicmotion modules.
        global MimicMotionPipeline, create_pipeline, save_to_mp4, get_video_pose, get_image_pose
        from mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipeline
        from mimicmotion.utils.loader import create_pipeline
        from mimicmotion.utils.utils import save_to_mp4
        from mimicmotion.dwpose.preprocess import get_video_pose, get_image_pose

        # Load config with the new checkpoint as default
        self.config = OmegaConf.create(
            {
                "base_model_path": "models/SVD/stable-video-diffusion-img2vid-xt-1-1",
                "ckpt_path": "models/MimicMotion_1-1.pth",
            }
        )
        # Create the pipeline with the new checkpoint
        self.pipeline = create_pipeline(self.config, self.device)
        self.current_checkpoint = "v1-1"
        self.current_dtype = torch.get_default_dtype()
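
    # current_checkpoint / current_dtype record what the live pipeline was built
    # with, so predict() only rebuilds it when the requested checkpoint version
    # or precision actually changes.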
    def predict(
        self,
        motion_video: Path = Input(
            description="Reference video file containing the motion to be mimicked"
        ),
        appearance_image: Path = Input(
            description="Reference image file for the appearance of the generated video"
        ),
        resolution: int = Input(
            description="Size of the short side of the output video in pixels. The other dimension is calculated automatically.",
            default=576,
            ge=64,
            le=1024,
        ),
        chunk_size: int = Input(
            description="Number of frames to generate in each processing chunk",
            default=16,
            ge=2,
        ),
        frames_overlap: int = Input(
            description="Number of overlapping frames between chunks for smoother transitions",
            default=6,
            ge=0,
        ),
        denoising_steps: int = Input(
            description="Number of denoising steps in the diffusion process. More steps can improve quality but increase processing time.",
            default=25,
            ge=1,
            le=100,
        ),
        noise_strength: float = Input(
            description="Strength of noise augmentation. Higher values add more variation but may reduce coherence with the reference.",
            default=0.0,
            ge=0.0,
            le=1.0,
        ),
        guidance_scale: float = Input(
            description="Strength of guidance towards the reference. Higher values adhere more closely to the reference but may reduce creativity.",
            default=2.0,
            ge=0.1,
            le=10.0,
        ),
        sample_stride: int = Input(
            description="Interval for sampling frames from the reference video. Higher values skip more frames.",
            default=2,
            ge=1,
        ),
        output_frames_per_second: int = Input(
            description="Frames per second of the output video. Affects playback speed.",
            default=15,
            ge=1,
            le=60,
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed.",
            default=None,
        ),
        checkpoint_version: str = Input(
            description="Checkpoint version to use",
            choices=["v1", "v1-1"],
            default="v1-1",
        ),
    ) -> Path:
        """Run a single prediction on the model"""
        ref_video = motion_video
        ref_image = appearance_image
        num_frames = chunk_size
        num_inference_steps = denoising_steps
        noise_aug_strength = noise_strength
        fps = output_frames_per_second
        use_fp16 = True

        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        need_pipeline_update = False
        # Check if we need to switch checkpoints
        if checkpoint_version != self.current_checkpoint:
            if checkpoint_version == "v1":
                self.config.ckpt_path = "models/MimicMotion.pth"
            else:  # v1-1
                self.config.ckpt_path = "models/MimicMotion_1-1.pth"
            need_pipeline_update = True
            self.current_checkpoint = checkpoint_version

        # Check if we need to switch dtype
        target_dtype = torch.float16 if use_fp16 else torch.float32
        if target_dtype != self.current_dtype:
            torch.set_default_dtype(target_dtype)
            need_pipeline_update = True
            self.current_dtype = target_dtype

        # Update pipeline if needed
        if need_pipeline_update:
            print(
                f"Updating pipeline with checkpoint: {self.config.ckpt_path} and dtype: {torch.get_default_dtype()}"
            )
            self.pipeline = create_pipeline(self.config, self.device)

        print(f"Using checkpoint: {self.config.ckpt_path}")
        print(f"Using dtype: {torch.get_default_dtype()}")
        print(
            f"[!] ({type(ref_video)}) ref_video={ref_video}, "
            f"[!] ({type(ref_image)}) ref_image={ref_image}, "
            f"[!] ({type(resolution)}) resolution={resolution}, "
            f"[!] ({type(num_frames)}) num_frames={num_frames}, "
            f"[!] ({type(frames_overlap)}) frames_overlap={frames_overlap}, "
            f"[!] ({type(num_inference_steps)}) num_inference_steps={num_inference_steps}, "
            f"[!] ({type(noise_aug_strength)}) noise_aug_strength={noise_aug_strength}, "
            f"[!] ({type(guidance_scale)}) guidance_scale={guidance_scale}, "
            f"[!] ({type(sample_stride)}) sample_stride={sample_stride}, "
            f"[!] ({type(fps)}) fps={fps}, "
            f"[!] ({type(seed)}) seed={seed}, "
            f"[!] ({type(use_fp16)}) use_fp16={use_fp16}"
        )
        # Input validation
        if not ref_video.exists():
            raise ValueError(f"Reference video file does not exist: {ref_video}")
        if not ref_image.exists():
            raise ValueError(f"Reference image file does not exist: {ref_image}")
        if resolution % 8 != 0:
            raise ValueError(f"Resolution must be a multiple of 8, got {resolution}")
        if resolution < 64 or resolution > 1024:
            raise ValueError(
                f"Resolution must be between 64 and 1024, got {resolution}"
            )
        if num_frames <= frames_overlap:
            raise ValueError(
                f"Number of frames ({num_frames}) must be greater than frames overlap ({frames_overlap})"
            )
        if num_frames < 2:
            raise ValueError(f"Number of frames must be at least 2, got {num_frames}")
        if frames_overlap < 0:
            raise ValueError(
                f"Frames overlap must be non-negative, got {frames_overlap}"
            )
        if num_inference_steps < 1 or num_inference_steps > 100:
            raise ValueError(
                f"Number of inference steps must be between 1 and 100, got {num_inference_steps}"
            )
        if noise_aug_strength < 0.0 or noise_aug_strength > 1.0:
            raise ValueError(
                f"Noise augmentation strength must be between 0.0 and 1.0, got {noise_aug_strength}"
            )
        if guidance_scale < 0.1 or guidance_scale > 10.0:
            raise ValueError(
                f"Guidance scale must be between 0.1 and 10.0, got {guidance_scale}"
            )
        if sample_stride < 1:
            raise ValueError(f"Sample stride must be at least 1, got {sample_stride}")
        if fps < 1 or fps > 60:
            raise ValueError(f"FPS must be between 1 and 60, got {fps}")

        try:
            # Preprocess
            pose_pixels, image_pixels = self.preprocess(
                str(ref_video),
                str(ref_image),
                resolution=resolution,
                sample_stride=sample_stride,
            )
            # Run pipeline
            video_frames = self.run_pipeline(
                image_pixels,
                pose_pixels,
                num_frames=num_frames,
                frames_overlap=frames_overlap,
                num_inference_steps=num_inference_steps,
                noise_aug_strength=noise_aug_strength,
                guidance_scale=guidance_scale,
                seed=seed,
            )
            # Save output
            output_path = f"/tmp/output_{datetime.now().strftime('%Y%m%d%H%M%S')}.mp4"
            save_to_mp4(video_frames, output_path, fps=fps)
            return Path(output_path)
        except Exception as e:
            print(f"An error occurred during prediction: {str(e)}")
            raise
    def preprocess(self, video_path, image_path, resolution=576, sample_stride=2):
        image_pixels = Image.open(image_path).convert("RGB")
        image_pixels = pil_to_tensor(image_pixels)  # (c, h, w)
        h, w = image_pixels.shape[-2:]
        # `resolution` sets the short side of the target; the other side is derived
        # from ASPECT_RATIO and rounded down to a multiple of 64.
        if h > w:
            w_target, h_target = resolution, int(resolution / ASPECT_RATIO // 64) * 64
        else:
            w_target, h_target = int(resolution / ASPECT_RATIO // 64) * 64, resolution
        h_w_ratio = float(h) / float(w)
        # Resize so the image fully covers the target size, then center-crop to it.
        if h_w_ratio < h_target / w_target:
            h_resize, w_resize = h_target, int(h_target / h_w_ratio)
        else:
            h_resize, w_resize = int(w_target * h_w_ratio), w_target
        image_pixels = resize(image_pixels, [h_resize, w_resize], antialias=None)
        image_pixels = center_crop(image_pixels, [h_target, w_target])
        image_pixels = image_pixels.permute((1, 2, 0)).numpy()
        # Pose for the reference image plus poses extracted from the motion video.
        image_pose = get_image_pose(image_pixels)
        video_pose = get_video_pose(
            video_path, image_pixels, sample_stride=sample_stride
        )
        pose_pixels = np.concatenate([np.expand_dims(image_pose, 0), video_pose])
        image_pixels = np.transpose(np.expand_dims(image_pixels, 0), (0, 3, 1, 2))
        return (
            torch.from_numpy(pose_pixels.copy()) / 127.5 - 1,
            torch.from_numpy(image_pixels) / 127.5 - 1,
        )
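
    # preprocess() returns two tensors scaled to [-1, 1]:
    #   pose_pixels:  the reference-image pose followed by the sampled video poses,
    #                 stacked along the first (frame) dimension
    #   image_pixels: the cropped appearance image with shape (1, c, h, w)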
    def run_pipeline(
        self,
        image_pixels,
        pose_pixels,
        num_frames,
        frames_overlap,
        num_inference_steps,
        noise_aug_strength,
        guidance_scale,
        seed,
    ):
        # Convert the [-1, 1] appearance tensor(s) back to PIL images for the pipeline.
        image_pixels = [
            Image.fromarray(
                (img.cpu().numpy().transpose(1, 2, 0) * 127.5 + 127.5).astype(np.uint8)
            )
            for img in image_pixels
        ]
        pose_pixels = pose_pixels.unsqueeze(0).to(self.device)
        generator = torch.Generator(device=self.device)
        generator.manual_seed(seed)
        frames = self.pipeline(
            image_pixels,
            image_pose=pose_pixels,
            num_frames=pose_pixels.size(1),
            tile_size=num_frames,
            tile_overlap=frames_overlap,
            height=pose_pixels.shape[-2],
            width=pose_pixels.shape[-1],
            fps=7,  # SVD conditioning fps, not the fps of the saved output video
            noise_aug_strength=noise_aug_strength,
            num_inference_steps=num_inference_steps,
            generator=generator,
            min_guidance_scale=guidance_scale,  # equal min/max gives constant guidance across frames
            max_guidance_scale=guidance_scale,
            decode_chunk_size=8,
            output_type="pt",
            device=self.device,
        ).frames.cpu()
        video_frames = (frames * 255.0).to(torch.uint8)
        return video_frames[0, 1:]  # Remove the first frame (the reference-image pose)
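

# Local usage sketch (file names are placeholders): with Cog installed, the
# predictor can be exercised end-to-end from the command line, e.g.
#
#   cog predict \
#       -i motion_video=@dance.mp4 \
#       -i appearance_image=@person.png \
#       -i resolution=576 \
#       -i checkpoint_version=v1-1
#
# Weights are downloaded on the first run (in setup()) and the generated MP4 is
# returned as the prediction output.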