jhj0517 committed on
Commit · 2837c6e
Parent(s): 4ddb4d7
Fix retargeting

- app.py +2 -2
- modules/live_portrait/live_portrait_inferencer.py +5 -11
app.py
CHANGED
@@ -50,8 +50,8 @@ class App:
             gr.Dropdown(label=_("Model Type"), visible=False, interactive=False,
                         choices=[item.value for item in ModelType],
                         value=ModelType.HUMAN.value),
-            gr.Slider(label=_("Retargeting Eyes"), minimum=0, maximum=1, step=0.01, value=0
-            gr.Slider(label=_("Retargeting Mouth"), minimum=0, maximum=1, step=0.01, value=0
+            gr.Slider(label=_("Retargeting Eyes"), minimum=0, maximum=1, step=0.01, value=0),
+            gr.Slider(label=_("Retargeting Mouth"), minimum=0, maximum=1, step=0.01, value=0),
             gr.Checkbox(label=_("Tracking Source Video"), value=False, visible=False),
             gr.Slider(label=_("Crop Factor"), minimum=1.5, maximum=2.5, step=0.1, value=1.7),
         ]
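The visible difference in the two changed lines is the closing `),` on each retargeting `gr.Slider(...)` entry (the removed lines are truncated in this view, so the full original text is not recoverable here). As a minimal, self-contained sketch of how sliders like these are typically wired to a handler in Gradio; the handler name and wiring below are illustrative, not part of this commit:

import gradio as gr

# Illustrative handler: in the real app the slider values would reach the
# LivePortrait inferencer; this function just echoes them.
def apply_retargeting(eyes_ratio, mouth_ratio):
    return f"eyes={eyes_ratio:.2f}, mouth={mouth_ratio:.2f}"

with gr.Blocks() as demo:
    eyes = gr.Slider(label="Retargeting Eyes", minimum=0, maximum=1, step=0.01, value=0)
    mouth = gr.Slider(label="Retargeting Mouth", minimum=0, maximum=1, step=0.01, value=0)
    out = gr.Textbox(label="Result")
    gr.Button("Run").click(apply_retargeting, inputs=[eyes, mouth], outputs=out)

demo.launch()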
modules/live_portrait/live_portrait_inferencer.py
CHANGED
@@ -268,20 +268,16 @@ class LivePortraitInferencer:
 
         vid_info = get_video_info(vid_input=driving_vid_path)
 
-        src_length = 1
-
         if src_image is not None:
-            src_length = len(src_image)
             if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor:
                 self.crop_factor = crop_factor
+                self.src_image = src_image
 
-                if 1 < src_length:
-                    self.psi_list = self.prepare_source(src_image, crop_factor, True, tracking_src_vid)
-                else:
-                    self.psi_list = self.prepare_source(src_image, crop_factor)
+                self.psi_list = [self.prepare_source(src_image, crop_factor)]
 
         progress(0, desc="Extracting frames from the video..")
         driving_images, vid_sound = extract_frames(driving_vid_path, os.path.join(self.output_dir, "temp", "video_frames")), extract_sound(driving_vid_path)
+
         driving_length = 0
         if driving_images is not None:
             if id(driving_images) != id(self.driving_images):
@@ -290,7 +286,6 @@
             driving_length = len(self.driving_values)
 
         total_length = len(driving_images)
-        self.psi_list = [self.psi_list[0] for _ in range(total_length)]
 
         c_i_es = ExpressionSet()
         c_o_es = ExpressionSet()
@@ -299,9 +294,8 @@
         psi = None
         for i in range(total_length):
 
-            if i < src_length:
+            if i == 0:
                 psi = self.psi_list[i]
-
             s_info = psi.x_s_info
             s_es = ExpressionSet(erst=(s_info['kp'] + s_info['exp'], torch.Tensor([0, 0, 0]), s_info['scale'], s_info['t']))
 
@@ -309,7 +303,7 @@
 
             if i < driving_length:
                 d_i_info = self.driving_values[i]
-                d_i_r = torch.Tensor([d_i_info['pitch'], d_i_info['yaw'], d_i_info['roll']])
+                d_i_r = torch.Tensor([d_i_info['pitch'], d_i_info['yaw'], d_i_info['roll']]) # .float().to(device="cuda:0")
 
                 if d_0_es is None:
                     d_0_es = ExpressionSet(erst = (d_i_info['exp'], d_i_r, d_i_info['scale'], d_i_info['t']))
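The core of the fix is the added `self.src_image = src_image` in the first hunk. The method only re-runs `prepare_source` when `id(src_image) != id(self.src_image)` or the crop factor changes, but since nothing ever stored the current image, the guard fired on every call. With a single prepared source, the render loop then looks it up once via `if i == 0:` instead of duplicating it across `total_length` entries. The sketch below reproduces that identity-based caching pattern in isolation; the class and function names are stand-ins, not the project's API:

class SourceCache:
    """Recomputes prepared data only when a new image object or crop factor arrives."""

    def __init__(self):
        self.src_image = None
        self.crop_factor = None
        self.psi_list = []

    def prepare(self, src_image, crop_factor, prepare_source):
        # Same guard as the inferencer: compare object identity, not pixel values.
        if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor:
            self.crop_factor = crop_factor
            self.src_image = src_image  # the line this commit adds; without it the guard always fires
            self.psi_list = [prepare_source(src_image, crop_factor)]
        return self.psi_list

img = object()  # stands in for a decoded source image
cache = SourceCache()
cache.prepare(img, 1.7, lambda im, cf: ("psi", cf))
cache.prepare(img, 1.7, lambda im, cf: ("psi", cf))  # cache hit: prepare_source not called again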
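The only change in the last hunk is the trailing `# .float().to(device="cuda:0")` comment on the `d_i_r` line, left as a hint rather than applied. If that cast were enabled, a device-safe version might look like the following sketch; the CUDA fallback logic and the sample values are assumptions, not part of the commit:

import torch

d_i_info = {"pitch": 0.1, "yaw": -0.2, "roll": 0.05}  # illustrative driving values

# Pick the GPU when available; the commented hint hardcodes "cuda:0".
device = "cuda:0" if torch.cuda.is_available() else "cpu"

d_i_r = torch.Tensor([d_i_info["pitch"], d_i_info["yaw"], d_i_info["roll"]])
d_i_r = d_i_r.float().to(device=device)  # what the commented-out suffix would do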