# detection_metrics/detection_metrics.py
# Original commit a52e8a5 ("first commit") by rapadilla.
from typing import Dict, List, Union
from pathlib import Path
import datasets
import torch
import evaluate
import json
from tqdm import tqdm
from detection_metrics.pycocotools.coco import COCO
from detection_metrics.coco_evaluate import COCOEvaluator
from detection_metrics.utils import _TYPING_PREDICTION, _TYPING_REFERENCE
# Short description shown in the metric card / MetricInfo.
_DESCRIPTION = "This class evaluates object detection models using the COCO dataset \
and its evaluation metrics."
_HOMEPAGE = "https://cocodataset.org"
# BibTeX entry for the COCO paper; the trailing backslashes join physical lines.
_CITATION = """
@misc{lin2015microsoft, \
title={Microsoft COCO: Common Objects in Context},
author={Tsung-Yi Lin and Michael Maire and Serge Belongie and Lubomir Bourdev and \
Ross Girshick and James Hays and Pietro Perona and Deva Ramanan and C. Lawrence Zitnick \
and Piotr Dollár},
year={2015},
eprint={1405.0312},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
"""
_REFERENCE_URLS = [
    "https://ieeexplore.ieee.org/abstract/document/9145130",
    "https://www.mdpi.com/2079-9292/10/3/279",
    "https://cocodataset.org/#detection-eval",
]
# Describes the inputs accepted by `add` / `compute` (used as `inputs_description`).
_KWARGS_DESCRIPTION = """\
Computes COCO metrics for object detection: AP(mAP) and its variants.
Args:
    predictions: A list of prediction dicts with the keys "scores", "labels" and "boxes".
    references: A list of reference dicts, each holding the "image_id" of the image
        the corresponding prediction belongs to.
Returns:
    The evaluation statistics produced by the COCO evaluator.
"""
class EvaluateObjectDetection(evaluate.Metric):
    """
    Class for evaluating object detection models.

    Ground-truth annotations are loaded once at construction time into a
    ``COCOEvaluator``. Predictions are registered with that evaluator in
    :meth:`add`, and the COCO metrics are produced in :meth:`_compute`.
    """

    def __init__(self, json_gt: Union[str, Path, Dict], iou_type: str = "bbox", **kwargs):
        """
        Initializes the EvaluateObjectDetection class.

        Args:
            json_gt: Ground-truth annotations in COCO format, either as an already
                loaded dict or as a path (``str`` or ``Path``) to a JSON file.
            iou_type: IoU type handed to the ``COCOEvaluator`` (default ``"bbox"``).
            **kwargs: Additional keyword arguments forwarded to ``evaluate.Metric``.

        Raises:
            FileNotFoundError: If ``json_gt`` is a path that does not exist.
            TypeError: If ``json_gt`` is neither a path nor a dict.
        """
        super().__init__(**kwargs)
        # Create COCO object from ground-truth annotations
        if isinstance(json_gt, (str, Path)):
            json_path = Path(json_gt)
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            if not json_path.exists():
                raise FileNotFoundError(f"Path {json_path} does not exist.")
            with open(json_path, encoding="utf-8") as f:
                json_data = json.load(f)
        elif isinstance(json_gt, dict):
            json_data = json_gt
        else:
            # Without this branch `json_data` would be unbound below.
            raise TypeError(
                f"json_gt must be a path (str or Path) or a dict, got {type(json_gt).__name__}."
            )
        coco = COCO(json_data)
        self.coco_evaluator = COCOEvaluator(coco, [iou_type])

    def remove_classes(self, classes_to_remove: List[str]):
        """
        Removes categories from the ground truth so they are ignored during evaluation.

        Category names are compared case-insensitively (both sides are upper-cased).

        Args:
            classes_to_remove: Names of the categories to drop.
        """
        to_remove = {name.upper() for name in classes_to_remove}
        coco_gt = self.coco_evaluator.coco_gt
        kept_cats = {
            cat_id: cat
            for cat_id, cat in coco_gt.cats.items()
            if cat["name"].upper() not in to_remove
        }
        # Update every configured IoU type rather than only "bbox", so an evaluator
        # built with a different `iou_type` does not raise a KeyError here.
        # NOTE(review): assumes `coco_eval` is a dict keyed by IoU type - confirm.
        for coco_eval in self.coco_evaluator.coco_eval.values():
            coco_eval.cocoGt.cats = kept_cats
            coco_eval.params.catIds = [cat["id"] for cat in kept_cats.values()]
        coco_gt.cats = kept_cats
        coco_gt.dataset["categories"] = list(kept_cats.values())

    def _info(self):
        """
        Returns the MetricInfo object with information about the module.

        Returns:
            evaluate.MetricInfo: Metric information object.
        """
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # This defines the format of each prediction and reference
            features=datasets.Features(
                {
                    "predictions": [
                        datasets.Features(
                            {
                                "scores": datasets.Sequence(datasets.Value("float")),
                                "labels": datasets.Sequence(datasets.Value("int64")),
                                "boxes": datasets.Sequence(
                                    datasets.Sequence(datasets.Value("float"))
                                ),
                            }
                        )
                    ],
                    "references": [
                        datasets.Features(
                            {
                                "image_id": datasets.Sequence(datasets.Value("int64")),
                            }
                        )
                    ],
                }
            ),
            # Homepage of the module for documentation
            homepage=_HOMEPAGE,
            # Additional links to the codebase or references
            reference_urls=_REFERENCE_URLS,
        )

    def _preprocess(
        self, predictions: List[Dict[str, torch.Tensor]]
    ) -> List[_TYPING_PREDICTION]:
        """
        Preprocesses the predictions before computing the scores.

        Tensor values are detached, moved to CPU and converted to Python lists.
        The "labels" entry is additionally cast to ``int``.

        Args:
            predictions (List[Dict[str, torch.Tensor]]): A list of prediction dicts.

        Returns:
            List[_TYPING_PREDICTION]: A list of preprocessed prediction dicts.
        """
        processed_predictions = []
        for pred in predictions:
            processed_pred: _TYPING_PREDICTION = {}
            for k, val in pred.items():
                if isinstance(val, torch.Tensor):
                    val = val.detach().cpu().tolist()
                if k == "labels":
                    val = list(map(int, val))
                processed_pred[k] = val
            processed_predictions.append(processed_pred)
        return processed_predictions

    def _clear_predictions(self, predictions):
        """
        Keeps only the "scores", "labels" and "boxes" keys of each prediction dict.

        Args:
            predictions: A list of prediction dicts.

        Returns:
            A new list of dicts restricted to the required keys.
        """
        required = {"scores", "labels", "boxes"}
        return [
            {k: v for k, v in prediction.items() if k in required}
            for prediction in predictions
        ]

    def _clear_references(self, references):
        """
        Keeps only the "image_id" key of each reference dict.

        "image_id" is the only reference feature declared in `_info`. The previous
        filter (``[""]``) discarded every key.

        Args:
            references: A list of reference dicts.

        Returns:
            A new list of dicts restricted to the required keys.
        """
        required = {"image_id"}
        return [{k: v for k, v in ref.items() if k in required} for ref in references]

    def add(self, *, prediction=None, reference=None, **kwargs):
        """
        Preprocesses the predictions and references and calls the parent class function.

        The i-th prediction is paired with the i-th reference and registered in the
        COCO evaluator under the first entry of that reference's "image_id".

        Args:
            prediction: A list of prediction dicts.
            reference: A list of reference dicts.
            **kwargs: Additional keyword arguments.

        Raises:
            ValueError: If ``prediction`` and ``reference`` have different lengths.
        """
        if prediction is not None:
            prediction = self._clear_predictions(prediction)
            prediction = self._preprocess(prediction)
        # Only feed the COCO evaluator when both sides are given. Previously a
        # missing `prediction` or `reference` crashed inside `zip`.
        if prediction is not None and reference is not None:
            if len(prediction) != len(reference):
                # `zip` would silently drop the unmatched tail and skew the metrics.
                raise ValueError(
                    f"Got {len(prediction)} predictions but {len(reference)} references."
                )
            res = {  # {image_id} : prediction
                target["image_id"][0]: output
                for output, target in zip(prediction, reference)
            }
            self.coco_evaluator.update(res)
        super(evaluate.Metric, self).add(prediction=prediction, reference=reference, **kwargs)

    def _compute(
        self,
        predictions: List[List[_TYPING_PREDICTION]],
        references: List[List[_TYPING_REFERENCE]],
    ) -> Dict[str, Dict[str, float]]:
        """
        Returns the evaluation scores.

        The ``predictions``/``references`` arguments are not used here: the data was
        already registered with the COCO evaluator in :meth:`add`.

        Args:
            predictions (List[List[_TYPING_PREDICTION]]): A list of predictions.
            references (List[List[_TYPING_REFERENCE]]): A list of references.

        Returns:
            Dict: A dictionary containing evaluation scores.
        """
        print("Synchronizing processes")
        self.coco_evaluator.synchronize_between_processes()
        print("Accumulating values")
        self.coco_evaluator.accumulate()
        print("Summarizing results")
        self.coco_evaluator.summarize()
        stats = self.coco_evaluator.get_results()
        return stats