TTP / mmdet /evaluation /metrics /refseg_metric.py
KyanChen's picture
Upload 1861 files
3b96cb1
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
import torch
from mmengine.evaluator import BaseMetric
from mmdet.registry import METRICS
@METRICS.register_module()
class RefSegMetric(BaseMetric):
"""Referring Expression Segmentation Metric."""
def __init__(self, metric: Sequence = ('cIoU', 'mIoU'), **kwargs):
super().__init__(**kwargs)
assert set(metric).issubset(['cIoU', 'mIoU']), \
f'Only support cIoU and mIoU, but got {metric}'
assert len(metric) > 0, 'metrics should not be empty'
self.metrics = metric
def compute_iou(self, pred_seg: torch.Tensor,
gt_seg: torch.Tensor) -> tuple:
overlap = pred_seg & gt_seg
union = pred_seg | gt_seg
return overlap, union
def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
"""Process one batch of data and data_samples.
The processed results should be stored in ``self.results``, which will
be used to compute the metrics when all batches have been processed.
Args:
data_batch (dict): A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for data_sample in data_samples:
pred_label = data_sample['pred_instances']['masks'].bool()
label = data_sample['gt_masks'].to_tensor(
pred_label.dtype, pred_label.device).bool()
# calculate iou
overlap, union = self.compute_iou(pred_label, label)
bs = len(pred_label)
iou = overlap.reshape(bs, -1).sum(-1) * 1.0 / union.reshape(
bs, -1).sum(-1)
iou = torch.nan_to_num_(iou, nan=0.0)
self.results.append((overlap.sum(), union.sum(), iou.sum(), bs))
def compute_metrics(self, results: list) -> dict:
results = tuple(zip(*results))
assert len(results) == 4
cum_i = sum(results[0])
cum_u = sum(results[1])
iou = sum(results[2])
seg_total = sum(results[3])
metrics = {}
if 'cIoU' in self.metrics:
metrics['cIoU'] = cum_i * 100 / cum_u
if 'mIoU' in self.metrics:
metrics['mIoU'] = iou * 100 / seg_total
return metrics