# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Sequence

from mmengine.dist import broadcast_object_list, is_main_process

from mmdet.registry import METRICS
from .base_video_metric import collect_tracking_results
from .coco_metric import CocoMetric


@METRICS.register_module()
class CocoVideoMetric(CocoMetric):
    """COCO evaluation metric for video-style data samples.

    Evaluate AR, AP, and mAP for detection tasks including proposal/box
    detection and instance segmentation. Each incoming sample wraps a list
    of per-frame samples under ``'video_data_samples'``; the frames are
    unpacked here and handed one at a time to ``CocoMetric.process``.

    Please refer to https://cocodataset.org/#detection-eval for more details.
    """

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Feed every relevant frame of each sample to the parent metric.

        The parent ``process`` is called once per frame with ``None`` as
        its batch argument and a one-element list holding the frame
        converted through ``to_dict()``.

        Args:
            data_batch (dict): A batch of data from the dataloader. It is
                not read by this method.
            data_samples (Sequence[dict]): A batch of data samples that
                contain annotations and predictions. Each item must provide
                a non-empty ``'video_data_samples'`` entry.
        """
        for track_data_sample in data_samples:
            frames = track_data_sample['video_data_samples']
            if frames[0].ori_video_length != len(frames):
                # image process: the sample does not hold the complete
                # video, so only its first frame is evaluated
                frames = [frames[0]]
            # video process: otherwise every frame is evaluated in order
            for frame in frames:
                super().process(None, [frame.to_dict()])

    def evaluate(self, size: int = 1) -> dict:
        """Compute the metrics over everything accumulated by ``process``.

        The stored results are gathered with ``collect_tracking_results``,
        the metrics are computed on the main process only, and the outcome
        is passed through ``broadcast_object_list`` before being returned.
        ``self.results`` is emptied afterwards.

        Args:
            size (int): Length of the entire validation dataset. It is not
                used by this implementation.

        Returns:
            dict: Evaluation metrics dict on the val dataset. The keys are
            the names of the metrics (prefixed with ``self.prefix`` and a
            ``/`` when a prefix is set), and the values are corresponding
            results.
        """
        if not self.results:
            warnings.warn(
                f'{self.__class__.__name__} got empty `self.results`. Please '
                'ensure that the processed results are properly added into '
                '`self.results` in `process` method.')

        gathered = collect_tracking_results(self.results,
                                            self.collect_device)

        # A single-slot list is handed to ``broadcast_object_list``; it
        # keeps the ``None`` placeholder on processes that are not the main
        # one until the broadcast runs.
        payload = [None]  # type: ignore
        if is_main_process():
            computed = self.compute_metrics(gathered)  # type: ignore
            # Add prefix to metric names
            if self.prefix:
                computed = {
                    '/'.join((self.prefix, name)): value
                    for name, value in computed.items()
                }
            payload = [computed]

        broadcast_object_list(payload)

        # reset the results list
        self.results.clear()
        return payload[0]