# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, Scale
from mmengine.config import ConfigDict
from mmengine.model import BaseModule, kaiming_init
from mmengine.structures import InstanceData
from torch import Tensor
from mmdet.registry import MODELS
from mmdet.structures.bbox import cat_boxes
from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType,
OptInstanceList, reduce_mean)
from ..task_modules.prior_generators import MlvlPointGenerator
from ..utils import (aligned_bilinear, filter_scores_and_topk, multi_apply,
relative_coordinate_maps, select_single_mlvl)
from ..utils.misc import empty_instances
from .base_mask_head import BaseMaskHead
from .fcos_head import FCOSHead
INF = 1e8
@MODELS.register_module()
class CondInstBboxHead(FCOSHead):
"""CondInst box head used in https://arxiv.org/abs/1904.02689.
Note that CondInst Bbox Head is a extension of FCOS head.
Two differences are described as follows:
1. CondInst box head predicts a set of params for each instance.
2. CondInst box head return the pos_gt_inds and pos_inds.
Args:
num_params (int): Number of params for instance segmentation.
"""
def __init__(self, *args, num_params: int = 169, **kwargs) -> None:
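        # ``num_params`` must be set before ``super().__init__()`` because
        # the FCOS/AnchorFreeHead constructor calls ``_init_layers()``,
        # which builds ``self.controller`` from it.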
self.num_params = num_params
super().__init__(*args, **kwargs)
def _init_layers(self) -> None:
"""Initialize layers of the head."""
super()._init_layers()
self.controller = nn.Conv2d(
self.feat_channels, self.num_params, 3, padding=1)
def forward_single(self, x: Tensor, scale: Scale,
stride: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Forward features of a single scale level.
Args:
x (Tensor): FPN feature maps of the specified stride.
scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
the bbox prediction.
stride (int): The corresponding stride for feature maps, only
used to normalize the bbox prediction when self.norm_on_bbox
is True.
Returns:
tuple: scores for each class, bbox predictions, centerness
predictions and param predictions of input feature maps.
"""
cls_score, bbox_pred, cls_feat, reg_feat = \
super(FCOSHead, self).forward_single(x)
if self.centerness_on_reg:
centerness = self.conv_centerness(reg_feat)
else:
centerness = self.conv_centerness(cls_feat)
# scale the bbox_pred of different level
# float to avoid overflow when enabling FP16
bbox_pred = scale(bbox_pred).float()
if self.norm_on_bbox:
# bbox_pred needed for gradient computation has been modified
# by F.relu(bbox_pred) when run with PyTorch 1.10. So replace
# F.relu(bbox_pred) with bbox_pred.clamp(min=0)
bbox_pred = bbox_pred.clamp(min=0)
if not self.training:
bbox_pred *= stride
else:
bbox_pred = bbox_pred.exp()
param_pred = self.controller(reg_feat)
return cls_score, bbox_pred, centerness, param_pred
def loss_by_feat(
self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
centernesses: List[Tensor],
param_preds: List[Tensor],
batch_gt_instances: InstanceList,
batch_img_metas: List[dict],
batch_gt_instances_ignore: OptInstanceList = None
) -> Dict[str, Tensor]:
"""Calculate the loss based on the features extracted by the detection
head.
Args:
cls_scores (list[Tensor]): Box scores for each scale level,
each is a 4D-tensor, the channel number is
num_points * num_classes.
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level, each is a 4D-tensor, the channel number is
num_points * 4.
centernesses (list[Tensor]): centerness for each scale level, each
is a 4D-tensor, the channel number is num_points * 1.
param_preds (List[Tensor]): param_pred for each scale level, each
is a 4D-tensor, the channel number is num_params.
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
assert len(cls_scores) == len(bbox_preds) == len(centernesses)
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
# Need stride for rel coord compute
all_level_points_strides = self.prior_generator.grid_priors(
featmap_sizes,
dtype=bbox_preds[0].dtype,
device=bbox_preds[0].device,
with_stride=True)
all_level_points = [i[:, :2] for i in all_level_points_strides]
all_level_strides = [i[:, 2] for i in all_level_points_strides]
labels, bbox_targets, pos_inds_list, pos_gt_inds_list = \
self.get_targets(all_level_points, batch_gt_instances)
num_imgs = cls_scores[0].size(0)
# flatten cls_scores, bbox_preds and centerness
flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
for cls_score in cls_scores
]
flatten_bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
for bbox_pred in bbox_preds
]
flatten_centerness = [
centerness.permute(0, 2, 3, 1).reshape(-1)
for centerness in centernesses
]
flatten_cls_scores = torch.cat(flatten_cls_scores)
flatten_bbox_preds = torch.cat(flatten_bbox_preds)
flatten_centerness = torch.cat(flatten_centerness)
flatten_labels = torch.cat(labels)
flatten_bbox_targets = torch.cat(bbox_targets)
# repeat points to align with bbox_preds
flatten_points = torch.cat(
[points.repeat(num_imgs, 1) for points in all_level_points])
# FG cat_id: [0, num_classes -1], BG cat_id: num_classes
bg_class_ind = self.num_classes
pos_inds = ((flatten_labels >= 0)
& (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
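        # average the number of positives across GPUs so that the loss
        # normalizer is consistent under distributed training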
num_pos = torch.tensor(
len(pos_inds), dtype=torch.float, device=bbox_preds[0].device)
num_pos = max(reduce_mean(num_pos), 1.0)
loss_cls = self.loss_cls(
flatten_cls_scores, flatten_labels, avg_factor=num_pos)
pos_bbox_preds = flatten_bbox_preds[pos_inds]
pos_centerness = flatten_centerness[pos_inds]
pos_bbox_targets = flatten_bbox_targets[pos_inds]
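        # ``centerness_target`` (inherited from FCOSHead) computes
        # sqrt(min(l, r) / max(l, r) * min(t, b) / max(t, b)) from the
        # (l, t, r, b) regression targets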
pos_centerness_targets = self.centerness_target(pos_bbox_targets)
# centerness weighted iou loss
centerness_denorm = max(
reduce_mean(pos_centerness_targets.sum().detach()), 1e-6)
if len(pos_inds) > 0:
pos_points = flatten_points[pos_inds]
pos_decoded_bbox_preds = self.bbox_coder.decode(
pos_points, pos_bbox_preds)
pos_decoded_target_preds = self.bbox_coder.decode(
pos_points, pos_bbox_targets)
loss_bbox = self.loss_bbox(
pos_decoded_bbox_preds,
pos_decoded_target_preds,
weight=pos_centerness_targets,
avg_factor=centerness_denorm)
loss_centerness = self.loss_centerness(
pos_centerness, pos_centerness_targets, avg_factor=num_pos)
else:
loss_bbox = pos_bbox_preds.sum()
loss_centerness = pos_centerness.sum()
self._raw_positive_infos.update(cls_scores=cls_scores)
self._raw_positive_infos.update(centernesses=centernesses)
self._raw_positive_infos.update(param_preds=param_preds)
self._raw_positive_infos.update(all_level_points=all_level_points)
self._raw_positive_infos.update(all_level_strides=all_level_strides)
self._raw_positive_infos.update(pos_gt_inds_list=pos_gt_inds_list)
self._raw_positive_infos.update(pos_inds_list=pos_inds_list)
return dict(
loss_cls=loss_cls,
loss_bbox=loss_bbox,
loss_centerness=loss_centerness)
def get_targets(
self, points: List[Tensor], batch_gt_instances: InstanceList
) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]:
"""Compute regression, classification and centerness targets for points
in multiple images.
Args:
points (list[Tensor]): Points of each fpn level, each has shape
(num_points, 2).
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
Returns:
tuple: Targets of each level.
- concat_lvl_labels (list[Tensor]): Labels of each level.
- concat_lvl_bbox_targets (list[Tensor]): BBox targets of each \
level.
- pos_inds_list (list[Tensor]): pos_inds of each image.
- pos_gt_inds_list (List[Tensor]): pos_gt_inds of each image.
"""
assert len(points) == len(self.regress_ranges)
num_levels = len(points)
# expand regress ranges to align with points
expanded_regress_ranges = [
points[i].new_tensor(self.regress_ranges[i])[None].expand_as(
points[i]) for i in range(num_levels)
]
# concat all levels points and regress ranges
concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0)
concat_points = torch.cat(points, dim=0)
# the number of points per img, per lvl
num_points = [center.size(0) for center in points]
# get labels and bbox_targets of each image
labels_list, bbox_targets_list, pos_inds_list, pos_gt_inds_list = \
multi_apply(
self._get_targets_single,
batch_gt_instances,
points=concat_points,
regress_ranges=concat_regress_ranges,
num_points_per_lvl=num_points)
# split to per img, per level
labels_list = [labels.split(num_points, 0) for labels in labels_list]
bbox_targets_list = [
bbox_targets.split(num_points, 0)
for bbox_targets in bbox_targets_list
]
        # concat the targets of each level over all images
concat_lvl_labels = []
concat_lvl_bbox_targets = []
for i in range(num_levels):
concat_lvl_labels.append(
torch.cat([labels[i] for labels in labels_list]))
bbox_targets = torch.cat(
[bbox_targets[i] for bbox_targets in bbox_targets_list])
if self.norm_on_bbox:
bbox_targets = bbox_targets / self.strides[i]
concat_lvl_bbox_targets.append(bbox_targets)
return (concat_lvl_labels, concat_lvl_bbox_targets, pos_inds_list,
pos_gt_inds_list)
def _get_targets_single(
self, gt_instances: InstanceData, points: Tensor,
regress_ranges: Tensor, num_points_per_lvl: List[int]
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Compute regression and classification targets for a single image."""
num_points = points.size(0)
num_gts = len(gt_instances)
gt_bboxes = gt_instances.bboxes
gt_labels = gt_instances.labels
gt_masks = gt_instances.get('masks', None)
if num_gts == 0:
return gt_labels.new_full((num_points,), self.num_classes), \
gt_bboxes.new_zeros((num_points, 4)), \
gt_bboxes.new_zeros((0,), dtype=torch.int64), \
gt_bboxes.new_zeros((0,), dtype=torch.int64)
areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
gt_bboxes[:, 3] - gt_bboxes[:, 1])
# TODO: figure out why these two are different
# areas = areas[None].expand(num_points, num_gts)
areas = areas[None].repeat(num_points, 1)
regress_ranges = regress_ranges[:, None, :].expand(
num_points, num_gts, 2)
gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
xs, ys = points[:, 0], points[:, 1]
xs = xs[:, None].expand(num_points, num_gts)
ys = ys[:, None].expand(num_points, num_gts)
left = xs - gt_bboxes[..., 0]
right = gt_bboxes[..., 2] - xs
top = ys - gt_bboxes[..., 1]
bottom = gt_bboxes[..., 3] - ys
bbox_targets = torch.stack((left, top, right, bottom), -1)
if self.center_sampling:
# condition1: inside a `center bbox`
radius = self.center_sample_radius
            # if gt_masks is not None, use the mask centroid to determine
            # the center region rather than the gt_bbox center
if gt_masks is None:
center_xs = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) / 2
center_ys = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) / 2
else:
h, w = gt_masks.height, gt_masks.width
masks = gt_masks.to_tensor(
dtype=torch.bool, device=gt_bboxes.device)
yys = torch.arange(
0, h, dtype=torch.float32, device=masks.device)
xxs = torch.arange(
0, w, dtype=torch.float32, device=masks.device)
                # m00/m10/m01 are the zeroth- and first-order image
                # moments of the mask; the centroid is (m10/m00, m01/m00)
m00 = masks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
m10 = (masks * xxs).sum(dim=-1).sum(dim=-1)
m01 = (masks * yys[:, None]).sum(dim=-1).sum(dim=-1)
center_xs = m10 / m00
center_ys = m01 / m00
center_xs = center_xs[None].expand(num_points, num_gts)
center_ys = center_ys[None].expand(num_points, num_gts)
center_gts = torch.zeros_like(gt_bboxes)
stride = center_xs.new_zeros(center_xs.shape)
# project the points on current lvl back to the `original` sizes
lvl_begin = 0
for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl):
lvl_end = lvl_begin + num_points_lvl
stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius
lvl_begin = lvl_end
x_mins = center_xs - stride
y_mins = center_ys - stride
x_maxs = center_xs + stride
y_maxs = center_ys + stride
center_gts[..., 0] = torch.where(x_mins > gt_bboxes[..., 0],
x_mins, gt_bboxes[..., 0])
center_gts[..., 1] = torch.where(y_mins > gt_bboxes[..., 1],
y_mins, gt_bboxes[..., 1])
center_gts[..., 2] = torch.where(x_maxs > gt_bboxes[..., 2],
gt_bboxes[..., 2], x_maxs)
center_gts[..., 3] = torch.where(y_maxs > gt_bboxes[..., 3],
gt_bboxes[..., 3], y_maxs)
cb_dist_left = xs - center_gts[..., 0]
cb_dist_right = center_gts[..., 2] - xs
cb_dist_top = ys - center_gts[..., 1]
cb_dist_bottom = center_gts[..., 3] - ys
center_bbox = torch.stack(
(cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom), -1)
inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0
else:
# condition1: inside a gt bbox
inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0
# condition2: limit the regression range for each location
max_regress_distance = bbox_targets.max(-1)[0]
inside_regress_range = (
(max_regress_distance >= regress_ranges[..., 0])
& (max_regress_distance <= regress_ranges[..., 1]))
        # if there is still more than one object for a location,
        # we choose the one with the minimal area
areas[inside_gt_bbox_mask == 0] = INF
areas[inside_regress_range == 0] = INF
min_area, min_area_inds = areas.min(dim=1)
labels = gt_labels[min_area_inds]
labels[min_area == INF] = self.num_classes # set as BG
bbox_targets = bbox_targets[range(num_points), min_area_inds]
# return pos_inds & pos_gt_inds
bg_class_ind = self.num_classes
pos_inds = ((labels >= 0)
& (labels < bg_class_ind)).nonzero().reshape(-1)
pos_gt_inds = min_area_inds[labels < self.num_classes]
return labels, bbox_targets, pos_inds, pos_gt_inds
def get_positive_infos(self) -> InstanceList:
"""Get positive information from sampling results.
Returns:
list[:obj:`InstanceData`]: Positive information of each image,
usually including positive bboxes, positive labels, positive
priors, etc.
"""
assert len(self._raw_positive_infos) > 0
pos_gt_inds_list = self._raw_positive_infos['pos_gt_inds_list']
pos_inds_list = self._raw_positive_infos['pos_inds_list']
num_imgs = len(pos_gt_inds_list)
cls_score_list = []
centerness_list = []
param_pred_list = []
point_list = []
stride_list = []
for cls_score_per_lvl, centerness_per_lvl, param_pred_per_lvl,\
point_per_lvl, stride_per_lvl in \
zip(self._raw_positive_infos['cls_scores'],
self._raw_positive_infos['centernesses'],
self._raw_positive_infos['param_preds'],
self._raw_positive_infos['all_level_points'],
self._raw_positive_infos['all_level_strides']):
cls_score_per_lvl = \
cls_score_per_lvl.permute(
0, 2, 3, 1).reshape(num_imgs, -1, self.num_classes)
centerness_per_lvl = \
centerness_per_lvl.permute(
0, 2, 3, 1).reshape(num_imgs, -1, 1)
param_pred_per_lvl = \
param_pred_per_lvl.permute(
0, 2, 3, 1).reshape(num_imgs, -1, self.num_params)
point_per_lvl = point_per_lvl.unsqueeze(0).repeat(num_imgs, 1, 1)
stride_per_lvl = stride_per_lvl.unsqueeze(0).repeat(num_imgs, 1)
cls_score_list.append(cls_score_per_lvl)
centerness_list.append(centerness_per_lvl)
param_pred_list.append(param_pred_per_lvl)
point_list.append(point_per_lvl)
stride_list.append(stride_per_lvl)
cls_scores = torch.cat(cls_score_list, dim=1)
centernesses = torch.cat(centerness_list, dim=1)
param_preds = torch.cat(param_pred_list, dim=1)
all_points = torch.cat(point_list, dim=1)
all_strides = torch.cat(stride_list, dim=1)
positive_infos = []
for i, (pos_gt_inds,
pos_inds) in enumerate(zip(pos_gt_inds_list, pos_inds_list)):
pos_info = InstanceData()
pos_info.points = all_points[i][pos_inds]
pos_info.strides = all_strides[i][pos_inds]
pos_info.scores = cls_scores[i][pos_inds]
pos_info.centernesses = centernesses[i][pos_inds]
pos_info.param_preds = param_preds[i][pos_inds]
pos_info.pos_assigned_gt_inds = pos_gt_inds
pos_info.pos_inds = pos_inds
positive_infos.append(pos_info)
return positive_infos
def predict_by_feat(self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
score_factors: Optional[List[Tensor]] = None,
param_preds: Optional[List[Tensor]] = None,
batch_img_metas: Optional[List[dict]] = None,
cfg: Optional[ConfigDict] = None,
rescale: bool = False,
with_nms: bool = True) -> InstanceList:
"""Transform a batch of output features extracted from the head into
bbox results.
        Note: When score_factors is not None, the cls_scores are
        usually multiplied by it to obtain the real scores used in NMS,
        e.g. the centerness in FCOS or the IoU branch in ATSS.
Args:
cls_scores (list[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
score_factors (list[Tensor], optional): Score factor for
all scale level, each is a 4D-tensor, has shape
(batch_size, num_priors * 1, H, W). Defaults to None.
param_preds (list[Tensor], optional): Params for all scale
level, each is a 4D-tensor, has shape
(batch_size, num_priors * num_params, H, W)
batch_img_metas (list[dict], Optional): Batch image meta info.
Defaults to None.
cfg (ConfigDict, optional): Test / postprocessing
configuration, if None, test_cfg would be used.
Defaults to None.
rescale (bool): If True, return boxes in original image space.
Defaults to False.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
list[:obj:`InstanceData`]: Object detection results of each image
after the post process. Each item usually contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (Tensor): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
"""
assert len(cls_scores) == len(bbox_preds)
if score_factors is None:
# e.g. Retina, FreeAnchor, Foveabox, etc.
with_score_factors = False
else:
# e.g. FCOS, PAA, ATSS, AutoAssign, etc.
with_score_factors = True
assert len(cls_scores) == len(score_factors)
num_levels = len(cls_scores)
featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
all_level_points_strides = self.prior_generator.grid_priors(
featmap_sizes,
dtype=bbox_preds[0].dtype,
device=bbox_preds[0].device,
with_stride=True)
all_level_points = [i[:, :2] for i in all_level_points_strides]
all_level_strides = [i[:, 2] for i in all_level_points_strides]
result_list = []
for img_id in range(len(batch_img_metas)):
img_meta = batch_img_metas[img_id]
cls_score_list = select_single_mlvl(
cls_scores, img_id, detach=True)
bbox_pred_list = select_single_mlvl(
bbox_preds, img_id, detach=True)
if with_score_factors:
score_factor_list = select_single_mlvl(
score_factors, img_id, detach=True)
else:
score_factor_list = [None for _ in range(num_levels)]
param_pred_list = select_single_mlvl(
param_preds, img_id, detach=True)
results = self._predict_by_feat_single(
cls_score_list=cls_score_list,
bbox_pred_list=bbox_pred_list,
score_factor_list=score_factor_list,
param_pred_list=param_pred_list,
mlvl_points=all_level_points,
mlvl_strides=all_level_strides,
img_meta=img_meta,
cfg=cfg,
rescale=rescale,
with_nms=with_nms)
result_list.append(results)
return result_list
def _predict_by_feat_single(self,
cls_score_list: List[Tensor],
bbox_pred_list: List[Tensor],
score_factor_list: List[Tensor],
param_pred_list: List[Tensor],
mlvl_points: List[Tensor],
mlvl_strides: List[Tensor],
img_meta: dict,
cfg: ConfigDict,
rescale: bool = False,
with_nms: bool = True) -> InstanceData:
"""Transform a single image's features extracted from the head into
bbox results.
Args:
cls_score_list (list[Tensor]): Box scores from all scale
levels of a single image, each item has shape
(num_priors * num_classes, H, W).
bbox_pred_list (list[Tensor]): Box energies / deltas from
all scale levels of a single image, each item has shape
(num_priors * 4, H, W).
score_factor_list (list[Tensor]): Score factor from all scale
levels of a single image, each item has shape
(num_priors * 1, H, W).
            param_pred_list (List[Tensor]): Param predictions from all scale
levels of a single image, each item has shape
(num_priors * num_params, H, W).
mlvl_points (list[Tensor]): Each element in the list is
the priors of a single level in feature pyramid.
It has shape (num_priors, 2)
            mlvl_strides (List[Tensor]): Each element in the list is
                the strides of a single level in the feature pyramid.
                It has shape (num_priors, ).
img_meta (dict): Image meta info.
cfg (mmengine.Config): Test / postprocessing configuration,
if None, test_cfg would be used.
rescale (bool): If True, return boxes in original image space.
Defaults to False.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
:obj:`InstanceData`: Detection results of each image
after the post process.
Each item usually contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (Tensor): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
"""
if score_factor_list[0] is None:
# e.g. Retina, FreeAnchor, etc.
with_score_factors = False
else:
# e.g. FCOS, PAA, ATSS, etc.
with_score_factors = True
cfg = self.test_cfg if cfg is None else cfg
cfg = copy.deepcopy(cfg)
img_shape = img_meta['img_shape']
nms_pre = cfg.get('nms_pre', -1)
mlvl_bbox_preds = []
mlvl_param_preds = []
mlvl_valid_points = []
mlvl_valid_strides = []
mlvl_scores = []
mlvl_labels = []
if with_score_factors:
mlvl_score_factors = []
else:
mlvl_score_factors = None
for level_idx, (cls_score, bbox_pred, score_factor,
param_pred, points, strides) in \
enumerate(zip(cls_score_list, bbox_pred_list,
score_factor_list, param_pred_list,
mlvl_points, mlvl_strides)):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
dim = self.bbox_coder.encode_size
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, dim)
if with_score_factors:
score_factor = score_factor.permute(1, 2,
0).reshape(-1).sigmoid()
cls_score = cls_score.permute(1, 2,
0).reshape(-1, self.cls_out_channels)
if self.use_sigmoid_cls:
scores = cls_score.sigmoid()
else:
                # Note that FG labels are set to [0, num_classes - 1]
                # since mmdet v2.0, while the BG cat_id is num_classes.
scores = cls_score.softmax(-1)[:, :-1]
param_pred = param_pred.permute(1, 2,
0).reshape(-1, self.num_params)
# After https://github.com/open-mmlab/mmdetection/pull/6268/,
# this operation keeps fewer bboxes under the same `nms_pre`.
# There is no difference in performance for most models. If you
# find a slight drop in performance, you can set a larger
# `nms_pre` than before.
score_thr = cfg.get('score_thr', 0)
results = filter_scores_and_topk(
scores, score_thr, nms_pre,
dict(
bbox_pred=bbox_pred,
param_pred=param_pred,
points=points,
strides=strides))
scores, labels, keep_idxs, filtered_results = results
bbox_pred = filtered_results['bbox_pred']
param_pred = filtered_results['param_pred']
points = filtered_results['points']
strides = filtered_results['strides']
if with_score_factors:
score_factor = score_factor[keep_idxs]
mlvl_bbox_preds.append(bbox_pred)
mlvl_param_preds.append(param_pred)
mlvl_valid_points.append(points)
mlvl_valid_strides.append(strides)
mlvl_scores.append(scores)
mlvl_labels.append(labels)
if with_score_factors:
mlvl_score_factors.append(score_factor)
bbox_pred = torch.cat(mlvl_bbox_preds)
priors = cat_boxes(mlvl_valid_points)
bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape)
results = InstanceData()
results.bboxes = bboxes
results.scores = torch.cat(mlvl_scores)
results.labels = torch.cat(mlvl_labels)
results.param_preds = torch.cat(mlvl_param_preds)
results.points = torch.cat(mlvl_valid_points)
results.strides = torch.cat(mlvl_valid_strides)
if with_score_factors:
results.score_factors = torch.cat(mlvl_score_factors)
return self._bbox_post_process(
results=results,
cfg=cfg,
rescale=rescale,
with_nms=with_nms,
img_meta=img_meta)
class MaskFeatModule(BaseModule):
"""CondInst mask feature map branch used in \
https://arxiv.org/abs/1904.02689.
Args:
in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of hidden channels of the mask feature
map branch.
        start_level (int): The starting feature map level of the FPN that
            will be used to predict the mask feature map.
        end_level (int): The ending feature map level of the FPN that
            will be used to predict the mask feature map.
        out_channels (int): Number of output channels of the mask feature
            map branch. This is the number of channels of the mask
            feature map on which the predicted dynamic kernels are
            applied.
mask_stride (int): Downsample factor of the mask feature map output.
Defaults to 4.
num_stacked_convs (int): Number of convs in mask feature branch.
conv_cfg (dict): Config dict for convolution layer. Default: None.
norm_cfg (dict): Config dict for normalization layer. Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
"""
def __init__(self,
in_channels: int,
feat_channels: int,
start_level: int,
end_level: int,
out_channels: int,
mask_stride: int = 4,
num_stacked_convs: int = 4,
conv_cfg: OptConfigType = None,
norm_cfg: OptConfigType = None,
init_cfg: MultiConfig = [
dict(type='Normal', layer='Conv2d', std=0.01)
],
**kwargs) -> None:
super().__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.feat_channels = feat_channels
self.start_level = start_level
self.end_level = end_level
self.mask_stride = mask_stride
self.num_stacked_convs = num_stacked_convs
assert start_level >= 0 and end_level >= start_level
self.out_channels = out_channels
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self._init_layers()
def _init_layers(self) -> None:
"""Initialize layers of the head."""
self.convs_all_levels = nn.ModuleList()
for i in range(self.start_level, self.end_level + 1):
convs_per_level = nn.Sequential()
convs_per_level.add_module(
f'conv{i}',
ConvModule(
self.in_channels,
self.feat_channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
inplace=False,
bias=False))
self.convs_all_levels.append(convs_per_level)
conv_branch = []
for _ in range(self.num_stacked_convs):
conv_branch.append(
ConvModule(
self.feat_channels,
self.feat_channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
bias=False))
self.conv_branch = nn.Sequential(*conv_branch)
self.conv_pred = nn.Conv2d(
self.feat_channels, self.out_channels, 1, stride=1)
def init_weights(self) -> None:
"""Initialize weights of the head."""
super().init_weights()
kaiming_init(self.convs_all_levels, a=1, distribution='uniform')
kaiming_init(self.conv_branch, a=1, distribution='uniform')
kaiming_init(self.conv_pred, a=1, distribution='uniform')
def forward(self, x: Tuple[Tensor]) -> Tensor:
"""Forward features from the upstream network.
Args:
x (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
Returns:
Tensor: The predicted mask feature map.
"""
inputs = x[self.start_level:self.end_level + 1]
assert len(inputs) == (self.end_level - self.start_level + 1)
feature_add_all_level = self.convs_all_levels[0](inputs[0])
target_h, target_w = feature_add_all_level.size()[2:]
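        # The first (highest-resolution) level sets the target size; every
        # other level is upsampled to it with ``aligned_bilinear`` and
        # summed before the stacked convs and the final 1x1 prediction.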
for i in range(1, len(inputs)):
input_p = inputs[i]
x_p = self.convs_all_levels[i](input_p)
h, w = x_p.size()[2:]
factor_h = target_h // h
factor_w = target_w // w
assert factor_h == factor_w
feature_per_level = aligned_bilinear(x_p, factor_h)
feature_add_all_level = feature_add_all_level + \
feature_per_level
feature_add_all_level = self.conv_branch(feature_add_all_level)
feature_pred = self.conv_pred(feature_add_all_level)
return feature_pred
@MODELS.register_module()
class CondInstMaskHead(BaseMaskHead):
"""CondInst mask head used in https://arxiv.org/abs/1904.02689.
This head outputs the mask for CondInst.
Args:
        mask_feature_head (dict): Config of the mask feature head
            (:class:`MaskFeatModule`).
        num_layers (int): Number of dynamic conv layers.
        feat_channels (int): Number of channels in the dynamic conv.
        mask_out_stride (int): The stride of the output masks.
        size_of_interest (int): The size of the region used to normalize
            the relative coordinate maps.
        max_masks_to_train (int): Maximum number of masks to train for
            each image. Defaults to -1, which means no limit.
        topk_masks_per_img (int): If not -1, keep at most this many mask
            predictions per image, split evenly across the GT instances
            and ranked by score * centerness. Defaults to -1.
        loss_mask (:obj:`ConfigDict` or dict, optional): Config of the
            mask loss.
train_cfg (:obj:`ConfigDict` or dict, optional): Training config
of head.
test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
head.
"""
def __init__(self,
mask_feature_head: ConfigType,
num_layers: int = 3,
feat_channels: int = 8,
mask_out_stride: int = 4,
size_of_interest: int = 8,
max_masks_to_train: int = -1,
topk_masks_per_img: int = -1,
loss_mask: ConfigType = None,
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None) -> None:
super().__init__()
self.mask_feature_head = MaskFeatModule(**mask_feature_head)
self.mask_feat_stride = self.mask_feature_head.mask_stride
self.in_channels = self.mask_feature_head.out_channels
self.num_layers = num_layers
self.feat_channels = feat_channels
self.size_of_interest = size_of_interest
self.mask_out_stride = mask_out_stride
self.max_masks_to_train = max_masks_to_train
self.topk_masks_per_img = topk_masks_per_img
self.prior_generator = MlvlPointGenerator([self.mask_feat_stride])
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.loss_mask = MODELS.build(loss_mask)
self._init_layers()
def _init_layers(self) -> None:
"""Initialize layers of the head."""
weight_nums, bias_nums = [], []
for i in range(self.num_layers):
if i == 0:
weight_nums.append((self.in_channels + 2) * self.feat_channels)
bias_nums.append(self.feat_channels)
elif i == self.num_layers - 1:
weight_nums.append(self.feat_channels * 1)
bias_nums.append(1)
else:
weight_nums.append(self.feat_channels * self.feat_channels)
bias_nums.append(self.feat_channels)
self.weight_nums = weight_nums
self.bias_nums = bias_nums
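        # With the reference configuration (in_channels=8, feat_channels=8,
        # num_layers=3) this gives weight_nums=[80, 64, 8] and
        # bias_nums=[8, 8, 1], so num_params below is 169, matching the
        # controller output of ``CondInstBboxHead``.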
self.num_params = sum(weight_nums) + sum(bias_nums)
def parse_dynamic_params(
self, params: Tensor) -> Tuple[List[Tensor], List[Tensor]]:
"""parse the dynamic params for dynamic conv."""
num_insts = params.size(0)
params_splits = list(
torch.split_with_sizes(
params, self.weight_nums + self.bias_nums, dim=1))
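        # e.g. with the default sizes, each row of ``params`` (one row per
        # instance) is split into chunks of [80, 64, 8, 8, 8, 1]:
        # three weight tensors followed by three bias tensors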
weight_splits = params_splits[:self.num_layers]
bias_splits = params_splits[self.num_layers:]
for i in range(self.num_layers):
if i < self.num_layers - 1:
                # hidden dynamic conv layers output ``feat_channels``
                # channels per instance (see ``_init_layers``)
                weight_splits[i] = weight_splits[i].reshape(
                    num_insts * self.feat_channels, -1, 1, 1)
                bias_splits[i] = bias_splits[i].reshape(
                    num_insts * self.feat_channels)
else:
# out_channels x in_channels x 1 x 1
weight_splits[i] = weight_splits[i].reshape(
num_insts * 1, -1, 1, 1)
bias_splits[i] = bias_splits[i].reshape(num_insts)
return weight_splits, bias_splits
def dynamic_conv_forward(self, features: Tensor, weights: List[Tensor],
biases: List[Tensor], num_insts: int) -> Tensor:
"""dynamic forward, each layer follow a relu."""
n_layers = len(weights)
x = features
for i, (w, b) in enumerate(zip(weights, biases)):
x = F.conv2d(x, w, bias=b, stride=1, padding=0, groups=num_insts)
if i < n_layers - 1:
x = F.relu(x)
return x
def forward(self, x: tuple, positive_infos: InstanceList) -> tuple:
"""Forward feature from the upstream network to get prototypes and
linearly combine the prototypes, using masks coefficients, into
instance masks. Finally, crop the instance masks with given bboxes.
Args:
x (Tuple[Tensor]): Feature from the upstream network, which is
a 4D-tensor.
positive_infos (List[:obj:``InstanceData``]): Positive information
that calculate from detect head.
Returns:
tuple: Predicted instance segmentation masks
"""
mask_feats = self.mask_feature_head(x)
return multi_apply(self.forward_single, mask_feats, positive_infos)
def forward_single(self, mask_feat: Tensor,
positive_info: InstanceData) -> Tensor:
"""Forward features of a each image."""
pos_param_preds = positive_info.get('param_preds')
pos_points = positive_info.get('points')
pos_strides = positive_info.get('strides')
num_inst = pos_param_preds.shape[0]
mask_feat = mask_feat[None].repeat(num_inst, 1, 1, 1)
_, _, H, W = mask_feat.size()
if num_inst == 0:
return (pos_param_preds.new_zeros((0, 1, H, W)), )
locations = self.prior_generator.single_level_grid_priors(
mask_feat.size()[2:], 0, device=mask_feat.device)
rel_coords = relative_coordinate_maps(locations, pos_points,
pos_strides,
self.size_of_interest,
mask_feat.size()[2:])
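        # rel_coords has shape (num_inst, 2, H, W): for every instance it
        # holds the (dx, dy) offsets between that instance's center point
        # and each mask-feature location, normalized by
        # ``size_of_interest`` * stride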
mask_head_inputs = torch.cat([rel_coords, mask_feat], dim=1)
mask_head_inputs = mask_head_inputs.reshape(1, -1, H, W)
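        # all instances are folded into the channel dimension of a single
        # batch element, so one grouped convolution (groups=num_inst) can
        # apply each instance's private filters in parallel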
weights, biases = self.parse_dynamic_params(pos_param_preds)
mask_preds = self.dynamic_conv_forward(mask_head_inputs, weights,
biases, num_inst)
mask_preds = mask_preds.reshape(-1, H, W)
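        # upsample from the mask-feature stride to the mask output stride,
        # e.g. a factor of 2 when mask_feat_stride=8 and mask_out_stride=4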
mask_preds = aligned_bilinear(
mask_preds.unsqueeze(0),
int(self.mask_feat_stride / self.mask_out_stride)).squeeze(0)
return (mask_preds, )
def loss_by_feat(self, mask_preds: List[Tensor],
batch_gt_instances: InstanceList,
batch_img_metas: List[dict], positive_infos: InstanceList,
**kwargs) -> dict:
"""Calculate the loss based on the features extracted by the mask head.
Args:
            mask_preds (list[Tensor]): List of predicted masks, one per
                image, each has shape (num_pos, H, W).
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes``, ``masks``,
and ``labels`` attributes.
batch_img_metas (list[dict]): Meta information of multiple images.
positive_infos (List[:obj:``InstanceData``]): Information of
positive samples of each image that are assigned in detection
head.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
assert positive_infos is not None, \
'positive_infos should not be None in `CondInstMaskHead`'
losses = dict()
loss_mask = 0.
num_imgs = len(mask_preds)
total_pos = 0
for idx in range(num_imgs):
(mask_pred, pos_mask_targets, num_pos) = \
self._get_targets_single(
mask_preds[idx], batch_gt_instances[idx],
positive_infos[idx])
# mask loss
total_pos += num_pos
if num_pos == 0 or pos_mask_targets is None:
loss = mask_pred.new_zeros(1).mean()
else:
loss = self.loss_mask(
mask_pred, pos_mask_targets,
reduction_override='none').sum()
loss_mask += loss
if total_pos == 0:
total_pos += 1 # avoid nan
loss_mask = loss_mask / total_pos
losses.update(loss_mask=loss_mask)
return losses
def _get_targets_single(self, mask_preds: Tensor,
gt_instances: InstanceData,
positive_info: InstanceData):
"""Compute targets for predictions of single image.
Args:
            mask_preds (Tensor): Predicted masks with shape
                (num_pos, H, W).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes``, ``labels``,
                and ``masks`` attributes.
positive_info (:obj:`InstanceData`): Information of positive
samples that are assigned in detection head. It usually
contains following keys.
                - pos_assigned_gt_inds (Tensor): Assigned GT indexes of
                  positive proposals, has shape (num_pos, ).
                - pos_inds (Tensor): Positive indexes of the image, has
                  shape (num_pos, ).
                - param_preds (Tensor): Positive param predictions
                  with shape (num_pos, num_params).
Returns:
tuple: Usually returns a tuple containing learning targets.
- mask_preds (Tensor): Positive predicted mask with shape
(num_pos, mask_h, mask_w).
- pos_mask_targets (Tensor): Positive mask targets with shape
(num_pos, mask_h, mask_w).
- num_pos (int): Positive numbers.
"""
gt_bboxes = gt_instances.bboxes
device = gt_bboxes.device
gt_masks = gt_instances.masks.to_tensor(
dtype=torch.bool, device=device).float()
# process with mask targets
pos_assigned_gt_inds = positive_info.get('pos_assigned_gt_inds')
scores = positive_info.get('scores')
centernesses = positive_info.get('centernesses')
num_pos = pos_assigned_gt_inds.size(0)
if gt_masks.size(0) == 0 or num_pos == 0:
return mask_preds, None, 0
# Since we're producing (near) full image masks,
# it'd take too much vram to backprop on every single mask.
# Thus we select only a subset.
if (self.max_masks_to_train != -1) and \
(num_pos > self.max_masks_to_train):
perm = torch.randperm(num_pos)
select = perm[:self.max_masks_to_train]
mask_preds = mask_preds[select]
pos_assigned_gt_inds = pos_assigned_gt_inds[select]
num_pos = self.max_masks_to_train
elif self.topk_masks_per_img != -1:
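            # budget ``topk_masks_per_img`` evenly across the GT instances
            # and keep, for each GT, the proposals with the highest
            # score * centerness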
unique_gt_inds = pos_assigned_gt_inds.unique()
num_inst_per_gt = max(
int(self.topk_masks_per_img / len(unique_gt_inds)), 1)
keep_mask_preds = []
keep_pos_assigned_gt_inds = []
for gt_ind in unique_gt_inds:
per_inst_pos_inds = (pos_assigned_gt_inds == gt_ind)
mask_preds_per_inst = mask_preds[per_inst_pos_inds]
gt_inds_per_inst = pos_assigned_gt_inds[per_inst_pos_inds]
if sum(per_inst_pos_inds) > num_inst_per_gt:
per_inst_scores = scores[per_inst_pos_inds].sigmoid().max(
dim=1)[0]
per_inst_centerness = centernesses[
per_inst_pos_inds].sigmoid().reshape(-1, )
select = (per_inst_scores * per_inst_centerness).topk(
k=num_inst_per_gt, dim=0)[1]
mask_preds_per_inst = mask_preds_per_inst[select]
gt_inds_per_inst = gt_inds_per_inst[select]
keep_mask_preds.append(mask_preds_per_inst)
keep_pos_assigned_gt_inds.append(gt_inds_per_inst)
mask_preds = torch.cat(keep_mask_preds)
pos_assigned_gt_inds = torch.cat(keep_pos_assigned_gt_inds)
num_pos = pos_assigned_gt_inds.size(0)
        # Follow the original CondInst implementation
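        # Sampling every ``mask_out_stride``-th pixel starting at
        # ``mask_out_stride // 2`` picks the pixel at the center of each
        # output-stride cell, so the downsampled GT masks align with the
        # prediction grid.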
start = int(self.mask_out_stride // 2)
gt_masks = gt_masks[:, start::self.mask_out_stride,
start::self.mask_out_stride]
gt_masks = gt_masks.gt(0.5).float()
pos_mask_targets = gt_masks[pos_assigned_gt_inds]
return (mask_preds, pos_mask_targets, num_pos)
def predict_by_feat(self,
mask_preds: List[Tensor],
results_list: InstanceList,
batch_img_metas: List[dict],
rescale: bool = True,
**kwargs) -> InstanceList:
"""Transform a batch of output features extracted from the head into
mask results.
Args:
            mask_preds (list[Tensor]): Predicted masks for each image,
                each has shape (num_pos, H, W).
            results_list (List[:obj:`InstanceData`]): BBoxHead results.
            batch_img_metas (list[dict]): Meta information of all images.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to True.
        Returns:
            list[:obj:`InstanceData`]: Processed results of multiple
            images. Each :obj:`InstanceData` usually contains the
            following keys.
- scores (Tensor): Classification scores, has shape
(num_instance,).
- labels (Tensor): Has shape (num_instances,).
- masks (Tensor): Processed mask results, has
shape (num_instances, h, w).
"""
assert len(mask_preds) == len(results_list) == len(batch_img_metas)
for img_id in range(len(batch_img_metas)):
img_meta = batch_img_metas[img_id]
results = results_list[img_id]
bboxes = results.bboxes
mask_pred = mask_preds[img_id]
if bboxes.shape[0] == 0 or mask_pred.shape[0] == 0:
results_list[img_id] = empty_instances(
[img_meta],
bboxes.device,
task_type='mask',
instance_results=[results])[0]
else:
im_mask = self._predict_by_feat_single(
mask_preds=mask_pred,
bboxes=bboxes,
img_meta=img_meta,
rescale=rescale)
results.masks = im_mask
return results_list
def _predict_by_feat_single(self,
mask_preds: Tensor,
bboxes: Tensor,
img_meta: dict,
rescale: bool,
cfg: OptConfigType = None):
"""Transform a single image's features extracted from the head into
mask results.
Args:
            mask_preds (Tensor): Predicted masks of a single image, has
                shape (num_instances, H, W).
            bboxes (Tensor): Predicted bboxes of a single image, has
                shape (num_instances, 4).
            img_meta (dict): Meta information of the image, e.g.,
                image size, scaling factor, etc.
            rescale (bool): If True, resize the masks to the original
                image shape.
            cfg (dict, optional): Config used in test phase.
                Defaults to None.
        Returns:
            Tensor: Binary masks of a single image, has shape
            (num_instances, h, w).
"""
cfg = self.test_cfg if cfg is None else cfg
img_h, img_w = img_meta['img_shape'][:2]
ori_h, ori_w = img_meta['ori_shape'][:2]
mask_preds = mask_preds.sigmoid().unsqueeze(0)
mask_preds = aligned_bilinear(mask_preds, self.mask_out_stride)
mask_preds = mask_preds[:, :, :img_h, :img_w]
        if rescale:  # rescale the bboxes in place
scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat(
(1, 2))
bboxes /= scale_factor
masks = F.interpolate(
mask_preds, (ori_h, ori_w),
mode='bilinear',
align_corners=False).squeeze(0) > cfg.mask_thr
else:
masks = mask_preds.squeeze(0) > cfg.mask_thr
return masks