# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
from mmengine import ConfigDict
from mmengine.structures import InstanceData
from scipy.optimize import linear_sum_assignment
from torch import Tensor

from mmdet.registry import TASK_UTILS
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


@TASK_UTILS.register_module()
class HungarianAssigner(BaseAssigner):
    """One-to-one assigner between predictions and ground truths.

    Every configured match cost is evaluated on the (prediction, ground
    truth) pairs, the resulting cost matrices are summed, and the summed
    matrix is solved as a linear sum assignment problem. For DETR the
    configured costs are typically a classification cost, a regression L1
    cost and a regression IoU cost.

    The targets do not contain a "no object" entry, so there are usually
    more predictions than targets. Predictions left unmatched after the
    one-to-one matching are treated as background. Each prediction
    therefore ends up with one of:

    - 0: negative sample, no ground truth assigned
    - positive integer: positive sample, 1-based index of the assigned
      ground truth

    Args:
        match_costs (:obj:`ConfigDict` or dict or
            List[Union[:obj:`ConfigDict`, dict]]): Match cost configs.
            A single config is treated as a one-element list.
    """

    def __init__(
        self, match_costs: Union[List[Union[dict, ConfigDict]], dict,
                                 ConfigDict]
    ) -> None:
        # Normalise the argument to a sequence of configs; a lone config
        # becomes a one-element list.
        if isinstance(match_costs, dict):
            cost_cfgs = [match_costs]
        else:
            if isinstance(match_costs, list):
                assert len(match_costs) > 0, \
                    'match_costs must not be a empty list.'
            cost_cfgs = match_costs

        # Build one cost object per config through the registry.
        self.match_costs = [TASK_UTILS.build(cfg) for cfg in cost_cfgs]

    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               img_meta: Optional[dict] = None,
               **kwargs) -> AssignResult:
        """Assign each prediction to a ground truth or to background.

        In the returned ``gt_inds``, -1 means "don't care", 0 means
        negative sample, and a positive number is the 1-based index of
        the assigned ground truth. The procedure is:

        1. Initialise every prediction to -1.
        2. If there are no ground truths, mark every prediction as
           background (0) and return. If there are ground truths but no
           predictions, return the (empty) initial assignment.
        3. Otherwise sum the configured match costs and run Hungarian
           matching on CPU over the summed cost matrix.
        4. Mark every prediction as background (0), then give each
           matched prediction the index of its ground truth plus 1 and
           that ground truth's label.

        Args:
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. It includes ``priors``, and the priors can
                be anchors or points, or the bboxes predicted by the
                previous stage, has shape (n, 4). The bboxes predicted by
                the current model or stage will be named ``bboxes``,
                ``labels``, and ``scores``, the same as the
                ``InstanceData`` in other places. It may include
                ``masks``, with shape (n, h, w) or (n, l).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``bboxes``, with shape
                (k, 4), ``labels``, with shape (k, ) and ``masks``, with
                shape (k, h, w) or (k, l).
            img_meta (dict): Image information, forwarded to every match
                cost.

        Returns:
            :obj:`AssignResult`: The assigned result. ``max_overlaps`` is
            always ``None``.
        """
        gt_labels = gt_instances.labels
        assert isinstance(gt_labels, Tensor)
        num_gts = len(gt_instances)
        num_preds = len(pred_instances)
        device = gt_labels.device

        # Step 1: everything starts out as "don't care" (-1).
        fill_opts = dict(dtype=torch.long, device=device)
        assigned_gt_inds = torch.full((num_preds, ), -1, **fill_opts)
        assigned_labels = torch.full((num_preds, ), -1, **fill_opts)

        if num_gts > 0 and num_preds > 0:
            # Steps 2-3: weighted costs + Hungarian matching.
            pred_idx, gt_idx = self._solve_matching(pred_instances,
                                                    gt_instances, img_meta,
                                                    device)
            # Step 4: background everywhere, then overwrite the matched
            # predictions with their 1-based gt index and gt label.
            assigned_gt_inds.fill_(0)
            assigned_gt_inds[pred_idx] = gt_idx + 1
            assigned_labels[pred_idx] = gt_labels[gt_idx]
        elif num_gts == 0:
            # No ground truth at all: every prediction is background.
            assigned_gt_inds.fill_(0)
        # Remaining case (ground truths but no predictions): the tensors
        # are empty and are returned untouched.

        return AssignResult(
            num_gts=num_gts,
            gt_inds=assigned_gt_inds,
            max_overlaps=None,
            labels=assigned_labels)

    def _solve_matching(self, pred_instances: InstanceData,
                        gt_instances: InstanceData,
                        img_meta: Optional[dict], device: torch.device):
        """Sum all match costs and solve the linear sum assignment.

        Args:
            pred_instances (:obj:`InstanceData`): Predictions passed to
                every match cost.
            gt_instances (:obj:`InstanceData`): Ground truths passed to
                every match cost.
            img_meta (dict, optional): Image information passed to every
                match cost.
            device (torch.device): Device the returned index tensors are
                moved to.

        Returns:
            tuple[Tensor, Tensor]: Matched row indices and matched column
            indices of the summed cost matrix, as returned by
            ``linear_sum_assignment`` and moved to ``device``.

        Raises:
            ImportError: If ``linear_sum_assignment`` is ``None``.
        """
        cost_terms = [
            match_cost(
                pred_instances=pred_instances,
                gt_instances=gt_instances,
                img_meta=img_meta) for match_cost in self.match_costs
        ]
        total_cost = torch.stack(cost_terms).sum(dim=0)

        # The solver runs on CPU, so detach and move the matrix first.
        cpu_cost = total_cost.detach().cpu()
        if linear_sum_assignment is None:
            raise ImportError('Please run "pip install scipy" '
                              'to install scipy first.')

        row_np, col_np = linear_sum_assignment(cpu_cost)
        pred_idx = torch.from_numpy(row_np).to(device)
        gt_idx = torch.from_numpy(col_np).to(device)
        return pred_idx, gt_idx