|
|
|
|
|
""" |
|
Benchmark the inference speed of each module in LivePortrait. |
|
|
|
TODO: heavy GPT style, need to refactor |
|
""" |
|
|
|
import yaml |
|
import torch |
|
import time |
|
import numpy as np |
|
from src.utils.helper import load_model, concat_feat |
|
from src.config.inference_config import InferenceConfig |
|
|
|
|
|
def initialize_inputs(batch_size=1): |
|
""" |
|
Generate random input tensors and move them to GPU |
|
""" |
|
feature_3d = torch.randn(batch_size, 32, 16, 64, 64).cuda().half() |
|
kp_source = torch.randn(batch_size, 21, 3).cuda().half() |
|
kp_driving = torch.randn(batch_size, 21, 3).cuda().half() |
|
source_image = torch.randn(batch_size, 3, 256, 256).cuda().half() |
|
generator_input = torch.randn(batch_size, 256, 64, 64).cuda().half() |
|
eye_close_ratio = torch.randn(batch_size, 3).cuda().half() |
|
lip_close_ratio = torch.randn(batch_size, 2).cuda().half() |
|
feat_stitching = concat_feat(kp_source, kp_driving).half() |
|
feat_eye = concat_feat(kp_source, eye_close_ratio).half() |
|
feat_lip = concat_feat(kp_source, lip_close_ratio).half() |
|
|
|
inputs = { |
|
'feature_3d': feature_3d, |
|
'kp_source': kp_source, |
|
'kp_driving': kp_driving, |
|
'source_image': source_image, |
|
'generator_input': generator_input, |
|
'feat_stitching': feat_stitching, |
|
'feat_eye': feat_eye, |
|
'feat_lip': feat_lip |
|
} |
|
|
|
return inputs |
|
|
|
|
|
def load_and_compile_models(cfg, model_config): |
|
""" |
|
Load and compile models for inference |
|
""" |
|
appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor') |
|
motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor') |
|
warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module') |
|
spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator') |
|
stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module') |
|
|
|
models_with_params = [ |
|
('Appearance Feature Extractor', appearance_feature_extractor), |
|
('Motion Extractor', motion_extractor), |
|
('Warping Network', warping_module), |
|
('SPADE Decoder', spade_generator) |
|
] |
|
|
|
compiled_models = {} |
|
for name, model in models_with_params: |
|
model = model.half() |
|
model = torch.compile(model, mode='max-autotune') |
|
model.eval() |
|
compiled_models[name] = model |
|
|
|
retargeting_models = ['stitching', 'eye', 'lip'] |
|
for retarget in retargeting_models: |
|
module = stitching_retargeting_module[retarget].half() |
|
module = torch.compile(module, mode='max-autotune') |
|
module.eval() |
|
stitching_retargeting_module[retarget] = module |
|
|
|
return compiled_models, stitching_retargeting_module |
|
|
|
|
|
def warm_up_models(compiled_models, stitching_retargeting_module, inputs): |
|
""" |
|
Warm up models to prepare them for benchmarking |
|
""" |
|
print("Warm up start!") |
|
with torch.no_grad(): |
|
for _ in range(10): |
|
compiled_models['Appearance Feature Extractor'](inputs['source_image']) |
|
compiled_models['Motion Extractor'](inputs['source_image']) |
|
compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source']) |
|
compiled_models['SPADE Decoder'](inputs['generator_input']) |
|
stitching_retargeting_module['stitching'](inputs['feat_stitching']) |
|
stitching_retargeting_module['eye'](inputs['feat_eye']) |
|
stitching_retargeting_module['lip'](inputs['feat_lip']) |
|
print("Warm up end!") |
|
|
|
|
|
def measure_inference_times(compiled_models, stitching_retargeting_module, inputs): |
|
""" |
|
Measure inference times for each model |
|
""" |
|
times = {name: [] for name in compiled_models.keys()} |
|
times['Retargeting Models'] = [] |
|
|
|
overall_times = [] |
|
|
|
with torch.no_grad(): |
|
for _ in range(100): |
|
torch.cuda.synchronize() |
|
overall_start = time.time() |
|
|
|
start = time.time() |
|
compiled_models['Appearance Feature Extractor'](inputs['source_image']) |
|
torch.cuda.synchronize() |
|
times['Appearance Feature Extractor'].append(time.time() - start) |
|
|
|
start = time.time() |
|
compiled_models['Motion Extractor'](inputs['source_image']) |
|
torch.cuda.synchronize() |
|
times['Motion Extractor'].append(time.time() - start) |
|
|
|
start = time.time() |
|
compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source']) |
|
torch.cuda.synchronize() |
|
times['Warping Network'].append(time.time() - start) |
|
|
|
start = time.time() |
|
compiled_models['SPADE Decoder'](inputs['generator_input']) |
|
torch.cuda.synchronize() |
|
times['SPADE Decoder'].append(time.time() - start) |
|
|
|
start = time.time() |
|
stitching_retargeting_module['stitching'](inputs['feat_stitching']) |
|
stitching_retargeting_module['eye'](inputs['feat_eye']) |
|
stitching_retargeting_module['lip'](inputs['feat_lip']) |
|
torch.cuda.synchronize() |
|
times['Retargeting Models'].append(time.time() - start) |
|
|
|
overall_times.append(time.time() - overall_start) |
|
|
|
return times, overall_times |
|
|
|
|
|
def print_benchmark_results(compiled_models, stitching_retargeting_module, retargeting_models, times, overall_times): |
|
""" |
|
Print benchmark results with average and standard deviation of inference times |
|
""" |
|
average_times = {name: np.mean(times[name]) * 1000 for name in times.keys()} |
|
std_times = {name: np.std(times[name]) * 1000 for name in times.keys()} |
|
|
|
for name, model in compiled_models.items(): |
|
num_params = sum(p.numel() for p in model.parameters()) |
|
num_params_in_millions = num_params / 1e6 |
|
print(f"Number of parameters for {name}: {num_params_in_millions:.2f} M") |
|
|
|
for index, retarget in enumerate(retargeting_models): |
|
num_params = sum(p.numel() for p in stitching_retargeting_module[retarget].parameters()) |
|
num_params_in_millions = num_params / 1e6 |
|
print(f"Number of parameters for part_{index} in Stitching and Retargeting Modules: {num_params_in_millions:.2f} M") |
|
|
|
for name, avg_time in average_times.items(): |
|
std_time = std_times[name] |
|
print(f"Average inference time for {name} over 100 runs: {avg_time:.2f} ms (std: {std_time:.2f} ms)") |
|
|
|
|
|
def main(): |
|
""" |
|
Main function to benchmark speed and model parameters |
|
""" |
|
|
|
inputs = initialize_inputs() |
|
|
|
|
|
cfg = InferenceConfig(device_id=0) |
|
model_config_path = cfg.models_config |
|
with open(model_config_path, 'r') as file: |
|
model_config = yaml.safe_load(file) |
|
|
|
|
|
compiled_models, stitching_retargeting_module = load_and_compile_models(cfg, model_config) |
|
|
|
|
|
warm_up_models(compiled_models, stitching_retargeting_module, inputs) |
|
|
|
|
|
times, overall_times = measure_inference_times(compiled_models, stitching_retargeting_module, inputs) |
|
|
|
|
|
print_benchmark_results(compiled_models, stitching_retargeting_module, ['stitching', 'eye', 'lip'], times, overall_times) |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|