jhj0517 commited on
Commit
2837c6e
·
1 Parent(s): 4ddb4d7

Fix retargeting

Browse files
app.py CHANGED
@@ -50,8 +50,8 @@ class App:
50
  gr.Dropdown(label=_("Model Type"), visible=False, interactive=False,
51
  choices=[item.value for item in ModelType],
52
  value=ModelType.HUMAN.value),
53
- gr.Slider(label=_("Retargeting Eyes"), minimum=0, maximum=1, step=0.01, value=0, visible=False),
54
- gr.Slider(label=_("Retargeting Mouth"), minimum=0, maximum=1, step=0.01, value=0, visible=False),
55
  gr.Checkbox(label=_("Tracking Source Video"), value=False, visible=False),
56
  gr.Slider(label=_("Crop Factor"), minimum=1.5, maximum=2.5, step=0.1, value=1.7),
57
  ]
 
50
  gr.Dropdown(label=_("Model Type"), visible=False, interactive=False,
51
  choices=[item.value for item in ModelType],
52
  value=ModelType.HUMAN.value),
53
+ gr.Slider(label=_("Retargeting Eyes"), minimum=0, maximum=1, step=0.01, value=0),
54
+ gr.Slider(label=_("Retargeting Mouth"), minimum=0, maximum=1, step=0.01, value=0),
55
  gr.Checkbox(label=_("Tracking Source Video"), value=False, visible=False),
56
  gr.Slider(label=_("Crop Factor"), minimum=1.5, maximum=2.5, step=0.1, value=1.7),
57
  ]
modules/live_portrait/live_portrait_inferencer.py CHANGED
@@ -268,20 +268,16 @@ class LivePortraitInferencer:
268
 
269
  vid_info = get_video_info(vid_input=driving_vid_path)
270
 
271
- src_length = 1
272
-
273
  if src_image is not None:
274
- src_length = len(src_image)
275
  if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor:
276
  self.crop_factor = crop_factor
 
277
 
278
- if 1 < src_length:
279
- self.psi_list = self.prepare_source(src_image, crop_factor, True, tracking_src_vid)
280
- else:
281
- self.psi_list = self.prepare_source(src_image, crop_factor)
282
 
283
  progress(0, desc="Extracting frames from the video..")
284
  driving_images, vid_sound = extract_frames(driving_vid_path, os.path.join(self.output_dir, "temp", "video_frames")), extract_sound(driving_vid_path)
 
285
  driving_length = 0
286
  if driving_images is not None:
287
  if id(driving_images) != id(self.driving_images):
@@ -290,7 +286,6 @@ class LivePortraitInferencer:
290
  driving_length = len(self.driving_values)
291
 
292
  total_length = len(driving_images)
293
- self.psi_list = [self.psi_list[0] for _ in range(total_length)]
294
 
295
  c_i_es = ExpressionSet()
296
  c_o_es = ExpressionSet()
@@ -299,9 +294,8 @@ class LivePortraitInferencer:
299
  psi = None
300
  for i in range(total_length):
301
 
302
- if i < src_length:
303
  psi = self.psi_list[i]
304
-
305
  s_info = psi.x_s_info
306
  s_es = ExpressionSet(erst=(s_info['kp'] + s_info['exp'], torch.Tensor([0, 0, 0]), s_info['scale'], s_info['t']))
307
 
@@ -309,7 +303,7 @@ class LivePortraitInferencer:
309
 
310
  if i < driving_length:
311
  d_i_info = self.driving_values[i]
312
- d_i_r = torch.Tensor([d_i_info['pitch'], d_i_info['yaw'], d_i_info['roll']]) #.float().to(device="cuda:0")
313
 
314
  if d_0_es is None:
315
  d_0_es = ExpressionSet(erst = (d_i_info['exp'], d_i_r, d_i_info['scale'], d_i_info['t']))
 
268
 
269
  vid_info = get_video_info(vid_input=driving_vid_path)
270
 
 
 
271
  if src_image is not None:
 
272
  if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor:
273
  self.crop_factor = crop_factor
274
+ self.src_image = src_image
275
 
276
+ self.psi_list = [self.prepare_source(src_image, crop_factor)]
 
 
 
277
 
278
  progress(0, desc="Extracting frames from the video..")
279
  driving_images, vid_sound = extract_frames(driving_vid_path, os.path.join(self.output_dir, "temp", "video_frames")), extract_sound(driving_vid_path)
280
+
281
  driving_length = 0
282
  if driving_images is not None:
283
  if id(driving_images) != id(self.driving_images):
 
286
  driving_length = len(self.driving_values)
287
 
288
  total_length = len(driving_images)
 
289
 
290
  c_i_es = ExpressionSet()
291
  c_o_es = ExpressionSet()
 
294
  psi = None
295
  for i in range(total_length):
296
 
297
+ if i == 0:
298
  psi = self.psi_list[i]
 
299
  s_info = psi.x_s_info
300
  s_es = ExpressionSet(erst=(s_info['kp'] + s_info['exp'], torch.Tensor([0, 0, 0]), s_info['scale'], s_info['t']))
301
 
 
303
 
304
  if i < driving_length:
305
  d_i_info = self.driving_values[i]
306
+ d_i_r = torch.Tensor([d_i_info['pitch'], d_i_info['yaw'], d_i_info['roll']]) # .float().to(device="cuda:0")
307
 
308
  if d_0_es is None:
309
  d_0_es = ExpressionSet(erst = (d_i_info['exp'], d_i_r, d_i_info['scale'], d_i_info['t']))