Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import glob | |
import os | |
import os.path as osp | |
import warnings | |
from typing import Union | |
from mmengine.config import Config, ConfigDict | |
from mmengine.logging import print_log | |
def find_latest_checkpoint(path, suffix='pth'):
    """Find the latest checkpoint from the working directory.

    A ``latest.<suffix>`` file, if present, is returned directly. Otherwise
    every ``*.<suffix>`` file whose name ends in ``_<number>.<suffix>``
    (e.g. ``epoch_12.pth`` or ``iter_5000.pth``) is considered, and the one
    with the largest number wins. Files without such a trailing number
    (e.g. ``model_final.pth``) are skipped instead of raising ``ValueError``.

    Args:
        path (str): The path to find checkpoints.
        suffix (str): File extension. Defaults to pth.

    Returns:
        latest_path (str | None): File path of the latest checkpoint, or
        None if ``path`` does not exist or holds no numbered checkpoint.

    References:
        .. [1] https://github.com/microsoft/SoftTeacher
               /blob/main/ssod/utils/patch.py
    """
    if not osp.exists(path):
        warnings.warn('The path of checkpoints does not exist.')
        return None
    if osp.exists(osp.join(path, f'latest.{suffix}')):
        return osp.join(path, f'latest.{suffix}')

    # glob() order is arbitrary; sorting makes ties between equally
    # numbered files resolve the same way on every run.
    checkpoints = sorted(glob.glob(osp.join(path, f'*.{suffix}')))
    if not checkpoints:
        warnings.warn('There are no checkpoints in the path.')
        return None

    latest = -1
    latest_path = None
    for checkpoint in checkpoints:
        # "epoch_12.pth" -> "12.pth" -> "12"
        number_str = osp.basename(checkpoint).split('_')[-1].split('.')[0]
        try:
            count = int(number_str)
        except ValueError:
            # Not a numbered checkpoint (e.g. "model_final.pth"); ignore it.
            continue
        if count > latest:
            latest = count
            latest_path = checkpoint
    if latest_path is None:
        warnings.warn('There are no numbered checkpoints in the path.')
    return latest_path
def update_data_root(cfg, logger=None):
    """Update the data root according to the env variable ``MMDET_DATASETS``.

    If ``MMDET_DATASETS`` is set, every occurrence of the old
    ``cfg.data_root`` inside string values of the config is replaced by the
    env value, and ``cfg.data_root`` is set to it. Otherwise the config is
    left untouched and ``cfg.data_root`` is used as the default.

    Args:
        cfg (:obj:`Config`): The model config to modify in place.
        logger (logging.Logger | str | None): The way to print the message;
            forwarded to ``print_log``. Defaults to None.
    """
    assert isinstance(cfg, Config), \
        f'cfg got wrong type: {type(cfg)}, expected mmengine.Config'

    if 'MMDET_DATASETS' not in os.environ:
        return
    dst_root = os.environ['MMDET_DATASETS']
    print_log(
        f'MMDET_DATASETS has been set to be {dst_root}. '
        f'Using {dst_root} as data root.',
        logger=logger)

    def _update(container, src_str, dst_str):
        """Replace ``src_str`` with ``dst_str`` in all nested strings.

        Descends into ``ConfigDict`` values and list elements, so paths
        inside dataset wrappers such as ``datasets=[...]`` are updated too.
        Tuples are left untouched (they cannot be modified in place).
        """
        if isinstance(container, list):
            keys = range(len(container))
        else:
            # Snapshot the keys because values are reassigned during the walk.
            keys = list(container.keys())
        for key in keys:
            value = container[key]
            if isinstance(value, (ConfigDict, list)):
                _update(value, src_str, dst_str)
            elif isinstance(value, str) and src_str in value:
                container[key] = value.replace(src_str, dst_str)

    src_root = getattr(cfg, 'data_root', None)
    # Without a real source root there is nothing to rewrite; an empty
    # string would even match (and be spliced into) every string value.
    if src_root:
        # Configs that keep datasets under ``cfg.data`` are walked from
        # there; configs without it (e.g. ones using ``train_dataloader`` /
        # ``test_dataloader``) are walked from the top level.
        data_cfg = getattr(cfg, 'data', None)
        _update(cfg if data_cfg is None else data_cfg, src_root, dst_root)
    cfg.data_root = dst_root
def get_test_pipeline_cfg(cfg: Union[str, ConfigDict]) -> ConfigDict:
    """Extract the test-time data pipeline from a full config.

    The search starts at ``cfg.test_dataloader.dataset`` and walks down
    through dataset wrappers. A ``dataset`` key is followed directly, and a
    ``datasets`` sequence is followed through its first entry. The walk
    stops at the first node that carries a ``pipeline`` key.

    Args:
        cfg (str or :obj:`ConfigDict`): The entire config, given either as
            a path to a config file or as an already-loaded config object.

    Returns:
        :obj:`ConfigDict`: The ``pipeline`` value of the innermost test
        dataset config.

    Raises:
        RuntimeError: If no ``pipeline`` is found along that chain.
    """
    if isinstance(cfg, str):
        cfg = Config.fromfile(cfg)

    node = cfg.test_dataloader.dataset
    while 'pipeline' not in node:
        if 'dataset' in node:
            # Single-dataset wrapper.
            node = node.dataset
        elif 'datasets' in node:
            # Multi-dataset wrapper (e.g. ConcatDataset): use the first one.
            node = node.datasets[0]
        else:
            raise RuntimeError('Cannot find `pipeline` in `test_dataloader`')
    return node.pipeline