Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import warnings | |
from typing import Sequence | |
from mmengine.dist import broadcast_object_list, is_main_process | |
from mmdet.registry import METRICS | |
from .base_video_metric import collect_tracking_results | |
from .coco_metric import CocoMetric | |
class CocoVideoMetric(CocoMetric):
    """COCO evaluation metric fed from video (track) data samples.

    Each incoming sample carries a ``'video_data_samples'`` sequence; its
    entries are converted to dicts and forwarded one at a time to
    ``CocoMetric.process``. AR, AP, and mAP for detection tasks including
    proposal/box detection and instance segmentation are then evaluated.
    Please refer to https://cocodataset.org/#detection-eval for more
    details.
    """

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Feed one batch of track data samples to the parent metric.

        Whatever the parent ``process`` stores (``self.results`` is what
        ``evaluate`` reads later) is accumulated frame by frame.

        Args:
            data_batch (dict): A batch of data from the dataloader. It is
                not read here; ``None`` is forwarded to the parent instead.
            data_samples (Sequence[dict]): A batch of data samples. Each
                one must provide a ``'video_data_samples'`` entry whose
                items expose ``to_dict()`` and whose first item exposes
                ``ori_video_length``.
        """
        for track_sample in data_samples:
            frames = track_sample['video_data_samples']
            declared_len = frames[0].ori_video_length
            actual_len = len(frames)

            if declared_len == actual_len:
                # video process: the whole video is present, use every frame
                selected = frames
            else:
                # image process: only the first frame is evaluated
                selected = [frames[0]]

            for frame in selected:
                super().process(None, [frame.to_dict()])

    def evaluate(self, size: int = 1) -> dict:
        """Compute the metrics from everything processed so far.

        The accumulated results are handed to ``collect_tracking_results``
        together with ``self.collect_device``. Metrics are computed on the
        main process only, then passed through ``broadcast_object_list``
        (presumably so every rank receives the same object; confirm
        against mmengine). ``self.results`` is cleared before returning.

        Args:
            size (int): Length of the entire validation dataset. This
                implementation does not read it.

        Returns:
            dict: Evaluation metrics dict on the val dataset. The keys are
            the names of the metrics (prefixed with ``self.prefix`` and a
            ``/`` when a prefix is set), and the values are the
            corresponding results.
        """
        if not self.results:
            warnings.warn(
                f'{self.__class__.__name__} got empty `self.results`. Please '
                'ensure that the processed results are properly added into '
                '`self.results` in `process` method.')

        gathered = collect_tracking_results(self.results, self.collect_device)

        # Non-main processes keep this placeholder until the broadcast.
        payload = [None]  # type: ignore
        if is_main_process():
            computed = self.compute_metrics(gathered)  # type: ignore
            if self.prefix:
                # Add prefix to metric names
                computed = {
                    '/'.join((self.prefix, name)): value
                    for name, value in computed.items()
                }
            payload = [computed]

        broadcast_object_list(payload)

        # reset the results list
        self.results.clear()
        return payload[0]