K00B404 commited on
Commit
7f861c0
·
verified ·
1 Parent(s): 62fea9c

Create helper_cpu.py

Browse files
Files changed (1) hide show
  1. src/utils/helper_cpu.py +173 -0
src/utils/helper_cpu.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+
3
+ """
4
+ Utility functions and classes to handle feature extraction and model loading
5
+ """
6
+
7
+ import os
8
+ import os.path as osp
9
+ import torch
10
+ from collections import OrderedDict
11
+ import psutil
12
+ from rich.console import Console
13
+ from rich.progress import Progress
14
+ from ..modules.spade_generator import SPADEDecoder
15
+ from ..modules.warping_network import WarpingNetwork
16
+ from ..modules.motion_extractor import MotionExtractor
17
+ from ..modules.appearance_feature_extractor import AppearanceFeatureExtractor
18
+ from ..modules.stitching_retargeting_network import StitchingRetargetingNetwork
19
+
20
+ from rich.console import Console
21
+ import psutil
22
+
23
+ console = Console()
24
+
25
+ def show_memory_usage():
26
+ """
27
+ Display the current memory usage in the terminal using rich.
28
+ """
29
+ mem_info = psutil.virtual_memory()
30
+ total_mem = mem_info.total / (1024 ** 3) # Convert to GB
31
+ used_mem = mem_info.used / (1024 ** 3) # Convert to GB
32
+ available_mem = mem_info.available / (1024 ** 3) # Convert to GB
33
+
34
+ console.log(f"[bold green]Memory Usage:[/bold green] [bold red]{used_mem:.2f} GB[/bold red] used of [bold blue]{total_mem:.2f} GB[/bold blue]")
35
+ console.log(f"[bold green]Available Memory:[/bold green] [bold yellow]{available_mem:.2f} GB[/bold yellow]")
36
+
37
+
38
+ def suffix(filename):
39
+ """a.jpg -> jpg"""
40
+ pos = filename.rfind(".")
41
+ if pos == -1:
42
+ return ""
43
+ return filename[pos + 1:]
44
+
45
+
46
+ def prefix(filename):
47
+ """a.jpg -> a"""
48
+ pos = filename.rfind(".")
49
+ if pos == -1:
50
+ return filename
51
+ return filename[:pos]
52
+
53
+
54
+ def basename(filename):
55
+ """a/b/c.jpg -> c"""
56
+ return prefix(osp.basename(filename))
57
+
58
+
59
+ def is_video(file_path):
60
+ if file_path.lower().endswith((".mp4", ".mov", ".avi", ".webm")) or osp.isdir(file_path):
61
+ return True
62
+ return False
63
+
64
+
65
+ def is_template(file_path):
66
+ if file_path.endswith(".pkl"):
67
+ return True
68
+ return False
69
+
70
+
71
+ def mkdir(d, log=False):
72
+ # return self-assigned `d`, for one line code
73
+ if not osp.exists(d):
74
+ os.makedirs(d, exist_ok=True)
75
+ if log:
76
+ log(f"Make dir: {d}")
77
+ return d
78
+
79
+
80
+ def squeeze_tensor_to_numpy(tensor):
81
+ out = tensor.data.squeeze(0).cpu().numpy()
82
+ return out
83
+
84
+
85
+ def dct2cpu(dct: dict, device='cpu'):
86
+ for key in dct:
87
+ dct[key] = torch.tensor(dct[key]).to(device)
88
+ return dct
89
+
90
+
91
+ def concat_feat(kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
92
+ """
93
+ kp_source: (bs, k, 3)
94
+ kp_driving: (bs, k, 3)
95
+ Return: (bs, 2k*3)
96
+ """
97
+ bs_src = kp_source.shape[0]
98
+ bs_dri = kp_driving.shape[0]
99
+ assert bs_src == bs_dri, 'batch size must be equal'
100
+
101
+ feat = torch.cat([kp_source.view(bs_src, -1), kp_driving.view(bs_dri, -1)], dim=1)
102
+ return feat
103
+
104
+
105
+ def remove_ddp_duplicate_key(state_dict):
106
+ state_dict_new = OrderedDict()
107
+ for key in state_dict.keys():
108
+ state_dict_new[key.replace('module.', '')] = state_dict[key]
109
+ return state_dict_new
110
+
111
+
112
+ def load_model(ckpt_path, model_config, device, model_type):
113
+ model_params = model_config['model_params'][f'{model_type}_params']
114
+
115
+ if model_type == 'appearance_feature_extractor':
116
+ model = AppearanceFeatureExtractor(**model_params).to('cpu')
117
+ elif model_type == 'motion_extractor':
118
+ model = MotionExtractor(**model_params).to('cpu')
119
+ elif model_type == 'warping_module':
120
+ model = WarpingNetwork(**model_params).to('cpu')
121
+ elif model_type == 'spade_generator':
122
+ model = SPADEDecoder(**model_params).to('cpu')
123
+ elif model_type == 'stitching_retargeting_module':
124
+ # Special handling for stitching and retargeting module
125
+ config = model_config['model_params']['stitching_retargeting_module_params']
126
+ checkpoint = torch.load(ckpt_path, map_location='cpu')
127
+
128
+ stitcher = StitchingRetargetingNetwork(**config.get('stitching'))
129
+ stitcher.load_state_dict(remove_ddp_duplicate_key(checkpoint['retarget_shoulder']))
130
+ stitcher = stitcher.to('cpu')
131
+ stitcher.eval()
132
+
133
+ retargetor_lip = StitchingRetargetingNetwork(**config.get('lip'))
134
+ retargetor_lip.load_state_dict(remove_ddp_duplicate_key(checkpoint['retarget_mouth']))
135
+ retargetor_lip = retargetor_lip.to('cpu')
136
+ retargetor_lip.eval()
137
+
138
+ retargetor_eye = StitchingRetargetingNetwork(**config.get('eye'))
139
+ retargetor_eye.load_state_dict(remove_ddp_duplicate_key(checkpoint['retarget_eye']))
140
+ retargetor_eye = retargetor_eye.to('cpu')
141
+ retargetor_eye.eval()
142
+
143
+ return {
144
+ 'stitching': stitcher,
145
+ 'lip': retargetor_lip,
146
+ 'eye': retargetor_eye
147
+ }
148
+ else:
149
+ raise ValueError(f"Unknown model type: {model_type}")
150
+
151
+ model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
152
+ model.eval()
153
+ return model
154
+
155
+
156
+ # Get coefficients of Eqn. 7
157
+ def calculate_transformation(config, s_kp_info, t_0_kp_info, t_i_kp_info, R_s, R_t_0, R_t_i):
158
+ if config.relative:
159
+ new_rotation = (R_t_i @ R_t_0.permute(0, 2, 1)) @ R_s
160
+ new_expression = s_kp_info['exp'] + (t_i_kp_info['exp'] - t_0_kp_info['exp'])
161
+ else:
162
+ new_rotation = R_t_i
163
+ new_expression = t_i_kp_info['exp']
164
+ new_translation = s_kp_info['t'] + (t_i_kp_info['t'] - t_0_kp_info['t'])
165
+ new_translation[..., 2].fill_(0) # Keep the z-axis unchanged
166
+ new_scale = s_kp_info['scale'] * (t_i_kp_info['scale'] / t_0_kp_info['scale'])
167
+ return new_rotation, new_expression, new_translation, new_scale
168
+
169
+
170
+ def load_description(fp):
171
+ with open(fp, 'r', encoding='utf-8') as f:
172
+ content = f.read()
173
+ return content