Matt Goyder committed · Commit 516b1c3 · Parent(s): 67cf56f

add main script
app.py CHANGED

@@ -7,10 +7,13 @@ description = "reconstruction and animation demos for DaGAN"
 
 def inference(mode, img, video):
     os.makedirs('tmp', exist_ok=True)
-
-    ""
-
-
+
+    cmd = f"ffmpeg -y -ss 00:00:00 -i {video} -to 00:00:05 -c copy cut_vid.mp4"
+    subprocess.run(cmd.split())
+
+    video = "cut_vid.mp4"
+    os.system(f"python run.py --source_image {img} --driving_video {video} --result_video tmp/result.mp4 --mode {mode}")
+    return "tmp/result.mp4"
 
 gr.Interface(
     inference,
@@ -19,7 +22,7 @@ gr.Interface(
         gr.inputs.Image(type="filepath", label="Source (only used if mode is reconstruction)"),
         gr.inputs.Image(type="mp4", label="Driving Video")
     ],
-    outputs=
+    outputs=gr.outputs.Video(type="mp4", label="Output Video"),
     title=title,
     description=description,
     theme="huggingface",
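
The new inference body first trims the driving clip to its first five seconds (-ss 00:00:00 ... -to 00:00:05 with -c copy is a stream copy, so no re-encode) and then shells out to the new run.py, returning the result path for Gradio to display. A minimal sketch of the same trim step on its own, assuming a local demo.mp4:

    import subprocess

    # copy the first 5 seconds of demo.mp4 into cut_vid.mp4 without re-encoding
    cmd = "ffmpeg -y -ss 00:00:00 -i demo.mp4 -to 00:00:05 -c copy cut_vid.mp4"
    subprocess.run(cmd.split(), check=True)
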
run.py ADDED
@@ -0,0 +1,219 @@
import os, sys
import yaml
from argparse import ArgumentParser
from tqdm import tqdm
import modules.generator as GEN
import imageio
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte
import torch
from sync_batchnorm import DataParallelWithCallback
import depth
from modules.keypoint_detector import KPDetector
from scipy.spatial import ConvexHull
from collections import OrderedDict
import warnings
warnings.filterwarnings("ignore")

def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
                 use_relative_movement=False, use_relative_jacobian=False):
    if adapt_movement_scale:
        source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
        driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
        adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
    else:
        adapt_movement_scale = 1

    kp_new = {k: v for k, v in kp_driving.items()}

    if use_relative_movement:
        kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
        kp_value_diff *= adapt_movement_scale
        kp_new['value'] = kp_value_diff + kp_source['value']

        if use_relative_jacobian:
            jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
            kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
    return kp_new

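# In equation form, with all three flags enabled, normalize_kp re-targets motion
# relative to the first driving frame:
#   s      = sqrt(vol(hull(kp_source)) / vol(hull(kp_driving_initial)))
#   kp_new = kp_source + s * (kp_driving - kp_driving_initial)
#   J_new  = J_driving @ inv(J_driving_initial) @ J_source
# With adapt_movement_scale off, s = 1; with use_relative_movement off, the raw
# driving keypoints pass through unchanged.
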
def find_best_frame(source, driving, cpu=False):
    import face_alignment

    def normalize_kp(kp):
        kp = kp - kp.mean(axis=0, keepdims=True)
        area = ConvexHull(kp[:, :2]).volume
        area = np.sqrt(area)
        kp[:, :2] = kp[:, :2] / area
        return kp

    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
                                      device='cpu' if cpu else 'cuda')
    kp_source = fa.get_landmarks(255 * source)[0]
    kp_source = normalize_kp(kp_source)
    norm = float('inf')
    frame_num = 0
    for i, image in tqdm(enumerate(driving)):
        kp_driving = fa.get_landmarks(255 * image)[0]
        kp_driving = normalize_kp(kp_driving)
        new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
        if new_norm < norm:
            norm = new_norm
            frame_num = i
    return frame_num

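# find_best_frame scores each driving frame by the squared distance between its
# normalized facial landmarks and the source's, returning the index of the closest
# match (frames are floats in [0, 1], hence the 255 * scaling). A hypothetical call:
#   best = find_best_frame(source_image, driving_video, cpu=True)
#   # driving_video[best] now has the pose closest to source_image
# Note: newer face_alignment releases rename LandmarksType._2D to LandmarksType.TWO_D.
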
def animation(source_image, driving_video, generator, kp_detector, depth_encoder, depth_decoder, relative=True, adapt_scale=True, cpu=False):
    with torch.no_grad():
        predictions = []
        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
        if not cpu:
            source = source.cuda()
            driving = driving.cuda()
        outputs = depth_decoder(depth_encoder(source))
        depth_source = outputs[("disp", 0)]

        outputs = depth_decoder(depth_encoder(driving[:, :, 0]))
        depth_driving = outputs[("disp", 0)]
        source_kp = torch.cat((source, depth_source), 1)
        driving_kp = torch.cat((driving[:, :, 0], depth_driving), 1)

        kp_source = kp_detector(source_kp)
        kp_driving_initial = kp_detector(driving_kp)

        for frame_idx in tqdm(range(driving.shape[2])):
            driving_frame = driving[:, :, frame_idx]

            if not cpu:
                driving_frame = driving_frame.cuda()
            outputs = depth_decoder(depth_encoder(driving_frame))
            depth_map = outputs[("disp", 0)]

            frame = torch.cat((driving_frame, depth_map), 1)
            kp_driving = kp_detector(frame)

            kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                   kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
                                   use_relative_jacobian=relative, adapt_movement_scale=adapt_scale)
            out = generator(source, kp_source=kp_source, kp_driving=kp_norm, source_depth=depth_source, driving_depth=depth_map)

            predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
    return predictions

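# Shape sketch for animation(), assuming T driving frames at 256x256 (the sizes the
# resize calls below produce) and a 1-channel full-resolution disparity map from the
# depth decoder, as in monodepth2:
#   source_image (256, 256, 3)       -> source  (1, 3, 256, 256)
#   driving_video: T x (256, 256, 3) -> driving (1, 3, T, 256, 256)
#   torch.cat((source, depth_source), 1) -> (1, 4, 256, 256), matching the
#   num_channels = 4 override applied to the keypoint detector below.
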
def reconstruction(source_image, driving_video, generator, kp_detector, depth_encoder, depth_decoder, cpu=False):
    with torch.no_grad():
        predictions = []
        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
        if not cpu:
            source = source.cuda()
            driving = driving.cuda()
        outputs = depth_decoder(depth_encoder(source))
        depth_source = outputs[("disp", 0)]

        outputs = depth_decoder(depth_encoder(driving[:, :, 0]))
        depth_driving = outputs[("disp", 0)]
        source_kp = torch.cat((source, depth_source), 1)
        driving_kp = torch.cat((driving[:, :, 0], depth_driving), 1)

        kp_source = kp_detector(source_kp)
        kp_driving_initial = kp_detector(driving_kp)

        for frame_idx in tqdm(range(driving.shape[2])):
            driving_frame = driving[:, :, frame_idx]

            if not cpu:
                driving_frame = driving_frame.cuda()
            outputs = depth_decoder(depth_encoder(driving_frame))
            depth_map = outputs[("disp", 0)]

            frame = torch.cat((driving_frame, depth_map), 1)
            kp_driving = kp_detector(frame)

            out = generator(source, kp_source=kp_source, kp_driving=kp_driving, source_depth=depth_source, driving_depth=depth_map)

            predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
    return predictions

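# reconstruction() mirrors animation() but feeds the raw per-frame keypoints straight
# to the generator; no normalize_kp re-targeting is needed because in this mode the
# source is the driving video's own first frame.
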
parser = ArgumentParser()
parser.add_argument("--source_image", help="path to source image")
parser.add_argument("--driving_video", help="path to driving video")
parser.add_argument("--result_video", help="path to output")
parser.add_argument("--mode", type=str, choices=["reconstruction", "animation"], help="mode to run")

opt = parser.parse_args()

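# The invocation matches app.py, e.g. (source.png is a hypothetical input; app.py
# passes the Gradio-provided file paths):
#   python run.py --source_image source.png --driving_video cut_vid.mp4 \
#                 --result_video tmp/result.mp4 --mode animation
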
with open("config/vox-adv-256.yaml") as f:
    config = yaml.load(f)

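# Caution: a bare yaml.load(f) only works on PyYAML < 5.1; 5.x warns and 6.x raises
# a TypeError. The modern equivalent is:
#   config = yaml.load(f, Loader=yaml.FullLoader)
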
generator = GEN.SPADEDepthAwareGenerator(**config['model_params']['generator_params'], **config['model_params']['common_params'])

config['model_params']['common_params']['num_channels'] = 4
kp_detector = KPDetector(**config['model_params']['kp_detector_params'], **config['model_params']['common_params'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cpu = False if torch.cuda.is_available() else True

g_checkpoint = torch.load("checkpoints/generator.pth", map_location=device)
kp_checkpoint = torch.load("checkpoints/kp_detector.pth", map_location=device)

ckp_generator = OrderedDict((k.replace('module.', ''), v) for k, v in g_checkpoint.items())
generator.load_state_dict(ckp_generator)
ckp_kp_detector = OrderedDict((k.replace('module.', ''), v) for k, v in kp_checkpoint.items())
kp_detector.load_state_dict(ckp_kp_detector)

depth_encoder = depth.ResnetEncoder(18, False)
depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4))
loaded_dict_enc = torch.load('checkpoints/encoder.pth')
loaded_dict_dec = torch.load('checkpoints/depth.pth')

filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
depth_encoder.load_state_dict(filtered_dict_enc)
depth_decoder.load_state_dict(loaded_dict_dec)

depth_encoder.eval()
depth_decoder.eval()
generator.eval()
kp_detector.eval()

generator.to(device)
kp_detector.to(device)
depth_encoder.to(device)
depth_decoder.to(device)

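# The k.replace('module.', '') passes above exist because checkpoints saved from a
# torch.nn.DataParallel-wrapped model prefix every parameter name with 'module.';
# stripping the prefix lets the bare modules load the weights directly.
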
with torch.inference_mode():
    if torch.cuda.is_available():
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()

    reader = imageio.get_reader(opt.driving_video)
    fps = reader.get_meta_data()['fps']
    driving_video = []
    try:
        for im in reader:
            driving_video.append(im)
    except RuntimeError:
        pass
    reader.close()

    if opt.mode == 'animation':
        source_image = imageio.imread(opt.source_image)
    elif opt.mode == 'reconstruction':
        source_image = driving_video[0]

    source_image = resize(source_image, (256, 256))[..., :3]
    driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]

    if opt.mode == 'animation':
        i = find_best_frame(source_image, driving_video, cpu=cpu)
        driving_forward = driving_video[i:]
        driving_backward = driving_video[:(i+1)][::-1]

        predictions_forward = animation(source_image, driving_forward, generator, kp_detector, depth_encoder, depth_decoder, cpu=cpu)
        predictions_backward = animation(source_image, driving_backward, generator, kp_detector, depth_encoder, depth_decoder, cpu=cpu)
        predictions = predictions_backward[::-1] + predictions_forward[1:]
    elif opt.mode == 'reconstruction':
        predictions = reconstruction(source_image, driving_video, generator, kp_detector, depth_encoder, depth_decoder, cpu)

    imageio.mimsave(opt.result_video, [img_as_ubyte(p) for p in predictions], fps=fps)
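
# The animation branch anchors generation at the best-matching frame i: it animates
# forward from i and backward over the reversed prefix (both halves start at frame i),
# then stitches predictions_backward[::-1] + predictions_forward[1:], with [1:]
# dropping the anchor frame the two halves share. For example, with 5 frames and i = 2:
#   driving_backward = [f2, f1, f0], driving_forward = [f2, f3, f4]
#   predictions_backward[::-1] -> [p0, p1, p2]; appending forward[1:] adds [p3, p4]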