Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Optional | |
import torch | |
from mmengine.structures import InstanceData | |
from mmdet.registry import TASK_UTILS | |
from .assign_result import AssignResult | |
from .base_assigner import BaseAssigner | |
@TASK_UTILS.register_module()
class PointAssigner(BaseAssigner):
    """Assign a corresponding gt bbox or background to each point.

    Each point is assigned either ``0`` or a positive integer indicating
    the ground-truth index:

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        scale (int): Reference size used to map a gt box to a feature
            level: ``lvl = int((log2(w / scale) + log2(h / scale)) / 2)``,
            clamped to the range of levels present in the priors.
            Defaults to 4.
        pos_num (int): Maximum number of points (the nearest ones on the
            gt's level) that may be assigned to each gt. Defaults to 3.
    """

    def __init__(self, scale: int = 4, pos_num: int = 3) -> None:
        self.scale = scale
        self.pos_num = pos_num

    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               gt_instances_ignore: Optional[InstanceData] = None,
               **kwargs) -> AssignResult:
        """Assign gt to points.

        Every point starts as background. Each gt box is then processed in
        turn and a point is assigned to it if:

        (i) the point lies on the gt's level and is among the ``pos_num``
            points nearest to the gt center on that level (distance is
            measured after normalizing by the gt width/height), and
        (ii) that distance is strictly smaller than the distance recorded
            for any gt previously assigned to the point.

        Args:
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. Its ``priors`` are the points to assign, with
                columns ``(x, y, stride)``. Only the first three columns
                are read. The level of each point is ``int(log2(stride))``.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It includes ``bboxes`` with shape (k, 4) in
                ``(x1, y1, x2, y2)`` order and ``labels`` with shape (k, ).
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. Accepted for interface
                compatibility but not used by this assigner.
                Defaults to None.

        Returns:
            :obj:`AssignResult`: The assign result. ``gt_inds`` is 0 for
            background and ``i + 1`` for a point assigned to gt ``i``.
            ``labels`` is -1 for background, otherwise the label of the
            assigned gt. ``max_overlaps`` is None.
        """
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        # points to be assigned, shape (n, 3) where the last
        # dimension stands for (x, y, stride).
        points = pred_instances.priors
        num_points = points.shape[0]
        num_gts = gt_bboxes.shape[0]

        if num_gts == 0 or num_points == 0:
            # If there is no gt (or no point), everything is background.
            assigned_gt_inds = points.new_full((num_points, ),
                                               0,
                                               dtype=torch.long)
            assigned_labels = points.new_full((num_points, ),
                                              -1,
                                              dtype=torch.long)
            return AssignResult(
                num_gts=num_gts,
                gt_inds=assigned_gt_inds,
                max_overlaps=None,
                labels=assigned_labels)

        points_xy = points[:, :2]
        points_stride = points[:, 2]
        points_lvl = torch.log2(
            points_stride).int()  # [3...,4...,5...,6...,7...]
        lvl_min, lvl_max = points_lvl.min(), points_lvl.max()

        # Map each gt box to a level from its (clamped) width and height.
        gt_bboxes_xy = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2
        gt_bboxes_wh = (gt_bboxes[:, 2:] - gt_bboxes[:, :2]).clamp(min=1e-6)
        scale = self.scale
        gt_bboxes_lvl = ((torch.log2(gt_bboxes_wh[:, 0] / scale) +
                          torch.log2(gt_bboxes_wh[:, 1] / scale)) / 2).int()
        gt_bboxes_lvl = torch.clamp(gt_bboxes_lvl, min=lvl_min, max=lvl_max)

        # stores the assigned gt index (1-based, 0 = none) of each point
        assigned_gt_inds = points.new_zeros((num_points, ), dtype=torch.long)
        # stores the assigned gt dist (to this point) of each point
        assigned_gt_dist = points.new_full((num_points, ), float('inf'))
        # Created on the points' device so it can be indexed by the masks
        # and indices below, which live on that device.
        points_range = torch.arange(num_points, device=points.device)

        for idx in range(num_gts):
            gt_lvl = gt_bboxes_lvl[idx]
            # mask of the points on this gt's level
            lvl_idx = gt_lvl == points_lvl
            points_index = points_range[lvl_idx]
            lvl_points = points_xy[lvl_idx, :]
            # center, width and height of the gt (kept 2-D for broadcasting)
            gt_point = gt_bboxes_xy[[idx], :]
            gt_wh = gt_bboxes_wh[[idx], :]
            # size-normalized distance between the gt center and
            # every point on this level
            points_gt_dist = ((lvl_points - gt_point) / gt_wh).norm(dim=1)
            # The level may hold fewer than ``pos_num`` points (or none);
            # topk raises if k exceeds the number of elements.
            num_candidates = min(self.pos_num, points_gt_dist.numel())
            min_dist, min_dist_index = torch.topk(
                points_gt_dist, num_candidates, largest=False)
            # global indices of the nearest points
            min_dist_points_index = points_index[min_dist_index]
            # Keep only the points for which this gt is strictly closer
            # than any gt assigned to them so far.
            less_than_recorded_index = min_dist < assigned_gt_dist[
                min_dist_points_index]
            min_dist_points_index = min_dist_points_index[
                less_than_recorded_index]
            assigned_gt_inds[min_dist_points_index] = idx + 1
            assigned_gt_dist[min_dist_points_index] = min_dist[
                less_than_recorded_index]

        assigned_labels = assigned_gt_inds.new_full((num_points, ), -1)
        # squeeze(1) (not squeeze()) keeps pos_inds 1-D even when exactly
        # one point is positive.
        pos_inds = torch.nonzero(
            assigned_gt_inds > 0, as_tuple=False).squeeze(1)
        if pos_inds.numel() > 0:
            assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] -
                                                  1]
        return AssignResult(
            num_gts=num_gts,
            gt_inds=assigned_gt_inds,
            max_overlaps=None,
            labels=assigned_labels)