KyanChen's picture
init
f549064
raw
history blame
6.19 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch
from mmengine.structures import InstanceData
from mmdet.registry import TASK_UTILS
from .assign_result import AssignResult
from .max_iou_assigner import MaxIoUAssigner
@TASK_UTILS.register_module()
class MultiInstanceAssigner(MaxIoUAssigner):
"""Assign a corresponding gt bbox or background to each proposal bbox. If
we need to use a proposal box to generate multiple predict boxes,
`MultiInstanceAssigner` can assign multiple gt to each proposal box.
Args:
num_instance (int): How many bboxes are predicted by each proposal box.
"""
def __init__(self, num_instance: int = 2, **kwargs):
super().__init__(**kwargs)
self.num_instance = num_instance
def assign(self,
pred_instances: InstanceData,
gt_instances: InstanceData,
gt_instances_ignore: Optional[InstanceData] = None,
**kwargs) -> AssignResult:
"""Assign gt to bboxes.
This method assign gt bboxes to every bbox (proposal/anchor), each bbox
is assigned a set of gts, and the number of gts in this set is defined
by `self.num_instance`.
Args:
pred_instances (:obj:`InstanceData`): Instances of model
predictions. It includes ``priors``, and the priors can
be anchors or points, or the bboxes predicted by the
previous stage, has shape (n, 4). The bboxes predicted by
the current model or stage will be named ``bboxes``,
``labels``, and ``scores``, the same as the ``InstanceData``
in other places.
gt_instances (:obj:`InstanceData`): Ground truth of instance
annotations. It usually includes ``bboxes``, with shape (k, 4),
and ``labels``, with shape (k, ).
gt_instances_ignore (:obj:`InstanceData`, optional): Instances
to be ignored during training. It includes ``bboxes``
attribute data that is ignored during training and testing.
Defaults to None.
Returns:
:obj:`AssignResult`: The assign result.
"""
gt_bboxes = gt_instances.bboxes
priors = pred_instances.priors
# Set the FG label to 1 and add ignored annotations
gt_labels = gt_instances.labels + 1
if gt_instances_ignore is not None:
gt_bboxes_ignore = gt_instances_ignore.bboxes
if hasattr(gt_instances_ignore, 'labels'):
gt_labels_ignore = gt_instances_ignore.labels
else:
gt_labels_ignore = torch.ones_like(gt_bboxes_ignore)[:, 0] * -1
else:
gt_bboxes_ignore = None
gt_labels_ignore = None
assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
gt_bboxes.shape[0] > self.gpu_assign_thr) else False
# compute overlap and assign gt on CPU when number of GT is large
if assign_on_cpu:
device = priors.device
priors = priors.cpu()
gt_bboxes = gt_bboxes.cpu()
gt_labels = gt_labels.cpu()
if gt_bboxes_ignore is not None:
gt_bboxes_ignore = gt_bboxes_ignore.cpu()
gt_labels_ignore = gt_labels_ignore.cpu()
if gt_bboxes_ignore is not None:
all_bboxes = torch.cat([gt_bboxes, gt_bboxes_ignore], dim=0)
all_labels = torch.cat([gt_labels, gt_labels_ignore], dim=0)
else:
all_bboxes = gt_bboxes
all_labels = gt_labels
all_priors = torch.cat([priors, all_bboxes], dim=0)
overlaps_normal = self.iou_calculator(
all_priors, all_bboxes, mode='iou')
overlaps_ignore = self.iou_calculator(
all_priors, all_bboxes, mode='iof')
gt_ignore_mask = all_labels.eq(-1).repeat(all_priors.shape[0], 1)
overlaps_normal = overlaps_normal * ~gt_ignore_mask
overlaps_ignore = overlaps_ignore * gt_ignore_mask
overlaps_normal, overlaps_normal_indices = overlaps_normal.sort(
descending=True, dim=1)
overlaps_ignore, overlaps_ignore_indices = overlaps_ignore.sort(
descending=True, dim=1)
# select the roi with the higher score
max_overlaps_normal = overlaps_normal[:, :self.num_instance].flatten()
gt_assignment_normal = overlaps_normal_indices[:, :self.
num_instance].flatten()
max_overlaps_ignore = overlaps_ignore[:, :self.num_instance].flatten()
gt_assignment_ignore = overlaps_ignore_indices[:, :self.
num_instance].flatten()
# ignore or not
ignore_assign_mask = (max_overlaps_normal < self.pos_iou_thr) * (
max_overlaps_ignore > max_overlaps_normal)
overlaps = (max_overlaps_normal * ~ignore_assign_mask) + (
max_overlaps_ignore * ignore_assign_mask)
gt_assignment = (gt_assignment_normal * ~ignore_assign_mask) + (
gt_assignment_ignore * ignore_assign_mask)
assigned_labels = all_labels[gt_assignment]
fg_mask = (overlaps >= self.pos_iou_thr) * (assigned_labels != -1)
bg_mask = (overlaps < self.neg_iou_thr) * (overlaps >= 0)
assigned_labels[fg_mask] = 1
assigned_labels[bg_mask] = 0
overlaps = overlaps.reshape(-1, self.num_instance)
gt_assignment = gt_assignment.reshape(-1, self.num_instance)
assigned_labels = assigned_labels.reshape(-1, self.num_instance)
assign_result = AssignResult(
num_gts=all_bboxes.size(0),
gt_inds=gt_assignment,
max_overlaps=overlaps,
labels=assigned_labels)
if assign_on_cpu:
assign_result.gt_inds = assign_result.gt_inds.to(device)
assign_result.max_overlaps = assign_result.max_overlaps.to(device)
if assign_result.labels is not None:
assign_result.labels = assign_result.labels.to(device)
return assign_result