# Hugging Face Spaces status: Running on L40S
import torch | |
import os | |
import time | |
import argparse | |
from diffueraser.diffueraser import DiffuEraser | |
from propainter.inference import Propainter, get_device | |
import gradio as gr | |
# Download Weights | |
from huggingface_hub import snapshot_download | |
# Download all model weights needed by infer() into ./weights/<subfolder>.
# Each key is the local subdirectory name; each value is the Hugging Face Hub
# repo mirrored into it. These directory names must stay in sync with the
# "weights/..." paths hard-coded in infer().
WEIGHTS_DIR = "weights"
WEIGHT_REPOS = {
    "diffuEraser": "lixiaowen/diffuEraser",
    "stable-diffusion-v1-5": "stable-diffusion-v1-5/stable-diffusion-v1-5",
    "PCM_Weights": "wangfuyun/PCM_Weights",
    "propainter": "camenduru/ProPainter",
    "sd-vae-ft-mse": "stabilityai/sd-vae-ft-mse",
}
# Names of the subdirectories created under WEIGHTS_DIR (kept as a module-level
# list for any code that still reads it).
subfolders = list(WEIGHT_REPOS)

for subfolder, repo_id in WEIGHT_REPOS.items():
    local_dir = os.path.join(WEIGHTS_DIR, subfolder)
    # Create the target under the same root the downloads use, so no stray
    # directory tree is left behind.
    os.makedirs(local_dir, exist_ok=True)
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
# ─────────────────────
def infer(input_video, input_mask):
    """Inpaint ``input_video`` wherever ``input_mask`` marks, using DiffuEraser.

    Two stages run in sequence: ProPainter writes a prior video to
    ``results/priori.mp4``. DiffuEraser then receives the original video, the
    mask, and that prior, and writes the final result.

    Args:
        input_video: File path of the source video from the ``gr.Video`` input
            (the UI asks for MP4).
        input_mask: File path of the mask video, same convention.

    Returns:
        Path of the result video, ``results/diffueraser_result.mp4``.

    Raises:
        gr.Error: If either input is missing (e.g. Submit clicked before both
            videos were uploaded).
    """
    # An empty gr.Video yields None; fail with a readable UI message instead of
    # an opaque traceback from inside the pipelines.
    if input_video is None or input_mask is None:
        raise gr.Error("Please upload both an input video and a mask video (MP4).")

    video_length = 10  # Maximum output video length (units per the pipelines; presumably seconds — confirm).
    mask_dilation_iter = 8  # Mask dilation amount; larger values expand the mask more.
    max_img_size = 960  # Upper bound on output width and height.
    save_path = "results"  # Output directory.
    ref_stride = 10
    neighbor_length = 10
    subvideo_length = 50
    # Must match the download locations used at module import time.
    base_model_path = "weights/stable-diffusion-v1-5"
    vae_path = "weights/sd-vae-ft-mse"
    diffueraser_path = "weights/diffuEraser"
    propainter_model_dir = "weights/propainter"

    # exist_ok avoids the check-then-create race of exists() followed by makedirs().
    os.makedirs(save_path, exist_ok=True)
    priori_path = os.path.join(save_path, "priori.mp4")
    output_path = os.path.join(save_path, "diffueraser_result.mp4")

    ## model initialization
    device = get_device()
    # PCM params: selects the PCM weights variant (presumably 2-step sampling — confirm).
    ckpt = "2-Step"
    # Pre-bind so the finally block can always drop them, even if construction fails.
    video_inpainting_sd = propainter = None
    try:
        video_inpainting_sd = DiffuEraser(device, base_model_path, vae_path, diffueraser_path, ckpt=ckpt)
        propainter = Propainter(propainter_model_dir, device=device)

        # Timing covers only the two forward passes, not model loading.
        start_time = time.perf_counter()
        ## priori: ProPainter writes its result to priori_path.
        propainter.forward(input_video, input_mask, priori_path, video_length=video_length,
                           ref_stride=ref_stride, neighbor_length=neighbor_length,
                           subvideo_length=subvideo_length, mask_dilation=mask_dilation_iter)
        ## diffueraser: consumes the prior and writes the final video to output_path.
        # None defers to DiffuEraser.forward's own default (the original note says 0 — confirm).
        guidance_scale = None
        video_inpainting_sd.forward(input_video, input_mask, priori_path, output_path,
                                    max_img_size=max_img_size, video_length=video_length,
                                    mask_dilation_iter=mask_dilation_iter,
                                    guidance_scale=guidance_scale)
        inference_time = time.perf_counter() - start_time
        print(f"DiffuEraser inference time: {inference_time:.4f} s")
    finally:
        # Drop the model references first: empty_cache() can only release GPU
        # memory that is no longer held by live tensors.
        del video_inpainting_sd, propainter
        torch.cuda.empty_cache()
    return output_path
# Gradio UI: title and description, project badges, video + mask inputs, the
# result player, bundled examples, and the Submit button wired to infer().
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# DiffuEraser: A Diffusion Model for Video Inpainting")
        gr.Markdown("DiffuEraser is a diffusion model for video inpainting, which outperforms state-of-the-art model Propainter in both content completeness and temporal consistency while maintaining acceptable efficiency.")
        # Badge row. The ArXiv badge points at the paper itself, not the project page.
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href="https://github.com/lixiaowen-xw/DiffuEraser">
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a>
            <a href="https://lixiaowen-xw.github.io/DiffuEraser-page">
                <img src='https://img.shields.io/badge/Project-Page-green'>
            </a>
            <a href="https://arxiv.org/abs/2501.10018">
                <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
            </a>
            <a href="https://huggingface.co/spaces/fffiloni/DiffuEraser-demo?duplicate=true">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
            </a>
        </div>
        """)
        with gr.Row():
            # Left column: the two inputs and the trigger.
            with gr.Column():
                input_video = gr.Video(label="Input Video (MP4 ONLY)")
                input_mask = gr.Video(label="Input Mask Video (MP4 ONLY)")
                submit_btn = gr.Button("Submit")
            # Right column: output plus clickable examples that fill the inputs.
            with gr.Column():
                video_result = gr.Video(label="Result")
                gr.Examples(
                    examples=[
                        ["./examples/example1/video.mp4", "./examples/example1/mask.mp4"],
                        ["./examples/example2/video.mp4", "./examples/example2/mask.mp4"],
                        ["./examples/example3/video.mp4", "./examples/example3/mask.mp4"],
                    ],
                    inputs=[input_video, input_mask],
                )
    # Event listeners must be registered inside the Blocks context.
    submit_btn.click(
        fn=infer,
        inputs=[input_video, input_mask],
        outputs=[video_result],
    )
# queue() serializes GPU jobs; show_error surfaces exceptions (incl. gr.Error) in the UI.
demo.queue().launch(show_api=False, show_error=True)