Spaces:
Runtime error
Runtime error
File size: 4,979 Bytes
3b96cb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os
import os.path as osp
import urllib
import warnings
from typing import Tuple, Union

import torch
from mmengine.config import Config, ConfigDict
from mmengine.logging import print_log
from mmengine.utils import scandir
# Lower-case file extensions (leading dot included) that ``get_file_list``
# recognises as images, both for a single path and when scanning a directory.
IMG_EXTENSIONS = (
    '.jpg',
    '.jpeg',
    '.png',
    '.ppm',
    '.bmp',
    '.pgm',
    '.tif',
    '.tiff',
    '.webp',
)
def find_latest_checkpoint(path, suffix='pth'):
    """Find the latest checkpoint from the working directory.

    A ``latest.<suffix>`` file inside ``path`` takes precedence. Otherwise
    every ``*.<suffix>`` file whose name ends in ``_<int>.<suffix>`` (e.g.
    ``epoch_12.pth``, ``iter_5000.pth``) is a candidate, and the one with the
    largest integer wins. Files whose names do not end in an integer (e.g.
    ``best.pth``, ``best_coco_bbox_mAP.pth``) are ignored instead of raising.

    Args:
        path (str): The path to find checkpoints.
        suffix (str): File extension without the leading dot.
            Defaults to 'pth'.

    Returns:
        str | None: File path of the latest checkpoint, or None when ``path``
        does not exist or contains no numbered checkpoint.

    References:
        .. [1] https://github.com/microsoft/SoftTeacher
               /blob/main/ssod/utils/patch.py
    """
    if not osp.exists(path):
        warnings.warn('The path of checkpoints does not exist.')
        return None

    latest_file = osp.join(path, f'latest.{suffix}')
    if osp.exists(latest_file):
        return latest_file

    # Escape the directory part so names containing glob metacharacters
    # (``[``, ``*``, ``?``) are matched literally.
    checkpoints = glob.glob(osp.join(glob.escape(path), f'*.{suffix}'))
    if not checkpoints:
        warnings.warn('There are no checkpoints in the path.')
        return None

    latest = -1
    latest_path = None
    for checkpoint in checkpoints:
        # ``epoch_12.pth`` -> ``12.pth`` -> ``12``
        stem = osp.basename(checkpoint).split('_')[-1].split('.')[0]
        try:
            count = int(stem)
        except ValueError:
            # Not a numbered checkpoint (e.g. ``best.pth``); cannot be ordered.
            continue
        if count > latest:
            latest = count
            latest_path = checkpoint

    if latest_path is None:
        warnings.warn('There are no numbered checkpoints in the path.')
    return latest_path
def update_data_root(cfg, logger=None):
    """Update data root according to env MMDET_DATASETS.

    If the environment variable ``MMDET_DATASETS`` is set, every string value
    under ``cfg.data`` (including nested ``ConfigDict`` values) that contains
    the old ``cfg.data_root`` has that substring replaced by the new root, and
    ``cfg.data_root`` is set to the new root. Otherwise ``cfg`` is left
    untouched.

    Args:
        cfg (:obj:`Config`): The model config to modify in place.
        logger (logging.Logger | str | None): Where to print the message;
            forwarded to ``print_log``. Defaults to None.
    """
    assert isinstance(cfg, Config), \
        f'cfg got wrong type: {type(cfg)}, expected mmengine.Config'

    if 'MMDET_DATASETS' not in os.environ:
        return

    dst_root = os.environ['MMDET_DATASETS']
    print_log(f'MMDET_DATASETS has been set to be {dst_root}. '
              f'Using {dst_root} as data root.', logger=logger)

    def update(cfg, src_str, dst_str):
        # Recurse into nested ConfigDicts and rewrite matching strings in
        # place; values of other types (e.g. lists) are not traversed.
        for k, v in cfg.items():
            if isinstance(v, ConfigDict):
                update(cfg[k], src_str, dst_str)
            if isinstance(v, str) and src_str in v:
                cfg[k] = v.replace(src_str, dst_str)

    src_root = cfg.data_root
    # An empty old root is a substring of every string, and
    # ``str.replace('', dst)`` would splice ``dst`` between every character,
    # so only rewrite paths when there is a real prefix to replace.
    if src_root:
        # NOTE(review): ``cfg.data`` is the MMDet 2.x config layout;
        # MMEngine-style configs keep datasets under ``*_dataloader`` keys.
        # Confirm which layout callers pass.
        update(cfg.data, src_root, dst_root)
    cfg.data_root = dst_root
def get_test_pipeline_cfg(cfg: Union[str, ConfigDict]) -> ConfigDict:
    """Extract the test-time data pipeline from a full config.

    Starting at ``cfg.test_dataloader.dataset``, dataset wrappers are
    unwrapped until a config holding ``pipeline`` is reached:

    * a wrapper with a ``dataset`` key is followed into that dataset;
    * a wrapper with a ``datasets`` key (e.g. ConcatDataset) is followed into
      its first entry.

    Args:
        cfg (str or :obj:`ConfigDict`): The entire config, either as a path to
            a config file or as an already loaded ``ConfigDict``.

    Returns:
        :obj:`ConfigDict`: The ``pipeline`` entry of the innermost test
        dataset config.

    Raises:
        RuntimeError: If no ``pipeline`` can be found along that chain.
    """
    if isinstance(cfg, str):
        cfg = Config.fromfile(cfg)

    node = cfg.test_dataloader.dataset
    while 'pipeline' not in node:
        if 'dataset' in node:
            node = node.dataset
        elif 'datasets' in node:
            node = node.datasets[0]
        else:
            raise RuntimeError('Cannot find `pipeline` in `test_dataloader`')
    return node.pipeline
def get_file_list(source_root: str) -> Tuple[list, dict]:
    """Get the list of image files referred to by ``source_root``.

    Args:
        source_root (str): One of
            * a directory, searched recursively for files whose extension is
              in ``IMG_EXTENSIONS``;
            * an ``http(s)`` URL, downloaded into the current working
              directory;
            * a single file path whose extension is in ``IMG_EXTENSIONS``.

    Returns:
        tuple[list, dict]:
            - source_file_path_list (list): Paths of all source files. It is
              empty when ``source_root`` matches none of the cases above.
            - source_type (dict): Boolean flags ``is_dir``, ``is_url`` and
              ``is_file`` describing how ``source_root`` was classified.

    Raises:
        ValueError: If ``source_root`` is a URL whose path has no file name.
    """
    is_dir = osp.isdir(source_root)
    is_url = source_root.startswith(('http:/', 'https:/'))
    is_file = osp.splitext(source_root)[-1].lower() in IMG_EXTENSIONS

    source_file_path_list = []
    if is_dir:
        # Entries yielded by ``scandir`` are joined onto ``source_root``.
        source_file_path_list = [
            osp.join(source_root, file)
            for file in scandir(source_root, IMG_EXTENSIONS, recursive=True)
        ]
    elif is_url:
        # Imported here because a bare ``import urllib`` does not guarantee
        # that the ``urllib.parse`` submodule has been loaded.
        from urllib.parse import unquote, urlparse

        # Split off query/fragment *before* unquoting so an encoded ``%3F``
        # or ``%23`` in the path is not mistaken for a separator, and a
        # ``#fragment`` does not end up in the saved file name.
        filename = osp.basename(unquote(urlparse(source_root).path))
        if not filename:
            # Without a name the download target would be the cwd directory.
            raise ValueError(
                f'Cannot infer a file name from URL: {source_root}')
        file_save_path = osp.join(os.getcwd(), filename)
        print(f'Downloading source file to {file_save_path}')
        torch.hub.download_url_to_file(source_root, file_save_path)
        source_file_path_list = [file_save_path]
    elif is_file:
        # A single image path is used as-is.
        source_file_path_list = [source_root]
    else:
        print('Cannot find image file.')

    source_type = dict(is_dir=is_dir, is_url=is_url, is_file=is_file)
    return source_file_path_list, source_type
|