Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Sequence | |
import torch | |
from mmengine.evaluator import BaseMetric | |
from mmdet.registry import METRICS | |
# NOTE(review): ``METRICS`` is imported in this file but this class is not
# decorated with ``@METRICS.register_module()`` here, so config lookups by
# ``type='RefSegMetric'`` will only work if it is registered elsewhere —
# confirm whether the decorator is intended.
class RefSegMetric(BaseMetric):
    """Referring Expression Segmentation Metric.

    Accumulates mask overlap statistics over the evaluation set and reports:

    * ``cIoU`` (cumulative IoU): total intersection pixels divided by total
      union pixels, summed over every mask seen, times 100.
    * ``mIoU`` (mean IoU): the mean of per-mask IoU values, times 100.
      A mask whose IoU is undefined (empty union, i.e. 0/0) contributes 0.

    Args:
        metric (str | Sequence[str]): Metric name(s) to report; any
            non-empty subset of ``('cIoU', 'mIoU')``. A single string such
            as ``'cIoU'`` is treated as a one-element sequence.
            Defaults to ``('cIoU', 'mIoU')``.
        **kwargs: Forwarded unchanged to ``BaseMetric.__init__``.
    """

    def __init__(self, metric: Sequence = ('cIoU', 'mIoU'), **kwargs):
        super().__init__(**kwargs)
        # A bare string is itself a Sequence of characters; wrap it so that
        # ``metric='cIoU'`` selects that metric instead of {'c','I','o','U'}.
        if isinstance(metric, str):
            metric = (metric, )
        assert set(metric).issubset(['cIoU', 'mIoU']), \
            f'Only support cIoU and mIoU, but got {metric}'
        assert len(metric) > 0, 'metrics should not be empty'
        self.metrics = metric

    def compute_iou(self, pred_seg: torch.Tensor,
                    gt_seg: torch.Tensor) -> tuple:
        """Return the element-wise intersection and union of two masks.

        Args:
            pred_seg (torch.Tensor): Boolean predicted mask(s).
            gt_seg (torch.Tensor): Boolean ground-truth mask(s); combined
                with ``pred_seg`` via ``&`` / ``|`` (broadcasting applies).

        Returns:
            tuple: ``(overlap, union)`` boolean tensors, where
            ``overlap = pred_seg & gt_seg`` and ``union = pred_seg | gt_seg``.
        """
        overlap = pred_seg & gt_seg
        union = pred_seg | gt_seg
        return overlap, union

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Process one batch of data and data_samples.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.
        Exactly one entry ``(intersection_sum, union_sum, iou_sum, num_masks)``
        is appended per data sample.

        Args:
            data_batch (dict): A batch of data from the dataloader (unused).
            data_samples (Sequence[dict]): A batch of outputs from the model.
                Each must provide ``['pred_instances']['masks']`` and a
                ``'gt_masks'`` object exposing ``to_tensor(dtype, device)``.
        """
        for data_sample in data_samples:
            pred_label = data_sample['pred_instances']['masks'].bool()
            # Materialize GT masks with the prediction's dtype/device so the
            # bitwise ops below operate on compatible tensors.
            label = data_sample['gt_masks'].to_tensor(
                pred_label.dtype, pred_label.device).bool()
            # calculate iou
            overlap, union = self.compute_iou(pred_label, label)

            # The leading axis is treated as the mask index — presumably
            # masks are (num_masks, H, W); TODO confirm against the model.
            bs = len(pred_label)
            if bs > 0:
                iou = overlap.reshape(bs, -1).sum(-1) * 1.0 / union.reshape(
                    bs, -1).sum(-1)
                # 0/0 (both masks empty) yields NaN; count it as IoU 0.
                iou = torch.nan_to_num_(iou, nan=0.0)
                iou_sum = iou.sum()
            else:
                # No predicted masks: ``reshape(0, -1)`` on an empty tensor
                # raises, so contribute zero IoU. Still append one entry so
                # per-rank result lists stay aligned with the dataset size.
                iou_sum = torch.zeros((), device=pred_label.device)
            self.results.append((overlap.sum(), union.sum(), iou_sum, bs))

    @staticmethod
    def _percent(numerator, denominator):
        """Return ``numerator * 100 / denominator``, or 0.0 if it is zero."""
        if denominator == 0:
            return 0.0
        return numerator * 100 / denominator

    def compute_metrics(self, results: list) -> dict:
        """Compute the requested metrics from the processed results.

        Args:
            results (list): Entries produced by :meth:`process`, each a
                4-tuple ``(intersection_sum, union_sum, iou_sum, num_masks)``.

        Returns:
            dict: Maps each requested name (``'cIoU'`` / ``'mIoU'``) to its
            value in percent. A zero denominator (e.g. no results) yields 0.0
            instead of crashing or producing NaN.
        """
        if results:
            columns = tuple(zip(*results))
            assert len(columns) == 4
            cum_i, cum_u, iou, seg_total = (sum(col) for col in columns)
        else:
            cum_i = cum_u = iou = seg_total = 0

        metrics = {}
        if 'cIoU' in self.metrics:
            metrics['cIoU'] = self._percent(cum_i, cum_u)
        if 'mIoU' in self.metrics:
            metrics['mIoU'] = self._percent(iou, seg_total)
        return metrics