K00B404 committed on
Commit 43648d4 · verified · 1 parent: 58710f2

Create live_portrait_cpu_pipeline.py

Files changed (1): src/live_portrait_cpu_pipeline.py (+217 -0)
src/live_portrait_cpu_pipeline.py ADDED
@@ -0,0 +1,217 @@
# coding: utf-8

"""
Pipeline of LivePortrait (CPU-optimized version)
"""

import torch
torch.set_num_threads(4)  # limit the number of CPU threads to reduce memory usage
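# A possible refinement (hypothetical, not in this commit): derive the thread
# cap from the machine instead of hard-coding 4, e.g.
#   torch.set_num_threads(max(1, (os.cpu_count() or 4) // 2))
# (requires `import os` before use; more threads speed up inference but raise
# peak memory)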

import cv2
import numpy as np
import pickle
import os
import os.path as osp
from rich.progress import track
import gc

from .config.argument_config import ArgumentConfig
from .config.inference_config import InferenceConfig
from .config.crop_config import CropConfig
from .utils.cropper import Cropper
from .utils.camera import get_rotation_matrix
from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
from .utils.crop import _transform_img, prepare_paste_back, paste_back
from .utils.retargeting_utils import calc_lip_close_ratio
from .utils.io import load_image_rgb, load_driving_info, resize_to_limit
from .utils.helper_cpu import mkdir, basename, dct2cpu, is_video, is_template, show_memory_usage
from .utils.rprint import rlog as log
from .live_portrait_wrapper_cpu import LivePortraitWrapperCPU as wrapper
# from .live_portrait_wrapper import LivePortraitWrapper as wrapper

def make_abs_path(fn):
    return osp.join(osp.dirname(osp.realpath(__file__)), fn)


class LiveCPUPortraitPipeline(object):

    def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
        self.live_portrait_wrapper: wrapper = wrapper(cfg=inference_cfg)
        self.cropper = Cropper(crop_cfg=crop_cfg)
        self.mem_mon = show_memory_usage()  # log the baseline memory footprint at startup

    def execute(self, args: ArgumentConfig):
        inference_cfg = self.live_portrait_wrapper.cfg  # for convenience

        ######## process source portrait ########
        img_rgb = load_image_rgb(args.source_image)
        log(f"resizing source image so its longest side is at most {inference_cfg.ref_max_shape}")
        img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
        log(f"processing image from {args.source_image}")
        crop_info = self.cropper.crop_single_image(img_rgb)
        source_lmk = crop_info['lmk_crop']
        img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
        if inference_cfg.flag_do_crop:
            log("Cropping source image.")
            I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
        else:
            log(f"Load source image from {args.source_image}")
            I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
        x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
        x_c_s = x_s_info['kp']
        R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
        f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
        x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)

        if inference_cfg.flag_lip_zero:
            # close the source lips before animation, unless they are already nearly closed
            c_d_lip_before_animation = [0.]
            combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
            if combined_lip_ratio_tensor_before_animation[0][0] < inference_cfg.lip_zero_threshold:
                inference_cfg.flag_lip_zero = False
            else:
                lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)

        ######## process driving info ########
        output_fps = 10  # default fps
        if is_video(args.driving_info):
            log(f"Load from video file (mp4 mov avi etc...): {args.driving_info}")
            output_fps = int(get_fps(args.driving_info))
            log(f'The FPS of {args.driving_info} is: {output_fps}')

            driving_rgb_lst = load_driving_info(args.driving_info)
            # NOTE: this CPU fork downsizes driving frames to 128x128 to save memory,
            # although the variable name (and upstream LivePortrait) suggest 256x256
            driving_rgb_lst_256 = [cv2.resize(_, (128, 128)) for _ in driving_rgb_lst]
            I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst_256)
            n_frames = I_d_lst.shape[0]
            if inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting:
                driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
                input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
        elif is_template(args.driving_info):
            log(f"Load from video templates {args.driving_info}")
            with open(args.driving_info, 'rb') as f:
                template_lst, driving_lmk_lst = pickle.load(f)
            n_frames = template_lst[0]['n_frames']
            input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
        else:
            raise ValueError("Unsupported driving type: expected a video file or a pickled template.")

        ######## prepare for pasteback ########
        if inference_cfg.flag_pasteback:
            mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
            I_p_paste_lst = []

        # Batch size trades throughput against peak memory; set it based on your
        # system's memory capacity and the frame size.
        batch_size = 128
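        # A dynamic alternative (hypothetical, not part of this commit) could size
        # the batch from available RAM, e.g. with psutil:
        #   import psutil
        #   free_gb = psutil.virtual_memory().available / 2**30
        #   batch_size = max(16, min(128, int(free_gb * 8)))  # ~8 frames per free GB, a rough guess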

        I_p_lst = []
        R_d_0, x_d_0_info = None, None
        n_batches = (n_frames + batch_size - 1) // batch_size  # ceiling division
        log(f'Number of frames: {n_frames}, processing in {n_batches} batches')
        for start in range(0, n_frames, batch_size):
            end = min(start + batch_size, n_frames)

            for i in track(range(start, end), description='Animating...', total=end - start):
                log(f'Processing frame {i+1}/{n_frames}')
                if is_video(args.driving_info):
                    I_d_i = I_d_lst[i]
                    x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
                    R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
                else:
                    x_d_i_info = template_lst[i]
                    x_d_i_info = dct2cpu(x_d_i_info)
                    R_d_i = x_d_i_info['R_d']

                if i == 0:
                    # cache the first driving frame as the reference for relative motion
                    R_d_0 = R_d_i
                    x_d_0_info = x_d_i_info

                if inference_cfg.flag_relative:
                    R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
                    delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
                    scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
                    t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
                else:
                    # absolute mode: take the driving pose and expression directly
                    R_new = R_d_i
                    delta_new = x_d_i_info['exp']
                    scale_new = x_s_info['scale']
                    t_new = x_d_i_info['t']

                t_new[..., 2].fill_(0)  # zero tz
                x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
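                # The line above is the motion-transfer step: the canonical source
                # keypoints x_c_s are rotated by R_new, offset by the expression
                # delta, then scaled and translated. In relative mode,
                # R_d_i @ R_d_0^T applies only the *change* in driving head pose
                # on top of the source pose R_s.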

                if not inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
                    if inference_cfg.flag_lip_zero:
                        x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
                elif inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
                    if inference_cfg.flag_lip_zero:
                        x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
                    else:
                        x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
                else:
                    eyes_delta, lip_delta = None, None
                    if inference_cfg.flag_eye_retargeting:
                        c_d_eyes_i = input_eye_ratio_lst[i]
                        combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
                        eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
                    if inference_cfg.flag_lip_retargeting:
                        c_d_lip_i = input_lip_ratio_lst[i]
                        combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
                        lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor)

                    if inference_cfg.flag_relative:
                        x_d_i_new = x_s + \
                            (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
                            (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
                    else:
                        x_d_i_new = x_d_i_new + \
                            (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
                            (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)

                    if inference_cfg.flag_stitching:
                        x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)

                show_memory_usage()  # log memory usage for every frame

                out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
                I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
                I_p_lst.append(I_p_i)
                log(f'Generated {len(I_p_lst)} frames')
                if inference_cfg.flag_pasteback:
                    I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
                    I_p_paste_lst.append(I_p_i_to_ori_blend)

            # Clear memory after processing the batch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # no-op on CPU-only builds
            # del I_d_lst, x_d_i_new, x_d_i_info, out, I_p_i  # clear batch-related variables
            gc.collect()  # force garbage collection

            show_memory_usage()  # memory check after each batch

        mkdir(args.output_dir)
        wfp_concat = None
        flag_has_audio = has_audio_stream(args.driving_info)

        if is_video(args.driving_info):
            frames_concatenated = concat_frames(I_p_lst, driving_rgb_lst, img_crop_256x256)
            wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
            images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
            if flag_has_audio:
                wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat_with_audio.mp4')
                add_audio_to_video(wfp_concat, args.driving_info, wfp_concat_with_audio)
                os.replace(wfp_concat_with_audio, wfp_concat)
                log(f"Replaced {wfp_concat} with {wfp_concat_with_audio}")

        wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4')
        if inference_cfg.flag_pasteback:
            images2video(I_p_paste_lst, wfp=wfp, fps=output_fps)
        else:
            images2video(I_p_lst, wfp=wfp, fps=output_fps)

        if flag_has_audio:
            wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_with_audio.mp4')
            add_audio_to_video(wfp, args.driving_info, wfp_with_audio)
            os.replace(wfp_with_audio, wfp)
            log(f"Replaced {wfp} with {wfp_with_audio}")

        return wfp, wfp_concat
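
A minimal driver for this pipeline might look like the sketch below. It assumes the three config dataclasses construct with defaults and that ArgumentConfig carries source_image, driving_info, and output_dir fields, which is all that execute() reads from args in this commit; nothing else is confirmed here.

# run_cpu_pipeline.py -- hypothetical usage sketch, not part of this commit
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig
from src.live_portrait_cpu_pipeline import LiveCPUPortraitPipeline

if __name__ == '__main__':
    args = ArgumentConfig(                     # field names taken from execute()
        source_image='assets/source.jpg',      # portrait to animate (assumed path)
        driving_info='assets/driving.mp4',     # driving video or pickled template
        output_dir='outputs',
    )
    pipeline = LiveCPUPortraitPipeline(
        inference_cfg=InferenceConfig(),       # assumes defaults exist
        crop_cfg=CropConfig(),
    )
    wfp, wfp_concat = pipeline.execute(args)   # returns output and concat video paths
    print(f'Result video: {wfp}')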