# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from pathlib import Path
from typing import Optional, Sequence, Union

import numpy as np
import torch
import torch.nn as nn
# RoIPool is referenced in the CPU-inference assertions below; leaving this
# import commented out raises a NameError at runtime.
from mmcv.ops import RoIPool
from mmcv.transforms import Compose
from mmengine.config import Config
from mmengine.model.utils import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint

from mmdet.registry import DATASETS

from ..evaluation import get_classes
from ..registry import MODELS
from ..structures import DetDataSample, SampleList
from ..utils import get_test_pipeline_cfg


def init_detector(
    config: Union[str, Path, Config],
    checkpoint: Optional[str] = None,
    palette: str = 'none',
    device: str = 'cuda:0',
    cfg_options: Optional[dict] = None,
) -> nn.Module:
"""Initialize a detector from config file. | |
Args: | |
config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path, | |
:obj:`Path`, or the config object. | |
checkpoint (str, optional): Checkpoint path. If left as None, the model | |
will not load any weights. | |
palette (str): Color palette used for visualization. If palette | |
is stored in checkpoint, use checkpoint's palette first, otherwise | |
use externally passed palette. Currently, supports 'coco', 'voc', | |
'citys' and 'random'. Defaults to none. | |
device (str): The device where the anchors will be put on. | |
Defaults to cuda:0. | |
cfg_options (dict, optional): Options to override some settings in | |
the used config. | |
Returns: | |
nn.Module: The constructed detector. | |
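
    Examples:
        A minimal usage sketch; the config and checkpoint paths below are
        hypothetical placeholders, not files shipped with this module:

        >>> from mmdet.apis import init_detector
        >>> cfg_path = 'configs/rtmdet/rtmdet_tiny_8xb32-300e_coco.py'  # hypothetical
        >>> ckpt_path = 'checkpoints/rtmdet_tiny.pth'  # hypothetical
        >>> model = init_detector(cfg_path, ckpt_path, device='cpu')  # doctest: +SKIP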
""" | |
    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)
    elif 'init_cfg' in config.model.backbone:
        # Skip initializing the backbone from pretrained weights; the full
        # checkpoint (if any) is loaded below anyway.
        config.model.backbone.init_cfg = None

    init_default_scope(config.get('default_scope', 'mmdet'))
    config.model._scope_ = 'mmdet'
    model = MODELS.build(config.model)
    model = revert_sync_batchnorm(model)
    if checkpoint is None:
        warnings.simplefilter('once')
        warnings.warn('checkpoint is None, use COCO classes by default.')
        model.dataset_meta = {'classes': get_classes('coco')}
    else:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        # Weights converted from elsewhere may not have meta fields.
        checkpoint_meta = checkpoint.get('meta', {})

        # Save the dataset_meta in the model for convenience.
        if 'dataset_meta' in checkpoint_meta:
            # mmdet 3.x: all keys should be lowercase.
            model.dataset_meta = {
                k.lower(): v
                for k, v in checkpoint_meta['dataset_meta'].items()
            }
        elif 'CLASSES' in checkpoint_meta:
            # < mmdet 3.x
            classes = checkpoint_meta['CLASSES']
            model.dataset_meta = {'classes': classes}
        else:
            warnings.simplefilter('once')
            warnings.warn(
                'dataset_meta or class names are not saved in the '
                'checkpoint\'s meta data, use COCO classes by default.')
            model.dataset_meta = {'classes': get_classes('coco')}

    # Priority: args.palette -> config -> checkpoint
    if palette != 'none':
        model.dataset_meta['palette'] = palette
    else:
        test_dataset_cfg = copy.deepcopy(config.test_dataloader.dataset)
        # Lazy init: we only need the metainfo.
        test_dataset_cfg['lazy_init'] = True
        test_dataset_cfg._scope_ = 'mmdet'
        metainfo = DATASETS.build(test_dataset_cfg).metainfo
        cfg_palette = metainfo.get('palette', None)
        if cfg_palette is not None:
            model.dataset_meta['palette'] = cfg_palette
        elif 'palette' not in model.dataset_meta:
            warnings.warn(
                'palette does not exist, random is used by default. '
                'You can also set the palette to customize.')
            model.dataset_meta['palette'] = 'random'

    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model


ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]


def inference_detector(
    model: nn.Module,
    imgs: ImagesType,
    test_pipeline: Optional[Compose] = None
) -> Union[DetDataSample, SampleList]:
"""Inference image(s) with the detector. | |
Args: | |
model (nn.Module): The loaded detector. | |
imgs (str, ndarray, Sequence[str/ndarray]): | |
Either image files or loaded images. | |
test_pipeline (:obj:`Compose`): Test pipeline. | |
Returns: | |
:obj:`DetDataSample` or list[:obj:`DetDataSample`]: | |
If imgs is a list or tuple, the same length list type results | |
will be returned, otherwise return the detection results directly. | |
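
    Examples:
        A minimal usage sketch, assuming ``model`` was built with
        :func:`init_detector`; the image path is a hypothetical placeholder:

        >>> result = inference_detector(model, 'demo/demo.jpg')  # doctest: +SKIP
        >>> # Predicted boxes, labels and scores live in result.pred_instances.
        >>> result.pred_instances  # doctest: +SKIP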
""" | |
    if isinstance(imgs, (list, tuple)):
        is_batch = True
    else:
        imgs = [imgs]
        is_batch = False

    cfg = model.cfg

    if test_pipeline is None:
        cfg = cfg.copy()
        test_pipeline = get_test_pipeline_cfg(cfg)
        if isinstance(imgs[0], np.ndarray):
            # Calling this method across libraries will result in a
            # module-unregistered error if not prefixed with mmdet.
            test_pipeline[0].type = 'mmdet.LoadImageFromNDArray'
        for idx, _ in enumerate(test_pipeline):
            test_pipeline[idx]._scope_ = 'mmdet'
        test_pipeline = Compose(test_pipeline)

    if model.data_preprocessor.device.type == 'cpu':
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'
    result_list = []
    for img in imgs:
        # Prepare data.
        if isinstance(img, np.ndarray):
            # TODO: remove img_id.
            data_ = dict(img=img, img_id=0)
        else:
            # TODO: remove img_id.
            data_ = dict(img_path=img, img_id=0)

        # Build the data pipeline.
        data_ = test_pipeline(data_)
        data_['inputs'] = [data_['inputs']]
        data_['data_samples'] = [data_['data_samples']]

        # Forward the model.
        with torch.no_grad():
            results = model.test_step(data_)[0]

        result_list.append(results)

    if not is_batch:
        return result_list[0]
    else:
        return result_list


# TODO: Awaiting refactoring
async def async_inference_detector(model, imgs):
    """Async inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str | ndarray | list): Either image files or loaded images.

    Returns:
        Awaitable detection results.
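
    Examples:
        A hedged sketch, assuming an mmdet 2.x-style config that still
        defines ``cfg.data.test`` (this function is awaiting refactoring):

        >>> import asyncio
        >>> result = asyncio.run(
        ...     async_inference_detector(model, 'demo/demo.jpg'))  # doctest: +SKIP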
""" | |
    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]

    cfg = model.cfg

    if isinstance(imgs[0], np.ndarray):
        cfg = cfg.copy()
        # Set the loading pipeline type.
        cfg.data.test.pipeline[0].type = 'LoadImageFromNDArray'

    # cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    test_pipeline = Compose(cfg.data.test.pipeline)

    datas = []
    for img in imgs:
        # Prepare data.
        if isinstance(img, np.ndarray):
            # Directly add the loaded image.
            data = dict(img=img)
        else:
            # Add the file information into the dict.
            data = dict(img_info=dict(filename=img), img_prefix=None)
        # Build the data pipeline.
        data = test_pipeline(data)
        datas.append(data)

    for m in model.modules():
        assert not isinstance(
            m,
            RoIPool), 'CPU inference with RoIPool is not supported currently.'

    # We don't restore `torch.is_grad_enabled()` value during concurrent
    # inference since execution can overlap.
    torch.set_grad_enabled(False)
    # Pass all prepared samples (was `data`, which only held the last image).
    results = await model.aforward_test(datas, rescale=True)
    return results
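

if __name__ == '__main__':
    # A minimal end-to-end sketch tying the APIs above together. The config,
    # checkpoint and image paths are hypothetical placeholders; substitute
    # files that exist in your environment.
    cfg_path = 'configs/rtmdet/rtmdet_tiny_8xb32-300e_coco.py'  # hypothetical
    ckpt_path = 'checkpoints/rtmdet_tiny.pth'  # hypothetical

    detector = init_detector(cfg_path, ckpt_path, device='cpu')
    sample = inference_detector(detector, 'demo/demo.jpg')  # hypothetical path
    # `sample` is a DetDataSample; boxes, labels and scores are stored in
    # its pred_instances field.
    print(sample.pred_instances)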