|
|
|
|
|
""" |
|
Pipeline of LivePortrait (CPU-optimized version) |
|
""" |
|
|
|
import torch |
|
torch.set_num_threads(4) |
|
|
|
import cv2 |
|
import numpy as np |
|
import pickle |
|
import os |
|
import os.path as osp |
|
from rich.progress import track |
|
import gc |
|
|
|
from .config.argument_config import ArgumentConfig |
|
from .config.inference_config import InferenceConfig |
|
from .config.crop_config import CropConfig |
|
from .utils.cropper import Cropper |
|
from .utils.camera import get_rotation_matrix |
|
from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream |
|
from .utils.crop import _transform_img, prepare_paste_back, paste_back |
|
from .utils.retargeting_utils import calc_lip_close_ratio |
|
from .utils.io import load_image_rgb, load_driving_info, resize_to_limit |
|
from .utils.helper_cpu import mkdir, basename, dct2cpu, is_video, is_template,show_memory_usage |
|
from .utils.rprint import rlog as log |
|
from .live_portrait_wrapper_cpu import LivePortraitWrapperCPU as wrapper |
|
|
|
|
|
def make_abs_path(fn):
    """Join *fn* onto the directory that contains this module.

    The module location is taken from ``__file__`` after resolving symlinks
    with ``os.path.realpath``.

    Args:
        fn: Path fragment to append to the module directory.

    Returns:
        The joined path as a string.
    """
    module_dir = osp.dirname(osp.realpath(__file__))
    return osp.join(module_dir, fn)
|
|
|
class LiveCPUPortraitPipeline(object):
    """Animate a single source portrait from a driving video or template on CPU.

    The class wires together a ``LivePortraitWrapperCPU`` instance (model
    calls) and a ``Cropper`` instance (source cropping / landmark extraction)
    and writes the resulting frames to video files.
    """

    def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
        """Build the model wrapper and the cropper.

        Args:
            inference_cfg: Configuration handed to the model wrapper; it is
                read back later through ``self.live_portrait_wrapper.cfg``.
            crop_cfg: Configuration handed to ``Cropper``.
        """
        self.live_portrait_wrapper: wrapper = wrapper(cfg=inference_cfg)
        self.cropper = Cropper(crop_cfg=crop_cfg)
        # Keeps whatever show_memory_usage() returns at construction time.
        self.mem_mon = show_memory_usage()

    def execute(self, args: ArgumentConfig):
        """Run the full animation and write the output video(s).

        Args:
            args: Run arguments. This method reads ``args.source_image``,
                ``args.driving_info`` and ``args.output_dir``.

        Returns:
            A tuple ``(wfp, wfp_concat)``: ``wfp`` is the path of the animated
            video; ``wfp_concat`` is the path of the concatenated comparison
            video, or ``None`` when the driving input is a template.

        Raises:
            Exception: If ``args.driving_info`` is neither accepted by
                ``is_video`` nor by ``is_template``.
        """
        inference_cfg = self.live_portrait_wrapper.cfg
        lpw = self.live_portrait_wrapper

        # ---------------- source portrait ----------------
        img_rgb = load_image_rgb(args.source_image)
        log(f"resizing source image to {inference_cfg.ref_max_shape}x{inference_cfg.ref_max_shape}")
        img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
        log(f"processing image from {args.source_image}")
        crop_info = self.cropper.crop_single_image(img_rgb)
        source_lmk = crop_info['lmk_crop']
        img_crop_256x256 = crop_info['img_crop_256x256']
        if inference_cfg.flag_do_crop:
            log("Cropping source image.")
            I_s = lpw.prepare_source(img_crop_256x256)
        else:
            log(f"Load source image from {args.source_image}")
            I_s = lpw.prepare_source(img_rgb)
        x_s_info = lpw.get_kp_info(I_s)
        x_c_s = x_s_info['kp']
        R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
        f_s = lpw.extract_feature_3d(I_s)
        x_s = lpw.transform_keypoint(x_s_info)

        # Decide lip-zeroing per call in a local flag. The config object is
        # shared by the wrapper across calls, so writing the decision back to
        # it would silently disable lip-zeroing for every later source image.
        flag_lip_zero = inference_cfg.flag_lip_zero
        lip_zero_delta = None
        if flag_lip_zero:
            lip_ratio_before = lpw.calc_combined_lip_ratio([0.], source_lmk)
            if lip_ratio_before[0][0] < inference_cfg.lip_zero_threshold:
                flag_lip_zero = False
            else:
                # The reshaped delta is the same for every frame; build it once.
                lip_zero_delta = lpw.retarget_lip(x_s, lip_ratio_before).reshape(-1, x_s.shape[1], 3)

        # ---------------- driving input ----------------
        output_fps = 10  # used as-is when the driving input is a template
        driving_is_video = is_video(args.driving_info)
        if driving_is_video:
            log(f"Load from video file (mp4 mov avi etc...): {args.driving_info}")
            output_fps = int(get_fps(args.driving_info))
            log(f'The FPS of {args.driving_info} is: {output_fps}')

            driving_rgb_lst = load_driving_info(args.driving_info)
            # NOTE(review): frames are resized to 128x128 here although the
            # original variable name said "256" - confirm the wrapper expects
            # this size.
            driving_rgb_lst_small = [cv2.resize(frame, (128, 128)) for frame in driving_rgb_lst]
            I_d_lst = lpw.prepare_driving_videos(driving_rgb_lst_small)
            n_frames = I_d_lst.shape[0]
            if inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting:
                driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
                input_eye_ratio_lst, input_lip_ratio_lst = lpw.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
        elif is_template(args.driving_info):
            log(f"Load from video templates {args.driving_info}")
            # NOTE(security): pickle.load can execute arbitrary code embedded
            # in the file; only load templates from trusted sources.
            with open(args.driving_info, 'rb') as f:
                template_lst, driving_lmk_lst = pickle.load(f)
            n_frames = template_lst[0]['n_frames']
            input_eye_ratio_lst, input_lip_ratio_lst = lpw.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
        else:
            raise Exception("Unsupported driving types!")

        # ---------------- paste-back preparation ----------------
        if inference_cfg.flag_pasteback:
            mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
            I_p_paste_lst = []

        # ---------------- animation ----------------
        # Frames are processed one by one; batching only determines how often
        # the cache flush / garbage collection below is triggered.
        batch_size = 128
        # Ceiling division: a trailing partial batch still counts as a batch.
        n_batches = (n_frames + batch_size - 1) // batch_size
        use_retargeting = inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting

        I_p_lst = []
        R_d_0, x_d_0_info = None, None
        log(f'Number of frames:{n_frames} processing in {n_batches} batches')
        for start in range(0, n_frames, batch_size):
            end = min(start + batch_size, n_frames)

            for i in track(range(start, end), description='Animating.....', total=end - start):
                log(f'Processing frame {i+1}/{end}')
                if driving_is_video:
                    I_d_i = I_d_lst[i]
                    x_d_i_info = lpw.get_kp_info(I_d_i)
                    R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
                else:
                    x_d_i_info = dct2cpu(template_lst[i])
                    R_d_i = x_d_i_info['R_d']

                # The first driving frame is the reference for relative motion.
                if i == 0:
                    R_d_0 = R_d_i
                    x_d_0_info = x_d_i_info

                if inference_cfg.flag_relative:
                    R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
                    delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
                    scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
                    t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
                else:
                    R_new = R_d_i
                    delta_new = x_d_i_info['exp']
                    scale_new = x_s_info['scale']
                    t_new = x_d_i_info['t']

                # Zero the z translation in place. In the non-relative branch
                # t_new is the very tensor stored in x_d_i_info['t'], so that
                # entry is modified as well.
                t_new[..., 2].fill_(0)
                x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new

                if not use_retargeting:
                    # No retargeting: optional stitching, then optional lip-zero offset.
                    if inference_cfg.flag_stitching:
                        x_d_i_new = lpw.stitching(x_s, x_d_i_new)
                    if flag_lip_zero:
                        x_d_i_new = x_d_i_new + lip_zero_delta
                else:
                    eyes_delta, lip_delta = None, None
                    if inference_cfg.flag_eye_retargeting:
                        c_d_eyes_i = input_eye_ratio_lst[i]
                        combined_eye_ratio_tensor = lpw.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
                        eyes_delta = lpw.retarget_eye(x_s, combined_eye_ratio_tensor)
                    if inference_cfg.flag_lip_retargeting:
                        c_d_lip_i = input_lip_ratio_lst[i]
                        combined_lip_ratio_tensor = lpw.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
                        lip_delta = lpw.retarget_lip(x_s, combined_lip_ratio_tensor)

                    eyes_term = eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0
                    lip_term = lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0
                    # Relative mode applies the deltas to the source keypoints,
                    # otherwise to the driven keypoints computed above.
                    base_kp = x_s if inference_cfg.flag_relative else x_d_i_new
                    x_d_i_new = base_kp + eyes_term + lip_term

                    if inference_cfg.flag_stitching:
                        x_d_i_new = lpw.stitching(x_s, x_d_i_new)

                show_memory_usage()

                out = lpw.warp_decode(f_s, x_s, x_d_i_new)
                I_p_i = lpw.parse_output(out['out'])[0]
                I_p_lst.append(I_p_i)
                log(f'Generated {len(I_p_lst)} frames ')
                if inference_cfg.flag_pasteback:
                    I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
                    I_p_paste_lst.append(I_p_i_to_ori_blend)

            # End of batch: release cached/unreferenced memory and report usage.
            torch.cuda.empty_cache()
            gc.collect()
            show_memory_usage()

        # ---------------- write results ----------------
        mkdir(args.output_dir)
        wfp_concat = None
        # A pickled template is not a media file, so it is never probed for audio.
        flag_has_audio = driving_is_video and has_audio_stream(args.driving_info)
        out_stem = f'{basename(args.source_image)}--{basename(args.driving_info)}'

        if driving_is_video:
            frames_concatenated = concat_frames(I_p_lst, driving_rgb_lst, img_crop_256x256)
            wfp_concat = osp.join(args.output_dir, f'{out_stem}_concat.mp4')
            images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
            if flag_has_audio:
                # Mux into a temporary file, then move it over the silent video.
                wfp_concat_with_audio = osp.join(args.output_dir, f'{out_stem}_concat_with_audio.mp4')
                add_audio_to_video(wfp_concat, args.driving_info, wfp_concat_with_audio)
                os.replace(wfp_concat_with_audio, wfp_concat)
                log(f"Replace {wfp_concat} with {wfp_concat_with_audio}")

        wfp = osp.join(args.output_dir, f'{out_stem}.mp4')
        if inference_cfg.flag_pasteback:
            images2video(I_p_paste_lst, wfp=wfp, fps=output_fps)
        else:
            images2video(I_p_lst, wfp=wfp, fps=output_fps)

        if flag_has_audio:
            wfp_with_audio = osp.join(args.output_dir, f'{out_stem}_with_audio.mp4')
            add_audio_to_video(wfp, args.driving_info, wfp_with_audio)
            os.replace(wfp_with_audio, wfp)
            log(f"Replace {wfp} with {wfp_with_audio}")

        return wfp, wfp_concat