DiffuEraser-demo / gradio_app.py
fffiloni's picture
Update gradio_app.py
8b634d8 verified
import torch
import os
import time
import argparse
from diffueraser.diffueraser import DiffuEraser
from propainter.inference import Propainter, get_device
import gradio as gr
# Download Weights
from huggingface_hub import snapshot_download
# Hugging Face repositories to fetch, keyed by the subfolder of "weights"
# each one is stored in. Dict insertion order is the download order.
_WEIGHT_REPOS = {
    "diffuEraser": "lixiaowen/diffuEraser",
    "stable-diffusion-v1-5": "stable-diffusion-v1-5/stable-diffusion-v1-5",
    "PCM_Weights": "wangfuyun/PCM_Weights",
    "propainter": "camenduru/ProPainter",
    "sd-vae-ft-mse": "stabilityai/sd-vae-ft-mse",
}
# List of subdirectories to create inside "weights"
subfolders = list(_WEIGHT_REPOS)
# Create each subdirectory and download its repository snapshot into it.
# The directory must be spelled "weights": infer() loads every model from
# "weights/<subfolder>".
for subfolder in subfolders:
    local_dir = os.path.join("weights", subfolder)
    os.makedirs(local_dir, exist_ok=True)
    snapshot_download(
        repo_id=_WEIGHT_REPOS[subfolder],
        local_dir=local_dir,
    )
# —————————————————————
def infer(input_video, input_mask):
    """Inpaint the masked region of a video with ProPainter + DiffuEraser.

    ProPainter first writes an intermediate "priori" video. That video is then
    passed, together with the original inputs, to DiffuEraser, which writes the
    final result. Both models are instantiated on every call. They are released
    (together with cached CUDA memory) before returning, even if inference fails.

    Args:
        input_video: The input video, presumably a file path supplied by
            ``gr.Video`` (TODO confirm for the installed Gradio version).
        input_mask: The mask video, in the same form as ``input_video``.

    Returns:
        Path of the DiffuEraser result, ``results/diffueraser_result.mp4``.
        The same path is overwritten by every call.

    Raises:
        ValueError: If either video is missing, e.g. Submit was clicked
            before both files were uploaded.
    """
    import gc

    # Fail early with a readable message instead of an error from deep
    # inside the models when an upload is missing.
    if not input_video or not input_mask:
        raise ValueError("Please provide both an input video and a mask video.")

    video_length = 10  # The maximum length of output video
    mask_dilation_iter = 8  # Adjust it to change the degree of mask expansion
    max_img_size = 960  # The maximum length of output width and height
    save_path = "results"  # Path to the output
    ref_stride = 10
    neighbor_length = 10
    subvideo_length = 50
    base_model_path = "weights/stable-diffusion-v1-5"
    vae_path = "weights/sd-vae-ft-mse"
    diffueraser_path = "weights/diffuEraser"
    propainter_model_dir = "weights/propainter"

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_path, exist_ok=True)
    priori_path = os.path.join(save_path, "priori.mp4")
    output_path = os.path.join(save_path, "diffueraser_result.mp4")

    ## model initialization
    device = get_device()
    ckpt = "2-Step"  # PCM params
    # Pre-bind so the cleanup in `finally` works even if construction fails.
    video_inpainting_sd = None
    propainter = None
    try:
        video_inpainting_sd = DiffuEraser(
            device, base_model_path, vae_path, diffueraser_path, ckpt=ckpt
        )
        propainter = Propainter(propainter_model_dir, device=device)

        start_time = time.perf_counter()
        ## priori
        propainter.forward(
            input_video,
            input_mask,
            priori_path,
            video_length=video_length,
            ref_stride=ref_stride,
            neighbor_length=neighbor_length,
            subvideo_length=subvideo_length,
            mask_dilation=mask_dilation_iter,
        )
        ## diffueraser
        # None defers to DiffuEraser's default, noted as 0 (TODO confirm
        # against DiffuEraser.forward).
        guidance_scale = None
        video_inpainting_sd.forward(
            input_video,
            input_mask,
            priori_path,
            output_path,
            max_img_size=max_img_size,
            video_length=video_length,
            mask_dilation_iter=mask_dilation_iter,
            guidance_scale=guidance_scale,
        )
        inference_time = time.perf_counter() - start_time
        print(f"DiffuEraser inference time: {inference_time:.4f} s")
    finally:
        # empty_cache() can only return *unoccupied* cached blocks, so the
        # model references must be dropped (and collected) first. Otherwise
        # their GPU memory stays occupied.
        del video_inpainting_sd, propainter
        gc.collect()
        torch.cuda.empty_cache()
    return output_path
# Gradio UI: two video inputs (source + mask), a result video, bundled
# examples, and a Submit button wired to infer().
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# DiffuEraser: A Diffusion Model for Video Inpainting")
        gr.Markdown("DiffuEraser is a diffusion model for video inpainting, which outperforms state-of-the-art model Propainter in both content completeness and temporal consistency while maintaining acceptable efficiency.")
        # Badge row: GitHub repo, project page, arXiv paper, "Duplicate this Space".
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href="https://github.com/lixiaowen-xw/DiffuEraser">
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a>
            <a href="https://lixiaowen-xw.github.io/DiffuEraser-page">
                <img src='https://img.shields.io/badge/Project-Page-green'>
            </a>
            <a href="https://arxiv.org/abs/2501.10018">
                <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
            </a>
            <a href="https://huggingface.co/spaces/fffiloni/DiffuEraser-demo?duplicate=true">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
            </a>
        </div>
        """)
        with gr.Row():
            with gr.Column():
                input_video = gr.Video(label="Input Video (MP4 ONLY)")
                input_mask = gr.Video(label="Input Mask Video (MP4 ONLY)")
                submit_btn = gr.Button("Submit")
            with gr.Column():
                video_result = gr.Video(label="Result")
                # Clicking an example only fills the two inputs; inference
                # still runs through the Submit button.
                gr.Examples(
                    examples=[
                        ["./examples/example1/video.mp4", "./examples/example1/mask.mp4"],
                        ["./examples/example2/video.mp4", "./examples/example2/mask.mp4"],
                        ["./examples/example3/video.mp4", "./examples/example3/mask.mp4"],
                    ],
                    inputs=[input_video, input_mask],
                )
    submit_btn.click(
        fn=infer,
        inputs=[input_video, input_mask],
        outputs=[video_result],
    )

demo.queue().launch(show_api=False, show_error=True)