|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
import os |
|
|
|
from .easydict import EasyDict as edict |
|
from .arg_utils import infer_type |
|
|
|
import pathlib |
|
import platform |
|
|
|
# Parent of the directory that contains this file, resolved to an absolute path.
ROOT = pathlib.Path(__file__).parent.parent.resolve()

# Home directory of the current user ("~" expanded once at import time).
HOME_DIR = os.path.expanduser("~")

# Settings merged into every configuration produced by get_config.
COMMON_CONFIG = dict(
    save_dir=os.path.expanduser("~/shortcuts/monodepth3_checkpoints"),
    project="ZoeDepth",
    tags='',
    notes="",
    gpu=None,
    root=".",
    uid=None,
    print_losses=False,
)
|
|
|
def _home_path(relative):
    """Join ``relative`` onto HOME_DIR with os.path.join."""
    return os.path.join(HOME_DIR, relative)


def _kitti_entry(do_random_rotate):
    """Build the settings shared by the "kitti" and "kitti_test" entries.

    The two entries are identical except for ``do_random_rotate``.
    """
    raw_dir = _home_path("shortcuts/datasets/kitti/raw")
    gt_dir = _home_path("shortcuts/datasets/kitti/gts")
    return {
        "dataset": "kitti",
        "min_depth": 0.001,
        "max_depth": 80,
        "data_path": raw_dir,
        "gt_path": gt_dir,
        "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt",
        "input_height": 352,
        "input_width": 1216,
        "data_path_eval": raw_dir,
        "gt_path_eval": gt_dir,
        "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt",
        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,
        "do_random_rotate": do_random_rotate,
        "degree": 1.0,
        "do_kb_crop": True,
        "garg_crop": True,
        "eigen_crop": False,
        "use_right": False,
    }


def _nyu_entry():
    """Build the settings of the "nyu" entry."""
    train_dir = _home_path("shortcuts/datasets/nyu_depth_v2/sync/")
    test_dir = _home_path("shortcuts/datasets/nyu_depth_v2/official_splits/test/")
    return {
        "dataset": "nyu",
        "avoid_boundary": False,
        "min_depth": 1e-3,
        "max_depth": 10,
        "data_path": train_dir,
        "gt_path": train_dir,
        "filenames_file": "./train_test_inputs/nyudepthv2_train_files_with_gt.txt",
        "input_height": 480,
        "input_width": 640,
        "data_path_eval": test_dir,
        "gt_path_eval": test_dir,
        "filenames_file_eval": "./train_test_inputs/nyudepthv2_test_files_with_gt.txt",
        "min_depth_eval": 1e-3,
        "max_depth_eval": 10,
        "min_depth_diff": -10,
        "max_depth_diff": 10,
        "do_random_rotate": True,
        "degree": 1.0,
        "do_kb_crop": False,
        "garg_crop": False,
        "eigen_crop": True,
    }


def _rooted_entry(name, root, *, eigen_crop, garg_crop, do_kb_crop,
                  min_depth_eval, max_depth_eval, max_depth):
    """Build an entry that is located by a single ``<name>_root`` directory.

    ``root`` is given relative to HOME_DIR. Every such entry uses a
    ``min_depth`` of 1e-3.
    """
    return {
        "dataset": name,
        f"{name}_root": _home_path(root),
        "eigen_crop": eigen_crop,
        "garg_crop": garg_crop,
        "do_kb_crop": do_kb_crop,
        "min_depth_eval": min_depth_eval,
        "max_depth_eval": max_depth_eval,
        "min_depth": 1e-3,
        "max_depth": max_depth,
    }


# Per-dataset settings, keyed by the dataset name accepted by get_config and
# change_dataset.
DATASETS_CONFIG = {
    "kitti": _kitti_entry(do_random_rotate=True),
    "kitti_test": _kitti_entry(do_random_rotate=False),
    "nyu": _nyu_entry(),
    "ibims": _rooted_entry(
        "ibims", "shortcuts/datasets/ibims/ibims1_core_raw/",
        eigen_crop=True, garg_crop=False, do_kb_crop=False,
        min_depth_eval=0, max_depth_eval=10, max_depth=10,
    ),
    "sunrgbd": _rooted_entry(
        "sunrgbd", "shortcuts/datasets/SUNRGBD/test/",
        eigen_crop=True, garg_crop=False, do_kb_crop=False,
        min_depth_eval=0, max_depth_eval=8, max_depth=10,
    ),
    "diml_indoor": _rooted_entry(
        "diml_indoor", "shortcuts/datasets/diml_indoor_test/",
        eigen_crop=True, garg_crop=False, do_kb_crop=False,
        min_depth_eval=0, max_depth_eval=10, max_depth=10,
    ),
    "diml_outdoor": _rooted_entry(
        "diml_outdoor", "shortcuts/datasets/diml_outdoor_test/",
        eigen_crop=False, garg_crop=True, do_kb_crop=False,
        min_depth_eval=2, max_depth_eval=80, max_depth=80,
    ),
    "diode_indoor": _rooted_entry(
        "diode_indoor", "shortcuts/datasets/diode_indoor/",
        eigen_crop=True, garg_crop=False, do_kb_crop=False,
        min_depth_eval=1e-3, max_depth_eval=10, max_depth=10,
    ),
    "diode_outdoor": _rooted_entry(
        "diode_outdoor", "shortcuts/datasets/diode_outdoor/",
        eigen_crop=False, garg_crop=True, do_kb_crop=False,
        min_depth_eval=1e-3, max_depth_eval=80, max_depth=80,
    ),
    "hypersim_test": _rooted_entry(
        "hypersim_test", "shortcuts/datasets/hypersim_test/",
        eigen_crop=True, garg_crop=False, do_kb_crop=False,
        min_depth_eval=1e-3, max_depth_eval=80, max_depth=10,
    ),
    "vkitti": _rooted_entry(
        "vkitti", "shortcuts/datasets/vkitti_test/",
        eigen_crop=False, garg_crop=True, do_kb_crop=True,
        min_depth_eval=1e-3, max_depth_eval=80, max_depth=80,
    ),
    "vkitti2": _rooted_entry(
        "vkitti2", "shortcuts/datasets/vkitti2/",
        eigen_crop=False, garg_crop=True, do_kb_crop=True,
        min_depth_eval=1e-3, max_depth_eval=80, max_depth=80,
    ),
    "ddad": _rooted_entry(
        "ddad", "shortcuts/datasets/ddad/ddad_val/",
        eigen_crop=False, garg_crop=True, do_kb_crop=True,
        min_depth_eval=1e-3, max_depth_eval=80, max_depth=80,
    ),
}
|
|
|
# Names of DATASETS_CONFIG entries, grouped into an indoor and an outdoor list.
ALL_INDOOR = [
    "nyu",
    "ibims",
    "sunrgbd",
    "diode_indoor",
    "hypersim_test",
]
ALL_OUTDOOR = [
    "kitti",
    "diml_outdoor",
    "diode_outdoor",
    "vkitti2",
    "ddad",
]
# Indoor names first, followed by the outdoor names.
ALL_EVAL_DATASETS = [*ALL_INDOOR, *ALL_OUTDOOR]
|
|
|
# Training-related defaults; get_config merges them on top of COMMON_CONFIG
# before any model-specific configuration is applied.
COMMON_TRAINING_CONFIG = dict(
    dataset="nyu",
    distributed=True,
    workers=16,
    clip_grad=0.1,
    use_shared_dict=False,
    shared_dict=None,
    use_amp=False,

    aug=True,
    random_crop=False,
    random_translate=False,
    translate_prob=0.2,
    max_translation=100,

    validate_every=0.25,
    log_images_every=0.1,
    prefetch=False,
)
|
|
|
|
|
def flatten(config, except_keys=('bin_conf',)):
    """Flatten a nested dict into a single-level dict.

    Nested dict values are replaced by their (recursively flattened) items;
    the key of the nested dict itself is dropped. When the same key occurs
    more than once, the value encountered last wins.

    Args:
        config (dict): Possibly nested dictionary. Any non-dict input yields
            an empty dict.
        except_keys (str or iterable of str, optional): Keys whose values are
            kept verbatim and never descended into, even when the value is a
            dict. A single string is treated as one key. Defaults to
            ('bin_conf',).

    Returns:
        dict: The flattened dictionary.
    """
    # A bare string would turn `key in except_keys` into a substring test
    # (e.g. "conf" in "bin_conf"), so wrap it into a one-element tuple.
    if isinstance(except_keys, str):
        except_keys = (except_keys,)

    def recurse(inp):
        if not isinstance(inp, dict):
            return
        for key, value in inp.items():
            if key in except_keys:
                # Preserve excepted entries as-is; do not flatten them.
                yield key, value
            elif isinstance(value, dict):
                yield from recurse(value)
            else:
                yield key, value

    return dict(recurse(config))
|
|
|
|
|
def split_combined_args(kwargs):
    """Expand '__'-combined arguments into individual key-value pairs.

    A combined argument has a key that starts with '__' and lists several
    names separated by '__'; its value lists the matching values separated
    by ';'. For example, the pair '__n_bins__lr' / '256;0.001' adds
    'n_bins' -> '256' and 'lr' -> '0.001'. The split values stay strings,
    and the combined key itself is kept in the result.

    Args:
        kwargs (dict): key-value pairs of arguments, where some pairs may be
            combined according to the format above.

    Returns:
        dict: A new dict holding everything from ``kwargs`` plus the
            individual pairs extracted from the combined arguments.
    """
    expanded = dict(kwargs)
    for combined_key, combined_value in kwargs.items():
        if not combined_key.startswith("__"):
            continue
        # The text before the first '__' is empty, hence the [1:].
        keys = combined_key.split("__")[1:]
        values = combined_value.split(";")
        assert len(keys) == len(values), f"Combined arguments should have equal number of keys and values. Keys are separated by '__' and Values are separated with ';'. For example, '__n_bins__lr=256;0.001. Given (keys,values) is ({keys}, {values})"
        expanded.update(zip(keys, values))
    return expanded
|
|
|
|
|
def parse_list(config, key, dtype=int):
    """Turn a comma-separated string stored under ``key`` into a typed list.

    Nothing happens when ``key`` is absent. A string value is split on ','
    and every piece is converted with ``dtype``; the result replaces the
    string in ``config`` (in-place modification). Whatever is stored
    afterwards must be a list whose elements are all instances of ``dtype``,
    otherwise an AssertionError is raised.
    """
    if key not in config:
        return

    current = config[key]
    if isinstance(current, str):
        config[key] = [dtype(piece) for piece in current.split(',')]

    is_typed_list = isinstance(config[key], list) and all(
        isinstance(e, dtype) for e in config[key])
    assert is_typed_list, f"{key} should be a list of values dtype {dtype}. Given {config[key]} of type {type(config[key])} with values of type {[type(e) for e in config[key]]}."
|
|
|
|
|
def get_model_config(model_name, model_version=None):
    """Find and parse the .json config file for the model.

    Args:
        model_name (str): name of the model. The config file should be named config_{model_name}[_{model_version}].json under the models/{model_name} directory.
        model_version (str, optional): Specific config version. If specified config_{model_name}_{model_version}.json is searched for and used. Otherwise config_{model_name}.json is used. Defaults to None.

    Returns:
        easydict: the config dictionary for the model, or None when the
            config file does not exist.

    Raises:
        ValueError: If the config names a model to inherit from
            (``train.inherit``) and that model has no config file.
    """
    if model_version is not None:
        config_fname = f"config_{model_name}_{model_version}.json"
    else:
        config_fname = f"config_{model_name}.json"
    config_file = os.path.join(ROOT, "models", model_name, config_fname)
    if not os.path.exists(config_file):
        return None

    # JSON is read as UTF-8 rather than with the platform default encoding.
    with open(config_file, "r", encoding="utf-8") as f:
        config = edict(json.load(f))

    # Optional inheritance: fill in train settings that this config does not
    # define itself from the base config of the model named by "inherit".
    parent_name = config.train.get("inherit")
    if parent_name is not None:
        parent_config = get_model_config(parent_name)
        if parent_config is None:
            raise ValueError(
                f"Config file for model {parent_name} (inherited by {config_fname}) not found.")
        for key, value in parent_config.train.items():
            if key not in config.train:
                config.train[key] = value
    return edict(config)
|
|
|
|
|
def update_model_config(config, mode, model_name, model_version=None, strict=False):
    """Overlay a model's config file onto ``config`` and return the result.

    The "model" section and the ``mode`` section of the file found by
    get_model_config are merged (the ``mode`` section wins on clashes),
    flattened, and laid over ``config``. The input dict is not modified.

    Args:
        config (dict): Current configuration.
        mode (str): Name of the section to merge on top of "model".
        model_name (str): Model whose config file is looked up.
        model_version (str, optional): Config version passed on to
            get_model_config. Defaults to None.
        strict (bool, optional): Raise instead of returning ``config``
            unchanged when no config file is found. Defaults to False.

    Returns:
        dict: A new merged dict, or ``config`` itself when nothing was found.

    Raises:
        ValueError: If no config file is found and ``strict`` is true.
    """
    loaded = get_model_config(model_name, model_version)
    if loaded is None:
        if strict:
            raise ValueError(f"Config file for model {model_name} not found.")
        return config

    sections = dict(loaded.model)
    sections.update(loaded[mode])

    merged = dict(config)
    merged.update(flatten(sections))
    return merged
|
|
|
|
|
def check_choices(name, value, choices):
    """Validate that ``value`` is one of ``choices``.

    Args:
        name (str): Label for the checked value, used in the error message.
        value: The value to validate.
        choices: Container of accepted values.

    Raises:
        ValueError: If ``value`` is not contained in ``choices``.
    """
    if value in choices:
        return
    raise ValueError(f"{name} {value} not in supported choices {choices}")
|
|
|
|
|
# Config keys whose values get_config converts to booleans.
KEYS_TYPE_BOOL = [
    "use_amp",
    "distributed",
    "use_shared_dict",
    "same_lr",
    "aug",
    "three_phase",
    "prefetch",
    "cycle_momentum",
]
|
|
|
|
|
def _coerce_bool(value):
    """Convert a config value to bool, honouring textual spellings of false.

    Plain ``bool("False")`` is True, so string values (e.g. from the command
    line) are compared case-insensitively against "", "0", "false", "no" and
    "off". Non-string values go through ``bool`` unchanged.
    """
    if isinstance(value, str):
        return value.strip().lower() not in ("", "0", "false", "no", "off")
    return bool(value)


def get_config(model_name, mode='train', dataset=None, **overwrite_kwargs):
    """Main entry point to get the config for the model.

    Args:
        model_name (str): name of the desired model ("zoedepth" or "zoedepth_nk").
        mode (str, optional): "train", "infer" or "eval". Defaults to 'train'.
        dataset (str, optional): If specified, the corresponding dataset configuration is loaded as well. In "train" mode only "nyu", "kitti", "mix" or None are accepted; "mix" uses the "nyu" dataset configuration. Defaults to None.

    Keyword Args: key-value pairs of arguments to overwrite the default config.

    The order of precedence for overwriting the config is (Higher precedence first):
        # 1. overwrite_kwargs (including combined '__a__b' arguments, see split_combined_args)
        # 2. "config_version": Config file version if specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{config_version}.json
        # 3. "version_name": Default Model version specific config specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{version_name}.json
        # 4. The base model config config_{model_name}.json
        # 5. common_config: Default config for all models specified in COMMON_CONFIG and COMMON_TRAINING_CONFIG
        # 6. The dataset configuration, which only supplies keys not set by any of the above

    Regardless of the above, "model" is always set to ``model_name``,
    "dataset" is set to the resolved dataset when one is given, and in
    "train" mode with a dataset "project" is set to "MonoDepth3-{dataset}".

    Returns:
        easydict: The config dictionary for the model.

    Raises:
        ValueError: If ``model_name``, ``mode`` or (in "train" mode)
            ``dataset`` is not a supported choice.
        KeyError: If the merged config has no "version_name" entry, or if
            ``dataset`` is not a key of DATASETS_CONFIG.
    """
    check_choices("Model", model_name, ["zoedepth", "zoedepth_nk"])
    check_choices("Mode", mode, ["train", "infer", "eval"])
    if mode == "train":
        check_choices("Dataset", dataset, ["nyu", "kitti", "mix", None])

    # Lowest layers: common defaults, then the base model config.
    config = flatten({**COMMON_CONFIG, **COMMON_TRAINING_CONFIG})
    config = update_model_config(config, mode, model_name)

    # Version-specific config; the version may be overridden by the caller.
    version_name = overwrite_kwargs.get("version_name", config["version_name"])
    config = update_model_config(config, mode, model_name, version_name)

    # Explicitly requested config version, if any.
    config_version = overwrite_kwargs.get("config_version", None)
    if config_version is not None:
        print("Overwriting config with config_version", config_version)
        config = update_model_config(config, mode, model_name, config_version)

    # Caller-supplied overrides win over everything loaded so far.
    overwrite_kwargs = split_combined_args(overwrite_kwargs)
    config = {**config, **overwrite_kwargs}

    # Normalize flag-like keys; strings such as "False" or "0" become False.
    for key in KEYS_TYPE_BOOL:
        if key in config:
            config[key] = _coerce_bool(config[key])

    # Accept n_attractors as a comma-separated string.
    parse_list(config, "n_attractors")

    # An overridden n_bins is propagated into every bin_conf entry. Entries
    # are copied instead of mutated, and the value gets the same infer_type
    # treatment as the top-level "n_bins" receives below.
    if 'bin_conf' in config and 'n_bins' in overwrite_kwargs:
        n_bins = infer_type(overwrite_kwargs['n_bins'])
        config['bin_conf'] = [
            {**conf, 'n_bins': n_bins} for conf in config['bin_conf']]

    if mode == "train":
        orig_dataset = dataset
        if dataset == "mix":
            dataset = 'nyu'  # "mix" uses the nyu dataset settings
        if dataset is not None:
            config['project'] = f"MonoDepth3-{orig_dataset}"

    if dataset is not None:
        config['dataset'] = dataset
        # Dataset settings only fill in keys that are not already present.
        config = {**DATASETS_CONFIG[dataset], **config}

    config['model'] = model_name
    typed_config = {k: infer_type(v) for k, v in config.items()}

    # Stored on the dict that is actually returned, and after type inference
    # so the host name is kept as reported by platform.node().
    typed_config['hostname'] = platform.node()
    return edict(typed_config)
|
|
|
|
|
def change_dataset(config, new_dataset):
    """Apply the settings of another dataset to ``config`` in place.

    Args:
        config: Configuration object offering a dict-style ``update``.
        new_dataset (str): Key of DATASETS_CONFIG; an unknown key raises
            KeyError.

    Returns:
        The same ``config`` object, after the update.
    """
    dataset_settings = DATASETS_CONFIG[new_dataset]
    config.update(dataset_settings)
    return config
|
|