jhj0517 committed
Commit 4daf2ff · 1 Parent(s): a83a3d9

Add inferencer class

modules/live_portrait/live_portrait_inferencer.py ADDED
@@ -0,0 +1,763 @@
import os
import sys
import numpy as np
import torch
import cv2
import time
import copy
import dill
from ultralytics import YOLO
import safetensors.torch
import gradio as gr

from modules.utils.paths import *
from modules.utils.image_helper import *
from modules.live_portrait.model_downloader import *
from modules.live_portrait_wrapper import LivePortraitWrapper
from modules.utils.camera import get_rotation_matrix
from modules.utils.helper import load_yaml
from modules.config.inference_config import InferenceConfig
from modules.live_portrait.spade_generator import SPADEDecoder
from modules.live_portrait.warping_network import WarpingNetwork
from modules.live_portrait.motion_extractor import MotionExtractor
from modules.live_portrait.appearance_feature_extractor import AppearanceFeatureExtractor
from modules.live_portrait.stitching_retargeting_network import StitchingRetargetingNetwork
from collections import OrderedDict


class LivePortraitInferencer:
    def __init__(self,
                 model_dir: str = MODELS_DIR,
                 output_dir: str = OUTPUTS_DIR):
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.model_config = load_yaml(MODEL_CONFIG)["model_params"]

        self.appearance_feature_extractor = None
        self.motion_extractor = None
        self.warping_module = None
        self.spade_generator = None
        self.stitching_retargeting_module = None
        self.pipeline = None
        self.detect_model = None
        self.device = self.get_device()

        self.mask_img = None
        self.temp_img_idx = 0
        self.src_image = None
        self.src_images = None  # cached batch of source frames used by create_video()
        self.sample_image = None
        self.driving_images = None
        self.driving_values = None
        self.crop_factor = None
        self.psi = None
        self.psi_list = None
        self.d_info = None

    def load_models(self):
        self.download_if_no_models()

        appearance_feat_config = self.model_config["appearance_feature_extractor_params"]
        self.appearance_feature_extractor = AppearanceFeatureExtractor(**appearance_feat_config).to(self.device)
        self.appearance_feature_extractor = self.load_safe_tensor(
            self.appearance_feature_extractor,
            os.path.join(self.model_dir, "appearance_feature_extractor.safetensors")
        )

        motion_ext_config = self.model_config["motion_extractor_params"]
        self.motion_extractor = MotionExtractor(**motion_ext_config).to(self.device)
        self.motion_extractor = self.load_safe_tensor(
            self.motion_extractor,
            os.path.join(self.model_dir, "motion_extractor.safetensors")
        )

        warping_module_config = self.model_config["warping_module_params"]
        self.warping_module = WarpingNetwork(**warping_module_config).to(self.device)
        self.warping_module = self.load_safe_tensor(
            self.warping_module,
            os.path.join(self.model_dir, "warping_module.safetensors")
        )

        spade_decoder_config = self.model_config["spade_generator_params"]
        self.spade_generator = SPADEDecoder(**spade_decoder_config).to(self.device)
        self.spade_generator = self.load_safe_tensor(
            self.spade_generator,
            os.path.join(self.model_dir, "spade_generator.safetensors")
        )

        def filter_stitcher(checkpoint, prefix):
            filtered_checkpoint = {key.replace(prefix + "_module.", ""): value
                                   for key, value in checkpoint.items() if key.startswith(prefix)}
            return filtered_checkpoint

        stitcher_config = self.model_config["stitching_retargeting_module_params"]
        self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching'))
        stitcher_model_path = os.path.join(self.model_dir, "stitching_retargeting_module.safetensors")
        ckpt = safetensors.torch.load_file(stitcher_model_path)
        self.stitching_retargeting_module.load_state_dict(filter_stitcher(ckpt, 'retarget_shoulder'))
        self.stitching_retargeting_module.to(self.device).eval()
        self.stitching_retargeting_module = {"stitching": self.stitching_retargeting_module}

        if self.pipeline is None:
            self.pipeline = LivePortraitWrapper(
                InferenceConfig(),
                self.appearance_feature_extractor,
                self.motion_extractor,
                self.warping_module,
                self.spade_generator,
                self.stitching_retargeting_module
            )

        self.detect_model = YOLO(MODEL_PATHS["face_yolov8n"])

    def edit_expression(self,
                        rotate_pitch=0,
                        rotate_yaw=0,
                        rotate_roll=0,
                        blink=0,
                        eyebrow=0,
                        wink=0,
                        pupil_x=0,
                        pupil_y=0,
                        aaa=0,
                        eee=0,
                        woo=0,
                        smile=0,
                        src_ratio=1,
                        sample_ratio=1,
                        sample_parts="All",
                        crop_factor=1.5,
                        src_image=None,
                        sample_image=None,
                        motion_link=None,
                        add_exp=None):
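        # Note: the expression sliders (blink, eyebrow, wink, pupil_x/y, aaa, eee,
        # woo, smile) are additive keypoint offsets applied in calc_fe(); when
        # sample_image is given, its expression and/or rotation is transferred to
        # the source face according to sample_parts.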
        if self.pipeline is None:
            self.load_models()

        try:
            rotate_yaw = -rotate_yaw

            new_editor_link = None
            if motion_link is not None:
                self.psi = motion_link[0]
                new_editor_link = motion_link.copy()
            elif src_image is not None:
                if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor:
                    self.crop_factor = crop_factor
                    self.psi = self.prepare_source(src_image, crop_factor)
                    self.src_image = src_image
                new_editor_link = []
                new_editor_link.append(self.psi)
            else:
                return None, None

            psi = self.psi
            s_info = psi.x_s_info
            #delta_new = copy.deepcopy()
            s_exp = s_info['exp'] * src_ratio
            s_exp[0, 5] = s_info['exp'][0, 5]
            s_exp += s_info['kp']

            es = ExpressionSet()

            if sample_image is not None:
                if id(self.sample_image) != id(sample_image):
                    self.sample_image = sample_image
                    d_image_np = (sample_image * 255).byte().numpy()
                    d_face = self.crop_face(d_image_np[0], 1.7)
                    i_d = self.prepare_src_image(d_face)
                    self.d_info = self.pipeline.get_kp_info(i_d)
                    self.d_info['exp'][0, 5, 0] = 0
                    self.d_info['exp'][0, 5, 1] = 0

                # "OnlyExpression", "OnlyRotation", "OnlyMouth", "OnlyEyes", "All"
                if sample_parts == "OnlyExpression" or sample_parts == "All":
                    es.e += self.d_info['exp'] * sample_ratio
                if sample_parts == "OnlyRotation" or sample_parts == "All":
                    rotate_pitch += self.d_info['pitch'] * sample_ratio
                    rotate_yaw += self.d_info['yaw'] * sample_ratio
                    rotate_roll += self.d_info['roll'] * sample_ratio
                elif sample_parts == "OnlyMouth":
                    self.retargeting(es.e, self.d_info['exp'], sample_ratio, (14, 17, 19, 20))
                elif sample_parts == "OnlyEyes":
                    self.retargeting(es.e, self.d_info['exp'], sample_ratio, (1, 2, 11, 13, 15, 16))

            es.r = self.calc_fe(es.e, blink, eyebrow, wink, pupil_x, pupil_y, aaa, eee, woo, smile,
                                rotate_pitch, rotate_yaw, rotate_roll)

            if add_exp is not None:
                es.add(add_exp)

            new_rotate = get_rotation_matrix(s_info['pitch'] + es.r[0], s_info['yaw'] + es.r[1],
                                             s_info['roll'] + es.r[2])
            x_d_new = (s_info['scale'] * (1 + es.s)) * ((s_exp + es.e) @ new_rotate) + s_info['t']

            x_d_new = self.pipeline.stitching(psi.x_s_user, x_d_new)

            crop_out = self.pipeline.warp_decode(psi.f_s_user, psi.x_s_user, x_d_new)
            crop_out = self.pipeline.parse_output(crop_out['out'])[0]

            crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb), cv2.INTER_LINEAR)
            out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(np.uint8)

            out_img = pil2tensor(out)
            out_img_path = get_auto_incremental_file_path(TEMP_DIR, "png")

            img = Image.fromarray(crop_out)
            img.save(out_img_path, compress_level=1)
            new_editor_link.append(es)

            return out_img  # {"ui": {"images": results}, "result": (out_img, new_editor_link, es)}
        except Exception:
            raise

    def create_video(self,
                     retargeting_eyes,
                     retargeting_mouth,
                     turn_on,
                     tracking_src_vid,
                     animate_without_vid,
                     command,
                     crop_factor,
                     src_images=None,
                     driving_images=None,
                     motion_link=None,
                     progress=gr.Progress()):
        if not turn_on:
            return None, None
        src_length = 1

        if src_images is None:
            if motion_link is not None:
                self.psi_list = [motion_link[0]]
            else:
                return None, None

        if src_images is not None:
            src_length = len(src_images)
            if id(src_images) != id(self.src_images) or self.crop_factor != crop_factor:
                self.crop_factor = crop_factor
                self.src_images = src_images
                if 1 < src_length:
                    self.psi_list = self.prepare_source(src_images, crop_factor, True, tracking_src_vid)
                else:
                    self.psi_list = [self.prepare_source(src_images, crop_factor)]

        cmd_list, cmd_length = self.parsing_command(command, motion_link)
        if cmd_list is None:
            return None, None
        cmd_idx = 0

        driving_length = 0
        if driving_images is not None:
            if id(driving_images) != id(self.driving_images):
                self.driving_images = driving_images
                self.driving_values = self.prepare_driving_video(driving_images)
            driving_length = len(self.driving_values)

        total_length = max(driving_length, src_length)

        if animate_without_vid:
            total_length = max(total_length, cmd_length)

        c_i_es = ExpressionSet()
        c_o_es = ExpressionSet()
        d_0_es = None
        out_list = []

        psi = None
        for i in range(total_length):

            if i < src_length:
                psi = self.psi_list[i]
                s_info = psi.x_s_info
                s_es = ExpressionSet(erst=(s_info['kp'] + s_info['exp'], torch.Tensor([0, 0, 0]), s_info['scale'], s_info['t']))

            new_es = ExpressionSet(es=s_es)

            if i < cmd_length:
                cmd = cmd_list[cmd_idx]
                if 0 < cmd.change:
                    cmd.change -= 1
                    c_i_es.add(cmd.es)
                    c_i_es.sub(c_o_es)
                elif 0 < cmd.keep:
                    cmd.keep -= 1

                new_es.add(c_i_es)

                if cmd.change == 0 and cmd.keep == 0:
                    cmd_idx += 1
                    if cmd_idx < len(cmd_list):
                        c_o_es = ExpressionSet(es=c_i_es)
                        cmd = cmd_list[cmd_idx]
                        c_o_es.div(cmd.change)
            elif 0 < cmd_length:
                new_es.add(c_i_es)

            if i < driving_length:
                d_i_info = self.driving_values[i]
                d_i_r = torch.Tensor([d_i_info['pitch'], d_i_info['yaw'], d_i_info['roll']])  #.float().to(device="cuda:0")

                if d_0_es is None:
                    d_0_es = ExpressionSet(erst=(d_i_info['exp'], d_i_r, d_i_info['scale'], d_i_info['t']))

                    self.retargeting(s_es.e, d_0_es.e, retargeting_eyes, (11, 13, 15, 16))
                    self.retargeting(s_es.e, d_0_es.e, retargeting_mouth, (14, 17, 19, 20))

                new_es.e += d_i_info['exp'] - d_0_es.e
                new_es.r += d_i_r - d_0_es.r
                new_es.t += d_i_info['t'] - d_0_es.t

            r_new = get_rotation_matrix(
                s_info['pitch'] + new_es.r[0], s_info['yaw'] + new_es.r[1], s_info['roll'] + new_es.r[2])
            d_new = new_es.s * (new_es.e @ r_new) + new_es.t
            d_new = self.pipeline.stitching(psi.x_s_user, d_new)
            crop_out = self.pipeline.warp_decode(psi.f_s_user, psi.x_s_user, d_new)
            crop_out = self.pipeline.parse_output(crop_out['out'])[0]

            crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb),
                                                cv2.INTER_LINEAR)
            out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(
                np.uint8)
            out_list.append(out)

            progress(i / total_length, "predicting..")

        if len(out_list) == 0:
            return None

        out_imgs = torch.cat([pil2tensor(img_rgb) for img_rgb in out_list])
        return out_imgs

    def download_if_no_models(self):
        for model_name, model_url in MODELS_URL.items():
            if model_url.endswith(".pt"):
                model_name += ".pt"
            else:
                model_name += ".safetensors"
            model_path = os.path.join(self.model_dir, model_name)
            if not os.path.exists(model_path):
                download_model(model_path, model_url)

    @staticmethod
    def load_safe_tensor(model, file_path):
        model.load_state_dict(safetensors.torch.load_file(file_path))
        model.eval()
        return model

    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"

    def get_temp_img_name(self):
        self.temp_img_idx += 1
        return "expression_edit_preview" + str(self.temp_img_idx) + ".png"

    @staticmethod
    def parsing_command(command, motion_link):
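        # Command syntax (one command per line, inferred from the parsing below):
        #   <motion_link index>=<change frames>:<keep frames>
        # e.g. "1=30:10" blends toward motion_link[1] over 30 frames and then holds
        # it for 10 frames; index 0 targets an empty (neutral) ExpressionSet.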
        command = command.replace(' ', '')
        lines = command.split('\n')

        cmd_list = []

        total_length = 0

        i = 0
        for line in lines:
            i += 1
            if not line:
                continue
            try:
                cmds = line.split('=')
                idx = int(cmds[0])
                if idx == 0: es = ExpressionSet()
                else: es = ExpressionSet(es=motion_link[idx])
                cmds = cmds[1].split(':')
                change = int(cmds[0])
                keep = int(cmds[1])
            except Exception as e:
                print(f"(AdvancedLivePortrait) Command Err Line {i}: {line}, :{e}")
                return None, None

            total_length += change + keep
            es.div(change)
            cmd_list.append(Command(es, change, keep))

        return cmd_list, total_length

    def get_face_bboxes(self, image_rgb):
        pred = self.detect_model(image_rgb, conf=0.7, device="")
        return pred[0].boxes.xyxy.cpu().numpy()

    def detect_face(self, image_rgb, crop_factor, sort=True):
        bboxes = self.get_face_bboxes(image_rgb)
        w, h = get_rgb_size(image_rgb)

        print(f"w, h:{w, h}")

        cx = w / 2
        min_diff = w
        best_box = None
        for x1, y1, x2, y2 in bboxes:
            bbox_w = x2 - x1
            if bbox_w < 30: continue
            diff = abs(cx - (x1 + bbox_w / 2))
            if diff < min_diff:
                best_box = [x1, y1, x2, y2]
                print(f"diff, min_diff, best_box:{diff, min_diff, best_box}")
                min_diff = diff

        if best_box is None:
            print("Failed to detect face!!")
            return [0, 0, w, h]

        x1, y1, x2, y2 = best_box

        #for x1, y1, x2, y2 in bboxes:
        bbox_w = x2 - x1
        bbox_h = y2 - y1

        crop_w = bbox_w * crop_factor
        crop_h = bbox_h * crop_factor

        crop_w = max(crop_h, crop_w)
        crop_h = crop_w

        kernel_x = int(x1 + bbox_w / 2)
        kernel_y = int(y1 + bbox_h / 2)

        new_x1 = int(kernel_x - crop_w / 2)
        new_x2 = int(kernel_x + crop_w / 2)
        new_y1 = int(kernel_y - crop_h / 2)
        new_y2 = int(kernel_y + crop_h / 2)

        if not sort:
            return [int(new_x1), int(new_y1), int(new_x2), int(new_y2)]

        if new_x1 < 0:
            new_x2 -= new_x1
            new_x1 = 0
        elif w < new_x2:
            new_x1 -= (new_x2 - w)
            new_x2 = w
            if new_x1 < 0:
                new_x2 -= new_x1
                new_x1 = 0

        if new_y1 < 0:
            new_y2 -= new_y1
            new_y1 = 0
        elif h < new_y2:
            new_y1 -= (new_y2 - h)
            new_y2 = h
            if new_y1 < 0:
                new_y2 -= new_y1
                new_y1 = 0

        if w < new_x2 and h < new_y2:
            over_x = new_x2 - w
            over_y = new_y2 - h
            over_min = min(over_x, over_y)
            new_x2 -= over_min
            new_y2 -= over_min

        return [int(new_x1), int(new_y1), int(new_x2), int(new_y2)]

    @staticmethod
    def retargeting(delta_out, driving_exp, factor, idxes):
        for idx in idxes:
            # delta_out[0, idx] -= src_exp[0, idx] * factor
            delta_out[0, idx] += driving_exp[0, idx] * factor

    @staticmethod
    def calc_face_region(square, dsize):
        region = copy.deepcopy(square)
        is_changed = False
        if dsize[0] < region[2]:
            region[2] = dsize[0]
            is_changed = True
        if dsize[1] < region[3]:
            region[3] = dsize[1]
            is_changed = True

        return region, is_changed

    @staticmethod
    def expand_img(rgb_img, square):
        crop_trans_m = create_transform_matrix(max(-square[0], 0), max(-square[1], 0), 1, 1)
        new_img = cv2.warpAffine(rgb_img, crop_trans_m, (square[2] - square[0], square[3] - square[1]),
                                 cv2.INTER_LINEAR)
        return new_img

    def prepare_src_image(self, img):
        h, w = img.shape[:2]
        input_shape = [256, 256]
        if h != input_shape[0] or w != input_shape[1]:
            if 256 < h: interpolation = cv2.INTER_AREA
            else: interpolation = cv2.INTER_LINEAR
            x = cv2.resize(img, (input_shape[0], input_shape[1]), interpolation=interpolation)
        else:
            x = img.copy()

        if x.ndim == 3:
            x = x[np.newaxis].astype(np.float32) / 255.  # HxWx3 -> 1xHxWx3, normalized to 0~1
        elif x.ndim == 4:
            x = x.astype(np.float32) / 255.  # BxHxWx3, normalized to 0~1
        else:
            raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
        x = np.clip(x, 0, 1)  # clip to 0~1
        x = torch.from_numpy(x).permute(0, 3, 1, 2)  # 1xHxWx3 -> 1x3xHxW
        x = x.to(self.device)
        return x

    def get_mask_img(self):
        if self.mask_img is None:
            self.mask_img = cv2.imread(MASK_TEMPLATES, cv2.IMREAD_COLOR)
        return self.mask_img

    def crop_face(self, img_rgb, crop_factor):
        crop_region = self.detect_face(img_rgb, crop_factor)
        face_region, is_changed = self.calc_face_region(crop_region, get_rgb_size(img_rgb))
        face_img = rgb_crop(img_rgb, face_region)
        if is_changed: face_img = self.expand_img(face_img, crop_region)
        return face_img

    def prepare_source(self, source_image, crop_factor, is_video=False, tracking=False):
        print("Prepare source...")
        #source_image_np = (source_image * 255).byte().numpy()
        # img_rgb = source_image_np[0]

        psi_list = []
        for img_rgb in source_image:
            if tracking or len(psi_list) == 0:
                crop_region = self.detect_face(img_rgb, crop_factor)
                face_region, is_changed = self.calc_face_region(crop_region, get_rgb_size(img_rgb))

                s_x = (face_region[2] - face_region[0]) / 512.
                s_y = (face_region[3] - face_region[1]) / 512.
                crop_trans_m = create_transform_matrix(crop_region[0], crop_region[1], s_x, s_y)
                mask_ori = cv2.warpAffine(self.get_mask_img(), crop_trans_m, get_rgb_size(img_rgb), cv2.INTER_LINEAR)
                mask_ori = mask_ori.astype(np.float32) / 255.

                if is_changed:
                    s = (crop_region[2] - crop_region[0]) / 512.
                    crop_trans_m = create_transform_matrix(crop_region[0], crop_region[1], s, s)

            face_img = rgb_crop(img_rgb, face_region)
            if is_changed: face_img = self.expand_img(face_img, crop_region)
            i_s = self.prepare_src_image(face_img)
            x_s_info = self.pipeline.get_kp_info(i_s)
            f_s_user = self.pipeline.extract_feature_3d(i_s)
            x_s_user = self.pipeline.transform_keypoint(x_s_info)
            psi = PreparedSrcImg(img_rgb, crop_trans_m, x_s_info, f_s_user, x_s_user, mask_ori)
            if not is_video:
                return psi
            psi_list.append(psi)

        return psi_list

    def prepare_driving_video(self, face_images):
        print("Prepare driving video...")
        f_img_np = (face_images * 255).byte().numpy()

        out_list = []
        for f_img in f_img_np:
            i_d = self.prepare_src_image(f_img)
            d_info = self.pipeline.get_kp_info(i_d)
            out_list.append(d_info)

        return out_list

    @staticmethod
    def calc_fe(x_d_new, eyes, eyebrow, wink, pupil_x, pupil_y, mouth, eee, woo, smile,
                rotate_pitch, rotate_yaw, rotate_roll):
        # Hand-tuned per-keypoint offsets for each expression slider; the extra
        # rotation implied by the sliders is returned as a (pitch, yaw, roll) tensor.

        x_d_new[0, 20, 1] += smile * -0.01
        x_d_new[0, 14, 1] += smile * -0.02
        x_d_new[0, 17, 1] += smile * 0.0065
        x_d_new[0, 17, 2] += smile * 0.003
        x_d_new[0, 13, 1] += smile * -0.00275
        x_d_new[0, 16, 1] += smile * -0.00275
        x_d_new[0, 3, 1] += smile * -0.0035
        x_d_new[0, 7, 1] += smile * -0.0035

        x_d_new[0, 19, 1] += mouth * 0.001
        x_d_new[0, 19, 2] += mouth * 0.0001
        x_d_new[0, 17, 1] += mouth * -0.0001
        rotate_pitch -= mouth * 0.05

        x_d_new[0, 20, 2] += eee * -0.001
        x_d_new[0, 20, 1] += eee * -0.001
        #x_d_new[0, 19, 1] += eee * 0.0006
        x_d_new[0, 14, 1] += eee * -0.001

        x_d_new[0, 14, 1] += woo * 0.001
        x_d_new[0, 3, 1] += woo * -0.0005
        x_d_new[0, 7, 1] += woo * -0.0005
        x_d_new[0, 17, 2] += woo * -0.0005

        x_d_new[0, 11, 1] += wink * 0.001
        x_d_new[0, 13, 1] += wink * -0.0003
        x_d_new[0, 17, 0] += wink * 0.0003
        x_d_new[0, 17, 1] += wink * 0.0003
        x_d_new[0, 3, 1] += wink * -0.0003
        rotate_roll -= wink * 0.1
        rotate_yaw -= wink * 0.1

        if 0 < pupil_x:
            x_d_new[0, 11, 0] += pupil_x * 0.0007
            x_d_new[0, 15, 0] += pupil_x * 0.001
        else:
            x_d_new[0, 11, 0] += pupil_x * 0.001
            x_d_new[0, 15, 0] += pupil_x * 0.0007

        x_d_new[0, 11, 1] += pupil_y * -0.001
        x_d_new[0, 15, 1] += pupil_y * -0.001
        eyes -= pupil_y / 2.

        x_d_new[0, 11, 1] += eyes * -0.001
        x_d_new[0, 13, 1] += eyes * 0.0003
        x_d_new[0, 15, 1] += eyes * -0.001
        x_d_new[0, 16, 1] += eyes * 0.0003
        x_d_new[0, 1, 1] += eyes * -0.00025
        x_d_new[0, 2, 1] += eyes * 0.00025

        if 0 < eyebrow:
            x_d_new[0, 1, 1] += eyebrow * 0.001
            x_d_new[0, 2, 1] += eyebrow * -0.001
        else:
            x_d_new[0, 1, 0] += eyebrow * -0.001
            x_d_new[0, 2, 0] += eyebrow * 0.001
            x_d_new[0, 1, 1] += eyebrow * 0.0003
            x_d_new[0, 2, 1] += eyebrow * -0.0003

        return torch.Tensor([rotate_pitch, rotate_yaw, rotate_roll])


class ExpressionSet:
    def __init__(self, erst=None, es=None):
        if es is not None:
            self.e = copy.deepcopy(es.e)  # [:, :, :]
            self.r = copy.deepcopy(es.r)  # [:]
            self.s = copy.deepcopy(es.s)
            self.t = copy.deepcopy(es.t)
        elif erst is not None:
            self.e = erst[0]
            self.r = erst[1]
            self.s = erst[2]
            self.t = erst[3]
        else:
            self.e = torch.from_numpy(np.zeros((1, 21, 3))).float().to(self.get_device())
            self.r = torch.Tensor([0, 0, 0])
            self.s = 0
            self.t = 0

    def div(self, value):
        self.e /= value
        self.r /= value
        self.s /= value
        self.t /= value

    def add(self, other):
        self.e += other.e
        self.r += other.r
        self.s += other.s
        self.t += other.t

    def sub(self, other):
        self.e -= other.e
        self.r -= other.r
        self.s -= other.s
        self.t -= other.t

    def mul(self, value):
        self.e *= value
        self.r *= value
        self.s *= value
        self.t *= value

    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"


def logging_time(original_fn):
    def wrapper_fn(*args, **kwargs):
        start_time = time.time()
        result = original_fn(*args, **kwargs)
        end_time = time.time()
        print("WorkingTime[{}]: {} sec".format(original_fn.__name__, end_time - start_time))
        return result

    return wrapper_fn


def save_exp_data(file_name: str, save_exp: ExpressionSet = None):
    if save_exp is None or not file_name:
        return file_name

    with open(os.path.join(EXP_OUTPUT_DIR, file_name + ".exp"), "wb") as f:
        dill.dump(save_exp, f)

    return file_name


def load_exp_data(file_name, ratio):
    file_list = [os.path.splitext(file)[0] for file in os.listdir(EXP_OUTPUT_DIR) if file.endswith('.exp')]
    with open(os.path.join(EXP_OUTPUT_DIR, file_name + ".exp"), 'rb') as f:
        es = dill.load(f)
    es.mul(ratio)
    return es


def handle_exp_data(code1, value1, code2, value2, code3, value3, code4, value4, code5, value5, add_exp=None):
    if add_exp is None:
        es = ExpressionSet()
    else:
        es = ExpressionSet(es=add_exp)

    codes = [code1, code2, code3, code4, code5]
    values = [value1, value2, value3, value4, value5]
    for i in range(5):
        idx = int(codes[i] / 10)  # keypoint index (0-20)
        r = codes[i] % 10         # axis (0=x, 1=y, 2=z)
        es.e[0, idx, r] += values[i] * 0.001

    return es


def print_exp_data(cut_noise, exp=None):
    if exp is None:
        return exp

    cut_list = []
    e = exp.e * 1000  # ExpressionSet stores the expression tensor in .e
    for idx in range(21):
        for r in range(3):
            a = abs(e[0, idx, r])
            if cut_noise < a: cut_list.append((a, e[0, idx, r], idx * 10 + r))

    sorted_list = sorted(cut_list, reverse=True, key=lambda item: item[0])
    print(f"sorted_list: {[[item[2], round(float(item[1]), 1)] for item in sorted_list]}")
    return exp


class Command:
    def __init__(self,
                 es: ExpressionSet,
                 change,
                 keep):
        self.es = es
        self.change = change
        self.keep = keep
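
For reference, a minimal usage sketch of the new class (not part of the commit). The expected input format for src_image and sensible slider values are assumptions inferred from the code above, so treat this as illustrative only:

    # Hypothetical example: load the models once, then edit a single portrait.
    from modules.live_portrait.live_portrait_inferencer import LivePortraitInferencer
    import cv2

    inferencer = LivePortraitInferencer()
    inferencer.load_models()  # downloads any missing weights and builds the LivePortraitWrapper pipeline

    # prepare_source() iterates over its input, so a one-element list of RGB frames
    # is assumed here; "portrait.png" and the slider values are placeholders.
    src = cv2.cvtColor(cv2.imread("portrait.png"), cv2.COLOR_BGR2RGB)
    out_tensor = inferencer.edit_expression(smile=1.0, rotate_yaw=5, src_image=[src])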