# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
from mmengine.structures import InstanceData

from mmdet.registry import TASK_UTILS
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


@TASK_UTILS.register_module()
class PointAssigner(BaseAssigner):
    """Assign a corresponding gt bbox or background to each point.

    Each point will be assigned with `0`, or a positive integer
    indicating the ground truth index.

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        scale (int): Divisor applied to the gt width and height before
            taking ``log2`` to pick the level a gt box is matched against.
            Defaults to 4.
        pos_num (int): Number of nearest points (within the matched level)
            that are considered as positive candidates for each gt box.
            Defaults to 3.
    """

    def __init__(self, scale: int = 4, pos_num: int = 3) -> None:
        self.scale = scale
        self.pos_num = pos_num

    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               gt_instances_ignore: Optional[InstanceData] = None,
               **kwargs) -> AssignResult:
        """Assign gt to points.

        Every point gets a gt index in ``gt_inds``: ``0`` means background
        and a positive integer is the 1-based index of the assigned gt.
        The returned ``labels`` hold ``-1`` for background points and the
        label of the assigned gt otherwise.

        The assignment is done in following steps, the order matters.

        1. assign every point to the background (gt index 0)
        2. A point is assigned to some gt bbox if
            (i) the point is within the k closest points to the gt bbox
            center among the points of the level matched to that gt
            (ii) the normalized distance between this point and the gt
            center is smaller than the distance to any previously
            assigned gt

        Args:
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. Only ``priors`` is read here; it is indexed
                as (n, 3) points where the last dimension stands for
                (x, y, stride).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. ``bboxes`` with shape (k, 4) and ``labels``
                with shape (k, ) are read.
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. Not used by this assigner.
                Defaults to None.

        Returns:
            :obj:`AssignResult`: The assign result. ``max_overlaps`` is
            always None.
        """
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        # points to be assigned, shape(n, 3) while last
        # dimension stands for (x, y, stride).
        points = pred_instances.priors

        num_points = points.shape[0]
        num_gts = gt_bboxes.shape[0]

        if num_gts == 0 or num_points == 0:
            # If no truth assign everything to the background
            assigned_gt_inds = points.new_full((num_points, ),
                                               0,
                                               dtype=torch.long)
            assigned_labels = points.new_full((num_points, ),
                                              -1,
                                              dtype=torch.long)
            return AssignResult(
                num_gts=num_gts,
                gt_inds=assigned_gt_inds,
                max_overlaps=None,
                labels=assigned_labels)

        points_xy = points[:, :2]
        points_stride = points[:, 2]
        points_lvl = torch.log2(
            points_stride).int()  # [3...,4...,5...,6...,7...]
        lvl_min, lvl_max = points_lvl.min(), points_lvl.max()

        # assign gt box
        gt_bboxes_xy = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2
        # clamp so that degenerate boxes do not produce log2(0) or a
        # division by zero below
        gt_bboxes_wh = (gt_bboxes[:, 2:] - gt_bboxes[:, :2]).clamp(min=1e-6)
        scale = self.scale
        gt_bboxes_lvl = ((torch.log2(gt_bboxes_wh[:, 0] / scale) +
                          torch.log2(gt_bboxes_wh[:, 1] / scale)) / 2).int()
        gt_bboxes_lvl = torch.clamp(gt_bboxes_lvl, min=lvl_min, max=lvl_max)

        # stores the assigned gt index of each point
        assigned_gt_inds = points.new_zeros((num_points, ), dtype=torch.long)
        # stores the assigned gt dist (to this point) of each point
        assigned_gt_dist = points.new_full((num_points, ), float('inf'))
        # Must live on the same device as ``points``: it is indexed with
        # masks / indices derived from ``points`` below.
        points_range = torch.arange(num_points, device=points.device)

        for idx in range(num_gts):
            gt_lvl = gt_bboxes_lvl[idx]
            # get the index of points in this level
            lvl_idx = gt_lvl == points_lvl
            points_index = points_range[lvl_idx]
            # get the points in this level
            lvl_points = points_xy[lvl_idx, :]
            # get the center point of gt
            gt_point = gt_bboxes_xy[[idx], :]
            # get width and height of gt
            gt_wh = gt_bboxes_wh[[idx], :]
            # compute the distance between gt center and
            #   all points in this level
            points_gt_dist = ((lvl_points - gt_point) / gt_wh).norm(dim=1)
            # topk raises if k exceeds the number of candidates, so never
            # ask for more points than this level holds.
            num_candidates = min(self.pos_num, points_gt_dist.numel())
            if num_candidates == 0:
                # no point lives in the level matched to this gt
                continue
            # find the nearest k points to gt center in this level
            min_dist, min_dist_index = torch.topk(
                points_gt_dist, num_candidates, largest=False)
            # the index of nearest k points to gt center in this level
            min_dist_points_index = points_index[min_dist_index]
            # The less_than_recorded_index stores the index
            #   of min_dist that is less then the assigned_gt_dist. Where
            #   assigned_gt_dist stores the dist from previous assigned gt
            #   (if exist) to each point.
            less_than_recorded_index = min_dist < assigned_gt_dist[
                min_dist_points_index]
            # The min_dist_points_index stores the index of points satisfy:
            #   (1) it is k nearest to current gt center in this level.
            #   (2) it is closer to current gt center than other gt center.
            min_dist_points_index = min_dist_points_index[
                less_than_recorded_index]
            # assign the result
            assigned_gt_inds[min_dist_points_index] = idx + 1
            assigned_gt_dist[min_dist_points_index] = min_dist[
                less_than_recorded_index]

        assigned_labels = assigned_gt_inds.new_full((num_points, ), -1)
        # squeeze only the trailing dim so a single positive stays 1-D
        pos_inds = torch.nonzero(
            assigned_gt_inds > 0, as_tuple=False).squeeze(1)
        if pos_inds.numel() > 0:
            # gt indices are 1-based, labels are looked up 0-based
            assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] -
                                                  1]

        return AssignResult(
            num_gts=num_gts,
            gt_inds=assigned_gt_inds,
            max_overlaps=None,
            labels=assigned_labels)