fffiloni commited on
Commit
219cab0
Β·
verified Β·
1 Parent(s): 8eb8300

Create gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +137 -0
gradio_app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import time
4
+ import argparse
5
+ from diffueraser.diffueraser import DiffuEraser
6
+ from propainter.inference import Propainter, get_device
7
+
8
+ # Download Weights
9
+ from huggingface_hub import snapshot_download
10
+
11
+ # List of subdirectories to create inside "checkpoints"
12
+ subfolders = [
13
+ "diffuEraser",
14
+ "stable-diffusion-v1-5",
15
+ "PCM_Weights",
16
+ "propainter",
17
+ "sd-vae-ft-mse"
18
+ ]
19
+ # Create each subdirectory
20
+ for subfolder in subfolders:
21
+ os.makedirs(os.path.join("weigths", subfolder), exist_ok=True)
22
+
23
+ snapshot_download(
24
+ repo_id = "lixiaowen/diffuEraser",
25
+ local_dir = "./weights/diffuEraser"
26
+ )
27
+
28
+ snapshot_download(
29
+ repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5",
30
+ local_dir = "./weights/stable-diffusion-v1-5"
31
+ )
32
+
33
+ snapshot_download(
34
+ repo_id = "wangfuyun/PCM_Weights",
35
+ local_dir = "./weights/PCM_Weights"
36
+ )
37
+
38
+ snapshot_download(
39
+ repo_id = "camenduru/ProPainter",
40
+ local_dir = "./weights/propainter"
41
+ )
42
+
43
+ snapshot_download(
44
+ repo_id = "stabilityai/sd-vae-ft-mse",
45
+ local_dir = "./weights/sd-vae-ft-mse"
46
+ )
47
+
48
+ # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
49
+
50
+ def infer(input_video, input_mask):
51
+
52
+ video_length = 10 # The maximum length of output video
53
+ mask_dilation_iter = 8 # Adjust it to change the degree of mask expansion
54
+ max_img_size = 960 # The maximum length of output width and height
55
+ save_path = "results" # Path to the output
56
+
57
+ ref_stride = 10
58
+ neighbor_length = 10
59
+ subvideo_length = 50
60
+
61
+ base_model_path = "weights/stable-diffusion-v1-5"
62
+ vae_path = "weights/sd-vae-ft-mse"
63
+ diffueraser_path = "weights/diffuEraser"
64
+ propainter_model_dir = "weights/propainter"
65
+
66
+ if not os.path.exists(save_path):
67
+ os.makedirs(save_path)
68
+ priori_path = os.path.join(save_path, "priori.mp4")
69
+ output_path = os.path.join(save_path, "diffueraser_result.mp4")
70
+
71
+ ## model initialization
72
+ device = get_device()
73
+ # PCM params
74
+ ckpt = "2-Step"
75
+ video_inpainting_sd = DiffuEraser(device, base_model_path, vae_path, diffueraser_path, ckpt=ckpt)
76
+ propainter = Propainter(propainter_model_dir, device=device)
77
+
78
+ start_time = time.time()
79
+
80
+ ## priori
81
+ propainter.forward(input_video, input_mask, priori_path, video_length=video_length,
82
+ ref_stride=ref_stride, neighbor_length=neighbor_length, subvideo_length = subvideo_length,
83
+ mask_dilation = mask_dilation_iter)
84
+
85
+ ## diffueraser
86
+ guidance_scale = None # The default value is 0.
87
+ video_inpainting_sd.forward(input_video, input_mask, priori_path, output_path,
88
+ max_img_size = max_img_size, video_length=video_length, mask_dilation_iter=mask_dilation_iter,
89
+ guidance_scale=guidance_scale)
90
+
91
+ end_time = time.time()
92
+ inference_time = end_time - start_time
93
+ print(f"DiffuEraser inference time: {inference_time:.4f} s")
94
+
95
+ torch.cuda.empty_cache()
96
+
97
+ return output_path, priori_path
98
+
99
+ with gr.Blocks() as demo:
100
+
101
+ with gr.Column():
102
+ gr.Markdown("# DiffuEraser")
103
+
104
+ with gr.Row():
105
+
106
+ with gr.Column():
107
+
108
+ input_video = gr.Video(label="Input Video (MP4 ONLY")
109
+ input_mask = gr.Video(label="Input Mask Video (MP4 ONLY")
110
+ submit_btn = gr.Button("Submit")
111
+
112
+ gr.Examples(
113
+ examples = [
114
+ [".examples/example1/video.mp4", "./examples/example1/mask.mp4"],
115
+ [".examples/example2/video.mp4", "./examples/example2/mask.mp4"],
116
+ [".examples/example3/video.mp4", "./examples/example3/mask.mp4"],
117
+ ],
118
+ inputs = [input_video, input_mask]
119
+ )
120
+
121
+ with gr.Column():
122
+
123
+ video_result = gr.Video(label="Result")
124
+ priori_result = gr.Video(label="ProPainter pass")
125
+
126
+
127
+ submit_btn.click(
128
+ fn = infer,
129
+ inputs = [input_video, input_mask],
130
+ outputs = [video_result, priori_result]
131
+ )
132
+
133
+ demo.queue().launch(show_api=False, show_error=True)
134
+
135
+
136
+
137
+