jhj0517 committed
Commit 4daf2ff · Parent(s): a83a3d9

Add inferencer class
modules/live_portrait/live_portrait_inferencer.py
ADDED
@@ -0,0 +1,763 @@
import os
import sys
import numpy as np
import torch
import cv2
import time
import copy
import dill
from ultralytics import YOLO
import safetensors.torch
import gradio as gr
from PIL import Image  # Image.fromarray() is used in edit_expression()

from modules.utils.paths import *
from modules.utils.image_helper import *
from modules.live_portrait.model_downloader import *
from modules.live_portrait_wrapper import LivePortraitWrapper
from modules.utils.camera import get_rotation_matrix
from modules.utils.helper import load_yaml
from modules.config.inference_config import InferenceConfig
from modules.live_portrait.spade_generator import SPADEDecoder
from modules.live_portrait.warping_network import WarpingNetwork
from modules.live_portrait.motion_extractor import MotionExtractor
from modules.live_portrait.appearance_feature_extractor import AppearanceFeatureExtractor
from modules.live_portrait.stitching_retargeting_network import StitchingRetargetingNetwork
from collections import OrderedDict


class LivePortraitInferencer:
    def __init__(self,
                 model_dir: str = MODELS_DIR,
                 output_dir: str = OUTPUTS_DIR):
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.model_config = load_yaml(MODEL_CONFIG)["model_params"]

        self.appearance_feature_extractor = None
        self.motion_extractor = None
        self.warping_module = None
        self.spade_generator = None
        self.stitching_retargeting_module = None
        self.pipeline = None
        self.detect_model = None
        self.device = self.get_device()

        self.mask_img = None
        self.temp_img_idx = 0
        self.src_image = None
        self.src_images = None  # last source batch passed to create_video()
        self.src_image_list = None
        self.sample_image = None
        self.driving_images = None
        self.driving_values = None
        self.crop_factor = None
        self.psi = None
        self.psi_list = None
        self.d_info = None

    def load_models(self):
        self.download_if_no_models()

        appearance_feat_config = self.model_config["appearance_feature_extractor_params"]
        self.appearance_feature_extractor = AppearanceFeatureExtractor(**appearance_feat_config).to(self.device)
        self.appearance_feature_extractor = self.load_safe_tensor(
            self.appearance_feature_extractor,
            os.path.join(self.model_dir, "appearance_feature_extractor.safetensors")
        )

        motion_ext_config = self.model_config["motion_extractor_params"]
        self.motion_extractor = MotionExtractor(**motion_ext_config).to(self.device)
        self.motion_extractor = self.load_safe_tensor(
            self.motion_extractor,
            os.path.join(self.model_dir, "motion_extractor.safetensors")
        )

        warping_module_config = self.model_config["warping_module_params"]
        self.warping_module = WarpingNetwork(**warping_module_config).to(self.device)
        self.warping_module = self.load_safe_tensor(
            self.warping_module,
            os.path.join(self.model_dir, "warping_module.safetensors")
        )

        spade_decoder_config = self.model_config["spade_generator_params"]
        self.spade_generator = SPADEDecoder(**spade_decoder_config).to(self.device)
        self.spade_generator = self.load_safe_tensor(
            self.spade_generator,
            os.path.join(self.model_dir, "spade_generator.safetensors")
        )

        def filter_stitcher(checkpoint, prefix):
            # Keep only one sub-network's keys and strip their "<prefix>_module." prefix.
            filtered_checkpoint = {key.replace(prefix + "_module.", ""): value
                                   for key, value in checkpoint.items() if key.startswith(prefix)}
            return filtered_checkpoint

        stitcher_config = self.model_config["stitching_retargeting_module_params"]
        self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching'))
        stitcher_model_path = os.path.join(self.model_dir, "stitching_retargeting_module.safetensors")
        ckpt = safetensors.torch.load_file(stitcher_model_path)
        self.stitching_retargeting_module.load_state_dict(filter_stitcher(ckpt, 'retarget_shoulder'))
        self.stitching_retargeting_module.to(self.device).eval()
        self.stitching_retargeting_module = {"stitching": self.stitching_retargeting_module}

        if self.pipeline is None:
            self.pipeline = LivePortraitWrapper(
                InferenceConfig(),
                self.appearance_feature_extractor,
                self.motion_extractor,
                self.warping_module,
                self.spade_generator,
                self.stitching_retargeting_module
            )

        self.detect_model = YOLO(MODEL_PATHS["face_yolov8n"])

+
def edit_expression(self,
|
114 |
+
rotate_pitch=0,
|
115 |
+
rotate_yaw=0,
|
116 |
+
rotate_roll=0,
|
117 |
+
blink=0,
|
118 |
+
eyebrow=0,
|
119 |
+
wink=0,
|
120 |
+
pupil_x=0,
|
121 |
+
pupil_y=0,
|
122 |
+
aaa=0,
|
123 |
+
eee=0,
|
124 |
+
woo=0,
|
125 |
+
smile=0,
|
126 |
+
src_ratio=1,
|
127 |
+
sample_ratio=1,
|
128 |
+
sample_parts="All",
|
129 |
+
crop_factor=1.5,
|
130 |
+
src_image=None,
|
131 |
+
sample_image=None,
|
132 |
+
motion_link=None,
|
133 |
+
add_exp=None):
|
134 |
+
if self.pipeline is None:
|
135 |
+
self.load_models()
|
136 |
+
|
137 |
+
try:
|
138 |
+
rotate_yaw = -rotate_yaw
|
139 |
+
|
140 |
+
new_editor_link = None
|
141 |
+
if motion_link is not None:
|
142 |
+
self.psi = motion_link[0]
|
143 |
+
new_editor_link = motion_link.copy()
|
144 |
+
elif src_image is not None:
|
145 |
+
if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor:
|
146 |
+
self.crop_factor = crop_factor
|
147 |
+
self.psi = self.prepare_source(src_image, crop_factor)
|
148 |
+
self.src_image = src_image
|
149 |
+
new_editor_link = []
|
150 |
+
new_editor_link.append(self.psi)
|
151 |
+
else:
|
152 |
+
return None, None
|
153 |
+
|
154 |
+
psi = self.psi
|
155 |
+
s_info = psi.x_s_info
|
156 |
+
#delta_new = copy.deepcopy()
|
157 |
+
s_exp = s_info['exp'] * src_ratio
|
158 |
+
s_exp[0, 5] = s_info['exp'][0, 5]
|
159 |
+
s_exp += s_info['kp']
|
160 |
+
|
161 |
+
es = ExpressionSet()
|
162 |
+
|
163 |
+
if sample_image is not None:
|
164 |
+
if id(self.sample_image) != id(sample_image):
|
165 |
+
self.sample_image = sample_image
|
166 |
+
d_image_np = (sample_image * 255).byte().numpy()
|
167 |
+
d_face = self.crop_face(d_image_np[0], 1.7)
|
168 |
+
i_d = self.prepare_src_image(d_face)
|
169 |
+
self.d_info = self.pipeline.get_kp_info(i_d)
|
170 |
+
self.d_info['exp'][0, 5, 0] = 0
|
171 |
+
self.d_info['exp'][0, 5, 1] = 0
|
172 |
+
|
173 |
+
# "OnlyExpression", "OnlyRotation", "OnlyMouth", "OnlyEyes", "All"
|
174 |
+
if sample_parts == "OnlyExpression" or sample_parts == "All":
|
175 |
+
es.e += self.d_info['exp'] * sample_ratio
|
176 |
+
if sample_parts == "OnlyRotation" or sample_parts == "All":
|
177 |
+
rotate_pitch += self.d_info['pitch'] * sample_ratio
|
178 |
+
rotate_yaw += self.d_info['yaw'] * sample_ratio
|
179 |
+
rotate_roll += self.d_info['roll'] * sample_ratio
|
180 |
+
elif sample_parts == "OnlyMouth":
|
181 |
+
self.retargeting(es.e, self.d_info['exp'], sample_ratio, (14, 17, 19, 20))
|
182 |
+
elif sample_parts == "OnlyEyes":
|
183 |
+
self.retargeting(es.e, self.d_info['exp'], sample_ratio, (1, 2, 11, 13, 15, 16))
|
184 |
+
|
185 |
+
es.r = self.calc_fe(es.e, blink, eyebrow, wink, pupil_x, pupil_y, aaa, eee, woo, smile,
|
186 |
+
rotate_pitch, rotate_yaw, rotate_roll)
|
187 |
+
|
188 |
+
if add_exp is not None:
|
189 |
+
es.add(add_exp)
|
190 |
+
|
191 |
+
new_rotate = get_rotation_matrix(s_info['pitch'] + es.r[0], s_info['yaw'] + es.r[1],
|
192 |
+
s_info['roll'] + es.r[2])
|
193 |
+
x_d_new = (s_info['scale'] * (1 + es.s)) * ((s_exp + es.e) @ new_rotate) + s_info['t']
|
194 |
+
|
195 |
+
x_d_new = self.pipeline.stitching(psi.x_s_user, x_d_new)
|
196 |
+
|
197 |
+
crop_out = self.pipeline.warp_decode(psi.f_s_user, psi.x_s_user, x_d_new)
|
198 |
+
crop_out = self.pipeline.parse_output(crop_out['out'])[0]
|
199 |
+
|
200 |
+
crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb), cv2.INTER_LINEAR)
|
201 |
+
out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(np.uint8)
|
202 |
+
|
203 |
+
out_img = pil2tensor(out)
|
204 |
+
out_img_path = get_auto_incremental_file_path(TEMP_DIR, "png")
|
205 |
+
|
206 |
+
img = Image.fromarray(crop_out)
|
207 |
+
img.save(out_img_path, compress_level=1)
|
208 |
+
new_editor_link.append(es)
|
209 |
+
|
210 |
+
return out_img # {"ui": {"images": results}, "result": (out_img, new_editor_link, es)}
|
211 |
+
except Exception as e:
|
212 |
+
raise
|
213 |
+
|
    def create_video(self,
                     retargeting_eyes,
                     retargeting_mouth,
                     turn_on,
                     tracking_src_vid,
                     animate_without_vid,
                     command,
                     crop_factor,
                     src_images=None,
                     driving_images=None,
                     motion_link=None,
                     progress=gr.Progress()):
        if not turn_on:
            return None, None
        src_length = 1

        if src_images is None:
            if motion_link is not None:
                self.psi_list = [motion_link[0]]
            else:
                return None, None

        if src_images is not None:
            src_length = len(src_images)
            if id(src_images) != id(self.src_images) or self.crop_factor != crop_factor:
                self.crop_factor = crop_factor
                self.src_images = src_images
                if 1 < src_length:
                    self.psi_list = self.prepare_source(src_images, crop_factor, True, tracking_src_vid)
                else:
                    self.psi_list = [self.prepare_source(src_images, crop_factor)]

        cmd_list, cmd_length = self.parsing_command(command, motion_link)
        if cmd_list is None:
            return None, None
        cmd_idx = 0

        driving_length = 0
        if driving_images is not None:
            if id(driving_images) != id(self.driving_images):
                self.driving_images = driving_images
                self.driving_values = self.prepare_driving_video(driving_images)
            driving_length = len(self.driving_values)

        total_length = max(driving_length, src_length)

        if animate_without_vid:
            total_length = max(total_length, cmd_length)

        c_i_es = ExpressionSet()
        c_o_es = ExpressionSet()
        d_0_es = None
        out_list = []

        psi = None
        for i in range(total_length):
            if i < src_length:
                psi = self.psi_list[i]
                s_info = psi.x_s_info
                s_es = ExpressionSet(erst=(s_info['kp'] + s_info['exp'], torch.Tensor([0, 0, 0]), s_info['scale'], s_info['t']))

            new_es = ExpressionSet(es=s_es)

            # Scripted command track: ramp in over `change` frames, then hold for `keep` frames.
            if i < cmd_length:
                cmd = cmd_list[cmd_idx]
                if 0 < cmd.change:
                    cmd.change -= 1
                    c_i_es.add(cmd.es)
                    c_i_es.sub(c_o_es)
                elif 0 < cmd.keep:
                    cmd.keep -= 1

                new_es.add(c_i_es)

                if cmd.change == 0 and cmd.keep == 0:
                    cmd_idx += 1
                    if cmd_idx < len(cmd_list):
                        c_o_es = ExpressionSet(es=c_i_es)
                        cmd = cmd_list[cmd_idx]
                        c_o_es.div(cmd.change)
            elif 0 < cmd_length:
                new_es.add(c_i_es)

            # Add the driving video's motion relative to its first frame.
            if i < driving_length:
                d_i_info = self.driving_values[i]
                d_i_r = torch.Tensor([d_i_info['pitch'], d_i_info['yaw'], d_i_info['roll']])

                if d_0_es is None:
                    d_0_es = ExpressionSet(erst=(d_i_info['exp'], d_i_r, d_i_info['scale'], d_i_info['t']))

                    self.retargeting(s_es.e, d_0_es.e, retargeting_eyes, (11, 13, 15, 16))
                    self.retargeting(s_es.e, d_0_es.e, retargeting_mouth, (14, 17, 19, 20))

                new_es.e += d_i_info['exp'] - d_0_es.e
                new_es.r += d_i_r - d_0_es.r
                new_es.t += d_i_info['t'] - d_0_es.t

            r_new = get_rotation_matrix(
                s_info['pitch'] + new_es.r[0], s_info['yaw'] + new_es.r[1], s_info['roll'] + new_es.r[2])
            d_new = new_es.s * (new_es.e @ r_new) + new_es.t
            d_new = self.pipeline.stitching(psi.x_s_user, d_new)
            crop_out = self.pipeline.warp_decode(psi.f_s_user, psi.x_s_user, d_new)
            crop_out = self.pipeline.parse_output(crop_out['out'])[0]

            crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb),
                                                cv2.INTER_LINEAR)
            out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(
                np.uint8)
            out_list.append(out)

            progress(i / total_length, "predicting..")

        if len(out_list) == 0:
            return None

        out_imgs = torch.cat([pil2tensor(img_rgb) for img_rgb in out_list])
        return out_imgs

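    # Illustrative call (shapes assumed from the code above, not a documented API):
    # with `src` a batch of RGB source frames shaped (N, H, W, 3) and `drv` a float
    # tensor of driving frames in [0, 1] (prepare_driving_video() multiplies by 255):
    #   frames = inferencer.create_video(0.0, 0.0, True, False, False, "", 1.7,
    #                                    src_images=src, driving_images=drv)
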
    def download_if_no_models(self):
        for model_name, model_url in MODELS_URL.items():
            if model_url.endswith(".pt"):
                model_name += ".pt"
            else:
                model_name += ".safetensors"
            model_path = os.path.join(self.model_dir, model_name)
            if not os.path.exists(model_path):
                download_model(model_path, model_url)

    @staticmethod
    def load_safe_tensor(model, file_path):
        model.load_state_dict(safetensors.torch.load_file(file_path))
        model.eval()
        return model

    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"

    def get_temp_img_name(self):
        self.temp_img_idx += 1
        return "expression_edit_preview" + str(self.temp_img_idx) + ".png"

+
@staticmethod
|
363 |
+
def parsing_command(command, motoin_link):
|
364 |
+
command.replace(' ', '')
|
365 |
+
lines = command.split('\n')
|
366 |
+
|
367 |
+
cmd_list = []
|
368 |
+
|
369 |
+
total_length = 0
|
370 |
+
|
371 |
+
i = 0
|
372 |
+
for line in lines:
|
373 |
+
i += 1
|
374 |
+
if not line:
|
375 |
+
continue
|
376 |
+
try:
|
377 |
+
cmds = line.split('=')
|
378 |
+
idx = int(cmds[0])
|
379 |
+
if idx == 0: es = ExpressionSet()
|
380 |
+
else: es = ExpressionSet(es = motoin_link[idx])
|
381 |
+
cmds = cmds[1].split(':')
|
382 |
+
change = int(cmds[0])
|
383 |
+
keep = int(cmds[1])
|
384 |
+
except Exception as e:
|
385 |
+
print(f"(AdvancedLivePortrait) Command Err Line {i}: {line}, :{e}")
|
386 |
+
return None, None
|
387 |
+
|
388 |
+
total_length += change + keep
|
389 |
+
es.div(change)
|
390 |
+
cmd_list.append(Command(es, change, keep))
|
391 |
+
|
392 |
+
return cmd_list, total_length
|
393 |
+
|
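    # Command syntax, one entry per line: "<idx>=<change>:<keep>". `idx` selects an
    # expression from `motion_link` (0 is the neutral expression), `change` is the
    # number of frames spent ramping toward it, and `keep` is how many frames it is
    # held afterwards; `change` is expected to be at least 1, since the expression
    # delta is divided by it. For example, "1=15:10" eases into expression 1 over
    # 15 frames and holds it for 10 more.
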
    def get_face_bboxes(self, image_rgb):
        # device="" lets ultralytics pick the default available device.
        pred = self.detect_model(image_rgb, conf=0.7, device="")
        return pred[0].boxes.xyxy.cpu().numpy()

    def detect_face(self, image_rgb, crop_factor, sort=True):
        bboxes = self.get_face_bboxes(image_rgb)
        w, h = get_rgb_size(image_rgb)

        print(f"w, h:{w, h}")

        # Pick the detection closest to the horizontal center of the image.
        cx = w / 2
        min_diff = w
        best_box = None
        for x1, y1, x2, y2 in bboxes:
            bbox_w = x2 - x1
            if bbox_w < 30:
                continue
            diff = abs(cx - (x1 + bbox_w / 2))
            if diff < min_diff:
                best_box = [x1, y1, x2, y2]
                print(f"diff, min_diff, best_box:{diff, min_diff, best_box}")
                min_diff = diff

        if best_box is None:
            print("Failed to detect face!!")
            return [0, 0, w, h]

        x1, y1, x2, y2 = best_box

        bbox_w = x2 - x1
        bbox_h = y2 - y1

        # Expand the detection to a square crop scaled by crop_factor.
        crop_w = bbox_w * crop_factor
        crop_h = bbox_h * crop_factor

        crop_w = max(crop_h, crop_w)
        crop_h = crop_w

        kernel_x = int(x1 + bbox_w / 2)
        kernel_y = int(y1 + bbox_h / 2)

        new_x1 = int(kernel_x - crop_w / 2)
        new_x2 = int(kernel_x + crop_w / 2)
        new_y1 = int(kernel_y - crop_h / 2)
        new_y2 = int(kernel_y + crop_h / 2)

        if not sort:
            return [int(new_x1), int(new_y1), int(new_x2), int(new_y2)]

        # Shift the square back inside the image bounds where possible.
        if new_x1 < 0:
            new_x2 -= new_x1
            new_x1 = 0
        elif w < new_x2:
            new_x1 -= (new_x2 - w)
            new_x2 = w
            if new_x1 < 0:
                new_x2 -= new_x1
                new_x1 = 0

        if new_y1 < 0:
            new_y2 -= new_y1
            new_y1 = 0
        elif h < new_y2:
            new_y1 -= (new_y2 - h)
            new_y2 = h
            if new_y1 < 0:
                new_y2 -= new_y1
                new_y1 = 0

        if w < new_x2 and h < new_y2:
            over_x = new_x2 - w
            over_y = new_y2 - h
            over_min = min(over_x, over_y)
            new_x2 -= over_min
            new_y2 -= over_min

        return [int(new_x1), int(new_y1), int(new_x2), int(new_y2)]

    @staticmethod
    def retargeting(delta_out, driving_exp, factor, idxes):
        # Blend the driving expression into delta_out at the given keypoint indices.
        for idx in idxes:
            delta_out[0, idx] += driving_exp[0, idx] * factor

    @staticmethod
    def calc_face_region(square, dsize):
        region = copy.deepcopy(square)
        is_changed = False
        if dsize[0] < region[2]:
            region[2] = dsize[0]
            is_changed = True
        if dsize[1] < region[3]:
            region[3] = dsize[1]
            is_changed = True

        return region, is_changed

    @staticmethod
    def expand_img(rgb_img, square):
        crop_trans_m = create_transform_matrix(max(-square[0], 0), max(-square[1], 0), 1, 1)
        new_img = cv2.warpAffine(rgb_img, crop_trans_m, (square[2] - square[0], square[3] - square[1]),
                                 cv2.INTER_LINEAR)
        return new_img

    def prepare_src_image(self, img):
        h, w = img.shape[:2]
        input_shape = [256, 256]
        if h != input_shape[0] or w != input_shape[1]:
            if 256 < h:
                interpolation = cv2.INTER_AREA
            else:
                interpolation = cv2.INTER_LINEAR
            x = cv2.resize(img, (input_shape[0], input_shape[1]), interpolation=interpolation)
        else:
            x = img.copy()

        if x.ndim == 3:
            x = x[np.newaxis].astype(np.float32) / 255.  # HxWx3 -> 1xHxWx3, normalized to 0~1
        elif x.ndim == 4:
            x = x.astype(np.float32) / 255.  # BxHxWx3, normalized to 0~1
        else:
            raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
        x = np.clip(x, 0, 1)  # clip to 0~1
        x = torch.from_numpy(x).permute(0, 3, 1, 2)  # 1xHxWx3 -> 1x3xHxW
        x = x.to(self.device)
        return x

    def get_mask_img(self):
        if self.mask_img is None:
            self.mask_img = cv2.imread(MASK_TEMPLATES, cv2.IMREAD_COLOR)
        return self.mask_img

    def crop_face(self, img_rgb, crop_factor):
        crop_region = self.detect_face(img_rgb, crop_factor)
        face_region, is_changed = self.calc_face_region(crop_region, get_rgb_size(img_rgb))
        face_img = rgb_crop(img_rgb, face_region)
        if is_changed:
            face_img = self.expand_img(face_img, crop_region)
        return face_img

    def prepare_source(self, source_image, crop_factor, is_video=False, tracking=False):
        print("Prepare source...")
        psi_list = []
        for img_rgb in source_image:
            # Detect on every frame when tracking; otherwise reuse the first frame's crop.
            if tracking or len(psi_list) == 0:
                crop_region = self.detect_face(img_rgb, crop_factor)
                face_region, is_changed = self.calc_face_region(crop_region, get_rgb_size(img_rgb))

                s_x = (face_region[2] - face_region[0]) / 512.
                s_y = (face_region[3] - face_region[1]) / 512.
                crop_trans_m = create_transform_matrix(crop_region[0], crop_region[1], s_x, s_y)
                mask_ori = cv2.warpAffine(self.get_mask_img(), crop_trans_m, get_rgb_size(img_rgb), cv2.INTER_LINEAR)
                mask_ori = mask_ori.astype(np.float32) / 255.

                if is_changed:
                    s = (crop_region[2] - crop_region[0]) / 512.
                    crop_trans_m = create_transform_matrix(crop_region[0], crop_region[1], s, s)

            face_img = rgb_crop(img_rgb, face_region)
            if is_changed:
                face_img = self.expand_img(face_img, crop_region)
            i_s = self.prepare_src_image(face_img)
            x_s_info = self.pipeline.get_kp_info(i_s)
            f_s_user = self.pipeline.extract_feature_3d(i_s)
            x_s_user = self.pipeline.transform_keypoint(x_s_info)
            psi = PreparedSrcImg(img_rgb, crop_trans_m, x_s_info, f_s_user, x_s_user, mask_ori)
            if not is_video:
                return psi
            psi_list.append(psi)

        return psi_list

    def prepare_driving_video(self, face_images):
        print("Prepare driving video...")
        f_img_np = (face_images * 255).byte().numpy()

        out_list = []
        for f_img in f_img_np:
            i_d = self.prepare_src_image(f_img)
            d_info = self.pipeline.get_kp_info(i_d)
            out_list.append(d_info)

        return out_list

    @staticmethod
    def calc_fe(x_d_new, eyes, eyebrow, wink, pupil_x, pupil_y, mouth, eee, woo, smile,
                rotate_pitch, rotate_yaw, rotate_roll):
        # Hand-tuned offsets to individual implicit keypoints; at the call site in
        # edit_expression(), `eyes` receives the blink slider and `mouth` receives aaa.

        # smile
        x_d_new[0, 20, 1] += smile * -0.01
        x_d_new[0, 14, 1] += smile * -0.02
        x_d_new[0, 17, 1] += smile * 0.0065
        x_d_new[0, 17, 2] += smile * 0.003
        x_d_new[0, 13, 1] += smile * -0.00275
        x_d_new[0, 16, 1] += smile * -0.00275
        x_d_new[0, 3, 1] += smile * -0.0035
        x_d_new[0, 7, 1] += smile * -0.0035

        # open mouth ("aaa")
        x_d_new[0, 19, 1] += mouth * 0.001
        x_d_new[0, 19, 2] += mouth * 0.0001
        x_d_new[0, 17, 1] += mouth * -0.0001
        rotate_pitch -= mouth * 0.05

        # "eee"
        x_d_new[0, 20, 2] += eee * -0.001
        x_d_new[0, 20, 1] += eee * -0.001
        x_d_new[0, 14, 1] += eee * -0.001

        # "woo"
        x_d_new[0, 14, 1] += woo * 0.001
        x_d_new[0, 3, 1] += woo * -0.0005
        x_d_new[0, 7, 1] += woo * -0.0005
        x_d_new[0, 17, 2] += woo * -0.0005

        # wink
        x_d_new[0, 11, 1] += wink * 0.001
        x_d_new[0, 13, 1] += wink * -0.0003
        x_d_new[0, 17, 0] += wink * 0.0003
        x_d_new[0, 17, 1] += wink * 0.0003
        x_d_new[0, 3, 1] += wink * -0.0003
        rotate_roll -= wink * 0.1
        rotate_yaw -= wink * 0.1

        # pupils
        if 0 < pupil_x:
            x_d_new[0, 11, 0] += pupil_x * 0.0007
            x_d_new[0, 15, 0] += pupil_x * 0.001
        else:
            x_d_new[0, 11, 0] += pupil_x * 0.001
            x_d_new[0, 15, 0] += pupil_x * 0.0007

        x_d_new[0, 11, 1] += pupil_y * -0.001
        x_d_new[0, 15, 1] += pupil_y * -0.001
        eyes -= pupil_y / 2.

        # blink
        x_d_new[0, 11, 1] += eyes * -0.001
        x_d_new[0, 13, 1] += eyes * 0.0003
        x_d_new[0, 15, 1] += eyes * -0.001
        x_d_new[0, 16, 1] += eyes * 0.0003
        x_d_new[0, 1, 1] += eyes * -0.00025
        x_d_new[0, 2, 1] += eyes * 0.00025

        # eyebrow
        if 0 < eyebrow:
            x_d_new[0, 1, 1] += eyebrow * 0.001
            x_d_new[0, 2, 1] += eyebrow * -0.001
        else:
            x_d_new[0, 1, 0] += eyebrow * -0.001
            x_d_new[0, 2, 0] += eyebrow * 0.001
            x_d_new[0, 1, 1] += eyebrow * 0.0003
            x_d_new[0, 2, 1] += eyebrow * -0.0003

        return torch.Tensor([rotate_pitch, rotate_yaw, rotate_roll])


class ExpressionSet:
    """Expression delta (e), rotation (r), scale (s) and translation (t) offsets."""
    def __init__(self, erst=None, es=None):
        if es is not None:
            self.e = copy.deepcopy(es.e)
            self.r = copy.deepcopy(es.r)
            self.s = copy.deepcopy(es.s)
            self.t = copy.deepcopy(es.t)
        elif erst is not None:
            self.e = erst[0]
            self.r = erst[1]
            self.s = erst[2]
            self.t = erst[3]
        else:
            self.e = torch.from_numpy(np.zeros((1, 21, 3))).float().to(self.get_device())
            self.r = torch.Tensor([0, 0, 0])
            self.s = 0
            self.t = 0

    def div(self, value):
        self.e /= value
        self.r /= value
        self.s /= value
        self.t /= value

    def add(self, other):
        self.e += other.e
        self.r += other.r
        self.s += other.s
        self.t += other.t

    def sub(self, other):
        self.e -= other.e
        self.r -= other.r
        self.s -= other.s
        self.t -= other.t

    def mul(self, value):
        self.e *= value
        self.r *= value
        self.s *= value
        self.t *= value

    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"


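# Illustrative only (hypothetical variables): ExpressionSet supports simple in-place
# blending, e.g. averaging two saved expressions `es_a` and `es_b`:
#   mix = ExpressionSet(es=es_a)
#   mix.add(es_b)
#   mix.div(2)
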
def logging_time(original_fn):
    def wrapper_fn(*args, **kwargs):
        start_time = time.time()
        result = original_fn(*args, **kwargs)
        end_time = time.time()
        print("WorkingTime[{}]: {} sec".format(original_fn.__name__, end_time - start_time))
        return result

    return wrapper_fn


def save_exp_data(file_name: str, save_exp: ExpressionSet = None):
    if save_exp is None or not file_name:
        return file_name

    with open(os.path.join(EXP_OUTPUT_DIR, file_name + ".exp"), "wb") as f:
        dill.dump(save_exp, f)

    return file_name


def load_exp_data(file_name, ratio):
    with open(os.path.join(EXP_OUTPUT_DIR, file_name + ".exp"), 'rb') as f:
        es = dill.load(f)
    es.mul(ratio)
    return es


def handle_exp_data(code1, value1, code2, value2, code3, value3, code4, value4, code5, value5, add_exp=None):
    if add_exp is None:
        es = ExpressionSet()
    else:
        es = ExpressionSet(es=add_exp)

    # Each code packs a keypoint index and an axis: code = idx * 10 + axis.
    codes = [code1, code2, code3, code4, code5]
    values = [value1, value2, value3, value4, value5]
    for i in range(5):
        idx = int(codes[i] / 10)
        r = codes[i] % 10
        es.e[0, idx, r] += values[i] * 0.001

    return es


def print_exp_data(cut_noise, exp=None):
    if exp is None:
        return exp

    cut_list = []
    e = exp.e * 1000
    for idx in range(21):
        for r in range(3):
            a = abs(e[0, idx, r])
            if cut_noise < a:
                cut_list.append((a, e[0, idx, r], idx * 10 + r))

    sorted_list = sorted(cut_list, reverse=True, key=lambda item: item[0])
    print(f"sorted_list: {[[item[2], round(float(item[1]), 1)] for item in sorted_list]}")
    return exp


class Command:
    def __init__(self,
                 es: ExpressionSet,
                 change,
                 keep):
        self.es = es
        self.change = change
        self.keep = keep
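
# Minimal usage sketch (illustrative only): the image path, batch shape, and slider
# values below are assumptions, not part of this module. edit_expression() downloads
# the model weights on first use via download_if_no_models().
if __name__ == "__main__":
    inferencer = LivePortraitInferencer()
    test_img = cv2.cvtColor(cv2.imread("portrait.png"), cv2.COLOR_BGR2RGB)
    # prepare_source() iterates over the first axis, so wrap the single frame in a batch.
    edited = inferencer.edit_expression(smile=60, blink=20, rotate_yaw=10,
                                        src_image=test_img[np.newaxis])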