Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import os | |
import os.path as osp | |
from typing import Optional, Sequence | |
from mmengine.dist import is_main_process | |
from mmengine.evaluator import BaseMetric | |
from mmengine.fileio import dump | |
from mmengine.logging import MMLogger | |
from mmengine.structures import InstanceData | |
from mmdet.registry import METRICS | |
class DumpProposals(BaseMetric): | |
"""Dump proposals pseudo metric. | |
Args: | |
output_dir (str): The root directory for ``proposals_file``. | |
Defaults to ''. | |
proposals_file (str): Proposals file path. Defaults to 'proposals.pkl'. | |
num_max_proposals (int, optional): Maximum number of proposals to dump. | |
If not specified, all proposals will be dumped. | |
file_client_args (dict, optional): Arguments to instantiate the | |
corresponding backend in mmdet <= 3.0.0rc6. Defaults to None. | |
backend_args (dict, optional): Arguments to instantiate the | |
corresponding backend. Defaults to None. | |
collect_device (str): Device name used for collecting results from | |
different ranks during distributed training. Must be 'cpu' or | |
'gpu'. Defaults to 'cpu'. | |
prefix (str, optional): The prefix that will be added in the metric | |
names to disambiguate homonymous metrics of different evaluators. | |
If prefix is not provided in the argument, self.default_prefix | |
will be used instead. Defaults to None. | |
""" | |
default_prefix: Optional[str] = 'dump_proposals' | |
def __init__(self, | |
output_dir: str = '', | |
proposals_file: str = 'proposals.pkl', | |
num_max_proposals: Optional[int] = None, | |
file_client_args: dict = None, | |
backend_args: dict = None, | |
collect_device: str = 'cpu', | |
prefix: Optional[str] = None) -> None: | |
super().__init__(collect_device=collect_device, prefix=prefix) | |
self.num_max_proposals = num_max_proposals | |
# TODO: update after mmengine finish refactor fileio. | |
self.backend_args = backend_args | |
if file_client_args is not None: | |
raise RuntimeError( | |
'The `file_client_args` is deprecated, ' | |
'please use `backend_args` instead, please refer to' | |
'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' # noqa: E501 | |
) | |
self.output_dir = output_dir | |
assert proposals_file.endswith(('.pkl', '.pickle')), \ | |
'The output file must be a pkl file.' | |
self.proposals_file = os.path.join(self.output_dir, proposals_file) | |
if is_main_process(): | |
os.makedirs(self.output_dir, exist_ok=True) | |
def process(self, data_batch: Sequence[dict], | |
data_samples: Sequence[dict]) -> None: | |
"""Process one batch of data samples and predictions. The processed | |
results should be stored in ``self.results``, which will be used to | |
compute the metrics when all batches have been processed. | |
Args: | |
data_batch (dict): A batch of data from the dataloader. | |
data_samples (Sequence[dict]): A batch of data samples that | |
contain annotations and predictions. | |
""" | |
for data_sample in data_samples: | |
pred = data_sample['pred_instances'] | |
# `bboxes` is sorted by `scores` | |
ranked_scores, rank_inds = pred['scores'].sort(descending=True) | |
ranked_bboxes = pred['bboxes'][rank_inds, :] | |
ranked_bboxes = ranked_bboxes.cpu().numpy() | |
ranked_scores = ranked_scores.cpu().numpy() | |
pred_instance = InstanceData() | |
pred_instance.bboxes = ranked_bboxes | |
pred_instance.scores = ranked_scores | |
if self.num_max_proposals is not None: | |
pred_instance = pred_instance[:self.num_max_proposals] | |
img_path = data_sample['img_path'] | |
# `file_name` is the key to obtain the proposals from the | |
# `proposals_list`. | |
file_name = osp.join( | |
osp.split(osp.split(img_path)[0])[-1], | |
osp.split(img_path)[-1]) | |
result = {file_name: pred_instance} | |
self.results.append(result) | |
def compute_metrics(self, results: list) -> dict: | |
"""Dump the processed results. | |
Args: | |
results (list): The processed results of each batch. | |
Returns: | |
dict: An empty dict. | |
""" | |
logger: MMLogger = MMLogger.get_current_instance() | |
dump_results = {} | |
for result in results: | |
dump_results.update(result) | |
dump( | |
dump_results, | |
file=self.proposals_file, | |
backend_args=self.backend_args) | |
logger.info(f'Results are saved at {self.proposals_file}') | |
return {} | |