# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Tuple

import torch
import torch.nn as nn
from mmcv.cnn import Scale
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures.bbox import bbox2distance
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
                         OptInstanceList, reduce_mean)
from ..utils import multi_apply
from .anchor_free_head import AnchorFreeHead

INF = 1000000000
RangeType = Sequence[Tuple[int, int]]


def _transpose(tensor_list: List[Tensor],
               num_point_list: list) -> List[Tensor]:
    """Transpose image-first tensors into level-first tensors.

    The input list is not modified.

    Args:
        tensor_list (list[Tensor]): One tensor per image, each holding the
            points of all levels concatenated along dim 0.
        num_point_list (list): Number of points of each level, used to split
            every per-image tensor along dim 0.

    Returns:
        list[Tensor]: One tensor per level, each holding that level's rows of
        all images concatenated along dim 0 (image order preserved).
    """
    split_per_img = [
        torch.split(tensor, num_point_list, dim=0) for tensor in tensor_list
    ]
    return [
        torch.cat(targets_per_level, dim=0)
        for targets_per_level in zip(*split_per_img)
    ]


@MODELS.register_module()
class CenterNetUpdateHead(AnchorFreeHead):
    """CenterNetUpdateHead is an improved version of CenterNet in CenterNet2.

    Paper link `<https://arxiv.org/abs/2103.07461>`_.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channel in the input feature map.
        regress_ranges (Sequence[Tuple[int, int]]): Regress range of multiple
            level points.
        hm_min_radius (int): Heatmap target minimum radius of cls branch.
            Defaults to 4.
        hm_min_overlap (float): Heatmap target minimum overlap of cls branch.
            Defaults to 0.8.
        more_pos_thresh (float): The filtering threshold when the cls branch
            adds more positive samples. Defaults to 0.2.
        more_pos_topk (int): The maximum number of additional positive samples
            added to each gt. Defaults to 9.
        soft_weight_on_reg (bool): Whether to use the soft target of the
            cls branch as the soft weight of the bbox branch.
            Defaults to False.
        loss_cls (:obj:`ConfigDict` or dict): Config of cls loss. Defaults to
            dict(type='GaussianFocalLoss', pos_weight=0.25, neg_weight=0.75,
            loss_weight=1.0).
        loss_bbox (:obj:`ConfigDict` or dict): Config of bbox loss. Defaults to
             dict(type='GIoULoss', loss_weight=2.0).
        norm_cfg (:obj:`ConfigDict` or dict, optional): dictionary to construct
            and config norm layer.  Defaults to
            ``norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)``.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config.
            Unused in CenterNet. Reserved for compatibility with
            SingleStageDetector.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config
            of CenterNet.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 regress_ranges: RangeType = ((0, 80), (64, 160), (128, 320),
                                              (256, 640), (512, INF)),
                 hm_min_radius: int = 4,
                 hm_min_overlap: float = 0.8,
                 more_pos_thresh: float = 0.2,
                 more_pos_topk: int = 9,
                 soft_weight_on_reg: bool = False,
                 loss_cls: ConfigType = dict(
                     type='GaussianFocalLoss',
                     pos_weight=0.25,
                     neg_weight=0.75,
                     loss_weight=1.0),
                 loss_bbox: ConfigType = dict(
                     type='GIoULoss', loss_weight=2.0),
                 norm_cfg: OptConfigType = dict(
                     type='GN', num_groups=32, requires_grad=True),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 **kwargs) -> None:
        super().__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            norm_cfg=norm_cfg,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            **kwargs)
        self.soft_weight_on_reg = soft_weight_on_reg
        self.hm_min_radius = hm_min_radius
        self.more_pos_thresh = more_pos_thresh
        self.more_pos_topk = more_pos_topk
        # Scales the gt box area into the heatmap Gaussian radius
        # (see ``_get_targets_single``).
        self.delta = (1 - hm_min_overlap) / (1 + hm_min_overlap)
        # Sigmoid outputs are clamped to [sigmoid_clamp, 1 - sigmoid_clamp]
        # before the cls loss so the log terms stay finite.
        self.sigmoid_clamp = 0.0001

        # GaussianFocalLoss must be sigmoid mode
        self.use_sigmoid_cls = True
        self.cls_out_channels = num_classes

        self.regress_ranges = regress_ranges
        self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])

    def _init_predictor(self) -> None:
        """Initialize predictor layers of the head."""
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.num_classes, 3, padding=1)
        self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)

    def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: A tuple of each level outputs.

            - cls_scores (list[Tensor]): Box scores for each scale level, \
            each is a 4D-tensor, the channel number is num_classes.
            - bbox_preds (list[Tensor]): Box energies / deltas for each \
            scale level, each is a 4D-tensor, the channel number is 4.
        """
        return multi_apply(self.forward_single, x, self.scales, self.strides)

    def forward_single(self, x: Tensor, scale: Scale,
                       stride: int) -> Tuple[Tensor, Tensor]:
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.
            scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
                the bbox prediction.
            stride (int): The corresponding stride for feature maps.

        Returns:
            tuple: scores for each class, bbox predictions of
            input feature maps. Bbox predictions are non-negative; they are
            multiplied by ``stride`` only in eval mode.
        """
        cls_score, bbox_pred, _, _ = super().forward_single(x)
        # scale the bbox_pred of different level
        # float to avoid overflow when enabling FP16
        bbox_pred = scale(bbox_pred).float()
        # bbox_pred needed for gradient computation has been modified
        # by F.relu(bbox_pred) when run with PyTorch 1.10. So replace
        # F.relu(bbox_pred) with bbox_pred.clamp(min=0)
        bbox_pred = bbox_pred.clamp(min=0)
        if not self.training:
            bbox_pred *= stride
        return cls_score, bbox_pred

    def loss_by_feat(
        self,
        cls_scores: List[Tensor],
        bbox_preds: List[Tensor],
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        batch_gt_instances_ignore: OptInstanceList = None
    ) -> Dict[str, Tensor]:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is 4.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None. Not used by this head.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        num_imgs = cls_scores[0].size(0)
        assert len(cls_scores) == len(bbox_preds)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=bbox_preds[0].dtype,
            device=bbox_preds[0].device)

        # 1 flatten outputs (level-first, then image, then H x W)
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)

        # repeat points to align with bbox_preds
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points])

        assert (torch.isfinite(flatten_bbox_preds).all().item())

        # 2 calc reg and cls branch targets
        cls_targets, bbox_targets = self.get_targets(all_level_points,
                                                     batch_gt_instances)

        # 3 add more pos index for cls branch
        featmap_sizes = flatten_points.new_tensor(featmap_sizes)
        pos_inds, cls_labels = self.add_cls_pos_inds(flatten_points,
                                                     flatten_bbox_preds,
                                                     featmap_sizes,
                                                     batch_gt_instances)

        # 4 calc cls loss
        if pos_inds is None:
            # num_gts=0
            num_pos_cls = bbox_preds[0].new_tensor(0, dtype=torch.float)
        else:
            num_pos_cls = bbox_preds[0].new_tensor(
                len(pos_inds), dtype=torch.float)
        num_pos_cls = max(reduce_mean(num_pos_cls), 1.0)
        flatten_cls_scores = flatten_cls_scores.sigmoid().clamp(
            min=self.sigmoid_clamp, max=1 - self.sigmoid_clamp)
        cls_loss = self.loss_cls(
            flatten_cls_scores,
            cls_targets,
            pos_inds=pos_inds,
            pos_labels=cls_labels,
            avg_factor=num_pos_cls)

        # 5 calc reg loss; points whose bbox target is negative have no
        # regression target and are skipped
        pos_bbox_inds = torch.nonzero(
            bbox_targets.max(dim=1)[0] >= 0).squeeze(1)
        pos_bbox_preds = flatten_bbox_preds[pos_bbox_inds]
        pos_bbox_targets = bbox_targets[pos_bbox_inds]

        bbox_weight_map = cls_targets.max(dim=1)[0]
        bbox_weight_map = bbox_weight_map[pos_bbox_inds]
        bbox_weight_map = bbox_weight_map if self.soft_weight_on_reg \
            else torch.ones_like(bbox_weight_map)
        num_pos_bbox = max(reduce_mean(bbox_weight_map.sum()), 1.0)

        if len(pos_bbox_inds) > 0:
            pos_points = flatten_points[pos_bbox_inds]
            pos_decoded_bbox_preds = self.bbox_coder.decode(
                pos_points, pos_bbox_preds)
            pos_decoded_target_preds = self.bbox_coder.decode(
                pos_points, pos_bbox_targets)
            bbox_loss = self.loss_bbox(
                pos_decoded_bbox_preds,
                pos_decoded_target_preds,
                weight=bbox_weight_map,
                avg_factor=num_pos_bbox)
        else:
            # keep the graph connected so every parameter receives a grad
            bbox_loss = flatten_bbox_preds.sum() * 0

        return dict(loss_cls=cls_loss, loss_bbox=bbox_loss)

    def get_targets(
        self,
        points: List[Tensor],
        batch_gt_instances: InstanceList,
    ) -> Tuple[Tensor, Tensor]:
        """Compute classification and bbox targets for points in multiple
        images.

        Args:
            points (list[Tensor]): Points of each fpn level, each has shape
                (num_points, 2).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance.  It usually includes ``bboxes`` and ``labels``
                attributes.

        Returns:
            tuple: Targets of all levels and images, ordered level-first and
            then image (the same order as the flattened predictions).

            - concat_lvl_cls_targets (Tensor): Heatmap targets of all \
            level and batch, shape (num_total_points, num_classes).
            - concat_lvl_bbox_targets (Tensor): BBox targets of all \
            level and batch, shape (num_total_points, 4); negative \
            rows mark points without a regression target.
        """
        assert len(points) == len(self.regress_ranges)
        num_levels = len(points)
        # the number of points per img, per lvl
        num_points = [center.size(0) for center in points]

        # expand regress ranges to align with points
        expanded_regress_ranges = [
            points[i].new_tensor(self.regress_ranges[i])[None].expand_as(
                points[i]) for i in range(num_levels)
        ]
        # concat all levels points and regress ranges
        concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0)
        concat_points = torch.cat(points, dim=0)
        concat_strides = torch.cat([
            concat_points.new_ones(num_points[i]) * self.strides[i]
            for i in range(num_levels)
        ])

        # get labels and bbox_targets of each image
        cls_targets_list, bbox_targets_list = multi_apply(
            self._get_targets_single,
            batch_gt_instances,
            points=concat_points,
            regress_ranges=concat_regress_ranges,
            strides=concat_strides)

        bbox_targets_list = _transpose(bbox_targets_list, num_points)
        cls_targets_list = _transpose(cls_targets_list, num_points)
        concat_lvl_bbox_targets = torch.cat(bbox_targets_list, 0)
        concat_lvl_cls_targets = torch.cat(cls_targets_list, dim=0)
        return concat_lvl_cls_targets, concat_lvl_bbox_targets

    def _get_targets_single(self, gt_instances: InstanceData, points: Tensor,
                            regress_ranges: Tensor,
                            strides: Tensor) -> Tuple[Tensor, Tensor]:
        """Compute classification and bbox targets for a single image.

        Args:
            gt_instances (:obj:`InstanceData`): Ground truth of one image,
                providing ``bboxes`` and ``labels``.
            points (Tensor): Points of all levels concatenated, shape
                (num_points, 2).
            regress_ranges (Tensor): Regress range of every point, shape
                (num_points, 2).
            strides (Tensor): Stride of every point, shape (num_points, ).

        Returns:
            tuple[Tensor, Tensor]:

            - cls_target: heatmap target of shape (num_points, num_classes).
            - bbox_target: distance target of shape (num_points, 4), divided
              by the point's stride; negative rows mean "no target".
        """
        num_points = points.size(0)
        num_gts = len(gt_instances)
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels

        if num_gts == 0:
            # No objects: every point is a negative, so the heatmap target is
            # all zeros. It must be a float heatmap like the ones built by
            # ``_create_heatmaps_from_dist`` (values in [0, 1]); filling it
            # with ``num_classes`` would corrupt the Gaussian focal loss.
            # Bbox targets of -1 fail the ``>= 0`` check in ``loss_by_feat``.
            return points.new_zeros((num_points, self.num_classes)), \
                gt_bboxes.new_full((num_points, 4), -1)

        # Calculate the distance target from every point to every gt box
        points = points[:, None].expand(num_points, num_gts, 2)
        gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
        strides = strides[:, None, None].expand(num_points, num_gts, 2)

        bbox_target = bbox2distance(points, gt_bboxes)  # M x N x 4

        # condition1: inside a gt bbox
        inside_gt_bbox_mask = bbox_target.min(dim=2)[0] > 0  # M x N

        # condition2: Calculate the nearest points from
        # the upper, lower, left and right ranges from
        # the center of the gt bbox
        centers = ((gt_bboxes[..., [0, 1]] + gt_bboxes[..., [2, 3]]) / 2)
        centers_discret = ((centers / strides).int() * strides).float() + \
            strides / 2

        centers_discret_dist = points - centers_discret
        dist_x = centers_discret_dist[..., 0].abs()
        dist_y = centers_discret_dist[..., 1].abs()
        inside_gt_center3x3_mask = (dist_x <= strides[..., 0]) & \
                                   (dist_y <= strides[..., 0])

        # condition3: limit the regression range for each location
        bbox_target_wh = bbox_target[..., :2] + bbox_target[..., 2:]
        crit = (bbox_target_wh**2).sum(dim=2)**0.5 / 2
        inside_fpn_level_mask = (crit >= regress_ranges[:, [0]]) & \
                                (crit <= regress_ranges[:, [1]])
        bbox_target_mask = inside_gt_bbox_mask & \
            inside_gt_center3x3_mask & \
            inside_fpn_level_mask

        # Calculate the distance weight map
        gt_center_peak_mask = ((centers_discret_dist**2).sum(dim=2) == 0)
        weighted_dist = ((points - centers)**2).sum(dim=2)  # M x N
        weighted_dist[gt_center_peak_mask] = 0

        areas = (gt_bboxes[..., 2] - gt_bboxes[..., 0]) * (
            gt_bboxes[..., 3] - gt_bboxes[..., 1])
        radius = self.delta**2 * 2 * areas
        radius = torch.clamp(radius, min=self.hm_min_radius**2)
        weighted_dist = weighted_dist / radius

        # Calculate bbox_target: each point regresses the closest (in
        # weighted distance) gt that satisfies all three conditions
        bbox_weighted_dist = weighted_dist.clone()
        bbox_weighted_dist[bbox_target_mask == 0] = INF * 1.0
        min_dist, min_inds = bbox_weighted_dist.min(dim=1)
        bbox_target = bbox_target[range(len(bbox_target)),
                                  min_inds]  # M x N x 4 --> M x 4
        bbox_target[min_dist == INF] = -INF

        # Convert to feature map scale
        bbox_target /= strides[:, 0, :].repeat(1, 2)

        # Calculate cls_target
        cls_target = self._create_heatmaps_from_dist(weighted_dist, gt_labels)

        return cls_target, bbox_target

    @torch.no_grad()
    def add_cls_pos_inds(
        self, flatten_points: Tensor, flatten_bbox_preds: Tensor,
        featmap_sizes: Tensor, batch_gt_instances: InstanceList
    ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        """Provide additional adaptive positive samples to the classification
        branch.

        Args:
            flatten_points (Tensor): The point after flatten, including
                batch image and all levels. The shape is (N, 2).
            flatten_bbox_preds (Tensor): The bbox predicts after flatten,
                including batch image and all levels. The shape is (N, 4).
            featmap_sizes (Tensor): Feature map size (H, W) of all levels.
                The shape is (num_levels, 2).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance.  It usually includes ``bboxes`` and ``labels``
                attributes.

        Returns:
            tuple: ``(None, None)`` when the batch has no gt, otherwise

            - pos_inds (Tensor): Adaptively selected positive sample index.
            - cls_labels (Tensor): Corresponding positive class label.
        """
        outputs = self._get_center3x3_region_index_targets(
            batch_gt_instances, featmap_sizes)
        cls_labels, fpn_level_masks, center3x3_inds, \
            center3x3_bbox_targets, center3x3_masks = outputs

        num_gts, total_level, K = cls_labels.shape[0], len(
            self.strides), center3x3_masks.shape[-1]

        if num_gts == 0:
            return None, None

        # The out-of-bounds index is forcibly set to 0
        # to prevent loss calculation errors
        center3x3_inds[center3x3_masks == 0] = 0
        reg_pred_center3x3 = flatten_bbox_preds[center3x3_inds]
        center3x3_points = flatten_points[center3x3_inds].view(-1, 2)

        center3x3_bbox_targets_expand = center3x3_bbox_targets.view(
            -1, 4).clamp(min=0)

        pos_decoded_bbox_preds = self.bbox_coder.decode(
            center3x3_points, reg_pred_center3x3.view(-1, 4))
        pos_decoded_target_preds = self.bbox_coder.decode(
            center3x3_points, center3x3_bbox_targets_expand)
        center3x3_bbox_loss = self.loss_bbox(
            pos_decoded_bbox_preds,
            pos_decoded_target_preds,
            None,
            reduction_override='none').view(num_gts, total_level,
                                            K) / self.loss_bbox.loss_weight

        # Invalid index Loss set to infinity
        center3x3_bbox_loss[center3x3_masks == 0] = INF

        # 4 is the center point of the sampled 9 points, the center point
        # of gt bbox after discretization.
        # The center point of gt bbox after discretization
        # must be a positive sample, so we force its loss to be set to 0.
        center3x3_bbox_loss.view(-1, K)[fpn_level_masks.view(-1), 4] = 0
        center3x3_bbox_loss = center3x3_bbox_loss.view(num_gts, -1)

        loss_thr = torch.kthvalue(
            center3x3_bbox_loss, self.more_pos_topk, dim=1)[0]

        loss_thr[loss_thr > self.more_pos_thresh] = self.more_pos_thresh
        new_pos = center3x3_bbox_loss < loss_thr.view(num_gts, 1)
        pos_inds = center3x3_inds.view(num_gts, -1)[new_pos]
        cls_labels = cls_labels.view(num_gts,
                                     1).expand(num_gts,
                                               total_level * K)[new_pos]
        return pos_inds, cls_labels

    def _create_heatmaps_from_dist(self, weighted_dist: Tensor,
                                   cls_labels: Tensor) -> Tensor:
        """Generate heatmaps of classification branch based on weighted
        distance map.

        Args:
            weighted_dist (Tensor): Normalized squared distance from every
                point to every gt, shape (num_points, num_gts).
            cls_labels (Tensor): Label of every gt, shape (num_gts, ).

        Returns:
            Tensor: Heatmap of shape (num_points, num_classes). Each entry is
            ``exp(-d)`` for the closest gt of that class, with values below
            1e-4 (and classes without any gt) set to 0.
        """
        heatmaps = weighted_dist.new_zeros(
            (weighted_dist.shape[0], self.num_classes))
        for c in range(self.num_classes):
            inds = (cls_labels == c)  # N
            if not inds.any():
                continue
            heatmaps[:, c] = torch.exp(-weighted_dist[:, inds].min(dim=1)[0])
            zeros = heatmaps[:, c] < 1e-4
            heatmaps[zeros, c] = 0
        return heatmaps

    def _get_center3x3_region_index_targets(self,
                                            bacth_gt_instances: InstanceList,
                                            shapes_per_level: Tensor) -> tuple:
        """Get the center (and the 3x3 region near center) locations and target
        of each objects.

        Args:
            bacth_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance with ``bboxes`` and ``labels``.
            shapes_per_level (Tensor): Feature map size (H, W) of every
                level, shape (num_levels, 2).

        Returns:
            tuple: ``(cls_labels, inside_fpn_level_masks, center3x3_inds,
            center3x3_bbox_targets, center3x3_masks)`` with shapes
            (G, ), (G, L), (G, L, 9), (G, L, 9, 4) and (G, L, 9), where G is
            the number of gts in the batch and L the number of levels.
        """
        cls_labels = []
        inside_fpn_level_masks = []
        center3x3_inds = []
        center3x3_masks = []
        center3x3_bbox_targets = []

        total_levels = len(self.strides)
        batch = len(bacth_gt_instances)

        shapes_per_level = shapes_per_level.long()
        area_per_level = (shapes_per_level[:, 0] * shapes_per_level[:, 1])

        # Select a total of 9 positions of 3x3 in the center of the gt bbox
        # as candidate positive samples
        K = 9
        dx = shapes_per_level.new_tensor([-1, 0, 1, -1, 0, 1, -1, 0,
                                          1]).view(1, 1, K)
        dy = shapes_per_level.new_tensor([-1, -1, -1, 0, 0, 0, 1, 1,
                                          1]).view(1, 1, K)

        regress_ranges = shapes_per_level.new_tensor(self.regress_ranges).view(
            len(self.regress_ranges), 2)  # L x 2
        strides = shapes_per_level.new_tensor(self.strides)

        # Offset of each level's first point in the flattened
        # (level, image, H, W) layout used by ``loss_by_feat``.
        start_coord_per_level = []
        _start = 0
        for level in range(total_levels):
            start_coord_per_level.append(_start)
            _start = _start + batch * area_per_level[level]
        start_coord_per_level = shapes_per_level.new_tensor(
            start_coord_per_level).view(1, total_levels, 1)
        area_per_level = area_per_level.view(1, total_levels, 1)

        for im_i in range(batch):
            gt_instance = bacth_gt_instances[im_i]
            gt_bboxes = gt_instance.bboxes
            gt_labels = gt_instance.labels
            num_gts = gt_bboxes.shape[0]
            if num_gts == 0:
                continue

            cls_labels.append(gt_labels)

            gt_bboxes = gt_bboxes[:, None].expand(num_gts, total_levels, 4)
            expanded_strides = strides[None, :,
                                       None].expand(num_gts, total_levels, 2)
            expanded_regress_ranges = regress_ranges[None].expand(
                num_gts, total_levels, 2)
            expanded_shapes_per_level = shapes_per_level[None].expand(
                num_gts, total_levels, 2)

            # calc reg_target
            centers = ((gt_bboxes[..., [0, 1]] + gt_bboxes[..., [2, 3]]) / 2)
            centers_inds = (centers / expanded_strides).long()
            centers_discret = centers_inds * expanded_strides \
                + expanded_strides // 2

            bbox_target = bbox2distance(centers_discret,
                                        gt_bboxes)  # M x N x 4

            # calc inside_fpn_level_mask
            bbox_target_wh = bbox_target[..., :2] + bbox_target[..., 2:]
            crit = (bbox_target_wh**2).sum(dim=2)**0.5 / 2
            inside_fpn_level_mask = \
                (crit >= expanded_regress_ranges[..., 0]) & \
                (crit <= expanded_regress_ranges[..., 1])

            inside_gt_bbox_mask = bbox_target.min(dim=2)[0] >= 0
            inside_fpn_level_mask = inside_gt_bbox_mask & inside_fpn_level_mask
            inside_fpn_level_masks.append(inside_fpn_level_mask)

            # calc center3x3_ind and mask
            expand_ws = expanded_shapes_per_level[..., 1:2].expand(
                num_gts, total_levels, K)
            expand_hs = expanded_shapes_per_level[..., 0:1].expand(
                num_gts, total_levels, K)
            centers_inds_x = centers_inds[..., 0:1]
            centers_inds_y = centers_inds[..., 1:2]

            center3x3_idx = start_coord_per_level + \
                im_i * area_per_level + \
                (centers_inds_y + dy) * expand_ws + \
                (centers_inds_x + dx)
            center3x3_mask = \
                ((centers_inds_y + dy) < expand_hs) & \
                ((centers_inds_y + dy) >= 0) & \
                ((centers_inds_x + dx) < expand_ws) & \
                ((centers_inds_x + dx) >= 0)

            # recalc center3x3 region reg target
            bbox_target = bbox_target / expanded_strides.repeat(1, 1, 2)
            center3x3_bbox_target = bbox_target[..., None, :].expand(
                num_gts, total_levels, K, 4).clone()
            center3x3_bbox_target[..., 0] += dx
            center3x3_bbox_target[..., 1] += dy
            center3x3_bbox_target[..., 2] -= dx
            center3x3_bbox_target[..., 3] -= dy
            # update center3x3_mask
            center3x3_mask = center3x3_mask & (
                center3x3_bbox_target.min(dim=3)[0] >= 0)  # n x L x K

            center3x3_inds.append(center3x3_idx)
            center3x3_masks.append(center3x3_mask)
            center3x3_bbox_targets.append(center3x3_bbox_target)

        if len(inside_fpn_level_masks) > 0:
            cls_labels = torch.cat(cls_labels, dim=0)
            inside_fpn_level_masks = torch.cat(inside_fpn_level_masks, dim=0)
            center3x3_inds = torch.cat(center3x3_inds, dim=0).long()
            center3x3_bbox_targets = torch.cat(center3x3_bbox_targets, dim=0)
            center3x3_masks = torch.cat(center3x3_masks, dim=0)
        else:
            cls_labels = shapes_per_level.new_zeros(0).long()
            inside_fpn_level_masks = shapes_per_level.new_zeros(
                (0, total_levels)).bool()
            center3x3_inds = shapes_per_level.new_zeros(
                (0, total_levels, K)).long()
            center3x3_bbox_targets = shapes_per_level.new_zeros(
                (0, total_levels, K, 4)).float()
            center3x3_masks = shapes_per_level.new_zeros(
                (0, total_levels, K)).bool()
        return cls_labels, inside_fpn_level_masks, center3x3_inds, \
            center3x3_bbox_targets, center3x3_masks