Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Optional | |
import torch | |
from mmengine.structures import InstanceData | |
from mmdet.registry import TASK_UTILS | |
from .assign_result import AssignResult | |
from .max_iou_assigner import MaxIoUAssigner | |
@TASK_UTILS.register_module()
class MultiInstanceAssigner(MaxIoUAssigner):
    """Assign a corresponding gt bbox or background to each proposal bbox. If
    we need to use a proposal box to generate multiple predict boxes,
    `MultiInstanceAssigner` can assign multiple gt to each proposal box.

    Each prior is matched to its ``num_instance`` best-overlapping boxes, so
    a single proposal can supervise several predictions. The gt and ignored
    boxes are themselves appended to the priors as extra proposals.

    Args:
        num_instance (int): How many bboxes are predicted by each proposal
            box. Must be at least 1. Defaults to 2.
        **kwargs: Passed unchanged to :class:`MaxIoUAssigner`. :meth:`assign`
            reads ``pos_iou_thr``, ``neg_iou_thr``, ``gpu_assign_thr`` and
            ``iou_calculator`` from the resulting instance.

    Raises:
        ValueError: If ``num_instance`` is smaller than 1.
    """

    def __init__(self, num_instance: int = 2, **kwargs) -> None:
        super().__init__(**kwargs)
        if num_instance < 1:
            raise ValueError(
                f'num_instance must be >= 1, but got {num_instance}')
        self.num_instance = num_instance

    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               gt_instances_ignore: Optional[InstanceData] = None,
               **kwargs) -> AssignResult:
        """Assign gt to bboxes.

        This method assign gt bboxes to every bbox (proposal/anchor), each bbox
        is assigned a set of gts, and the number of gts in this set is defined
        by `self.num_instance`.

        For every prior, the ``num_instance`` best normal matches (IoU) and
        the ``num_instance`` best ignore matches ('iof' mode) are computed.
        A slot takes its ignore match when the normal overlap is below
        ``pos_iou_thr`` and the ignore overlap is larger. Slot labels are then
        set to 1 (overlap >= ``pos_iou_thr`` on a non-ignored box), 0
        (0 <= overlap < ``neg_iou_thr``), or otherwise keep the label of the
        matched box (gt label + 1, or the ignore label).

        If there are fewer boxes than ``num_instance`` (but at least one), the
        missing slots are filled with overlap 0 and box index 0, so they are
        labelled background whenever ``neg_iou_thr > 0``.

        Args:
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. It includes ``priors``, and the priors can
                be anchors or points, or the bboxes predicted by the
                previous stage, has shape (n, 4). The bboxes predicted by
                the current model or stage will be named ``bboxes``,
                ``labels``, and ``scores``, the same as the ``InstanceData``
                in other places.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``bboxes``, with shape (k, 4),
                and ``labels``, with shape (k, ).
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                If it has no ``labels``, every ignored box gets label -1.
                Defaults to None.
            **kwargs: Not used by this method.

        Returns:
            :obj:`AssignResult`: The assign result. ``gt_inds``,
            ``max_overlaps`` and ``labels`` have shape (R, num_instance),
            where R is the number of priors plus the number of gt and ignored
            boxes. ``gt_inds`` holds 0-based indices into the gt boxes
            followed by the ignored boxes, and ``num_gts`` counts both.
        """
        gt_bboxes = gt_instances.bboxes
        priors = pred_instances.priors
        # Set the FG label to 1 and add ignored annotations: gt labels are
        # shifted by one so that 0 is free for background; -1 marks ignores.
        gt_labels = gt_instances.labels + 1
        if gt_instances_ignore is not None:
            gt_bboxes_ignore = gt_instances_ignore.bboxes
            if hasattr(gt_instances_ignore, 'labels'):
                gt_labels_ignore = gt_instances_ignore.labels
            else:
                # Build -1 labels with the gt label dtype (not the float bbox
                # dtype) so the concatenated label tensor stays integral.
                gt_labels_ignore = torch.full((gt_bboxes_ignore.size(0), ),
                                              -1,
                                              dtype=gt_labels.dtype,
                                              device=gt_bboxes_ignore.device)
        else:
            gt_bboxes_ignore = None
            gt_labels_ignore = None

        # compute overlap and assign gt on CPU when number of GT is large
        assign_on_cpu = (self.gpu_assign_thr > 0
                         and gt_bboxes.shape[0] > self.gpu_assign_thr)
        if assign_on_cpu:
            device = priors.device
            priors = priors.cpu()
            gt_bboxes = gt_bboxes.cpu()
            gt_labels = gt_labels.cpu()
            if gt_bboxes_ignore is not None:
                gt_bboxes_ignore = gt_bboxes_ignore.cpu()
                gt_labels_ignore = gt_labels_ignore.cpu()

        if gt_bboxes_ignore is not None:
            all_bboxes = torch.cat([gt_bboxes, gt_bboxes_ignore], dim=0)
            all_labels = torch.cat([gt_labels, gt_labels_ignore], dim=0)
        else:
            all_bboxes = gt_bboxes
            all_labels = gt_labels
        # The boxes themselves are appended to the priors as extra proposals.
        all_priors = torch.cat([priors, all_bboxes], dim=0)

        overlaps_normal = self.iou_calculator(
            all_priors, all_bboxes, mode='iou')
        overlaps_ignore = self.iou_calculator(
            all_priors, all_bboxes, mode='iof')
        # Keep IoU only for normal boxes and 'iof' only for ignored boxes.
        gt_ignore_mask = all_labels.eq(-1).repeat(all_priors.shape[0], 1)
        overlaps_normal = overlaps_normal * ~gt_ignore_mask
        overlaps_ignore = overlaps_ignore * gt_ignore_mask

        # select the roi with the higher score
        max_overlaps_normal, gt_assignment_normal = \
            self._select_top_matches(overlaps_normal)
        max_overlaps_ignore, gt_assignment_ignore = \
            self._select_top_matches(overlaps_ignore)

        # A slot takes its ignore match only when it has no confident normal
        # match and the ignored box overlaps it more.
        ignore_assign_mask = (max_overlaps_normal < self.pos_iou_thr) & (
            max_overlaps_ignore > max_overlaps_normal)
        overlaps = torch.where(ignore_assign_mask, max_overlaps_ignore,
                               max_overlaps_normal)
        gt_assignment = torch.where(ignore_assign_mask, gt_assignment_ignore,
                                    gt_assignment_normal)

        # Advanced indexing copies, so the edits below leave all_labels as is.
        assigned_labels = all_labels[gt_assignment]
        fg_mask = (overlaps >= self.pos_iou_thr) & (assigned_labels != -1)
        bg_mask = (overlaps < self.neg_iou_thr) & (overlaps >= 0)
        assigned_labels[fg_mask] = 1
        assigned_labels[bg_mask] = 0

        overlaps = overlaps.reshape(-1, self.num_instance)
        gt_assignment = gt_assignment.reshape(-1, self.num_instance)
        assigned_labels = assigned_labels.reshape(-1, self.num_instance)

        assign_result = AssignResult(
            num_gts=all_bboxes.size(0),
            gt_inds=gt_assignment,
            max_overlaps=overlaps,
            labels=assigned_labels)

        # Move results back to the priors' original device.
        if assign_on_cpu:
            assign_result.gt_inds = assign_result.gt_inds.to(device)
            assign_result.max_overlaps = assign_result.max_overlaps.to(device)
            if assign_result.labels is not None:
                assign_result.labels = assign_result.labels.to(device)
        return assign_result

    def _select_top_matches(self, overlaps: torch.Tensor):
        """Pick the ``num_instance`` largest overlaps in every row.

        Args:
            overlaps (torch.Tensor): 2-D overlap matrix, one row per prior
                and one column per box.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The selected overlap values
            and their column indices. Both are flattened row-major, so each
            consecutive group of ``num_instance`` entries belongs to one row.
            If the matrix has fewer columns than ``num_instance`` but at least
            one, the missing slots are padded with overlap 0 and index 0.
            This keeps the per-row grouping intact for the later reshape.
        """
        values, indices = overlaps.sort(descending=True, dim=1)
        values = values[:, :self.num_instance]
        indices = indices[:, :self.num_instance]
        num_missing = self.num_instance - values.size(1)
        # With no columns at all there is no valid index to pad with; keep
        # the (empty) selection unchanged in that case.
        if num_missing > 0 and values.size(1) > 0:
            values = torch.cat(
                [values, values.new_zeros((values.size(0), num_missing))],
                dim=1)
            indices = torch.cat(
                [indices, indices.new_zeros((indices.size(0), num_missing))],
                dim=1)
        return values.flatten(), indices.flatten()