Update inference.py
Browse files- inference.py +3 -2
inference.py
CHANGED
@@ -5,7 +5,7 @@ from src.config.argument_config import ArgumentConfig
|
|
5 |
from src.config.inference_config import InferenceConfig
|
6 |
from src.config.crop_config import CropConfig
|
7 |
from src.live_portrait_pipeline import LivePortraitPipeline
|
8 |
-
|
9 |
|
10 |
def partial_fields(target_class, kwargs):
|
11 |
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
|
@@ -20,7 +20,8 @@ def main():
|
|
20 |
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
|
21 |
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
|
22 |
|
23 |
-
live_portrait_pipeline = LivePortraitPipeline(
|
|
|
24 |
inference_cfg=inference_cfg,
|
25 |
crop_cfg=crop_cfg
|
26 |
)
|
|
|
5 |
from src.config.inference_config import InferenceConfig
|
6 |
from src.config.crop_config import CropConfig
|
7 |
from src.live_portrait_pipeline import LivePortraitPipeline
|
8 |
+
from src.live_portrait_cpu_pipeline import LiveCPUPortraitPipeline
|
9 |
|
10 |
def partial_fields(target_class, kwargs):
|
11 |
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
|
|
|
20 |
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
|
21 |
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
|
22 |
|
23 |
+
# live_portrait_pipeline = LivePortraitPipeline(
|
24 |
+
live_portrait_pipeline = LiveCPUPortraitPipeline(
|
25 |
inference_cfg=inference_cfg,
|
26 |
crop_cfg=crop_cfg
|
27 |
)
|