# Spaces: Running
from typing import Dict, List, Union | |
from pathlib import Path | |
import datasets | |
import torch | |
import evaluate | |
import json | |
from tqdm import tqdm | |
from detection_metrics.pycocotools.coco import COCO | |
from detection_metrics.coco_evaluate import COCOEvaluator | |
from detection_metrics.utils import _TYPING_PREDICTION, _TYPING_REFERENCE | |
# Short description shown in the metric card.
# Implicit string concatenation is used instead of a backslash continuation
# inside the literal, so source indentation can never leak into the text.
_DESCRIPTION = (
    "This class evaluates object detection models using the COCO dataset "
    "and its evaluation metrics."
)

# Homepage of the COCO dataset / evaluation protocol.
_HOMEPAGE = "https://cocodataset.org"

# BibTeX entry for the COCO paper. A trailing backslash inside the literal is
# a line continuation: the following newline is not part of the string.
_CITATION = """
@misc{lin2015microsoft, \
title={Microsoft COCO: Common Objects in Context},
author={Tsung-Yi Lin and Michael Maire and Serge Belongie and Lubomir Bourdev and \
Ross Girshick and James Hays and Pietro Perona and Deva Ramanan and C. Lawrence Zitnick \
and Piotr Dollár},
year={2015},
eprint={1405.0312},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
"""

# Additional references on detection metrics.
_REFERENCE_URLS = [
    "https://ieeexplore.ieee.org/abstract/document/9145130",
    "https://www.mdpi.com/2079-9292/10/3/279",
    "https://cocodataset.org/#detection-eval",
]

# Description of the metric *inputs*; it mirrors the features declared for
# the module: per-image prediction dicts and per-image reference dicts.
_KWARGS_DESCRIPTION = """\
Computes COCO metrics for object detection: AP(mAP) and its variants.
Args:
    predictions: List of dicts, one per image, with the keys
        "scores" (list of float), "labels" (list of int) and
        "boxes" (list of lists of float).
    references: List of dicts, one per image, with the key
        "image_id" (list of int).
"""
class EvaluateObjectDetection(evaluate.Metric):
    """
    Metric for evaluating object detection models with COCO-style scores.

    A ``COCOEvaluator`` is built from ground-truth annotations at construction
    time. Predictions are pushed into it through :meth:`add`, and the final
    scores are read back from it in :meth:`_compute`.
    """

    def __init__(self, json_gt: Union[str, Path, Dict], iou_type: str = "bbox", **kwargs):
        """
        Initializes the EvaluateObjectDetection class.

        Args:
            json_gt: Ground-truth annotations in COCO format, either as a path
                to a JSON file (``str`` or ``Path``) or as an already loaded
                dict.
            iou_type: IoU type handed to ``COCOEvaluator`` (as a one-element
                list). Defaults to ``"bbox"``.
            **kwargs: Additional keyword arguments forwarded to
                ``evaluate.Metric``.

        Raises:
            FileNotFoundError: If ``json_gt`` is a path that does not exist.
            TypeError: If ``json_gt`` is neither a path nor a dict.
        """
        super().__init__(**kwargs)

        # Create COCO object from ground-truth annotations
        if isinstance(json_gt, (str, Path)):
            json_path = Path(json_gt)
            # Explicit check instead of `assert`, which is stripped under -O.
            if not json_path.exists():
                raise FileNotFoundError(f"Path {json_path} does not exist.")
            with open(json_path, encoding="utf-8") as f:
                json_data = json.load(f)
        elif isinstance(json_gt, dict):
            json_data = json_gt
        else:
            # Previously this fell through and failed later with an unbound
            # `json_data`; fail early with an actionable message instead.
            raise TypeError(
                "json_gt must be a path (str or Path) or a dict, "
                f"got {type(json_gt).__name__}."
            )

        # Remembered so that other methods address the matching evaluator
        # entry instead of assuming "bbox".
        self._iou_type = iou_type

        coco = COCO(json_data)
        self.coco_evaluator = COCOEvaluator(coco, [iou_type])

    def remove_classes(self, classes_to_remove: List[str]):
        """
        Drops the given categories from the ground truth used for evaluation.

        Category names are compared case-insensitively. The evaluator is
        modified in place: the category mapping of the ground-truth objects,
        the ``"categories"`` entry of the dataset and the evaluation
        ``catIds`` are all replaced with the filtered set.

        Args:
            classes_to_remove: Names of the categories to exclude.
        """
        excluded = {name.upper() for name in classes_to_remove}
        coco_eval = self.coco_evaluator.coco_eval[self._iou_type]

        kept = {
            cat_id: cat
            for cat_id, cat in coco_eval.cocoGt.cats.items()
            if cat["name"].upper() not in excluded
        }

        coco_eval.cocoGt.cats = kept
        self.coco_evaluator.coco_gt.cats = kept
        self.coco_evaluator.coco_gt.dataset["categories"] = list(kept.values())
        coco_eval.params.catIds = [cat["id"] for cat in kept.values()]

    def _info(self):
        """
        Returns the MetricInfo object with information about the module.

        Returns:
            evaluate.MetricInfo: Metric information object.
        """
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # This defines the format of each prediction and reference
            features=datasets.Features(
                {
                    "predictions": [
                        datasets.Features(
                            {
                                "scores": datasets.Sequence(datasets.Value("float")),
                                "labels": datasets.Sequence(datasets.Value("int64")),
                                "boxes": datasets.Sequence(
                                    datasets.Sequence(datasets.Value("float"))
                                ),
                            }
                        )
                    ],
                    "references": [
                        datasets.Features(
                            {
                                "image_id": datasets.Sequence(datasets.Value("int64")),
                            }
                        )
                    ],
                }
            ),
            # Homepage of the module for documentation
            homepage=_HOMEPAGE,
            # Additional links to the codebase or references
            reference_urls=_REFERENCE_URLS,
        )

    def _preprocess(
        self, predictions: List[Dict[str, torch.Tensor]]
    ) -> List[_TYPING_PREDICTION]:
        """
        Converts prediction values into plain Python containers.

        Tensors are detached, moved to the CPU and turned into (nested) lists;
        the entries under ``"labels"`` are additionally cast to ``int``.
        Non-tensor values are passed through unchanged (apart from the
        ``"labels"`` cast).

        Args:
            predictions (List[Dict[str, torch.Tensor]]): A list of prediction dicts.

        Returns:
            List[_TYPING_PREDICTION]: A list of preprocessed prediction dicts.
        """
        processed_predictions = []
        for pred in predictions:
            processed_pred: _TYPING_PREDICTION = {}
            for key, val in pred.items():
                if isinstance(val, torch.Tensor):
                    val = val.detach().cpu().tolist()
                if key == "labels":
                    val = [int(label) for label in val]
                processed_pred[key] = val
            processed_predictions.append(processed_pred)
        return processed_predictions

    def _clear_predictions(self, predictions):
        """
        Returns copies of the prediction dicts restricted to the keys
        ``"scores"``, ``"labels"`` and ``"boxes"``.
        """
        required = ("scores", "labels", "boxes")
        return [
            {k: v for k, v in prediction.items() if k in required}
            for prediction in predictions
        ]

    def _clear_references(self, references):
        """
        Returns copies of the reference dicts restricted to ``"image_id"``.

        ``"image_id"`` is the only reference feature declared in ``_info``.
        (The previous filter list contained only the empty string, which
        dropped every real key.)
        """
        required = ("image_id",)
        return [
            {k: v for k, v in ref.items() if k in required}
            for ref in references
        ]

    def add(self, *, prediction=None, reference=None, **kwargs):
        """
        Preprocesses the predictions, updates the COCO evaluator and calls the
        parent class function.

        Args:
            prediction: A list of prediction dicts, one per image.
            reference: A list of reference dicts, one per image. The first
                element of each ``"image_id"`` entry is used as the key under
                which the matching prediction is handed to the evaluator.
            **kwargs: Additional keyword arguments.

        Raises:
            ValueError: If ``prediction`` or ``reference`` is missing, or if
                they differ in length.
        """
        # Both inputs are needed to pair predictions with image ids; without
        # this check a missing input surfaced as an opaque TypeError in zip().
        if prediction is None or reference is None:
            raise ValueError("Both `prediction` and `reference` must be provided.")
        # zip() would silently drop the surplus items of the longer input.
        if len(prediction) != len(reference):
            raise ValueError(
                f"Got {len(prediction)} predictions but {len(reference)} references."
            )

        prediction = self._clear_predictions(prediction)
        prediction = self._preprocess(prediction)

        # {image_id: prediction}
        res = {
            target["image_id"][0]: output
            for output, target in zip(prediction, reference)
        }
        self.coco_evaluator.update(res)

        # NOTE(review): the call deliberately mirrors the original: it starts
        # the method lookup after `evaluate.Metric` in the MRO and passes the
        # references under the plural keyword `references`. Presumably the
        # parent `add` accepts that through its **kwargs -- confirm against
        # the installed `evaluate` version before simplifying.
        super(evaluate.Metric, self).add(prediction=prediction, references=reference, **kwargs)

    def _compute(
        self,
        predictions: List[List[_TYPING_PREDICTION]],
        references: List[List[_TYPING_REFERENCE]],
    ) -> Dict[str, Dict[str, float]]:
        """
        Returns the evaluation scores.

        The arguments are not read here: the scores come from the state the
        COCO evaluator accumulated through :meth:`add`.

        Args:
            predictions (List[List[_TYPING_PREDICTION]]): A list of predictions.
            references (List[List[_TYPING_REFERENCE]]): A list of references.

        Returns:
            Dict: A dictionary containing evaluation scores.
        """
        print("Synchronizing processes")
        self.coco_evaluator.synchronize_between_processes()

        print("Accumulating values")
        self.coco_evaluator.accumulate()

        print("Summarizing results")
        self.coco_evaluator.summarize()

        stats = self.coco_evaluator.get_results()
        return stats