Matt Goyder committed on
Commit 516b1c3
1 Parent(s): 67cf56f

add main script

Files changed (2)
  1. app.py +8 -5
  2. run.py +219 -0
app.py CHANGED
@@ -7,10 +7,13 @@ description = "reconstruction and animation demos for DaGAN"
 
 def inference(mode, img, video):
     os.makedirs('tmp', exist_ok=True)
-    subprocess.run("echo 'hi'")
-    """
-    os.system("python run.py --source_image {} --driving_video
-    """
+
+    cmd = f"ffmpeg -y -ss 00:00:00 -i {video} -to 00:00:05 -c copy cut_vid.mp4"
+    subprocess.run(cmd.split())
+
+    video = "cut_vid.mp4"
+    os.system(f"python run.py --source_image {img} --driving_video {video} --result_video tmp/result.mp4 --mode {mode}")
+    return "tmp/result.mp4"
 
 gr.Interface(
     inference,
@@ -19,7 +22,7 @@ gr.Interface(
         gr.inputs.Image(type="filepath", label="Source (only used if mode is reconstruction)"),
         gr.inputs.Image(type="mp4", label="Driving Video")
     ],
-    outputs=None,
+    outputs=gr.outputs.Video(type="mp4", label="Output Video"),
     title=title,
     description=description,
     theme="huggingface",
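The trimmed driving clip is the key new step in `inference`: the driving video is cut to its first five seconds with a stream-copy ffmpeg call (presumably to keep per-request inference time on the Space bounded) before being handed to `run.py`. A minimal standalone sketch of that trim step, with `driving.mp4` as a placeholder input path:

```python
import subprocess

# Copy (without re-encoding) the first five seconds of the driving clip into cut_vid.mp4.
# Same command as in inference() above; check=True is an addition so a failed ffmpeg
# run raises instead of passing silently.
video = "driving.mp4"  # placeholder path, not part of the commit
cmd = f"ffmpeg -y -ss 00:00:00 -i {video} -to 00:00:05 -c copy cut_vid.mp4"
subprocess.run(cmd.split(), check=True)
```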
run.py ADDED
@@ -0,0 +1,219 @@
+import os, sys
+import yaml
+from argparse import ArgumentParser
+from tqdm import tqdm
+import modules.generator as GEN
+import imageio
+import numpy as np
+from skimage.transform import resize
+from skimage import img_as_ubyte
+import torch
+from sync_batchnorm import DataParallelWithCallback
+import depth
+from modules.keypoint_detector import KPDetector
+from scipy.spatial import ConvexHull
+from collections import OrderedDict
+import warnings
+warnings.filterwarnings("ignore")
+
+def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
+                 use_relative_movement=False, use_relative_jacobian=False):
+    if adapt_movement_scale:
+        source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
+        driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
+        adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
+    else:
+        adapt_movement_scale = 1
+
+    kp_new = {k: v for k, v in kp_driving.items()}
+
+    if use_relative_movement:
+        kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
+        kp_value_diff *= adapt_movement_scale
+        kp_new['value'] = kp_value_diff + kp_source['value']
+
+        if use_relative_jacobian:
+            jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
+            kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
+    return kp_new
+
+def find_best_frame(source, driving, cpu=False):
+    import face_alignment
+
+    def normalize_kp(kp):
+        kp = kp - kp.mean(axis=0, keepdims=True)
+        area = ConvexHull(kp[:, :2]).volume
+        area = np.sqrt(area)
+        kp[:, :2] = kp[:, :2] / area
+        return kp
+
+    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
+                                      device='cpu' if cpu else 'cuda')
+    kp_source = fa.get_landmarks(255 * source)[0]
+    kp_source = normalize_kp(kp_source)
+    norm = float('inf')
+    frame_num = 0
+    for i, image in tqdm(enumerate(driving)):
+        kp_driving = fa.get_landmarks(255 * image)[0]
+        kp_driving = normalize_kp(kp_driving)
+        new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
+        if new_norm < norm:
+            norm = new_norm
+            frame_num = i
+    return frame_num
+
+def animation(source_image, driving_video, generator, kp_detector, depth_encoder, depth_decoder, relative=True, adapt_scale=True, cpu=False):
+    with torch.no_grad():
+        predictions = []
+        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
+        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
+        if not cpu:
+            source = source.cuda()
+            driving = driving.cuda()
+        outputs = depth_decoder(depth_encoder(source))
+        depth_source = outputs[("disp", 0)]
+
+        outputs = depth_decoder(depth_encoder(driving[:, :, 0]))
+        depth_driving = outputs[("disp", 0)]
+        source_kp = torch.cat((source, depth_source), 1)
+        driving_kp = torch.cat((driving[:, :, 0], depth_driving), 1)
+
+        kp_source = kp_detector(source_kp)
+        kp_driving_initial = kp_detector(driving_kp)
+
+
+        for frame_idx in tqdm(range(driving.shape[2])):
+            driving_frame = driving[:, :, frame_idx]
+
+            if not cpu:
+                driving_frame = driving_frame.cuda()
+            outputs = depth_decoder(depth_encoder(driving_frame))
+            depth_map = outputs[("disp", 0)]
+
+            frame = torch.cat((driving_frame, depth_map), 1)
+            kp_driving = kp_detector(frame)
+
+            kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
+                                   kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
+                                   use_relative_jacobian=relative, adapt_movement_scale=adapt_scale)
+            out = generator(source, kp_source=kp_source, kp_driving=kp_norm, source_depth=depth_source, driving_depth=depth_map)
+
+            predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
+    return predictions
+
+def reconstruction(source_image, driving_video, generator, kp_detector, depth_encoder, depth_decoder, cpu=False):
+    with torch.no_grad():
+        predictions = []
+        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
+        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
+        if not cpu:
+            source = source.cuda()
+            driving = driving.cuda()
+        outputs = depth_decoder(depth_encoder(source))
+        depth_source = outputs[("disp", 0)]
+
+        outputs = depth_decoder(depth_encoder(driving[:, :, 0]))
+        depth_driving = outputs[("disp", 0)]
+        source_kp = torch.cat((source, depth_source), 1)
+        driving_kp = torch.cat((driving[:, :, 0], depth_driving), 1)
+
+        kp_source = kp_detector(source_kp)
+        kp_driving_initial = kp_detector(driving_kp)
+
+
+        for frame_idx in tqdm(range(driving.shape[2])):
+            driving_frame = driving[:, :, frame_idx]
+
+            if not cpu:
+                driving_frame = driving_frame.cuda()
+            outputs = depth_decoder(depth_encoder(driving_frame))
+            depth_map = outputs[("disp", 0)]
+
+            frame = torch.cat((driving_frame, depth_map), 1)
+            kp_driving = kp_detector(frame)
+
+            out = generator(source, kp_source=kp_source, kp_driving=kp_driving, source_depth=depth_source, driving_depth=depth_map)
+
+            predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
+    return predictions
+
+parser = ArgumentParser()
+parser.add_argument("--source_image", help="path to source image")
+parser.add_argument("--driving_video", help="path to driving video")
+parser.add_argument("--result_video", help="path to output")
+parser.add_argument("--mode", type=str, choices=["reconstruction", "animation"], help="mode to run")
+
+opt = parser.parse_args()
+
+with open("config/vox-adv-256.yaml") as f:
+    config = yaml.load(f)
+
+generator = GEN.SPADEDepthAwareGenerator(**config['model_params']['generator_params'], **config['model_params']['common_params'])
+
+config['model_params']['common_params']['num_channels'] = 4
+kp_detector = KPDetector(**config['model_params']['kp_detector_params'], **config['model_params']['common_params'])
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+cpu = False if torch.cuda.is_available() else True
+
+g_checkpoint = torch.load("checkpoints/generator.pth", map_location=device)
+kp_checkpoint = torch.load("checkpoints/kp_detector.pth", map_location=device)
+
+ckp_generator = OrderedDict((k.replace('module.', ''), v) for k, v in g_checkpoint.items())
+generator.load_state_dict(ckp_generator)
+ckp_kp_detector = OrderedDict((k.replace('module.', ''), v) for k, v in kp_checkpoint.items())
+kp_detector.load_state_dict(ckp_kp_detector)
+
+depth_encoder = depth.ResnetEncoder(18, False)
+depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4))
+loaded_dict_enc = torch.load('checkpoints/encoder.pth')
+loaded_dict_dec = torch.load('checkpoints/depth.pth')
+
+filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
+depth_encoder.load_state_dict(filtered_dict_enc)
+depth_decoder.load_state_dict(loaded_dict_dec)
+
+depth_encoder.eval()
+depth_decoder.eval()
+generator.eval()
+kp_detector.eval()
+
+generator.to(device)
+kp_detector.to(device)
+depth_encoder.to(device)
+depth_decoder.to(device)
+
+with torch.inference_mode():
+    if torch.cuda.is_available():
+        torch.cuda.ipc_collect()
+        torch.cuda.empty_cache()
+
+    reader = imageio.get_reader(opt.driving_video)
+    fps = reader.get_meta_data()['fps']
+    driving_video = []
+    try:
+        for im in reader:
+            driving_video.append(im)
+    except RuntimeError:
+        pass
+    reader.close()
+
+    if opt.mode == 'animation':
+        source_image = imageio.imread(opt.source_image)
+    elif opt.mode == 'reconstruction':
+        source_image = driving_video[0]
+
+    source_image = resize(source_image, (256, 256))[..., :3]
+    driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
+
+    if opt.mode == 'animation':
+        i = find_best_frame(source_image, driving_video, cpu=cpu)
+        driving_forward = driving_video[i:]
+        driving_backward = driving_video[:(i+1)][::-1]
+
+        predictions_forward = animation(source_image, driving_forward, generator, kp_detector, depth_encoder, depth_decoder, cpu=cpu)
+        predictions_backward = animation(source_image, driving_backward, generator, kp_detector, depth_encoder, depth_decoder, cpu=cpu)
+        predictions = predictions_backward[::-1] + predictions_forward[1:]
+    elif opt.mode == 'reconstruction':
+        predictions = reconstruction(source_image, driving_video, generator, kp_detector, depth_encoder, depth_decoder, cpu)
+
+    imageio.mimsave(opt.result_video, [img_as_ubyte(p) for p in predictions], fps=fps)
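`run.py` is driven entirely by the four CLI flags defined above; `app.py` shells out to it once per request. A minimal sketch of an equivalent direct invocation (the example paths are placeholders, and `config/vox-adv-256.yaml` plus the `checkpoints/*.pth` files loaded above must already be present):

```python
import subprocess

# Hypothetical direct call, mirroring the os.system() line added to app.py.
cmd = [
    "python", "run.py",
    "--source_image", "example/source.png",    # placeholder source image
    "--driving_video", "example/driving.mp4",  # placeholder driving clip
    "--result_video", "tmp/result.mp4",
    "--mode", "animation",                     # or "reconstruction"
]
subprocess.run(cmd, check=True)
```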