# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Tuple

import torch
import torch.nn as nn
from mmcv.cnn import Scale
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures.bbox import bbox2distance
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
                         OptInstanceList, reduce_mean)
from ..utils import multi_apply
from .anchor_free_head import AnchorFreeHead

INF = 1000000000
RangeType = Sequence[Tuple[int, int]]


def _transpose(tensor_list: List[Tensor],
               num_point_list: list) -> List[Tensor]:
    """Transpose image-first tensors to level-first ones."""
    for img_idx in range(len(tensor_list)):
        tensor_list[img_idx] = torch.split(
            tensor_list[img_idx], num_point_list, dim=0)

    tensors_level_first = []
    for targets_per_level in zip(*tensor_list):
        tensors_level_first.append(torch.cat(targets_per_level, dim=0))
    return tensors_level_first


@MODELS.register_module()
class CenterNetUpdateHead(AnchorFreeHead):
    """CenterNetUpdateHead is an improved version of CenterNet in CenterNet2.

    Paper link `<https://arxiv.org/abs/2103.07461>`_.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        regress_ranges (Sequence[Tuple[int, int]]): Regress range of multiple
            level points.
        hm_min_radius (int): Heatmap target minimum radius of cls branch.
            Defaults to 4.
        hm_min_overlap (float): Heatmap target minimum overlap of cls branch.
            Defaults to 0.8.
        more_pos_thresh (float): The filtering threshold when the cls branch
            adds more positive samples. Defaults to 0.2.
        more_pos_topk (int): The maximum number of additional positive samples
            added to each gt. Defaults to 9.
        soft_weight_on_reg (bool): Whether to use the soft target of the
            cls branch as the soft weight of the bbox branch.
            Defaults to False.
        loss_cls (:obj:`ConfigDict` or dict): Config of cls loss. Defaults to
            dict(type='GaussianFocalLoss', loss_weight=1.0).
        loss_bbox (:obj:`ConfigDict` or dict): Config of bbox loss. Defaults
            to dict(type='GIoULoss', loss_weight=2.0).
        norm_cfg (:obj:`ConfigDict` or dict, optional): Dictionary to
            construct and config the norm layer. Defaults to
            ``norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)``.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config.
            Unused in CenterNet. Reserved for compatibility with
            SingleStageDetector.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config
            of CenterNet.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 regress_ranges: RangeType = ((0, 80), (64, 160), (128, 320),
                                              (256, 640), (512, INF)),
                 hm_min_radius: int = 4,
                 hm_min_overlap: float = 0.8,
                 more_pos_thresh: float = 0.2,
                 more_pos_topk: int = 9,
                 soft_weight_on_reg: bool = False,
                 loss_cls: ConfigType = dict(
                     type='GaussianFocalLoss',
                     pos_weight=0.25,
                     neg_weight=0.75,
                     loss_weight=1.0),
                 loss_bbox: ConfigType = dict(
                     type='GIoULoss', loss_weight=2.0),
                 norm_cfg: OptConfigType = dict(
                     type='GN', num_groups=32, requires_grad=True),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 **kwargs) -> None:
        super().__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            norm_cfg=norm_cfg,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            **kwargs)
        self.soft_weight_on_reg = soft_weight_on_reg
        self.hm_min_radius = hm_min_radius
        self.more_pos_thresh = more_pos_thresh
        self.more_pos_topk = more_pos_topk
        self.delta = (1 - hm_min_overlap) / (1 + hm_min_overlap)
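        # Sanity-check example (not extra logic): this is the IoU-derived
        # radius factor used for CenterNet-style gaussian targets; the
        # default hm_min_overlap=0.8 gives delta = 0.2 / 1.8 ≈ 0.111.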
        self.sigmoid_clamp = 0.0001

        # GaussianFocalLoss must be sigmoid mode
        self.use_sigmoid_cls = True
        self.cls_out_channels = num_classes

        self.regress_ranges = regress_ranges
        self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])

    def _init_predictor(self) -> None:
        """Initialize predictor layers of the head."""
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.num_classes, 3, padding=1)
        self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)

    def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: A tuple of each level outputs.

            - cls_scores (list[Tensor]): Box scores for each scale level, \
            each is a 4D-tensor, the channel number is num_classes.
            - bbox_preds (list[Tensor]): Box energies / deltas for each \
            scale level, each is a 4D-tensor, the channel number is 4.
        """
        return multi_apply(self.forward_single, x, self.scales, self.strides)

    def forward_single(self, x: Tensor, scale: Scale,
                       stride: int) -> Tuple[Tensor, Tensor]:
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.
            scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
                the bbox prediction.
            stride (int): The corresponding stride for feature maps.

        Returns:
            tuple: scores for each class, bbox predictions of
            input feature maps.
        """
        cls_score, bbox_pred, _, _ = super().forward_single(x)
        # scale the bbox_pred of different level
        # float to avoid overflow when enabling FP16
        bbox_pred = scale(bbox_pred).float()
        # bbox_pred needed for gradient computation has been modified
        # by F.relu(bbox_pred) when run with PyTorch 1.10. So replace
        # F.relu(bbox_pred) with bbox_pred.clamp(min=0)
        bbox_pred = bbox_pred.clamp(min=0)
        if not self.training:
            bbox_pred *= stride
        return cls_score, bbox_pred

    def loss_by_feat(
        self,
        cls_scores: List[Tensor],
        bbox_preds: List[Tensor],
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        batch_gt_instances_ignore: OptInstanceList = None
    ) -> Dict[str, Tensor]:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is 4.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        num_imgs = cls_scores[0].size(0)
        assert len(cls_scores) == len(bbox_preds)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=bbox_preds[0].dtype,
            device=bbox_preds[0].device)

        # 1 flatten outputs
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        # repeat points to align with bbox_preds
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points])
        assert torch.isfinite(flatten_bbox_preds).all().item()

        # 2 calc reg and cls branch targets
        cls_targets, bbox_targets = self.get_targets(all_level_points,
                                                     batch_gt_instances)

        # 3 add more pos index for cls branch
        featmap_sizes = flatten_points.new_tensor(featmap_sizes)
        pos_inds, cls_labels = self.add_cls_pos_inds(flatten_points,
                                                     flatten_bbox_preds,
                                                     featmap_sizes,
                                                     batch_gt_instances)

        # 4 calc cls loss
        if pos_inds is None:
            # num_gts=0
            num_pos_cls = bbox_preds[0].new_tensor(0, dtype=torch.float)
        else:
            num_pos_cls = bbox_preds[0].new_tensor(
                len(pos_inds), dtype=torch.float)
        num_pos_cls = max(reduce_mean(num_pos_cls), 1.0)
        flatten_cls_scores = flatten_cls_scores.sigmoid().clamp(
            min=self.sigmoid_clamp, max=1 - self.sigmoid_clamp)
        cls_loss = self.loss_cls(
            flatten_cls_scores,
            cls_targets,
            pos_inds=pos_inds,
            pos_labels=cls_labels,
            avg_factor=num_pos_cls)

        # 5 calc reg loss
        pos_bbox_inds = torch.nonzero(
            bbox_targets.max(dim=1)[0] >= 0).squeeze(1)
        pos_bbox_preds = flatten_bbox_preds[pos_bbox_inds]
        pos_bbox_targets = bbox_targets[pos_bbox_inds]

        bbox_weight_map = cls_targets.max(dim=1)[0]
        bbox_weight_map = bbox_weight_map[pos_bbox_inds]
        bbox_weight_map = bbox_weight_map if self.soft_weight_on_reg \
            else torch.ones_like(bbox_weight_map)
        num_pos_bbox = max(reduce_mean(bbox_weight_map.sum()), 1.0)

        if len(pos_bbox_inds) > 0:
            pos_points = flatten_points[pos_bbox_inds]
            pos_decoded_bbox_preds = self.bbox_coder.decode(
                pos_points, pos_bbox_preds)
            pos_decoded_target_preds = self.bbox_coder.decode(
                pos_points, pos_bbox_targets)
            bbox_loss = self.loss_bbox(
                pos_decoded_bbox_preds,
                pos_decoded_target_preds,
                weight=bbox_weight_map,
                avg_factor=num_pos_bbox)
        else:
            bbox_loss = flatten_bbox_preds.sum() * 0

        return dict(loss_cls=cls_loss, loss_bbox=bbox_loss)

    def get_targets(
        self,
        points: List[Tensor],
        batch_gt_instances: InstanceList,
    ) -> Tuple[Tensor, Tensor]:
        """Compute classification and bbox targets for points in multiple
        images.

        Args:
            points (list[Tensor]): Points of each fpn level, each has shape
                (num_points, 2).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.

        Returns:
            tuple: Targets of each level.

            - concat_lvl_labels (Tensor): Labels of all levels and batch.
            - concat_lvl_bbox_targets (Tensor): BBox targets of all \
            levels and batch.
        """
        assert len(points) == len(self.regress_ranges)

        num_levels = len(points)
        # the number of points per img, per lvl
        num_points = [center.size(0) for center in points]

        # expand regress ranges to align with points
        expanded_regress_ranges = [
            points[i].new_tensor(self.regress_ranges[i])[None].expand_as(
                points[i]) for i in range(num_levels)
        ]
        # concat all levels points and regress ranges
        concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0)
        concat_points = torch.cat(points, dim=0)
        concat_strides = torch.cat([
            concat_points.new_ones(num_points[i]) * self.strides[i]
            for i in range(num_levels)
        ])

        # get labels and bbox_targets of each image
        cls_targets_list, bbox_targets_list = multi_apply(
            self._get_targets_single,
            batch_gt_instances,
            points=concat_points,
            regress_ranges=concat_regress_ranges,
            strides=concat_strides)

        bbox_targets_list = _transpose(bbox_targets_list, num_points)
        cls_targets_list = _transpose(cls_targets_list, num_points)
        concat_lvl_bbox_targets = torch.cat(bbox_targets_list, 0)
        concat_lvl_cls_targets = torch.cat(cls_targets_list, dim=0)
        return concat_lvl_cls_targets, concat_lvl_bbox_targets

    def _get_targets_single(self, gt_instances: InstanceData, points: Tensor,
                            regress_ranges: Tensor,
                            strides: Tensor) -> Tuple[Tensor, Tensor]:
        """Compute classification and bbox targets for a single image."""
        num_points = points.size(0)
        num_gts = len(gt_instances)
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels

        if num_gts == 0:
            return gt_labels.new_full((num_points,
                                       self.num_classes),
                                      self.num_classes), \
                   gt_bboxes.new_full((num_points, 4), -1)

        # Calculate the regression tblr target corresponding to all points
        points = points[:, None].expand(num_points, num_gts, 2)
        gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
        strides = strides[:, None, None].expand(num_points, num_gts, 2)
        bbox_target = bbox2distance(points, gt_bboxes)  # M x N x 4

        # condition1: inside a gt bbox
        inside_gt_bbox_mask = bbox_target.min(dim=2)[0] > 0  # M x N

        # condition2: Calculate the nearest points from
        # the upper, lower, left and right ranges from
        # the center of the gt bbox
        centers = ((gt_bboxes[..., [0, 1]] + gt_bboxes[..., [2, 3]]) / 2)
        centers_discret = ((centers / strides).int() * strides).float() + \
            strides / 2
        centers_discret_dist = points - centers_discret
        dist_x = centers_discret_dist[..., 0].abs()
        dist_y = centers_discret_dist[..., 1].abs()
        inside_gt_center3x3_mask = (dist_x <= strides[..., 0]) & \
                                   (dist_y <= strides[..., 0])

        # condition3: limit the regression range for each location
        bbox_target_wh = bbox_target[..., :2] + bbox_target[..., 2:]
        crit = (bbox_target_wh**2).sum(dim=2)**0.5 / 2
        inside_fpn_level_mask = (crit >= regress_ranges[:, [0]]) & \
                                (crit <= regress_ranges[:, [1]])
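        # Worked example (illustrative numbers): for an 80x60 gt box,
        # crit is half the box diagonal, sqrt(80**2 + 60**2) / 2 = 50,
        # which falls only inside the first default range (0, 80), so the
        # box is assigned to the first FPN level.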
        bbox_target_mask = inside_gt_bbox_mask & \
            inside_gt_center3x3_mask & \
            inside_fpn_level_mask

        # Calculate the distance weight map
        gt_center_peak_mask = ((centers_discret_dist**2).sum(dim=2) == 0)
        weighted_dist = ((points - centers)**2).sum(dim=2)  # M x N
        weighted_dist[gt_center_peak_mask] = 0

        areas = (gt_bboxes[..., 2] - gt_bboxes[..., 0]) * (
            gt_bboxes[..., 3] - gt_bboxes[..., 1])
        radius = self.delta**2 * 2 * areas
        radius = torch.clamp(radius, min=self.hm_min_radius**2)
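        # Worked example (illustrative numbers): a 100x100 box has area
        # 10000, so with delta ≈ 0.111 the squared radius is
        # 0.111**2 * 2 * 10000 ≈ 247, well above the clamp floor
        # hm_min_radius**2 = 16.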
        weighted_dist = weighted_dist / radius

        # Calculate bbox_target
        bbox_weighted_dist = weighted_dist.clone()
        bbox_weighted_dist[bbox_target_mask == 0] = INF * 1.0
        min_dist, min_inds = bbox_weighted_dist.min(dim=1)
        bbox_target = bbox_target[range(len(bbox_target)),
                                  min_inds]  # M x N x 4 --> M x 4
        bbox_target[min_dist == INF] = -INF

        # Convert to feature map scale
        bbox_target /= strides[:, 0, :].repeat(1, 2)

        # Calculate cls_target
        cls_target = self._create_heatmaps_from_dist(weighted_dist, gt_labels)
        return cls_target, bbox_target

    def add_cls_pos_inds(
        self, flatten_points: Tensor, flatten_bbox_preds: Tensor,
        featmap_sizes: Tensor, batch_gt_instances: InstanceList
    ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        """Provide additional adaptive positive samples to the classification
        branch.

        Args:
            flatten_points (Tensor): The points after flattening, including
                batch image and all levels. The shape is (N, 2).
            flatten_bbox_preds (Tensor): The bbox predictions after
                flattening, including batch image and all levels. The shape
                is (N, 4).
            featmap_sizes (Tensor): Feature map size of all layers.
                The shape is (5, 2).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.

        Returns:
            tuple:

            - pos_inds (Tensor): Adaptively selected positive sample index.
            - cls_labels (Tensor): Corresponding positive class label.
        """
        outputs = self._get_center3x3_region_index_targets(
            batch_gt_instances, featmap_sizes)
        cls_labels, fpn_level_masks, center3x3_inds, \
            center3x3_bbox_targets, center3x3_masks = outputs

        num_gts, total_level, K = cls_labels.shape[0], len(
            self.strides), center3x3_masks.shape[-1]
        if num_gts == 0:
            return None, None

        # The out-of-bounds index is forcibly set to 0
        # to prevent loss calculation errors
        center3x3_inds[center3x3_masks == 0] = 0
        reg_pred_center3x3 = flatten_bbox_preds[center3x3_inds]
        center3x3_points = flatten_points[center3x3_inds].view(-1, 2)

        center3x3_bbox_targets_expand = center3x3_bbox_targets.view(
            -1, 4).clamp(min=0)

        pos_decoded_bbox_preds = self.bbox_coder.decode(
            center3x3_points, reg_pred_center3x3.view(-1, 4))
        pos_decoded_target_preds = self.bbox_coder.decode(
            center3x3_points, center3x3_bbox_targets_expand)
        center3x3_bbox_loss = self.loss_bbox(
            pos_decoded_bbox_preds,
            pos_decoded_target_preds,
            None,
            reduction_override='none').view(num_gts, total_level,
                                            K) / self.loss_bbox.loss_weight

        # Invalid index Loss set to infinity
        center3x3_bbox_loss[center3x3_masks == 0] = INF

        # 4 is the center point of the sampled 9 points, the center point
        # of gt bbox after discretization.
        # The center point of gt bbox after discretization
        # must be a positive sample, so we force its loss to be set to 0.
        center3x3_bbox_loss.view(-1, K)[fpn_level_masks.view(-1), 4] = 0
        center3x3_bbox_loss = center3x3_bbox_loss.view(num_gts, -1)

        loss_thr = torch.kthvalue(
            center3x3_bbox_loss, self.more_pos_topk, dim=1)[0]
        loss_thr[loss_thr > self.more_pos_thresh] = self.more_pos_thresh
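        # Thresholding example (illustrative numbers): if a gt's 9th-smallest
        # candidate loss (more_pos_topk=9) is 0.35 while more_pos_thresh=0.2,
        # the threshold is capped at 0.2; only candidates whose loss falls
        # strictly below the threshold become extra cls positives, which
        # bounds how many positives each gt can add.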
        new_pos = center3x3_bbox_loss < loss_thr.view(num_gts, 1)
        pos_inds = center3x3_inds.view(num_gts, -1)[new_pos]
        cls_labels = cls_labels.view(num_gts,
                                     1).expand(num_gts,
                                               total_level * K)[new_pos]
        return pos_inds, cls_labels

    def _create_heatmaps_from_dist(self, weighted_dist: Tensor,
                                   cls_labels: Tensor) -> Tensor:
        """Generate heatmaps of classification branch based on weighted
        distance map."""
        heatmaps = weighted_dist.new_zeros(
            (weighted_dist.shape[0], self.num_classes))
        for c in range(self.num_classes):
            inds = (cls_labels == c)  # N
            if inds.int().sum() == 0:
                continue
            heatmaps[:, c] = torch.exp(-weighted_dist[:, inds].min(dim=1)[0])
            zeros = heatmaps[:, c] < 1e-4
            heatmaps[zeros, c] = 0
        return heatmaps

    def _get_center3x3_region_index_targets(self,
                                            batch_gt_instances: InstanceList,
                                            shapes_per_level: Tensor) -> tuple:
        """Get the center (and the 3x3 region near center) locations and
        targets of each object."""
        cls_labels = []
        inside_fpn_level_masks = []
        center3x3_inds = []
        center3x3_masks = []
        center3x3_bbox_targets = []

        total_levels = len(self.strides)
        batch = len(batch_gt_instances)

        shapes_per_level = shapes_per_level.long()
        area_per_level = (shapes_per_level[:, 0] * shapes_per_level[:, 1])

        # Select a total of 9 positions of 3x3 in the center of the gt bbox
        # as candidate positive samples
        K = 9
        dx = shapes_per_level.new_tensor([-1, 0, 1, -1, 0, 1, -1, 0,
                                          1]).view(1, 1, K)
        dy = shapes_per_level.new_tensor([-1, -1, -1, 0, 0, 0, 1, 1,
                                          1]).view(1, 1, K)

        regress_ranges = shapes_per_level.new_tensor(
            self.regress_ranges).view(len(self.regress_ranges), 2)  # L x 2
        strides = shapes_per_level.new_tensor(self.strides)

        start_coord_per_level = []
        _start = 0
        for level in range(total_levels):
            start_coord_per_level.append(_start)
            _start = _start + batch * area_per_level[level]
        start_coord_per_level = shapes_per_level.new_tensor(
            start_coord_per_level).view(1, total_levels, 1)
        area_per_level = area_per_level.view(1, total_levels, 1)
        for im_i in range(batch):
            gt_instance = batch_gt_instances[im_i]
            gt_bboxes = gt_instance.bboxes
            gt_labels = gt_instance.labels
            num_gts = gt_bboxes.shape[0]
            if num_gts == 0:
                continue

            cls_labels.append(gt_labels)

            gt_bboxes = gt_bboxes[:, None].expand(num_gts, total_levels, 4)
            expanded_strides = strides[None, :,
                                       None].expand(num_gts, total_levels, 2)
            expanded_regress_ranges = regress_ranges[None].expand(
                num_gts, total_levels, 2)
            expanded_shapes_per_level = shapes_per_level[None].expand(
                num_gts, total_levels, 2)

            # calc reg_target
            centers = ((gt_bboxes[..., [0, 1]] + gt_bboxes[..., [2, 3]]) / 2)
            centers_inds = (centers / expanded_strides).long()
            centers_discret = centers_inds * expanded_strides \
                + expanded_strides // 2
            bbox_target = bbox2distance(centers_discret,
                                        gt_bboxes)  # M x N x 4

            # calc inside_fpn_level_mask
            bbox_target_wh = bbox_target[..., :2] + bbox_target[..., 2:]
            crit = (bbox_target_wh**2).sum(dim=2)**0.5 / 2
            inside_fpn_level_mask = \
                (crit >= expanded_regress_ranges[..., 0]) & \
                (crit <= expanded_regress_ranges[..., 1])
            inside_gt_bbox_mask = bbox_target.min(dim=2)[0] >= 0
            inside_fpn_level_mask = \
                inside_gt_bbox_mask & inside_fpn_level_mask
            inside_fpn_level_masks.append(inside_fpn_level_mask)

            # calc center3x3_ind and mask
            expand_ws = expanded_shapes_per_level[..., 1:2].expand(
                num_gts, total_levels, K)
            expand_hs = expanded_shapes_per_level[..., 0:1].expand(
                num_gts, total_levels, K)
            centers_inds_x = centers_inds[..., 0:1]
            centers_inds_y = centers_inds[..., 1:2]
            center3x3_idx = start_coord_per_level + \
                im_i * area_per_level + \
                (centers_inds_y + dy) * expand_ws + \
                (centers_inds_x + dx)
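            # Flat-index example (illustrative numbers): with a level whose
            # feature map is 100x152, the candidate at cell (y=3, x=5) of
            # image 0 on that level gets index
            # start_coord_per_level + 0 * (100 * 152) + 3 * 152 + 5,
            # i.e. a row-major offset into the flattened all-level tensor.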
            center3x3_mask = \
                ((centers_inds_y + dy) < expand_hs) & \
                ((centers_inds_y + dy) >= 0) & \
                ((centers_inds_x + dx) < expand_ws) & \
                ((centers_inds_x + dx) >= 0)

            # recalc center3x3 region reg target
            bbox_target = bbox_target / expanded_strides.repeat(1, 1, 2)
            center3x3_bbox_target = bbox_target[..., None, :].expand(
                num_gts, total_levels, K, 4).clone()
            center3x3_bbox_target[..., 0] += dx
            center3x3_bbox_target[..., 1] += dy
            center3x3_bbox_target[..., 2] -= dx
            center3x3_bbox_target[..., 3] -= dy
            # update center3x3_mask
            center3x3_mask = center3x3_mask & (
                center3x3_bbox_target.min(dim=3)[0] >= 0)  # n x L x K

            center3x3_inds.append(center3x3_idx)
            center3x3_masks.append(center3x3_mask)
            center3x3_bbox_targets.append(center3x3_bbox_target)
        if len(inside_fpn_level_masks) > 0:
            cls_labels = torch.cat(cls_labels, dim=0)
            inside_fpn_level_masks = torch.cat(inside_fpn_level_masks, dim=0)
            center3x3_inds = torch.cat(center3x3_inds, dim=0).long()
            center3x3_bbox_targets = torch.cat(center3x3_bbox_targets, dim=0)
            center3x3_masks = torch.cat(center3x3_masks, dim=0)
        else:
            cls_labels = shapes_per_level.new_zeros(0).long()
            inside_fpn_level_masks = shapes_per_level.new_zeros(
                (0, total_levels)).bool()
            center3x3_inds = shapes_per_level.new_zeros(
                (0, total_levels, K)).long()
            center3x3_bbox_targets = shapes_per_level.new_zeros(
                (0, total_levels, K, 4)).float()
            center3x3_masks = shapes_per_level.new_zeros(
                (0, total_levels, K)).bool()
        return cls_labels, inside_fpn_level_masks, center3x3_inds, \
            center3x3_bbox_targets, center3x3_masks