# coding: utf-8
# Source: B2BMGMT_LivePortrait_cpu / src/live_portrait_cpu_pipeline.py
# Author: K00B404 - commit 43648d4 ("Create live_portrait_cpu_pipeline.py", verified)
"""
Pipeline of LivePortrait (CPU-optimized version)
"""
import torch
torch.set_num_threads(4) # Limit the number of threads to reduce memory usage
import cv2
import numpy as np
import pickle
import os
import os.path as osp
from rich.progress import track
import gc
from .config.argument_config import ArgumentConfig
from .config.inference_config import InferenceConfig
from .config.crop_config import CropConfig
from .utils.cropper import Cropper
from .utils.camera import get_rotation_matrix
from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
from .utils.crop import _transform_img, prepare_paste_back, paste_back
from .utils.retargeting_utils import calc_lip_close_ratio
from .utils.io import load_image_rgb, load_driving_info, resize_to_limit
from .utils.helper_cpu import mkdir, basename, dct2cpu, is_video, is_template,show_memory_usage
from .utils.rprint import rlog as log
from .live_portrait_wrapper_cpu import LivePortraitWrapperCPU as wrapper
# from .live_portrait_wrapper import LivePortraitWrapper as wrapper
def make_abs_path(fn):
    """Return ``fn`` joined onto the directory that contains this module.

    The module path is passed through ``osp.realpath`` first, so symlinks to
    this file resolve to its real location.
    """
    module_dir = osp.dirname(osp.realpath(__file__))
    return osp.join(module_dir, fn)
class LiveCPUPortraitPipeline(object):
    """Animate a source portrait with motion from a driving video or template.

    Combines a ``LivePortraitWrapperCPU`` (keypoints, features, warping and
    decoding) with a ``Cropper`` (face crop and landmarks) and writes the
    resulting videos to disk.
    """

    # Number of frames processed between memory clean-ups (gc + usage report).
    # NOTE(review): every generated frame is still held in memory until the
    # videos are written, so this controls clean-up frequency, not peak memory.
    batch_size = 128

    def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
        """Build the model wrapper and the face cropper.

        Args:
            inference_cfg: inference options passed to the wrapper. The wrapper
                exposes its config as ``.cfg``, which ``execute`` reads.
            crop_cfg: options forwarded to ``Cropper``.
        """
        self.live_portrait_wrapper: wrapper = wrapper(cfg=inference_cfg)
        self.cropper = Cropper(crop_cfg=crop_cfg)
        # Stores whatever show_memory_usage() returns (presumably a report or None - confirm).
        self.mem_mon = show_memory_usage()

    @staticmethod
    def _output_path(args, suffix):
        """Return ``<output_dir>/<source>--<driving><suffix>.mp4`` for this run."""
        return osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}{suffix}.mp4')

    @staticmethod
    def _mux_audio(wfp, audio_source, wfp_with_audio):
        """Add the audio of ``audio_source`` to ``wfp``.

        The muxed video is written to ``wfp_with_audio`` and then moved over
        ``wfp``, so the final result lives at ``wfp``.
        """
        add_audio_to_video(wfp, audio_source, wfp_with_audio)
        os.replace(wfp_with_audio, wfp)
        log(f"Replace {wfp} with {wfp_with_audio}")

    def execute(self, args: ArgumentConfig):
        """Run the full animation pipeline and write the output video(s).

        Args:
            args: provides ``source_image``, ``driving_info`` (a video file or a
                pickled motion template) and ``output_dir``.

        Returns:
            ``(wfp, wfp_concat)``: the path of the animated video, and the path of
            the side-by-side comparison video. ``wfp_concat`` is ``None`` when
            driving from a template.

        Raises:
            ValueError: if ``args.driving_info`` is neither a video nor a template.
        """
        inference_cfg = self.live_portrait_wrapper.cfg  # for convenience

        ######## process source portrait ########
        img_rgb = load_image_rgb(args.source_image)
        log(f"resizing source image to {inference_cfg.ref_max_shape}x{inference_cfg.ref_max_shape}")
        img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
        log(f"processing image from {args.source_image}")
        crop_info = self.cropper.crop_single_image(img_rgb)
        source_lmk = crop_info['lmk_crop']
        img_crop_256x256 = crop_info['img_crop_256x256']
        if inference_cfg.flag_do_crop:
            log("Cropping source image.")
            I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
        else:
            log(f"Load source image from {args.source_image}")
            I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
        x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
        x_c_s = x_s_info['kp']
        R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
        f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
        x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)

        # Use a local copy of the flag. Writing it back into the shared config
        # would disable lip-zeroing for every later call on this pipeline.
        flag_lip_zero = inference_cfg.flag_lip_zero
        lip_delta_before_animation = None
        if flag_lip_zero:
            c_d_lip_before_animation = [0.]
            combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
            if combined_lip_ratio_tensor_before_animation[0][0] < inference_cfg.lip_zero_threshold:
                flag_lip_zero = False
            else:
                lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)

        ######## process driving info ########
        flag_is_video = is_video(args.driving_info)
        output_fps = 10  # default fps (kept for template driving)
        if flag_is_video:
            log(f"Load from video file (mp4 mov avi etc...): {args.driving_info}")
            output_fps = int(get_fps(args.driving_info))
            log(f'The FPS of {args.driving_info} is: {output_fps}')
            driving_rgb_lst = load_driving_info(args.driving_info)
            # NOTE(review): the name says 256 but frames are resized to 128x128,
            # presumably to save CPU time - confirm the motion extractor accepts this size.
            driving_rgb_lst_256 = [cv2.resize(frame, (128, 128)) for frame in driving_rgb_lst]
            I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst_256)
            n_frames = I_d_lst.shape[0]
            if inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting:
                driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
                input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
        elif is_template(args.driving_info):
            log(f"Load from video templates {args.driving_info}")
            # SECURITY: pickle.load can execute arbitrary code - only load trusted templates.
            with open(args.driving_info, 'rb') as f:
                template_lst, driving_lmk_lst = pickle.load(f)
            n_frames = template_lst[0]['n_frames']
            input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
        else:
            raise ValueError("Unsupported driving types!")

        ######## prepare for pasteback ########
        mask_ori = None
        I_p_paste_lst = []
        if inference_cfg.flag_pasteback:
            mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))

        ######## animate ########
        batch_size = self.batch_size
        n_batches = (n_frames + batch_size - 1) // batch_size  # ceiling division
        I_p_lst = []
        R_d_0, x_d_0_info = None, None
        flag_retarget = inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting
        log(f'Number of frames:{n_frames} processing in {n_batches} batches')
        for start in range(0, n_frames, batch_size):
            end = min(start + batch_size, n_frames)
            for i in track(range(start, end), description='Animating.....', total=end - start):
                log(f'Processing frame {i+1}/{n_frames}')
                if flag_is_video:
                    I_d_i = I_d_lst[i]
                    x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
                    R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
                else:
                    x_d_i_info = dct2cpu(template_lst[i])
                    R_d_i = x_d_i_info['R_d']

                if i == 0:
                    # The first driving frame is the reference for relative motion.
                    R_d_0 = R_d_i
                    x_d_0_info = x_d_i_info

                if inference_cfg.flag_relative:
                    R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
                    delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
                    scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
                    t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
                else:
                    R_new = R_d_i
                    delta_new = x_d_i_info['exp']
                    scale_new = x_s_info['scale']
                    t_new = x_d_i_info['t']
                t_new[..., 2].fill_(0)  # zero tz
                x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new

                if not inference_cfg.flag_stitching and not flag_retarget:
                    if flag_lip_zero:
                        x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
                elif inference_cfg.flag_stitching and not flag_retarget:
                    if flag_lip_zero:
                        x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
                    else:
                        x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
                else:
                    eyes_delta, lip_delta = None, None
                    if inference_cfg.flag_eye_retargeting:
                        c_d_eyes_i = input_eye_ratio_lst[i]
                        combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
                        eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
                    if inference_cfg.flag_lip_retargeting:
                        c_d_lip_i = input_lip_ratio_lst[i]
                        combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
                        lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor)
                    # In relative mode, retargeting starts from the source keypoints x_s;
                    # otherwise it adds onto the driven keypoints computed above.
                    base = x_s if inference_cfg.flag_relative else x_d_i_new
                    x_d_i_new = base + \
                        (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
                        (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
                    if inference_cfg.flag_stitching:
                        x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)

                # Report memory usage once per frame.
                show_memory_usage()
                out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
                I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
                I_p_lst.append(I_p_i)
                log(f'Generated {len(I_p_lst)} frames ')
                if inference_cfg.flag_pasteback:
                    I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
                    I_p_paste_lst.append(I_p_i_to_ori_blend)

            # Free memory between batches. This is a CPU pipeline, so the CUDA cache
            # is only touched when CUDA is actually present.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
            show_memory_usage()

        ######## write outputs ########
        mkdir(args.output_dir)
        wfp_concat = None
        # Only a driving video can carry an audio track; a template is a pickle.
        flag_has_audio = flag_is_video and has_audio_stream(args.driving_info)

        if flag_is_video:
            frames_concatenated = concat_frames(I_p_lst, driving_rgb_lst, img_crop_256x256)
            wfp_concat = self._output_path(args, '_concat')
            images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
            if flag_has_audio:
                self._mux_audio(wfp_concat, args.driving_info, self._output_path(args, '_concat_with_audio'))

        wfp = self._output_path(args, '')
        frames_out = I_p_paste_lst if inference_cfg.flag_pasteback else I_p_lst
        images2video(frames_out, wfp=wfp, fps=output_fps)
        if flag_has_audio:
            self._mux_audio(wfp, args.driving_info, self._output_path(args, '_with_audio'))
        return wfp, wfp_concat