Create live_portrait_cpu_pipeline.py
src/live_portrait_cpu_pipeline.py
ADDED
@@ -0,0 +1,217 @@
# coding: utf-8

"""
Pipeline of LivePortrait (CPU-optimized version)
"""

import torch
torch.set_num_threads(4)  # Limit the number of threads to reduce memory usage
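# (On machines with more cores, a higher cap such as
# torch.set_num_threads(os.cpu_count() or 4) may run faster at the cost of
# extra memory; 4 is a conservative CPU default.)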

import cv2
import numpy as np
import pickle
import os
import os.path as osp
from rich.progress import track
import gc

from .config.argument_config import ArgumentConfig
from .config.inference_config import InferenceConfig
from .config.crop_config import CropConfig
from .utils.cropper import Cropper
from .utils.camera import get_rotation_matrix
from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
from .utils.crop import _transform_img, prepare_paste_back, paste_back
from .utils.retargeting_utils import calc_lip_close_ratio
from .utils.io import load_image_rgb, load_driving_info, resize_to_limit
from .utils.helper_cpu import mkdir, basename, dct2cpu, is_video, is_template, show_memory_usage
from .utils.rprint import rlog as log
from .live_portrait_wrapper_cpu import LivePortraitWrapperCPU as wrapper
# from .live_portrait_wrapper import LivePortraitWrapper as wrapper

def make_abs_path(fn):
    return osp.join(osp.dirname(osp.realpath(__file__)), fn)

class LiveCPUPortraitPipeline(object):

    def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
        self.live_portrait_wrapper: wrapper = wrapper(cfg=inference_cfg)
        self.cropper = Cropper(crop_cfg=crop_cfg)
        self.mem_mon = show_memory_usage()  # log the baseline memory footprint

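    # A minimal usage sketch (hypothetical paths; the field names are taken from
    # how execute() reads its args, not from a documented API):
    #
    #   pipeline = LiveCPUPortraitPipeline(InferenceConfig(), CropConfig())
    #   wfp, wfp_concat = pipeline.execute(ArgumentConfig(
    #       source_image='assets/source.jpg',
    #       driving_info='assets/driving.mp4',
    #       output_dir='animations/',
    #   ))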
    def execute(self, args: ArgumentConfig):
        inference_cfg = self.live_portrait_wrapper.cfg  # for convenience

        ######## process source portrait ########
        img_rgb = load_image_rgb(args.source_image)
        log(f"Resizing source image to a maximum side of {inference_cfg.ref_max_shape}px")
        img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
        log(f"Processing image from {args.source_image}")
        crop_info = self.cropper.crop_single_image(img_rgb)
        source_lmk = crop_info['lmk_crop']
        img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
        if inference_cfg.flag_do_crop:
            log("Cropping source image.")
            I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
        else:
            log(f"Loading source image from {args.source_image}")
            I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
        x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
        x_c_s = x_s_info['kp']
        R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
        f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
        x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)

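        # flag_lip_zero pre-closes the source's lips before animation: if the
        # measured lip-close ratio is already below lip_zero_threshold the step
        # is skipped, otherwise a keypoint delta is precomputed here and reused
        # for every animated frame.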
        if inference_cfg.flag_lip_zero:
            c_d_lip_before_animation = [0.]
            combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
            if combined_lip_ratio_tensor_before_animation[0][0] < inference_cfg.lip_zero_threshold:
                inference_cfg.flag_lip_zero = False
            else:
                lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)

        ######## process driving info ########
        output_fps = 10  # default fps
        if is_video(args.driving_info):
            log(f"Loading from video file (mp4, mov, avi, etc.): {args.driving_info}")
            output_fps = int(get_fps(args.driving_info))
            log(f'The FPS of {args.driving_info} is: {output_fps}')

            driving_rgb_lst = load_driving_info(args.driving_info)
            # NOTE: driving frames are downsampled to 128x128 here, not the 256x256
            # the variable name suggests -- presumably a CPU memory/speed trade-off.
            driving_rgb_lst_256 = [cv2.resize(_, (128, 128)) for _ in driving_rgb_lst]
            I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst_256)
            n_frames = I_d_lst.shape[0]
            if inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting:
                driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
                input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
        elif is_template(args.driving_info):
            log(f"Loading from video template {args.driving_info}")
            with open(args.driving_info, 'rb') as f:
                template_lst, driving_lmk_lst = pickle.load(f)
            n_frames = template_lst[0]['n_frames']
            input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
        else:
            raise Exception("Unsupported driving types!")

        ######## prepare for pasteback ########
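        # prepare_paste_back warps the crop-space mask through M_c2o (the 256x256
        # crop -> original-image transform), giving a soft mask in full-image
        # coordinates for blending each animated crop back into the source frame.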
        if inference_cfg.flag_pasteback:
            mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
            I_p_paste_lst = []

        # Determine the batch size based on available memory and frame size
        batch_size = 128  # set this according to your system's memory capacity

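        # A memory-aware alternative (a sketch, not part of the pipeline: it
        # assumes psutil is installed and uses an illustrative per-frame estimate):
        #
        #   import psutil
        #   bytes_per_frame = 256 * 256 * 3 * 4 * 32   # crude activation estimate
        #   avail = psutil.virtual_memory().available
        #   batch_size = max(1, min(128, int(0.5 * avail // bytes_per_frame)))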
        I_p_lst = []
        R_d_0, x_d_0_info = None, None
        n_batches = (n_frames + batch_size - 1) // batch_size  # ceil division
        log(f'Number of frames: {n_frames}, processing in {n_batches} batches')
        for start in range(0, n_frames, batch_size):

            end = min(start + batch_size, n_frames)

            for i in track(range(start, end), description='Animating...', total=end - start):
                log(f'Processing frame {i + 1}/{n_frames}')
                if is_video(args.driving_info):
                    I_d_i = I_d_lst[i]
                    x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
                    R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
                else:
                    x_d_i_info = template_lst[i]
                    x_d_i_info = dct2cpu(x_d_i_info)
                    R_d_i = x_d_i_info['R_d']

                if i == 0:
                    R_d_0 = R_d_i
                    x_d_0_info = x_d_i_info

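                # Relative mode transfers motion: rotation, expression, scale and
                # translation are applied as offsets from the first driving frame
                # (R_d_0 / x_d_0_info), preserving the source's own pose; absolute
                # mode copies the driving frame's pose directly.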
                if inference_cfg.flag_relative:
                    R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
                    delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
                    scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
                    t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
                else:
                    R_new = R_d_i
                    delta_new = x_d_i_info['exp']
                    scale_new = x_s_info['scale']
                    t_new = x_d_i_info['t']

                t_new[..., 2].fill_(0)  # zero tz
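                # The line below is the implicit-keypoint driving transform:
                #   x_d = s * (x_c @ R + delta) + t
                # canonical source keypoints are rotated to the new pose, offset by
                # the expression deltas, scaled, then translated (tz zeroed above).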
                x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new

                if not inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
                    if inference_cfg.flag_lip_zero:
                        x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
                elif inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
                    if inference_cfg.flag_lip_zero:
                        x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
                    else:
                        x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
                else:
                    eyes_delta, lip_delta = None, None
                    if inference_cfg.flag_eye_retargeting:
                        c_d_eyes_i = input_eye_ratio_lst[i]
                        combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
                        eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
                    if inference_cfg.flag_lip_retargeting:
                        c_d_lip_i = input_lip_ratio_lst[i]
                        combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
                        lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor)

                    if inference_cfg.flag_relative:
                        x_d_i_new = x_s + \
                            (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
                            (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
                    else:
                        x_d_i_new = x_d_i_new + \
                            (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
                            (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)

                    if inference_cfg.flag_stitching:
                        x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
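                # (The stitching module nudges the driven keypoints back toward the
                # source layout so the animated crop blends cleanly at paste-back.)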

                # Check memory usage periodically
                show_memory_usage()

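                # warp_decode deforms the 3D source feature volume f_s from the
                # source keypoints x_s to the driven keypoints x_d_i_new, then
                # decodes the warped features into an output frame.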
                out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
                I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
                I_p_lst.append(I_p_i)
                log(f'Generated {len(I_p_lst)} frames')
                if inference_cfg.flag_pasteback:
                    I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
                    I_p_paste_lst.append(I_p_i_to_ori_blend)

            # Clear memory after processing the batch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # guarded: a no-op on this CPU-only pipeline
            # del I_d_lst, x_d_i_new, x_d_i_info, out, I_p_i  # clear batch-related variables
            gc.collect()  # force garbage collection

            # Check memory usage after each batch
            show_memory_usage()

        mkdir(args.output_dir)
        wfp_concat = None
        flag_has_audio = has_audio_stream(args.driving_info)

        if is_video(args.driving_info):
            frames_concatenated = concat_frames(I_p_lst, driving_rgb_lst, img_crop_256x256)
            wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
            images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
            if flag_has_audio:
                wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat_with_audio.mp4')
                add_audio_to_video(wfp_concat, args.driving_info, wfp_concat_with_audio)
                os.replace(wfp_concat_with_audio, wfp_concat)
                log(f"Replaced {wfp_concat} with {wfp_concat_with_audio}")

        wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4')
        if inference_cfg.flag_pasteback:
            images2video(I_p_paste_lst, wfp=wfp, fps=output_fps)
        else:
            images2video(I_p_lst, wfp=wfp, fps=output_fps)

        if flag_has_audio:
            wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_with_audio.mp4')
            add_audio_to_video(wfp, args.driving_info, wfp_with_audio)
            os.replace(wfp_with_audio, wfp)
            log(f"Replaced {wfp} with {wfp_with_audio}")

        return wfp, wfp_concat