K00B404 committed (verified)
Commit 8e78a3e · Parent(s): 43648d4

Create live_portrait_wrapper_cpu.py

Files changed (1)
  1. src/live_portrait_wrapper_cpu.py +288 -0
src/live_portrait_wrapper_cpu.py ADDED
@@ -0,0 +1,288 @@
+ # coding: utf-8
+
+ """
+ Wrapper for LivePortrait core functions (CPU-optimized version)
+ """
+
+ import os.path as osp
+ import numpy as np
+ import cv2
+ import torch
+ import yaml
+ import psutil
+
+ from .utils.timer import Timer
+ from .utils.helper_cpu import load_model, concat_feat
+ from .utils.camera import headpose_pred_to_degree, get_rotation_matrix
+ from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
+ from .config.inference_config import InferenceConfig
+ from .utils.rprint import rlog as log
+
+ class LivePortraitWrapperCPU(object):
+
+     def __init__(self, cfg: InferenceConfig):
+         model_config = yaml.load(open(cfg.models_config, 'r'), Loader=yaml.SafeLoader)
+
+         # Check available memory
+         available_memory = psutil.virtual_memory().available / (1024 * 1024 * 1024)  # in GB
+         if available_memory < 2:  # If less than 2GB available
+             log(f"Warning: Only {available_memory:.2f}GB of RAM available. This may cause performance issues or crashes.")
+
+         # init F
+         self.appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, 'cpu', 'appearance_feature_extractor')
+         log(f'Load appearance_feature_extractor done.')
+         # init M
+         self.motion_extractor = load_model(cfg.checkpoint_M, model_config, 'cpu', 'motion_extractor')
+         log(f'Load motion_extractor done.')
+         # init W
+         self.warping_module = load_model(cfg.checkpoint_W, model_config, 'cpu', 'warping_module')
+         log(f'Load warping_module done.')
+         # init G
+         self.spade_generator = load_model(cfg.checkpoint_G, model_config, 'cpu', 'spade_generator')
+         log(f'Load spade_generator done.')
+         # init S and R
+         if cfg.checkpoint_S is not None and osp.exists(cfg.checkpoint_S):
+             self.stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, 'cpu', 'stitching_retargeting_module')
+             log(f'Load stitching_retargeting_module done.')
+         else:
+             self.stitching_retargeting_module = None
+         self.device = 'cpu'
+         self.cfg = cfg
+         self.timer = Timer()
+
+     def update_config(self, user_args):
+         for k, v in user_args.items():
+             if hasattr(self.cfg, k):
+                 setattr(self.cfg, k, v)
+
+     def prepare_source(self, img: np.ndarray) -> torch.Tensor:
+         """ construct the input as standard
+         img: HxWx3, uint8, 256x256
+         """
+         h, w = img.shape[:2]
+         if h != self.cfg.input_shape[0] or w != self.cfg.input_shape[1]:
+             x = cv2.resize(img, (self.cfg.input_shape[0], self.cfg.input_shape[1]))
+         else:
+             x = img.copy()
+
+         if x.ndim == 3:
+             x = x[np.newaxis].astype(np.float32) / 255.  # HxWx3 -> 1xHxWx3, normalized to 0~1
+         elif x.ndim == 4:
+             x = x.astype(np.float32) / 255.  # BxHxWx3, normalized to 0~1
+         else:
+             raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
+         x = np.clip(x, 0, 1)  # clip to 0~1
+         x = torch.from_numpy(x).permute(0, 3, 1, 2)  # 1xHxWx3 -> 1x3xHxW
+         return x
+
+     def prepare_driving_videos(self, imgs) -> torch.Tensor:
+         """ construct the input as standard
+         imgs: NxBxHxWx3, uint8
+         """
+         if isinstance(imgs, list):
+             _imgs = np.array(imgs)[..., np.newaxis]  # TxHxWx3x1
+         elif isinstance(imgs, np.ndarray):
+             _imgs = imgs
+         else:
+             raise ValueError(f'imgs type error: {type(imgs)}')
+
+         y = _imgs.astype(np.float32) / 255.
+         y = np.clip(y, 0, 1)  # clip to 0~1
+         y = torch.from_numpy(y).permute(0, 4, 3, 1, 2)  # TxHxWx3x1 -> Tx1x3xHxW
+         return y
+
+     def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor:
+         """ get the appearance feature of the image by F
+         x: Bx3xHxW, normalized to 0~1
+         """
+         with torch.no_grad():
+             feature_3d = self.appearance_feature_extractor(x)
+         return feature_3d
+
+     def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict:
+         """ get the implicit keypoint information
+         x: Bx3xHxW, normalized to 0~1
+         flag_refine_info: whether to convert the predicted pose to degrees and reshape 'kp' and 'exp' to BxNx3
+         return: a dict with keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
+         """
+         with torch.no_grad():
+             kp_info = self.motion_extractor(x)
+
+         flag_refine_info: bool = kwargs.get('flag_refine_info', True)
+         if flag_refine_info:
+             bs = kp_info['kp'].shape[0]
+             kp_info['pitch'] = headpose_pred_to_degree(kp_info['pitch'])[:, None]  # Bx1
+             kp_info['yaw'] = headpose_pred_to_degree(kp_info['yaw'])[:, None]  # Bx1
+             kp_info['roll'] = headpose_pred_to_degree(kp_info['roll'])[:, None]  # Bx1
+             kp_info['kp'] = kp_info['kp'].reshape(bs, -1, 3)  # BxNx3
+             kp_info['exp'] = kp_info['exp'].reshape(bs, -1, 3)  # BxNx3
+
+         return kp_info
+
+     def get_pose_dct(self, kp_info: dict) -> dict:
+         pose_dct = dict(
+             pitch=headpose_pred_to_degree(kp_info['pitch']).item(),
+             yaw=headpose_pred_to_degree(kp_info['yaw']).item(),
+             roll=headpose_pred_to_degree(kp_info['roll']).item(),
+         )
+         return pose_dct
+
+     def get_fs_and_kp_info(self, source_prepared, driving_first_frame):
+         # get the canonical keypoints of source image by M
+         source_kp_info = self.get_kp_info(source_prepared, flag_refine_info=True)
+         source_rotation = get_rotation_matrix(source_kp_info['pitch'], source_kp_info['yaw'], source_kp_info['roll'])
+
+         # get the canonical keypoints of first driving frame by M
+         driving_first_frame_kp_info = self.get_kp_info(driving_first_frame, flag_refine_info=True)
+         driving_first_frame_rotation = get_rotation_matrix(
+             driving_first_frame_kp_info['pitch'],
+             driving_first_frame_kp_info['yaw'],
+             driving_first_frame_kp_info['roll']
+         )
+
+         # get feature volume by F
+         source_feature_3d = self.extract_feature_3d(source_prepared)
+
+         return source_kp_info, source_rotation, source_feature_3d, driving_first_frame_kp_info, driving_first_frame_rotation
+
+     def transform_keypoint(self, kp_info: dict):
+         """
+         transform the implicit keypoints with the pose, shift, and expression deformation
+         kp: BxNx3
+         """
+         kp = kp_info['kp']  # (bs, k, 3)
+         pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll']
+
+         t, exp = kp_info['t'], kp_info['exp']
+         scale = kp_info['scale']
+
+         pitch = headpose_pred_to_degree(pitch)
+         yaw = headpose_pred_to_degree(yaw)
+         roll = headpose_pred_to_degree(roll)
+
+         bs = kp.shape[0]
+         if kp.ndim == 2:
+             num_kp = kp.shape[1] // 3  # Bx(num_kpx3)
+         else:
+             num_kp = kp.shape[1]  # Bxnum_kpx3
+
+         rot_mat = get_rotation_matrix(pitch, yaw, roll)  # (bs, 3, 3)
+
+         # Eqn.2: s * (R * x_c,s + exp) + t
+         kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3)
+         kp_transformed *= scale[..., None]  # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
+         kp_transformed[:, :, 0:2] += t[:, None, 0:2]  # remove z, only apply tx ty
+
+         return kp_transformed
+
+     def retarget_eye(self, kp_source: torch.Tensor, eye_close_ratio: torch.Tensor) -> torch.Tensor:
+         """
+         kp_source: BxNx3
+         eye_close_ratio: Bx3
+         Return: Bx(3*num_kp+2)
+         """
+         feat_eye = concat_feat(kp_source, eye_close_ratio)
+
+         with torch.no_grad():
+             delta = self.stitching_retargeting_module['eye'](feat_eye)
+
+         return delta
+
+     def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor:
+         """
+         kp_source: BxNx3
+         lip_close_ratio: Bx2
+         """
+         feat_lip = concat_feat(kp_source, lip_close_ratio)
+
+         with torch.no_grad():
+             delta = self.stitching_retargeting_module['lip'](feat_lip)
+
+         return delta
+
+     def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
+         """
+         kp_source: BxNx3
+         kp_driving: BxNx3
+         Return: Bx(3*num_kp+2)
+         """
+         feat_stitching = concat_feat(kp_source, kp_driving)
+
+         with torch.no_grad():
+             delta = self.stitching_retargeting_module['stitching'](feat_stitching)
+
+         return delta
+
+     def stitching(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
+         """ conduct the stitching
+         kp_source: Bxnum_kpx3
+         kp_driving: Bxnum_kpx3
+         """
+         if self.stitching_retargeting_module is not None:
+             bs, num_kp = kp_source.shape[:2]
+
+             kp_driving_new = kp_driving.clone()
+             delta = self.stitch(kp_source, kp_driving_new)
+
+             delta_exp = delta[..., :3*num_kp].reshape(bs, num_kp, 3)  # 1x20x3
+             delta_tx_ty = delta[..., 3*num_kp:3*num_kp+2].reshape(bs, 1, 2)  # 1x1x2
+
+             kp_driving_new += delta_exp
+             kp_driving_new[..., :2] += delta_tx_ty
+
+             return kp_driving_new
+
+         return kp_driving
+
+     def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
+         """ get the image after the warping of the implicit keypoints
+         feature_3d: Bx32x16x64x64, feature volume
+         kp_source: BxNx3
+         kp_driving: BxNx3
+         """
+         # The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i))
+         with torch.no_grad():
+             # get decoder input
+             ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
+             # decode
+             ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
+
+         return ret_dct
+
+     def parse_output(self, out: torch.Tensor) -> np.ndarray:
+         """ construct the output as standard
+         return: 1xHxWx3, uint8
+         """
+         out = np.transpose(out.data.numpy(), [0, 2, 3, 1])  # 1x3xHxW -> 1xHxWx3
+         out = np.clip(out, 0, 1)  # clip to 0~1
+         out = np.clip(out * 255, 0, 255).astype(np.uint8)  # 0~1 -> 0~255
+
+         return out
+
+     def calc_retargeting_ratio(self, source_lmk, driving_lmk_lst):
+         input_eye_ratio_lst = []
+         input_lip_ratio_lst = []
+         for lmk in driving_lmk_lst:
+             # for eyes retargeting
+             input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None]))
+             # for lip retargeting
+             input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
+         return input_eye_ratio_lst, input_lip_ratio_lst
+
+     def calc_combined_eye_ratio(self, input_eye_ratio, source_lmk):
+         eye_close_ratio = calc_eye_close_ratio(source_lmk[None])
+         eye_close_ratio_tensor = torch.from_numpy(eye_close_ratio).float().to(self.device)
+         input_eye_ratio_tensor = torch.tensor([input_eye_ratio[0][0]]).reshape(1, 1).to(self.device)
+         # [c_s,eyes, c_d,eyes,i]
+         combined_eye_ratio_tensor = torch.cat([eye_close_ratio_tensor, input_eye_ratio_tensor], dim=1)
+         return combined_eye_ratio_tensor
+
+     def calc_combined_lip_ratio(self, input_lip_ratio, source_lmk):
+         lip_close_ratio = calc_lip_close_ratio(source_lmk[None])
+         lip_close_ratio_tensor = torch.from_numpy(lip_close_ratio).float().to(self.device)
+         # [c_s,lip, c_d,lip,i]
+         input_lip_ratio_tensor = torch.tensor([input_lip_ratio[0]]).to(self.device)
+         if input_lip_ratio_tensor.shape != torch.Size([1, 1]):
+             input_lip_ratio_tensor = input_lip_ratio_tensor.reshape(1, 1)
+         combined_lip_ratio_tensor = torch.cat([lip_close_ratio_tensor, input_lip_ratio_tensor], dim=1)
+         return combined_lip_ratio_tensor
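
For reference, a minimal usage sketch of the wrapper added above (not part of the commit). It assumes the repository root is on the Python path so the module imports as src.live_portrait_wrapper_cpu, that the default InferenceConfig already points at valid CPU checkpoints, and that source.jpg is a placeholder for a cropped face image.

# Hypothetical usage sketch; paths, config defaults, and file names are assumptions.
import cv2
from src.config.inference_config import InferenceConfig
from src.live_portrait_wrapper_cpu import LivePortraitWrapperCPU

cfg = InferenceConfig()                          # assumes checkpoint paths are set in the config
wrapper = LivePortraitWrapperCPU(cfg)

img = cv2.imread('source.jpg')                   # placeholder path; HxWx3 uint8, ideally 256x256
x_s = wrapper.prepare_source(img)                # 1x3xHxW float tensor in [0, 1]
f_s = wrapper.extract_feature_3d(x_s)            # appearance feature volume from F
kp_info_s = wrapper.get_kp_info(x_s)             # pitch/yaw/roll/t/exp/scale/kp from M
x_s_kp = wrapper.transform_keypoint(kp_info_s)   # Eqn.2: s * (R * x_c,s + exp) + t

From here, driving-frame keypoints would be obtained the same way and passed to stitching() and warp_decode(), with parse_output() converting the decoded tensor back to a uint8 image.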