Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
import re | |
from mmengine.config import Config | |
def replace_cfg_vals(ori_cfg):
    """Replace the string "${key}" with the corresponding value.

    Replace the "${key}" with the value of ori_cfg.key in the config. And
    support replacing the chained ${key}. Such as, replace "${key0.key1}"
    with the value of cfg.key0.key1. Code is modified from `vars.py
    <https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/vars.py>`_  # noqa: E501

    Two string shapes are handled:

    1) ``"${key}"`` -- the whole string is one placeholder; it is replaced
       by the referenced value itself, whatever its type.
    2) ``"xxx${key}xxx"`` or ``"xxx${key1}xxx${key2}xxx"`` -- every
       placeholder is replaced by ``str(value)`` inside the string.

    After substitution, if the resulting config has a non-``None``
    ``model_wrapper`` entry, it is moved to ``model`` and the
    ``model_wrapper`` entry is removed.

    Args:
        ori_cfg (mmengine.config.Config):
            The origin config with "${key}" generated from a file.

    Returns:
        mmengine.config.Config:
            The config with "${key}" replaced by the corresponding value.

    Raises:
        AssertionError: If a placeholder embedded in a longer string
            (shape 2 above) refers to a dict, list, or tuple value.
    """
    # The pattern of string "${key}": letters, digits, "_" and "." between
    # the braces.
    pattern_key = re.compile(r'\$\{[a-zA-Z\d_.]*\}')

    def get_value(cfg, key):
        """Index into ``cfg`` once per dot-separated part of ``key``."""
        for k in key.split('.'):
            cfg = cfg[k]
        return cfg

    def replace_value(cfg):
        """Return ``cfg`` with placeholders substituted, recursing into
        dicts, lists, and tuples; other non-string values pass through."""
        if isinstance(cfg, dict):
            return {key: replace_value(value) for key, value in cfg.items()}
        if isinstance(cfg, list):
            return [replace_value(item) for item in cfg]
        if isinstance(cfg, tuple):
            return tuple(replace_value(item) for item in cfg)
        if not isinstance(cfg, str):
            return cfg

        # Each found key still carries its "${" prefix and "}" suffix, so
        # strip them before looking the key up in the original config.
        keys = pattern_key.findall(cfg)
        values = [get_value(ori_cfg, key[2:-1]) for key in keys]

        if len(keys) == 1 and keys[0] == cfg:
            # The format of string cfg is "${key}": hand back the
            # referenced value directly, keeping its type.
            return values[0]

        # The format of string cfg is
        # "xxx${key}xxx" or "xxx${key1}xxx${key2}xxx".
        # Keep the untouched string so the error message shows what the
        # user wrote rather than a partially substituted string.
        original = cfg
        for key, value in zip(keys, values):
            if isinstance(value, (dict, list, tuple)):
                # Raised explicitly (not via ``assert``) so the check
                # survives ``python -O``; the exception type is unchanged.
                raise AssertionError(
                    'for the format of string cfg is '
                    "'xxxxx${key}xxxxx' or 'xxx${key}xxx${key}xxx', "
                    f"the type of the value of '{key}' "
                    'can not be dict, list, or tuple, '
                    f'but you input {type(value)} in {original}')
            cfg = cfg.replace(key, str(value))
        return cfg

    # The type of ori_cfg._cfg_dict is mmengine.config.ConfigDict.
    updated_cfg = Config(
        replace_value(ori_cfg._cfg_dict), filename=ori_cfg.filename)

    # Replace the model with model_wrapper.
    if updated_cfg.get('model_wrapper', None) is not None:
        updated_cfg.model = updated_cfg.model_wrapper
        updated_cfg.pop('model_wrapper')
    return updated_cfg