MotionLCM / mld /config.py
wxDai's picture
init
6b1e9f7
raw
history blame contribute delete
No virus
1.68 kB
import os
import importlib
from typing import Type, TypeVar
from argparse import ArgumentParser
from omegaconf import OmegaConf, DictConfig
def get_module_config(cfg_model: DictConfig, path: str = "modules") -> DictConfig:
    """Merge every ``.yaml`` file found in ``./configs/<path>/`` into ``cfg_model``.

    The directory is resolved relative to the current working directory.
    Files are merged in sorted filename order so that, when several files
    define the same key, the winner is the same on every machine.

    Args:
        cfg_model: Config node that receives the merged values. It is
            updated in place via ``merge_with``.
        path: Sub-directory of ``./configs/`` to scan for YAML files.

    Returns:
        The same ``cfg_model`` object, after all YAML files were merged in.

    Raises:
        FileNotFoundError: If ``./configs/<path>/`` does not exist
            (raised by ``os.listdir``).
    """
    config_dir = f'./configs/{path}/'
    # os.listdir() yields entries in arbitrary, filesystem-dependent order;
    # sorting makes the override order between files reproducible.
    for file in sorted(os.listdir(config_dir)):
        if file.endswith('.yaml'):
            with open(config_dir + file, 'r') as f:
                cfg_model.merge_with(OmegaConf.load(f))
    return cfg_model
def get_obj_from_str(string: str, reload: bool = False) -> Type:
    """Resolve a dotted path such as ``"package.module.Name"`` to the object it names.

    Args:
        string: Dotted path; everything before the last ``.`` is the module
            to import, the final component is the attribute to fetch.
        reload: When true, the module is re-executed with
            ``importlib.reload`` before the attribute is looked up.

    Returns:
        The attribute found on the imported module.

    Raises:
        ValueError: If ``string`` contains no ``.`` (the split cannot be
            unpacked into module and attribute).
        ModuleNotFoundError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with that name.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        # Import first so there is a module object to hand to reload().
        importlib.reload(importlib.import_module(module_name))
    target_module = importlib.import_module(module_name)
    return getattr(target_module, attr_name)
def instantiate_from_config(config: DictConfig) -> TypeVar:
    """Build an object from a config holding a ``target`` path and optional ``params``.

    Args:
        config: Mapping-like config. ``config["target"]`` is a dotted path
            resolved with ``get_obj_from_str``; ``config.get("params")``,
            when present, supplies the keyword arguments for the call.

    Returns:
        Whatever calling the resolved target with the given keyword
        arguments returns.
    """
    target = get_obj_from_str(config["target"])
    # A missing "params" entry means the target is called with no arguments.
    kwargs = config.get("params", dict())
    return target(**kwargs)
def parse_args() -> DictConfig:
    """Parse the command line and assemble the run configuration.

    The file given by ``--cfg`` is loaded, the YAML files selected by
    ``cfg.model.target`` are merged into ``cfg.model`` through
    ``get_module_config``, and that merged model node is then merged into
    the top-level config. The demo options are finally copied onto the
    config as ``example``, ``no_plot`` and ``replication``.

    Returns:
        The merged configuration.
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument(
        "--cfg", type=str, required=True, help="config file")

    # Demo Args
    arg_parser.add_argument(
        "--example",
        type=str,
        help="input text and lengths with txt format",
    )
    arg_parser.add_argument(
        "--no-plot",
        action="store_true",
        help="whether plot the skeleton-based motion",
    )
    arg_parser.add_argument(
        "--replication",
        type=int,
        default=1,
        help="the number of replication of sampling",
    )
    cli = arg_parser.parse_args()

    base_cfg = OmegaConf.load(cli.cfg)
    # cfg.model.target names the sub-directory of ./configs/ to pull YAML from.
    model_cfg = get_module_config(base_cfg.model, base_cfg.model.target)
    merged_cfg = OmegaConf.merge(base_cfg, model_cfg)

    # Expose the demo options on the config itself.
    merged_cfg.example = cli.example
    merged_cfg.no_plot = cli.no_plot
    merged_cfg.replication = cli.replication
    return merged_cfg