KyanChen's picture
init
f549064
raw
history blame
43.6 kB
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, is_norm
from mmcv.ops import batched_nms
from mmengine.model import (BaseModule, bias_init_with_prob, constant_init,
normal_init)
from mmengine.structures import InstanceData
from torch import Tensor
from mmdet.models.layers.transformer import inverse_sigmoid
from mmdet.models.utils import (filter_scores_and_topk, multi_apply,
select_single_mlvl, sigmoid_geometric_mean)
from mmdet.registry import MODELS
from mmdet.structures.bbox import (cat_boxes, distance2bbox, get_box_tensor,
get_box_wh, scale_boxes)
from mmdet.utils import ConfigType, InstanceList, OptInstanceList, reduce_mean
from .rtmdet_head import RTMDetHead
@MODELS.register_module()
class RTMDetInsHead(RTMDetHead):
"""Detection Head of RTMDet-Ins.
Args:
num_prototypes (int): Number of mask prototype features extracted
from the mask head. Defaults to 8.
dyconv_channels (int): Channel of the dynamic conv layers.
Defaults to 8.
num_dyconvs (int): Number of the dynamic convolution layers.
Defaults to 3.
mask_loss_stride (int): Down sample stride of the masks for loss
computation. Defaults to 4.
loss_mask (:obj:`ConfigDict` or dict): Config dict for mask loss.
"""
def __init__(self,
*args,
num_prototypes: int = 8,
dyconv_channels: int = 8,
num_dyconvs: int = 3,
mask_loss_stride: int = 4,
loss_mask=dict(
type='DiceLoss',
loss_weight=2.0,
eps=5e-6,
reduction='mean'),
**kwargs) -> None:
self.num_prototypes = num_prototypes
self.num_dyconvs = num_dyconvs
self.dyconv_channels = dyconv_channels
self.mask_loss_stride = mask_loss_stride
super().__init__(*args, **kwargs)
self.loss_mask = MODELS.build(loss_mask)
def _init_layers(self) -> None:
"""Initialize layers of the head."""
super()._init_layers()
# a branch to predict kernels of dynamic convs
self.kernel_convs = nn.ModuleList()
# calculate num dynamic parameters
weight_nums, bias_nums = [], []
for i in range(self.num_dyconvs):
if i == 0:
weight_nums.append(
# mask prototype and coordinate features
(self.num_prototypes + 2) * self.dyconv_channels)
bias_nums.append(self.dyconv_channels * 1)
elif i == self.num_dyconvs - 1:
weight_nums.append(self.dyconv_channels * 1)
bias_nums.append(1)
else:
weight_nums.append(self.dyconv_channels * self.dyconv_channels)
bias_nums.append(self.dyconv_channels * 1)
self.weight_nums = weight_nums
self.bias_nums = bias_nums
self.num_gen_params = sum(weight_nums) + sum(bias_nums)
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
self.kernel_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
pred_pad_size = self.pred_kernel_size // 2
self.rtm_kernel = nn.Conv2d(
self.feat_channels,
self.num_gen_params,
self.pred_kernel_size,
padding=pred_pad_size)
self.mask_head = MaskFeatModule(
in_channels=self.in_channels,
feat_channels=self.feat_channels,
stacked_convs=4,
num_levels=len(self.prior_generator.strides),
num_prototypes=self.num_prototypes,
act_cfg=self.act_cfg,
norm_cfg=self.norm_cfg)
def forward(self, feats: Tuple[Tensor, ...]) -> tuple:
"""Forward features from the upstream network.
Args:
feats (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
Returns:
tuple: Usually a tuple of classification scores and bbox prediction
- cls_scores (list[Tensor]): Classification scores for all scale
levels, each is a 4D-tensor, the channels number is
num_base_priors * num_classes.
- bbox_preds (list[Tensor]): Box energies / deltas for all scale
levels, each is a 4D-tensor, the channels number is
num_base_priors * 4.
- kernel_preds (list[Tensor]): Dynamic conv kernels for all scale
levels, each is a 4D-tensor, the channels number is
num_gen_params.
- mask_feat (Tensor): Output feature of the mask head. Each is a
4D-tensor, the channels number is num_prototypes.
"""
mask_feat = self.mask_head(feats)
cls_scores = []
bbox_preds = []
kernel_preds = []
for idx, (x, scale, stride) in enumerate(
zip(feats, self.scales, self.prior_generator.strides)):
cls_feat = x
reg_feat = x
kernel_feat = x
for cls_layer in self.cls_convs:
cls_feat = cls_layer(cls_feat)
cls_score = self.rtm_cls(cls_feat)
for kernel_layer in self.kernel_convs:
kernel_feat = kernel_layer(kernel_feat)
kernel_pred = self.rtm_kernel(kernel_feat)
for reg_layer in self.reg_convs:
reg_feat = reg_layer(reg_feat)
if self.with_objectness:
objectness = self.rtm_obj(reg_feat)
cls_score = inverse_sigmoid(
sigmoid_geometric_mean(cls_score, objectness))
reg_dist = scale(self.rtm_reg(reg_feat)) * stride[0]
cls_scores.append(cls_score)
bbox_preds.append(reg_dist)
kernel_preds.append(kernel_pred)
return tuple(cls_scores), tuple(bbox_preds), tuple(
kernel_preds), mask_feat
def predict_by_feat(self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
kernel_preds: List[Tensor],
mask_feat: Tensor,
score_factors: Optional[List[Tensor]] = None,
batch_img_metas: Optional[List[dict]] = None,
cfg: Optional[ConfigType] = None,
rescale: bool = False,
with_nms: bool = True) -> InstanceList:
"""Transform a batch of output features extracted from the head into
bbox results.
Note: When score_factors is not None, the cls_scores are
usually multiplied by it then obtain the real score used in NMS,
such as CenterNess in FCOS, IoU branch in ATSS.
Args:
cls_scores (list[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
kernel_preds (list[Tensor]): Kernel predictions of dynamic
convs for all scale levels, each is a 4D-tensor, has shape
(batch_size, num_params, H, W).
mask_feat (Tensor): Mask prototype features extracted from the
mask head, has shape (batch_size, num_prototypes, H, W).
score_factors (list[Tensor], optional): Score factor for
all scale level, each is a 4D-tensor, has shape
(batch_size, num_priors * 1, H, W). Defaults to None.
batch_img_metas (list[dict], Optional): Batch image meta info.
Defaults to None.
cfg (ConfigDict, optional): Test / postprocessing
configuration, if None, test_cfg would be used.
Defaults to None.
rescale (bool): If True, return boxes in original image space.
Defaults to False.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
list[:obj:`InstanceData`]: Object detection results of each image
after the post process. Each item usually contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (Tensor): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
- masks (Tensor): Has a shape (num_instances, h, w).
"""
assert len(cls_scores) == len(bbox_preds)
if score_factors is None:
# e.g. Retina, FreeAnchor, Foveabox, etc.
with_score_factors = False
else:
# e.g. FCOS, PAA, ATSS, AutoAssign, etc.
with_score_factors = True
assert len(cls_scores) == len(score_factors)
num_levels = len(cls_scores)
featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes,
dtype=cls_scores[0].dtype,
device=cls_scores[0].device,
with_stride=True)
result_list = []
for img_id in range(len(batch_img_metas)):
img_meta = batch_img_metas[img_id]
cls_score_list = select_single_mlvl(
cls_scores, img_id, detach=True)
bbox_pred_list = select_single_mlvl(
bbox_preds, img_id, detach=True)
kernel_pred_list = select_single_mlvl(
kernel_preds, img_id, detach=True)
if with_score_factors:
score_factor_list = select_single_mlvl(
score_factors, img_id, detach=True)
else:
score_factor_list = [None for _ in range(num_levels)]
results = self._predict_by_feat_single(
cls_score_list=cls_score_list,
bbox_pred_list=bbox_pred_list,
kernel_pred_list=kernel_pred_list,
mask_feat=mask_feat[img_id],
score_factor_list=score_factor_list,
mlvl_priors=mlvl_priors,
img_meta=img_meta,
cfg=cfg,
rescale=rescale,
with_nms=with_nms)
result_list.append(results)
return result_list
def _predict_by_feat_single(self,
cls_score_list: List[Tensor],
bbox_pred_list: List[Tensor],
kernel_pred_list: List[Tensor],
mask_feat: Tensor,
score_factor_list: List[Tensor],
mlvl_priors: List[Tensor],
img_meta: dict,
cfg: ConfigType,
rescale: bool = False,
with_nms: bool = True) -> InstanceData:
"""Transform a single image's features extracted from the head into
bbox and mask results.
Args:
cls_score_list (list[Tensor]): Box scores from all scale
levels of a single image, each item has shape
(num_priors * num_classes, H, W).
bbox_pred_list (list[Tensor]): Box energies / deltas from
all scale levels of a single image, each item has shape
(num_priors * 4, H, W).
kernel_preds (list[Tensor]): Kernel predictions of dynamic
convs for all scale levels of a single image, each is a
4D-tensor, has shape (num_params, H, W).
mask_feat (Tensor): Mask prototype features of a single image
extracted from the mask head, has shape (num_prototypes, H, W).
score_factor_list (list[Tensor]): Score factor from all scale
levels of a single image, each item has shape
(num_priors * 1, H, W).
mlvl_priors (list[Tensor]): Each element in the list is
the priors of a single level in feature pyramid. In all
anchor-based methods, it has shape (num_priors, 4). In
all anchor-free methods, it has shape (num_priors, 2)
when `with_stride=True`, otherwise it still has shape
(num_priors, 4).
img_meta (dict): Image meta info.
cfg (mmengine.Config): Test / postprocessing configuration,
if None, test_cfg would be used.
rescale (bool): If True, return boxes in original image space.
Defaults to False.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
:obj:`InstanceData`: Detection results of each image
after the post process.
Each item usually contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (Tensor): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
- masks (Tensor): Has a shape (num_instances, h, w).
"""
if score_factor_list[0] is None:
# e.g. Retina, FreeAnchor, etc.
with_score_factors = False
else:
# e.g. FCOS, PAA, ATSS, etc.
with_score_factors = True
cfg = self.test_cfg if cfg is None else cfg
cfg = copy.deepcopy(cfg)
img_shape = img_meta['img_shape']
nms_pre = cfg.get('nms_pre', -1)
mlvl_bbox_preds = []
mlvl_kernels = []
mlvl_valid_priors = []
mlvl_scores = []
mlvl_labels = []
if with_score_factors:
mlvl_score_factors = []
else:
mlvl_score_factors = None
for level_idx, (cls_score, bbox_pred, kernel_pred,
score_factor, priors) in \
enumerate(zip(cls_score_list, bbox_pred_list, kernel_pred_list,
score_factor_list, mlvl_priors)):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
dim = self.bbox_coder.encode_size
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, dim)
if with_score_factors:
score_factor = score_factor.permute(1, 2,
0).reshape(-1).sigmoid()
cls_score = cls_score.permute(1, 2,
0).reshape(-1, self.cls_out_channels)
kernel_pred = kernel_pred.permute(1, 2, 0).reshape(
-1, self.num_gen_params)
if self.use_sigmoid_cls:
scores = cls_score.sigmoid()
else:
# remind that we set FG labels to [0, num_class-1]
# since mmdet v2.0
# BG cat_id: num_class
scores = cls_score.softmax(-1)[:, :-1]
# After https://github.com/open-mmlab/mmdetection/pull/6268/,
# this operation keeps fewer bboxes under the same `nms_pre`.
# There is no difference in performance for most models. If you
# find a slight drop in performance, you can set a larger
# `nms_pre` than before.
score_thr = cfg.get('score_thr', 0)
results = filter_scores_and_topk(
scores, score_thr, nms_pre,
dict(
bbox_pred=bbox_pred,
priors=priors,
kernel_pred=kernel_pred))
scores, labels, keep_idxs, filtered_results = results
bbox_pred = filtered_results['bbox_pred']
priors = filtered_results['priors']
kernel_pred = filtered_results['kernel_pred']
if with_score_factors:
score_factor = score_factor[keep_idxs]
mlvl_bbox_preds.append(bbox_pred)
mlvl_valid_priors.append(priors)
mlvl_scores.append(scores)
mlvl_labels.append(labels)
mlvl_kernels.append(kernel_pred)
if with_score_factors:
mlvl_score_factors.append(score_factor)
bbox_pred = torch.cat(mlvl_bbox_preds)
priors = cat_boxes(mlvl_valid_priors)
bboxes = self.bbox_coder.decode(
priors[..., :2], bbox_pred, max_shape=img_shape)
results = InstanceData()
results.bboxes = bboxes
results.priors = priors
results.scores = torch.cat(mlvl_scores)
results.labels = torch.cat(mlvl_labels)
results.kernels = torch.cat(mlvl_kernels)
if with_score_factors:
results.score_factors = torch.cat(mlvl_score_factors)
return self._bbox_mask_post_process(
results=results,
mask_feat=mask_feat,
cfg=cfg,
rescale=rescale,
with_nms=with_nms,
img_meta=img_meta)
def _bbox_mask_post_process(
self,
results: InstanceData,
mask_feat,
cfg: ConfigType,
rescale: bool = False,
with_nms: bool = True,
img_meta: Optional[dict] = None) -> InstanceData:
"""bbox and mask post-processing method.
The boxes would be rescaled to the original image scale and do
the nms operation. Usually `with_nms` is False is used for aug test.
Args:
results (:obj:`InstaceData`): Detection instance results,
each item has shape (num_bboxes, ).
cfg (ConfigDict): Test / postprocessing configuration,
if None, test_cfg would be used.
rescale (bool): If True, return boxes in original image space.
Default to False.
with_nms (bool): If True, do nms before return boxes.
Default to True.
img_meta (dict, optional): Image meta info. Defaults to None.
Returns:
:obj:`InstanceData`: Detection results of each image
after the post process.
Each item usually contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (Tensor): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
- masks (Tensor): Has a shape (num_instances, h, w).
"""
stride = self.prior_generator.strides[0][0]
if rescale:
assert img_meta.get('scale_factor') is not None
scale_factor = [1 / s for s in img_meta['scale_factor']]
results.bboxes = scale_boxes(results.bboxes, scale_factor)
if hasattr(results, 'score_factors'):
# TODO: Add sqrt operation in order to be consistent with
# the paper.
score_factors = results.pop('score_factors')
results.scores = results.scores * score_factors
# filter small size bboxes
if cfg.get('min_bbox_size', -1) >= 0:
w, h = get_box_wh(results.bboxes)
valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
if not valid_mask.all():
results = results[valid_mask]
# TODO: deal with `with_nms` and `nms_cfg=None` in test_cfg
assert with_nms, 'with_nms must be True for RTMDet-Ins'
if results.bboxes.numel() > 0:
bboxes = get_box_tensor(results.bboxes)
det_bboxes, keep_idxs = batched_nms(bboxes, results.scores,
results.labels, cfg.nms)
results = results[keep_idxs]
# some nms would reweight the score, such as softnms
results.scores = det_bboxes[:, -1]
results = results[:cfg.max_per_img]
# process masks
mask_logits = self._mask_predict_by_feat_single(
mask_feat, results.kernels, results.priors)
mask_logits = F.interpolate(
mask_logits.unsqueeze(0), scale_factor=stride, mode='bilinear')
if rescale:
ori_h, ori_w = img_meta['ori_shape'][:2]
mask_logits = F.interpolate(
mask_logits,
size=[
math.ceil(mask_logits.shape[-2] * scale_factor[0]),
math.ceil(mask_logits.shape[-1] * scale_factor[1])
],
mode='bilinear',
align_corners=False)[..., :ori_h, :ori_w]
masks = mask_logits.sigmoid().squeeze(0)
masks = masks > cfg.mask_thr_binary
results.masks = masks
else:
h, w = img_meta['ori_shape'][:2] if rescale else img_meta[
'img_shape'][:2]
results.masks = torch.zeros(
size=(results.bboxes.shape[0], h, w),
dtype=torch.bool,
device=results.bboxes.device)
return results
def parse_dynamic_params(self, flatten_kernels: Tensor) -> tuple:
"""split kernel head prediction to conv weight and bias."""
n_inst = flatten_kernels.size(0)
n_layers = len(self.weight_nums)
params_splits = list(
torch.split_with_sizes(
flatten_kernels, self.weight_nums + self.bias_nums, dim=1))
weight_splits = params_splits[:n_layers]
bias_splits = params_splits[n_layers:]
for i in range(n_layers):
if i < n_layers - 1:
weight_splits[i] = weight_splits[i].reshape(
n_inst * self.dyconv_channels, -1, 1, 1)
bias_splits[i] = bias_splits[i].reshape(n_inst *
self.dyconv_channels)
else:
weight_splits[i] = weight_splits[i].reshape(n_inst, -1, 1, 1)
bias_splits[i] = bias_splits[i].reshape(n_inst)
return weight_splits, bias_splits
def _mask_predict_by_feat_single(self, mask_feat: Tensor, kernels: Tensor,
priors: Tensor) -> Tensor:
"""Generate mask logits from mask features with dynamic convs.
Args:
mask_feat (Tensor): Mask prototype features.
Has shape (num_prototypes, H, W).
kernels (Tensor): Kernel parameters for each instance.
Has shape (num_instance, num_params)
priors (Tensor): Center priors for each instance.
Has shape (num_instance, 4).
Returns:
Tensor: Instance segmentation masks for each instance.
Has shape (num_instance, H, W).
"""
num_inst = priors.shape[0]
h, w = mask_feat.size()[-2:]
if num_inst < 1:
return torch.empty(
size=(num_inst, h, w),
dtype=mask_feat.dtype,
device=mask_feat.device)
if len(mask_feat.shape) < 4:
mask_feat.unsqueeze(0)
coord = self.prior_generator.single_level_grid_priors(
(h, w), level_idx=0).reshape(1, -1, 2)
num_inst = priors.shape[0]
points = priors[:, :2].reshape(-1, 1, 2)
strides = priors[:, 2:].reshape(-1, 1, 2)
relative_coord = (points - coord).permute(0, 2, 1) / (
strides[..., 0].reshape(-1, 1, 1) * 8)
relative_coord = relative_coord.reshape(num_inst, 2, h, w)
mask_feat = torch.cat(
[relative_coord,
mask_feat.repeat(num_inst, 1, 1, 1)], dim=1)
weights, biases = self.parse_dynamic_params(kernels)
n_layers = len(weights)
x = mask_feat.reshape(1, -1, h, w)
for i, (weight, bias) in enumerate(zip(weights, biases)):
x = F.conv2d(
x, weight, bias=bias, stride=1, padding=0, groups=num_inst)
if i < n_layers - 1:
x = F.relu(x)
x = x.reshape(num_inst, h, w)
return x
def loss_mask_by_feat(self, mask_feats: Tensor, flatten_kernels: Tensor,
sampling_results_list: list,
batch_gt_instances: InstanceList) -> Tensor:
"""Compute instance segmentation loss.
Args:
mask_feats (list[Tensor]): Mask prototype features extracted from
the mask head. Has shape (N, num_prototypes, H, W)
flatten_kernels (list[Tensor]): Kernels of the dynamic conv layers.
Has shape (N, num_instances, num_params)
sampling_results_list (list[:obj:`SamplingResults`]) Batch of
assignment results.
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
Returns:
Tensor: The mask loss tensor.
"""
batch_pos_mask_logits = []
pos_gt_masks = []
for idx, (mask_feat, kernels, sampling_results,
gt_instances) in enumerate(
zip(mask_feats, flatten_kernels, sampling_results_list,
batch_gt_instances)):
pos_priors = sampling_results.pos_priors
pos_inds = sampling_results.pos_inds
pos_kernels = kernels[pos_inds] # n_pos, num_gen_params
pos_mask_logits = self._mask_predict_by_feat_single(
mask_feat, pos_kernels, pos_priors)
if gt_instances.masks.numel() == 0:
gt_masks = torch.empty_like(gt_instances.masks)
else:
gt_masks = gt_instances.masks[
sampling_results.pos_assigned_gt_inds, :]
batch_pos_mask_logits.append(pos_mask_logits)
pos_gt_masks.append(gt_masks)
pos_gt_masks = torch.cat(pos_gt_masks, 0)
batch_pos_mask_logits = torch.cat(batch_pos_mask_logits, 0)
# avg_factor
num_pos = batch_pos_mask_logits.shape[0]
num_pos = reduce_mean(mask_feats.new_tensor([num_pos
])).clamp_(min=1).item()
if batch_pos_mask_logits.shape[0] == 0:
return mask_feats.sum() * 0
scale = self.prior_generator.strides[0][0] // self.mask_loss_stride
# upsample pred masks
batch_pos_mask_logits = F.interpolate(
batch_pos_mask_logits.unsqueeze(0),
scale_factor=scale,
mode='bilinear',
align_corners=False).squeeze(0)
# downsample gt masks
pos_gt_masks = pos_gt_masks[:, self.mask_loss_stride //
2::self.mask_loss_stride,
self.mask_loss_stride //
2::self.mask_loss_stride]
loss_mask = self.loss_mask(
batch_pos_mask_logits,
pos_gt_masks,
weight=None,
avg_factor=num_pos)
return loss_mask
def loss_by_feat(self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
kernel_preds: List[Tensor],
mask_feat: Tensor,
batch_gt_instances: InstanceList,
batch_img_metas: List[dict],
batch_gt_instances_ignore: OptInstanceList = None):
"""Compute losses of the head.
Args:
cls_scores (list[Tensor]): Box scores for each scale level
Has shape (N, num_anchors * num_classes, H, W)
bbox_preds (list[Tensor]): Decoded box for each scale
level with shape (N, num_anchors * 4, H, W) in
[tl_x, tl_y, br_x, br_y] format.
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
num_imgs = len(batch_img_metas)
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == self.prior_generator.num_levels
device = cls_scores[0].device
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, batch_img_metas, device=device)
flatten_cls_scores = torch.cat([
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
self.cls_out_channels)
for cls_score in cls_scores
], 1)
flatten_kernels = torch.cat([
kernel_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
self.num_gen_params)
for kernel_pred in kernel_preds
], 1)
decoded_bboxes = []
for anchor, bbox_pred in zip(anchor_list[0], bbox_preds):
anchor = anchor.reshape(-1, 4)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
bbox_pred = distance2bbox(anchor, bbox_pred)
decoded_bboxes.append(bbox_pred)
flatten_bboxes = torch.cat(decoded_bboxes, 1)
for gt_instances in batch_gt_instances:
gt_instances.masks = gt_instances.masks.to_tensor(
dtype=torch.bool, device=device)
cls_reg_targets = self.get_targets(
flatten_cls_scores,
flatten_bboxes,
anchor_list,
valid_flag_list,
batch_gt_instances,
batch_img_metas,
batch_gt_instances_ignore=batch_gt_instances_ignore)
(anchor_list, labels_list, label_weights_list, bbox_targets_list,
assign_metrics_list, sampling_results_list) = cls_reg_targets
losses_cls, losses_bbox,\
cls_avg_factors, bbox_avg_factors = multi_apply(
self.loss_by_feat_single,
cls_scores,
decoded_bboxes,
labels_list,
label_weights_list,
bbox_targets_list,
assign_metrics_list,
self.prior_generator.strides)
cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item()
losses_cls = list(map(lambda x: x / cls_avg_factor, losses_cls))
bbox_avg_factor = reduce_mean(
sum(bbox_avg_factors)).clamp_(min=1).item()
losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox))
loss_mask = self.loss_mask_by_feat(mask_feat, flatten_kernels,
sampling_results_list,
batch_gt_instances)
loss = dict(
loss_cls=losses_cls, loss_bbox=losses_bbox, loss_mask=loss_mask)
return loss
class MaskFeatModule(BaseModule):
"""Mask feature head used in RTMDet-Ins.
Args:
in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of hidden channels of the mask feature
map branch.
num_levels (int): The starting feature map level from RPN that
will be used to predict the mask feature map.
num_prototypes (int): Number of output channel of the mask feature
map branch. This is the channel count of the mask
feature map that to be dynamically convolved with the predicted
kernel.
stacked_convs (int): Number of convs in mask feature branch.
act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
Default: dict(type='ReLU', inplace=True)
norm_cfg (dict): Config dict for normalization layer. Default: None.
"""
def __init__(
self,
in_channels: int,
feat_channels: int = 256,
stacked_convs: int = 4,
num_levels: int = 3,
num_prototypes: int = 8,
act_cfg: ConfigType = dict(type='ReLU', inplace=True),
norm_cfg: ConfigType = dict(type='BN')
) -> None:
super().__init__(init_cfg=None)
self.num_levels = num_levels
self.fusion_conv = nn.Conv2d(num_levels * in_channels, in_channels, 1)
convs = []
for i in range(stacked_convs):
in_c = in_channels if i == 0 else feat_channels
convs.append(
ConvModule(
in_c,
feat_channels,
3,
padding=1,
act_cfg=act_cfg,
norm_cfg=norm_cfg))
self.stacked_convs = nn.Sequential(*convs)
self.projection = nn.Conv2d(
feat_channels, num_prototypes, kernel_size=1)
def forward(self, features: Tuple[Tensor, ...]) -> Tensor:
# multi-level feature fusion
fusion_feats = [features[0]]
size = features[0].shape[-2:]
for i in range(1, self.num_levels):
f = F.interpolate(features[i], size=size, mode='bilinear')
fusion_feats.append(f)
fusion_feats = torch.cat(fusion_feats, dim=1)
fusion_feats = self.fusion_conv(fusion_feats)
# pred mask feats
mask_features = self.stacked_convs(fusion_feats)
mask_features = self.projection(mask_features)
return mask_features
@MODELS.register_module()
class RTMDetInsSepBNHead(RTMDetInsHead):
"""Detection Head of RTMDet-Ins with sep-bn layers.
Args:
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
share_conv (bool): Whether to share conv layers between stages.
Defaults to True.
norm_cfg (:obj:`ConfigDict` or dict)): Config dict for normalization
layer. Defaults to dict(type='BN').
act_cfg (:obj:`ConfigDict` or dict)): Config dict for activation layer.
Defaults to dict(type='SiLU', inplace=True).
pred_kernel_size (int): Kernel size of prediction layer. Defaults to 1.
"""
def __init__(self,
num_classes: int,
in_channels: int,
share_conv: bool = True,
with_objectness: bool = False,
norm_cfg: ConfigType = dict(type='BN', requires_grad=True),
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
pred_kernel_size: int = 1,
**kwargs) -> None:
self.share_conv = share_conv
super().__init__(
num_classes,
in_channels,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
pred_kernel_size=pred_kernel_size,
with_objectness=with_objectness,
**kwargs)
def _init_layers(self) -> None:
"""Initialize layers of the head."""
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
self.kernel_convs = nn.ModuleList()
self.rtm_cls = nn.ModuleList()
self.rtm_reg = nn.ModuleList()
self.rtm_kernel = nn.ModuleList()
self.rtm_obj = nn.ModuleList()
# calculate num dynamic parameters
weight_nums, bias_nums = [], []
for i in range(self.num_dyconvs):
if i == 0:
weight_nums.append(
(self.num_prototypes + 2) * self.dyconv_channels)
bias_nums.append(self.dyconv_channels)
elif i == self.num_dyconvs - 1:
weight_nums.append(self.dyconv_channels)
bias_nums.append(1)
else:
weight_nums.append(self.dyconv_channels * self.dyconv_channels)
bias_nums.append(self.dyconv_channels)
self.weight_nums = weight_nums
self.bias_nums = bias_nums
self.num_gen_params = sum(weight_nums) + sum(bias_nums)
pred_pad_size = self.pred_kernel_size // 2
for n in range(len(self.prior_generator.strides)):
cls_convs = nn.ModuleList()
reg_convs = nn.ModuleList()
kernel_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
cls_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
reg_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
kernel_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
self.cls_convs.append(cls_convs)
self.reg_convs.append(cls_convs)
self.kernel_convs.append(kernel_convs)
self.rtm_cls.append(
nn.Conv2d(
self.feat_channels,
self.num_base_priors * self.cls_out_channels,
self.pred_kernel_size,
padding=pred_pad_size))
self.rtm_reg.append(
nn.Conv2d(
self.feat_channels,
self.num_base_priors * 4,
self.pred_kernel_size,
padding=pred_pad_size))
self.rtm_kernel.append(
nn.Conv2d(
self.feat_channels,
self.num_gen_params,
self.pred_kernel_size,
padding=pred_pad_size))
if self.with_objectness:
self.rtm_obj.append(
nn.Conv2d(
self.feat_channels,
1,
self.pred_kernel_size,
padding=pred_pad_size))
if self.share_conv:
for n in range(len(self.prior_generator.strides)):
for i in range(self.stacked_convs):
self.cls_convs[n][i].conv = self.cls_convs[0][i].conv
self.reg_convs[n][i].conv = self.reg_convs[0][i].conv
self.mask_head = MaskFeatModule(
in_channels=self.in_channels,
feat_channels=self.feat_channels,
stacked_convs=4,
num_levels=len(self.prior_generator.strides),
num_prototypes=self.num_prototypes,
act_cfg=self.act_cfg,
norm_cfg=self.norm_cfg)
def init_weights(self) -> None:
"""Initialize weights of the head."""
for m in self.modules():
if isinstance(m, nn.Conv2d):
normal_init(m, mean=0, std=0.01)
if is_norm(m):
constant_init(m, 1)
bias_cls = bias_init_with_prob(0.01)
for rtm_cls, rtm_reg, rtm_kernel in zip(self.rtm_cls, self.rtm_reg,
self.rtm_kernel):
normal_init(rtm_cls, std=0.01, bias=bias_cls)
normal_init(rtm_reg, std=0.01, bias=1)
if self.with_objectness:
for rtm_obj in self.rtm_obj:
normal_init(rtm_obj, std=0.01, bias=bias_cls)
def forward(self, feats: Tuple[Tensor, ...]) -> tuple:
"""Forward features from the upstream network.
Args:
feats (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
Returns:
tuple: Usually a tuple of classification scores and bbox prediction
- cls_scores (list[Tensor]): Classification scores for all scale
levels, each is a 4D-tensor, the channels number is
num_base_priors * num_classes.
- bbox_preds (list[Tensor]): Box energies / deltas for all scale
levels, each is a 4D-tensor, the channels number is
num_base_priors * 4.
- kernel_preds (list[Tensor]): Dynamic conv kernels for all scale
levels, each is a 4D-tensor, the channels number is
num_gen_params.
- mask_feat (Tensor): Output feature of the mask head. Each is a
4D-tensor, the channels number is num_prototypes.
"""
mask_feat = self.mask_head(feats)
cls_scores = []
bbox_preds = []
kernel_preds = []
for idx, (x, stride) in enumerate(
zip(feats, self.prior_generator.strides)):
cls_feat = x
reg_feat = x
kernel_feat = x
for cls_layer in self.cls_convs[idx]:
cls_feat = cls_layer(cls_feat)
cls_score = self.rtm_cls[idx](cls_feat)
for kernel_layer in self.kernel_convs[idx]:
kernel_feat = kernel_layer(kernel_feat)
kernel_pred = self.rtm_kernel[idx](kernel_feat)
for reg_layer in self.reg_convs[idx]:
reg_feat = reg_layer(reg_feat)
if self.with_objectness:
objectness = self.rtm_obj[idx](reg_feat)
cls_score = inverse_sigmoid(
sigmoid_geometric_mean(cls_score, objectness))
reg_dist = F.relu(self.rtm_reg[idx](reg_feat)) * stride[0]
cls_scores.append(cls_score)
bbox_preds.append(reg_dist)
kernel_preds.append(kernel_pred)
return tuple(cls_scores), tuple(bbox_preds), tuple(
kernel_preds), mask_feat