Spaces:
Sleeping
Sleeping
import os, sys | |
import os.path as osp | |
import ast | |
import tempfile | |
import shutil | |
from importlib import import_module | |
from argparse import Action | |
from addict import Dict | |
from yapf.yapflib.yapf_api import FormatCode | |
# Key in a config file that names the parent config file(s) to inherit from.
BASE_KEY = "_base_"

# Key that, when set to a truthy value inside a child dict, makes that dict
# replace the base value instead of being merged into it.
DELETE_KEY = "_delete_"

# Top-level config keys rejected by ``Config.__init__`` because they would
# collide with attributes/methods of ``Config`` itself.
RESERVED_KEYS = [
    "filename",
    "text",
    "pretty_text",
    "get",
    "dump",
    "merge_from_dict",
]
def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    """Raise ``FileNotFoundError`` unless ``filename`` is an existing file.

    Args:
        filename: Path to test with ``os.path.isfile``.
        msg_tmpl: ``str.format`` template for the error message; it receives
            ``filename`` as its single positional argument.

    Raises:
        FileNotFoundError: If ``filename`` is not a regular file.
    """
    if osp.isfile(filename):
        return
    raise FileNotFoundError(msg_tmpl.format(filename))
class ConfigDict(Dict):
    """``addict.Dict`` variant that fails loudly on unknown keys.

    Missing keys raise ``KeyError`` on item access and ``AttributeError`` on
    attribute access.
    """

    def __missing__(self, name):
        """Reject lookups of absent keys with ``KeyError``."""
        raise KeyError(name)

    def __getattr__(self, name):
        """Return the attribute ``name``, mapping ``KeyError`` to ``AttributeError``.

        Any other exception raised by the parent lookup is re-raised as is.
        """
        try:
            return super().__getattr__(name)
        except KeyError:
            cls_name = self.__class__.__name__
            pending = AttributeError(
                f"'{cls_name}' object has no attribute '{name}'")
        except Exception as err:
            pending = err
        # Raised outside the ``except`` blocks so the new exception is not
        # chained to the one being handled.
        raise pending
class Config(object):
    """Attribute-style view over a configuration dict loaded from ``.py`` files.

    The configuration is held in a ``ConfigDict`` stored on the private
    attribute ``_cfg_dict``; ``_filename`` and ``_text`` keep the source path
    and the concatenated source text. Attribute and item access on a
    ``Config`` are forwarded to ``_cfg_dict``.

    A config file may declare ``_base_`` (one path or a list of paths,
    relative to the file's directory); those files are loaded first and the
    current file is merged on top of them.
    """

    @staticmethod
    def _validate_py_syntax(filename):
        """Parse ``filename`` with ``ast`` and raise on syntax errors.

        Raises:
            SyntaxError: If the file content cannot be parsed.
        """
        with open(filename) as f:
            content = f.read()
        try:
            ast.parse(content)
        except SyntaxError as err:
            raise SyntaxError('There are syntax errors in config '
                              f'file {filename}') from err

    @staticmethod
    def _file2dict(filename):
        """Load a ``.py`` config file (and its bases) into a dict.

        The file is copied into a temporary directory under a unique module
        name, imported from there, and every module attribute whose name does
        not start with ``__`` becomes a config entry.

        Args:
            filename: Path of the config file; ``~`` is expanded.

        Returns:
            tuple: ``(cfg_dict, cfg_text)`` where ``cfg_text`` is the file
            path followed by the file content, preceded by the text of any
            base files.

        Raises:
            FileNotFoundError: If the file does not exist.
            IOError: If the file is not a ``.py`` file.
            SyntaxError: If the file has syntax errors.
            KeyError: If two base files define the same top-level key.
        """
        filename = osp.abspath(osp.expanduser(filename))
        check_file_exist(filename)
        if not filename.lower().endswith('.py'):
            # Only the .py branch is implemented in this loader.
            raise IOError('Only py type is supported now!')

        # Validate before touching sys.path so a broken file leaves no
        # global state behind.
        Config._validate_py_syntax(filename)

        with tempfile.TemporaryDirectory() as temp_config_dir:
            # The named temporary file is only used to reserve a unique
            # module name inside the temporary directory.
            temp_config_file = tempfile.NamedTemporaryFile(
                dir=temp_config_dir, suffix='.py')
            temp_config_name = osp.basename(temp_config_file.name)
            temp_module_name = osp.splitext(temp_config_name)[0]
            sys.path.insert(0, temp_config_dir)
            try:
                shutil.copyfile(filename,
                                osp.join(temp_config_dir, temp_config_name))
                mod = import_module(temp_module_name)
                cfg_dict = {
                    name: value
                    for name, value in mod.__dict__.items()
                    if not name.startswith('__')
                }
            finally:
                # Always undo the global side effects, even when the import
                # fails. Remove by value: the imported file may itself have
                # modified sys.path.
                if temp_config_dir in sys.path:
                    sys.path.remove(temp_config_dir)
                sys.modules.pop(temp_module_name, None)
                temp_config_file.close()

        cfg_text = filename + '\n'
        with open(filename, 'r') as f:
            cfg_text += f.read()

        if BASE_KEY in cfg_dict:
            cfg_dir = osp.dirname(filename)
            base_filename = cfg_dict.pop(BASE_KEY)
            if not isinstance(base_filename, list):
                base_filename = [base_filename]

            cfg_dict_list = []
            cfg_text_list = []
            for base in base_filename:
                _cfg_dict, _cfg_text = Config._file2dict(
                    osp.join(cfg_dir, base))
                cfg_dict_list.append(_cfg_dict)
                cfg_text_list.append(_cfg_text)

            base_cfg_dict = dict()
            for c in cfg_dict_list:
                if len(base_cfg_dict.keys() & c.keys()) > 0:
                    raise KeyError('Duplicate key is not allowed among bases')
                    # TODO Allow the duplicate key while warning the user
                base_cfg_dict.update(c)

            # The current file overrides whatever the bases define.
            cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)

            # Base texts come first, the current file's text last.
            cfg_text_list.append(cfg_text)
            cfg_text = '\n'.join(cfg_text_list)

        return cfg_dict, cfg_text

    @staticmethod
    def _merge_a_into_b(a, b):
        """Merge ``a`` into a shallow copy of ``b`` and return the copy.

        Values in ``a`` overwrite those in ``b``. A dict value in ``a`` is
        merged recursively into the matching entry of ``b`` unless it carries
        a truthy ``_delete_`` key, in which case it replaces that entry. When
        ``b`` is a list, the keys of ``a`` are used as integer indices.

        Note:
            ``b`` is copied, but the ``_delete_`` key is popped from the
            nested dicts of ``a`` in place.

        Args:
            a: Source of overriding values. If it is not a dict it is
                returned unchanged.
            b (dict | list): Base container to merge into.

        Returns:
            The merged copy of ``b`` (or ``a`` itself if ``a`` is not a dict).

        Raises:
            TypeError: If a dict in ``a`` targets a non-dict, non-list value
                in ``b``, or if ``b`` is a list and a key of ``a`` cannot be
                converted to ``int``.
        """
        if not isinstance(a, dict):
            return a

        b = b.copy()
        for k, v in a.items():
            if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
                if not isinstance(b[k], dict) and not isinstance(b[k], list):
                    raise TypeError(
                        f'{k}={v} in child config cannot inherit from base '
                        f'because {k} is a dict in the child config but is of '
                        f'type {type(b[k])} in base config. You may set '
                        f'`{DELETE_KEY}=True` to ignore the base config')
                b[k] = Config._merge_a_into_b(v, b[k])
            elif isinstance(b, list):
                try:
                    index = int(k)
                except (TypeError, ValueError):
                    raise TypeError(
                        f'b is a list, '
                        f'index {k} should be an int when input but {type(k)}')
                b[index] = Config._merge_a_into_b(v, b[index])
            else:
                b[k] = v

        return b

    @staticmethod
    def fromfile(filename):
        """Build a ``Config`` from the config file at ``filename``."""
        cfg_dict, cfg_text = Config._file2dict(filename)
        return Config(cfg_dict, cfg_text=cfg_text, filename=filename)

    def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
        """Create a config from an in-memory dict.

        Args:
            cfg_dict (dict, optional): Configuration content; defaults to an
                empty dict.
            cfg_text (str, optional): Source text to keep. If empty and
                ``filename`` is given, the text is read from that file.
            filename (str, optional): Path the config was loaded from.

        Raises:
            TypeError: If ``cfg_dict`` is not a dict.
            KeyError: If ``cfg_dict`` uses a key listed in ``RESERVED_KEYS``.
        """
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError('cfg_dict must be a dict, but '
                            f'got {type(cfg_dict)}')
        for key in cfg_dict:
            if key in RESERVED_KEYS:
                raise KeyError(f'{key} is reserved for config file')

        # Bypass this class's __setattr__, which forwards to _cfg_dict.
        super().__setattr__('_cfg_dict', ConfigDict(cfg_dict))
        super().__setattr__('_filename', filename)
        if cfg_text:
            text = cfg_text
        elif filename:
            with open(filename, 'r') as f:
                text = f.read()
        else:
            text = ''
        super().__setattr__('_text', text)

    @property
    def filename(self):
        """Path this config was loaded from, or ``None``."""
        return self._filename

    @property
    def text(self):
        """Source text stored for this config."""
        return self._text

    @property
    def pretty_text(self):
        """The config rendered as Python source and formatted with yapf."""
        indent = 4

        def _indent(s_, num_spaces):
            # Indent every line except the first by ``num_spaces`` spaces.
            s = s_.split('\n')
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * ' ') + line for line in s]
            s = '\n'.join(s)
            s = first + '\n' + s
            return s

        def _format_basic_types(k, v, use_mapping=False):
            # Render ``k=v`` (or ``'k': v`` inside a ``{...}`` mapping).
            if isinstance(v, str):
                v_str = f"'{v}'"
            else:
                v_str = str(v)

            if use_mapping:
                k_str = f"'{k}'" if isinstance(k, str) else str(k)
                attr_str = f'{k_str}: {v_str}'
            else:
                attr_str = f'{str(k)}={v_str}'
            attr_str = _indent(attr_str, indent)

            return attr_str

        def _format_list(k, v, use_mapping=False):
            # Lists whose items are all dicts are expanded one dict per
            # entry; any other list is rendered with ``str``.
            if all(isinstance(_, dict) for _ in v):
                v_str = '[\n'
                v_str += '\n'.join(
                    f'dict({_indent(_format_dict(v_), indent)}),'
                    for v_ in v).rstrip(',')
                if use_mapping:
                    k_str = f"'{k}'" if isinstance(k, str) else str(k)
                    attr_str = f'{k_str}: {v_str}'
                else:
                    attr_str = f'{str(k)}={v_str}'
                attr_str = _indent(attr_str, indent) + ']'
            else:
                attr_str = _format_basic_types(k, v, use_mapping)
            return attr_str

        def _contain_invalid_identifier(dict_str):
            # True if any key cannot be written as a keyword argument.
            return any(
                not str(key_name).isidentifier() for key_name in dict_str)

        def _format_dict(input_dict, outest_level=False):
            r = ''
            s = []

            # Keys that are not identifiers force ``{...}`` syntax instead
            # of ``dict(k=v)`` syntax.
            use_mapping = _contain_invalid_identifier(input_dict)
            if use_mapping:
                r += '{'
            for idx, (k, v) in enumerate(input_dict.items()):
                is_last = idx >= len(input_dict) - 1
                end = '' if outest_level or is_last else ','
                if isinstance(v, dict):
                    v_str = '\n' + _format_dict(v)
                    if use_mapping:
                        k_str = f"'{k}'" if isinstance(k, str) else str(k)
                        attr_str = f'{k_str}: dict({v_str}'
                    else:
                        attr_str = f'{str(k)}=dict({v_str}'
                    attr_str = _indent(attr_str, indent) + ')' + end
                elif isinstance(v, list):
                    attr_str = _format_list(k, v, use_mapping) + end
                else:
                    attr_str = _format_basic_types(k, v, use_mapping) + end

                s.append(attr_str)
            r += '\n'.join(s)
            if use_mapping:
                r += '}'
            return r

        cfg_dict = self._cfg_dict.to_dict()
        text = _format_dict(cfg_dict, outest_level=True)
        # copied from setup.cfg
        yapf_style = dict(
            based_on_style='pep8',
            blank_line_before_nested_class_or_def=True,
            split_before_expression_after_opening_paren=True)
        try:
            text, _ = FormatCode(text, style_config=yapf_style, verify=True)
        except TypeError:
            # NOTE(review): some yapf releases reportedly no longer accept
            # the ``verify`` keyword; retry without it in that case.
            text, _ = FormatCode(text, style_config=yapf_style)

        return text

    def __repr__(self):
        return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'

    def __len__(self):
        return len(self._cfg_dict)

    def __getattr__(self, name):
        # Only called when normal lookup fails. The private slots are never
        # forwarded: on a half-initialised instance (e.g. one created with
        # ``__new__``) forwarding would look up ``_cfg_dict`` through this
        # method again and recurse without bound.
        if name in ('_cfg_dict', '_filename', '_text'):
            raise AttributeError(name)
        return getattr(self._cfg_dict, name)

    def __getitem__(self, name):
        return self._cfg_dict.__getitem__(name)

    def __setattr__(self, name, value):
        # Plain dicts are wrapped so nested access keeps attribute style.
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)

    def __setitem__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)

    def __iter__(self):
        return iter(self._cfg_dict)

    def dump(self, file=None):
        """Return the pretty-printed config, or write it to ``file``.

        Args:
            file (str, optional): Destination path. If ``None`` the text is
                returned instead of written.

        Returns:
            str | None: ``pretty_text`` when ``file`` is ``None``.
        """
        if file is None:
            return self.pretty_text
        with open(file, 'w') as f:
            f.write(self.pretty_text)

    def merge_from_dict(self, options):
        """Merge a flat dict of dotted keys into this config.

        Each key of ``options`` is split on ``.`` to address a nested entry;
        the resulting nested dict is merged over the current content.

        Examples:
            >>> options = {'model.backbone.depth': 50,
            ...            'model.backbone.with_cp': True}
            >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
            >>> cfg.merge_from_dict(options)
            >>> cfg_dict = super(Config, cfg).__getattribute__('_cfg_dict')
            >>> assert cfg_dict == dict(
            ...     model=dict(backbone=dict(
            ...         type='ResNet', depth=50, with_cp=True)))

        Args:
            options (dict): dict of configs to merge from.
        """
        option_cfg_dict = {}
        for full_key, v in options.items():
            d = option_cfg_dict
            key_list = full_key.split('.')
            for subkey in key_list[:-1]:
                d.setdefault(subkey, ConfigDict())
                d = d[subkey]
            subkey = key_list[-1]
            d[subkey] = v

        cfg_dict = super().__getattribute__('_cfg_dict')
        super().__setattr__(
            '_cfg_dict', Config._merge_a_into_b(option_cfg_dict, cfg_dict))

    # for multiprocess
    def __getstate__(self):
        """Return ``(_cfg_dict, _filename, _text)`` for pickling/copying."""
        return (self._cfg_dict, self._filename, self._text)

    def __setstate__(self, state):
        """Restore from ``__getstate__`` output or from a plain config dict.

        Without an explicit ``__getstate__`` the state received here depends
        on the interpreter's default pickling behaviour, so the tuple form is
        restored directly. Any other state is treated as a config dict and
        passed to ``__init__``, as before.
        """
        if isinstance(state, tuple):
            cfg_dict, filename, text = state
            super().__setattr__('_cfg_dict', cfg_dict)
            super().__setattr__('_filename', filename)
            super().__setattr__('_text', text)
        else:
            self.__init__(state)

    def copy(self):
        """Return a new ``Config`` built from a shallow copy of the content."""
        return Config(self._cfg_dict.copy())

    def deepcopy(self):
        """Return a new ``Config`` built from a deep copy of the content."""
        return Config(self._cfg_dict.deepcopy())
class DictAction(Action):
    """argparse action that parses ``KEY=VALUE`` arguments into a dict.

    Each argument is split on the first ``=``. The value part is split on
    ``,``; a single item is stored as a scalar, several items as a list,
    i.e. ``KEY=V1,V2,V3``. Every item is converted with
    ``_parse_int_float_bool``. The resulting dict replaces whatever is stored
    on the namespace under ``dest``.
    """

    @staticmethod
    def _parse_int_float_bool(val):
        """Convert a string to int, float, bool or ``None`` when possible.

        Conversion is tried in that order; ``true``/``false`` and
        ``none``/``null`` are matched case-insensitively. Anything else is
        returned unchanged.
        """
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            pass
        lowered = val.lower()
        if lowered in ('true', 'false'):
            return lowered == 'true'
        if lowered in ('none', 'null'):
            return None
        return val

    def __call__(self, parser, namespace, values, option_string=None):
        """Parse ``values`` and store the resulting dict on ``namespace``.

        Raises:
            ValueError: If an item does not contain ``=``.
        """
        options = {}
        for kv in values:
            key, sep, raw = kv.partition('=')
            if not sep:
                raise ValueError(
                    f'expected KEY=VALUE, but got {kv!r}')
            parsed = [self._parse_int_float_bool(v) for v in raw.split(',')]
            options[key] = parsed[0] if len(parsed) == 1 else parsed
        setattr(namespace, self.dest, options)
# Module-level Config created with no arguments (empty content, no filename)
# when this module is imported.
cfg = Config()