Create helper_cpu.py
src/utils/helper_cpu.py (ADDED: +173 −0)
@@ -0,0 +1,173 @@
# coding: utf-8

"""
Utility functions and classes to handle feature extraction and model loading
"""

import os
import os.path as osp
import torch
from collections import OrderedDict
import psutil
from rich.console import Console
from rich.progress import Progress

from ..modules.spade_generator import SPADEDecoder
from ..modules.warping_network import WarpingNetwork
from ..modules.motion_extractor import MotionExtractor
from ..modules.appearance_feature_extractor import AppearanceFeatureExtractor
from ..modules.stitching_retargeting_network import StitchingRetargetingNetwork

console = Console()


def show_memory_usage():
    """
    Display the current memory usage in the terminal using rich.
    """
    mem_info = psutil.virtual_memory()
    total_mem = mem_info.total / (1024 ** 3)  # Convert to GB
    used_mem = mem_info.used / (1024 ** 3)  # Convert to GB
    available_mem = mem_info.available / (1024 ** 3)  # Convert to GB

    console.log(f"[bold green]Memory Usage:[/bold green] [bold red]{used_mem:.2f} GB[/bold red] used of [bold blue]{total_mem:.2f} GB[/bold blue]")
    console.log(f"[bold green]Available Memory:[/bold green] [bold yellow]{available_mem:.2f} GB[/bold yellow]")


def suffix(filename):
    """a.jpg -> jpg"""
    pos = filename.rfind(".")
    if pos == -1:
        return ""
    return filename[pos + 1:]


def prefix(filename):
    """a.jpg -> a"""
    pos = filename.rfind(".")
    if pos == -1:
        return filename
    return filename[:pos]


def basename(filename):
    """a/b/c.jpg -> c"""
    return prefix(osp.basename(filename))


def is_video(file_path):
    # A directory (e.g. a folder of frames) also counts as a video input.
    return file_path.lower().endswith((".mp4", ".mov", ".avi", ".webm")) or osp.isdir(file_path)


def is_template(file_path):
    return file_path.endswith(".pkl")


def mkdir(d, log=False):
    # return self-assigned `d`, for one-line code
    if not osp.exists(d):
        os.makedirs(d, exist_ok=True)
        if log:
            console.log(f"Make dir: {d}")
    return d


def squeeze_tensor_to_numpy(tensor):
    out = tensor.data.squeeze(0).cpu().numpy()
    return out


def dct2cpu(dct: dict, device='cpu'):
    # Convert every value (e.g. numpy array or list) to a torch.Tensor on `device`.
    for key in dct:
        dct[key] = torch.tensor(dct[key]).to(device)
    return dct


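# Example (sketch): dct2cpu({'exp': [[0.0, 0.0, 0.0]]})['exp'] yields a CPU
# torch.Tensor of shape (1, 3); useful for normalizing template dicts loaded
# from .pkl files before feeding them to the networks below.

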
def concat_feat(kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
    """
    kp_source: (bs, k, 3)
    kp_driving: (bs, k, 3)
    Return: (bs, 2k*3)
    """
    bs_src = kp_source.shape[0]
    bs_dri = kp_driving.shape[0]
    assert bs_src == bs_dri, 'batch size must be equal'

    feat = torch.cat([kp_source.view(bs_src, -1), kp_driving.view(bs_dri, -1)], dim=1)
    return feat


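# Example (sketch): with bs=1 and k=21 keypoints, two (1, 21, 3) tensors are
# flattened and concatenated into one (1, 126) feature vector:
#   feat = concat_feat(torch.zeros(1, 21, 3), torch.zeros(1, 21, 3))
#   assert feat.shape == (1, 126)

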
def remove_ddp_duplicate_key(state_dict):
    state_dict_new = OrderedDict()
    for key in state_dict.keys():
        state_dict_new[key.replace('module.', '')] = state_dict[key]
    return state_dict_new


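# Example (sketch): checkpoints saved from a torch.nn.parallel.DistributedDataParallel
# wrapper store keys like 'module.fc.weight'; stripping the prefix lets them load
# into a plain, non-DDP module:
#   sd = OrderedDict({'module.fc.weight': torch.zeros(2, 2)})
#   assert list(remove_ddp_duplicate_key(sd).keys()) == ['fc.weight']

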
def load_model(ckpt_path, model_config, device, model_type):
    # Note: `device` is kept for interface parity; this CPU variant pins every module to 'cpu'.
    model_params = model_config['model_params'][f'{model_type}_params']

    if model_type == 'appearance_feature_extractor':
        model = AppearanceFeatureExtractor(**model_params).to('cpu')
    elif model_type == 'motion_extractor':
        model = MotionExtractor(**model_params).to('cpu')
    elif model_type == 'warping_module':
        model = WarpingNetwork(**model_params).to('cpu')
    elif model_type == 'spade_generator':
        model = SPADEDecoder(**model_params).to('cpu')
    elif model_type == 'stitching_retargeting_module':
        # Special handling for stitching and retargeting module
        config = model_config['model_params']['stitching_retargeting_module_params']
        checkpoint = torch.load(ckpt_path, map_location='cpu')

        stitcher = StitchingRetargetingNetwork(**config.get('stitching'))
        stitcher.load_state_dict(remove_ddp_duplicate_key(checkpoint['retarget_shoulder']))
        stitcher = stitcher.to('cpu')
        stitcher.eval()

        retargetor_lip = StitchingRetargetingNetwork(**config.get('lip'))
        retargetor_lip.load_state_dict(remove_ddp_duplicate_key(checkpoint['retarget_mouth']))
        retargetor_lip = retargetor_lip.to('cpu')
        retargetor_lip.eval()

        retargetor_eye = StitchingRetargetingNetwork(**config.get('eye'))
        retargetor_eye.load_state_dict(remove_ddp_duplicate_key(checkpoint['retarget_eye']))
        retargetor_eye = retargetor_eye.to('cpu')
        retargetor_eye.eval()

        return {
            'stitching': stitcher,
            'lip': retargetor_lip,
            'eye': retargetor_eye
        }
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    model.eval()
    return model


# Get coefficients of Eqn. 7
def calculate_transformation(config, s_kp_info, t_0_kp_info, t_i_kp_info, R_s, R_t_0, R_t_i):
    if config.relative:
        new_rotation = (R_t_i @ R_t_0.permute(0, 2, 1)) @ R_s
        new_expression = s_kp_info['exp'] + (t_i_kp_info['exp'] - t_0_kp_info['exp'])
    else:
        new_rotation = R_t_i
        new_expression = t_i_kp_info['exp']
    new_translation = s_kp_info['t'] + (t_i_kp_info['t'] - t_0_kp_info['t'])
    new_translation[..., 2].fill_(0)  # Keep the z-axis unchanged
    new_scale = s_kp_info['scale'] * (t_i_kp_info['scale'] / t_0_kp_info['scale'])
    return new_rotation, new_expression, new_translation, new_scale


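# Relative mode realizes Eqn. 7 for pose and expression: R_new = (R_{t,i} @ R_{t,0}^T) @ R_s
# and exp_new = exp_s + (exp_{t,i} - exp_{t,0}); the scale and translation deltas are
# applied in both modes, with the z-translation zeroed so depth stays unchanged.

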
def load_description(fp):
    with open(fp, 'r', encoding='utf-8') as f:
        content = f.read()
    return content
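A minimal usage sketch (illustrative only; the config and checkpoint paths below are assumptions, not part of this file):

    import yaml  # assumes the project configures models via a YAML file
    from src.utils.helper_cpu import load_model, show_memory_usage

    with open('configs/models.yaml') as f:      # hypothetical path
        model_config = yaml.safe_load(f)

    # Load one sub-network on CPU; the key must match a `model_params` entry.
    motion_extractor = load_model(
        'pretrained/motion_extractor.pth',      # hypothetical checkpoint path
        model_config, device='cpu', model_type='motion_extractor',
    )
    show_memory_usage()  # log RAM headroom after the load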