|
|
|
|
|
import tyro |
|
from src.config.argument_config import ArgumentConfig |
|
from src.config.inference_config import InferenceConfig |
|
from src.config.crop_config import CropConfig |
|
from src.live_portrait_pipeline import LivePortraitPipeline |
|
from src.live_portrait_cpu_pipeline import LiveCPUPortraitPipeline |
|
|
|
def partial_fields(target_class, kwargs):
    """Instantiate ``target_class`` from the entries of ``kwargs`` it accepts.

    Args:
        target_class: Class to instantiate. Dataclasses get field-aware
            filtering; any other class is filtered with ``hasattr``.
        kwargs: Mapping of candidate keyword arguments. Entries that
            ``target_class`` does not accept are ignored.

    Returns:
        A new ``target_class`` instance built from the accepted entries.
    """
    # Local import keeps this helper self-contained.
    import dataclasses

    if dataclasses.is_dataclass(target_class):
        # ``hasattr`` on the class misses dataclass fields that have no
        # class-level default (required fields and ``default_factory``
        # fields), so consult the declared init fields instead.
        accepted = {f.name for f in dataclasses.fields(target_class) if f.init}
        selected = {k: v for k, v in kwargs.items() if k in accepted}
    else:
        selected = {k: v for k, v in kwargs.items() if hasattr(target_class, k)}
    return target_class(**selected)
|
|
|
|
|
def main():
    """Parse CLI arguments and run the CPU portrait pipeline.

    The arguments are parsed into an ``ArgumentConfig`` by ``tyro``. The
    inference and crop configs are derived from the parsed attributes via
    ``partial_fields``, and the pipeline is then executed with the full
    argument object.
    """
    # Accent color used by tyro's CLI output.
    tyro.extras.set_accent_color("bright_cyan")
    args = tyro.cli(ArgumentConfig)

    # Both configs are built from the same parsed-argument attributes.
    cli_values = vars(args)
    pipeline = LiveCPUPortraitPipeline(
        inference_cfg=partial_fields(InferenceConfig, cli_values),
        crop_cfg=partial_fields(CropConfig, cli_values),
    )

    pipeline.execute(args)
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|