TTP / mmdet /evaluation /metrics /coco_video_metric.py
KyanChen's picture
Upload 1861 files
3b96cb1
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Sequence
from mmengine.dist import broadcast_object_list, is_main_process
from mmdet.registry import METRICS
from .base_video_metric import collect_tracking_results
from .coco_metric import CocoMetric
@METRICS.register_module()
class CocoVideoMetric(CocoMetric):
    """COCO evaluation metric for track (video) data samples.

    Each incoming sample is expected to carry its per-frame data samples
    under the ``'video_data_samples'`` key. The frames are unpacked and
    handed one at a time to :meth:`CocoMetric.process`, so the AR / AP / mAP
    computation itself is inherited from the parent class. Please refer to
    https://cocodataset.org/#detection-eval for more details on the metric.
    """

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Unpack one batch of track samples and process them per frame.

        The per-frame results are accumulated by the parent ``process``
        implementation and consumed later by :meth:`evaluate`.

        Args:
            data_batch (dict): A batch of data from the dataloader. It is
                not read here; ``None`` is forwarded to the parent class.
            data_samples (Sequence[dict]): A batch of track data samples.
                Each one must provide ``'video_data_samples'``, a non-empty
                sequence of frame samples exposing ``ori_video_length`` and
                ``to_dict()``.
        """
        for track_sample in data_samples:
            frames = track_sample['video_data_samples']
            first_frame = frames[0]

            if first_frame.ori_video_length != len(frames):
                # Image mode: the sample does not hold the complete video,
                # so only its first frame is evaluated.
                super().process(None, [first_frame.to_dict()])
                continue

            # Video mode: the sample holds every frame of the video, so
            # each frame is evaluated, one parent call per frame.
            for frame in frames:
                super().process(None, [frame.to_dict()])

    def evaluate(self, size: int = 1) -> dict:
        """Compute the metrics once every batch has been processed.

        Results are gathered with ``collect_tracking_results``, the metrics
        are computed on the main process only and then shared through
        ``broadcast_object_list``. ``self.results`` is emptied afterwards.

        Args:
            size (int): Length of the entire validation dataset. This
                implementation does not consult it.

        Returns:
            dict: Evaluation metrics dict on the val dataset. The keys are
            the names of the metrics (prefixed with ``self.prefix`` and a
            ``/`` when a prefix is set), and the values are the
            corresponding results.
        """
        if not self.results:
            message = (
                f'{self.__class__.__name__} got empty `self.results`. '
                'Please ensure that the processed results are properly '
                'added into `self.results` in `process` method.')
            warnings.warn(message)

        gathered = collect_tracking_results(self.results,
                                            self.collect_device)

        # Single-slot container: filled on the main process, left as a
        # ``None`` placeholder everywhere else until the broadcast.
        payload = [None]  # type: ignore
        if is_main_process():
            computed = self.compute_metrics(gathered)  # type: ignore
            if self.prefix:
                # Add prefix to metric names
                computed = {
                    '/'.join((self.prefix, name)): value
                    for name, value in computed.items()
                }
            payload = [computed]
        broadcast_object_list(payload)

        # Reset the results list so the metric can be reused.
        self.results.clear()
        return payload[0]