# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.structures import InstanceData
from torch import Tensor
from mmdet.models.utils.misc import floordiv
from mmdet.registry import MODELS
from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptConfigType
from ..layers import mask_matrix_nms
from ..utils import center_of_mass, generate_coordinate, multi_apply
from .base_mask_head import BaseMaskHead
@MODELS.register_module()
class SOLOHead(BaseMaskHead):
"""SOLO mask head used in `SOLO: Segmenting Objects by Locations.
<https://arxiv.org/abs/1912.04488>`_
Args:
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of hidden channels. Used in child classes.
Defaults to 256.
stacked_convs (int): Number of stacking convs of the head.
Defaults to 4.
strides (tuple): Downsample factor of each feature map.
        scale_ranges (tuple[tuple[int, int]]): Area range of multi-level
            masks, in the format [(min1, max1), (min2, max2), ...]. For
            example, (16, 64) means instances whose scale (the square root
            of the bbox area) lies in [16, 64] are assigned to this level.
pos_scale (float): Constant scale factor to control the center region.
        num_grids (list[int]): Divide the image into uniform grids; each
            feature level uses a different grid number. The number of
            output channels of the mask branch is grid ** 2. Defaults to
            [40, 36, 24, 16, 12].
cls_down_index (int): The index of downsample operation in
classification branch. Defaults to 0.
loss_mask (dict): Config of mask loss.
loss_cls (dict): Config of classification loss.
norm_cfg (dict): Dictionary to construct and config norm layer.
Defaults to norm_cfg=dict(type='GN', num_groups=32,
requires_grad=True).
train_cfg (dict): Training config of head.
test_cfg (dict): Testing config of head.
init_cfg (dict or list[dict], optional): Initialization config dict.
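
    Example:
        A minimal shape-check sketch; it assumes the default DiceLoss and
        FocalLoss can be built from the registry, and the 256x256 input
        size is illustrative only:

        >>> import torch
        >>> self = SOLOHead(num_classes=80, in_channels=256)
        >>> feats = tuple(
        ...     torch.rand(1, 256, 256 // s, 256 // s) for s in self.strides)
        >>> mask_preds, cls_preds = self.forward(feats)
        >>> [p.shape[1] for p in mask_preds]  # num_grids**2 per level
        [1600, 1296, 576, 256, 144]
        >>> [p.shape[-1] for p in cls_preds]  # num_grids per level
        [40, 36, 24, 16, 12]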
"""
def __init__(
self,
num_classes: int,
in_channels: int,
feat_channels: int = 256,
stacked_convs: int = 4,
strides: tuple = (4, 8, 16, 32, 64),
        scale_ranges: tuple = ((8, 32), (16, 64), (32, 128), (64, 256),
                               (128, 512)),
pos_scale: float = 0.2,
num_grids: list = [40, 36, 24, 16, 12],
cls_down_index: int = 0,
loss_mask: ConfigType = dict(
type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
loss_cls: ConfigType = dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
norm_cfg: ConfigType = dict(
type='GN', num_groups=32, requires_grad=True),
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
init_cfg: MultiConfig = [
dict(type='Normal', layer='Conv2d', std=0.01),
dict(
type='Normal',
std=0.01,
bias_prob=0.01,
override=dict(name='conv_mask_list')),
dict(
type='Normal',
std=0.01,
bias_prob=0.01,
override=dict(name='conv_cls'))
]
) -> None:
super().__init__(init_cfg=init_cfg)
self.num_classes = num_classes
self.cls_out_channels = self.num_classes
self.in_channels = in_channels
self.feat_channels = feat_channels
self.stacked_convs = stacked_convs
self.strides = strides
self.num_grids = num_grids
# number of FPN feats
self.num_levels = len(strides)
assert self.num_levels == len(scale_ranges) == len(num_grids)
self.scale_ranges = scale_ranges
self.pos_scale = pos_scale
self.cls_down_index = cls_down_index
self.loss_cls = MODELS.build(loss_cls)
self.loss_mask = MODELS.build(loss_mask)
self.norm_cfg = norm_cfg
self.init_cfg = init_cfg
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self._init_layers()
def _init_layers(self) -> None:
"""Initialize layers of the head."""
self.mask_convs = nn.ModuleList()
self.cls_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels + 2 if i == 0 else self.feat_channels
self.mask_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
norm_cfg=self.norm_cfg))
chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
norm_cfg=self.norm_cfg))
self.conv_mask_list = nn.ModuleList()
for num_grid in self.num_grids:
self.conv_mask_list.append(
nn.Conv2d(self.feat_channels, num_grid**2, 1))
self.conv_cls = nn.Conv2d(
self.feat_channels, self.cls_out_channels, 3, padding=1)
def resize_feats(self, x: Tuple[Tensor]) -> List[Tensor]:
"""Downsample the first feat and upsample last feat in feats.
Args:
x (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
Returns:
list[Tensor]: Features after resizing, each is a 4D-tensor.
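
        Example:
            A sketch of the resulting sizes under the default five strides,
            assuming level ``i`` of the FPN has spatial size
            ``256 / strides[i]``::

                (64, 32, 16, 8, 4) -> (32, 32, 16, 8, 8)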
"""
out = []
for i in range(len(x)):
if i == 0:
out.append(
F.interpolate(x[0], scale_factor=0.5, mode='bilinear'))
elif i == len(x) - 1:
out.append(
F.interpolate(
x[i], size=x[i - 1].shape[-2:], mode='bilinear'))
else:
out.append(x[i])
return out
def forward(self, x: Tuple[Tensor]) -> tuple:
"""Forward features from the upstream network.
Args:
x (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
Returns:
tuple: A tuple of classification scores and mask prediction.
                - mlvl_mask_preds (list[Tensor]): Multi-level mask
                  prediction. Each element in the list has shape
                  (batch_size, num_grids**2, h, w).
                - mlvl_cls_preds (list[Tensor]): Multi-level scores.
                  Each element in the list has shape
                  (batch_size, num_classes, num_grids, num_grids).
"""
assert len(x) == self.num_levels
feats = self.resize_feats(x)
mlvl_mask_preds = []
mlvl_cls_preds = []
for i in range(self.num_levels):
x = feats[i]
mask_feat = x
cls_feat = x
# generate and concat the coordinate
coord_feat = generate_coordinate(mask_feat.size(),
mask_feat.device)
mask_feat = torch.cat([mask_feat, coord_feat], 1)
            for mask_layer in self.mask_convs:
mask_feat = mask_layer(mask_feat)
mask_feat = F.interpolate(
mask_feat, scale_factor=2, mode='bilinear')
mask_preds = self.conv_mask_list[i](mask_feat)
# cls branch
for j, cls_layer in enumerate(self.cls_convs):
if j == self.cls_down_index:
num_grid = self.num_grids[i]
cls_feat = F.interpolate(
cls_feat, size=num_grid, mode='bilinear')
cls_feat = cls_layer(cls_feat)
cls_pred = self.conv_cls(cls_feat)
if not self.training:
feat_wh = feats[0].size()[-2:]
upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
mask_preds = F.interpolate(
mask_preds.sigmoid(), size=upsampled_size, mode='bilinear')
cls_pred = cls_pred.sigmoid()
# get local maximum
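                # (2x2 max-pooling plus an equality test keeps only grid
                # cells that are local maxima of the score map, acting as
                # a cheap point-wise NMS)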
local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
keep_mask = local_max[:, :, :-1, :-1] == cls_pred
cls_pred = cls_pred * keep_mask
mlvl_mask_preds.append(mask_preds)
mlvl_cls_preds.append(cls_pred)
return mlvl_mask_preds, mlvl_cls_preds
def loss_by_feat(self, mlvl_mask_preds: List[Tensor],
mlvl_cls_preds: List[Tensor],
batch_gt_instances: InstanceList,
batch_img_metas: List[dict], **kwargs) -> dict:
"""Calculate the loss based on the features extracted by the mask head.
Args:
            mlvl_mask_preds (list[Tensor]): Multi-level mask prediction.
                Each element in the list has shape
                (batch_size, num_grids**2, h, w).
            mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids, num_grids).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``masks``,
                and ``labels`` attributes.
            batch_img_metas (list[dict]): Meta information of multiple images.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
num_levels = self.num_levels
num_imgs = len(batch_img_metas)
featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds]
        # each `BoolTensor` in `pos_masks` indicates whether the
        # corresponding grid point is positive
pos_mask_targets, labels, pos_masks = multi_apply(
self._get_targets_single,
batch_gt_instances,
featmap_sizes=featmap_sizes)
        # regroup predictions: the outer list currently indexes images;
        # make it index feature levels instead
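        # e.g. pos_mask_targets[img_id][lvl] is regrouped into
        # mlvl_pos_mask_targets[lvl][img_id]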
mlvl_pos_mask_targets = [[] for _ in range(num_levels)]
mlvl_pos_mask_preds = [[] for _ in range(num_levels)]
mlvl_pos_masks = [[] for _ in range(num_levels)]
mlvl_labels = [[] for _ in range(num_levels)]
for img_id in range(num_imgs):
assert num_levels == len(pos_mask_targets[img_id])
for lvl in range(num_levels):
mlvl_pos_mask_targets[lvl].append(
pos_mask_targets[img_id][lvl])
mlvl_pos_mask_preds[lvl].append(
mlvl_mask_preds[lvl][img_id, pos_masks[img_id][lvl], ...])
mlvl_pos_masks[lvl].append(pos_masks[img_id][lvl].flatten())
mlvl_labels[lvl].append(labels[img_id][lvl].flatten())
        # concatenate predictions over images
temp_mlvl_cls_preds = []
for lvl in range(num_levels):
mlvl_pos_mask_targets[lvl] = torch.cat(
mlvl_pos_mask_targets[lvl], dim=0)
mlvl_pos_mask_preds[lvl] = torch.cat(
mlvl_pos_mask_preds[lvl], dim=0)
mlvl_pos_masks[lvl] = torch.cat(mlvl_pos_masks[lvl], dim=0)
mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0)
temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute(
0, 2, 3, 1).reshape(-1, self.cls_out_channels))
num_pos = sum(item.sum() for item in mlvl_pos_masks)
# dice loss
loss_mask = []
for pred, target in zip(mlvl_pos_mask_preds, mlvl_pos_mask_targets):
if pred.size()[0] == 0:
loss_mask.append(pred.sum().unsqueeze(0))
continue
loss_mask.append(
self.loss_mask(pred, target, reduction_override='none'))
if num_pos > 0:
loss_mask = torch.cat(loss_mask).sum() / num_pos
else:
loss_mask = torch.cat(loss_mask).mean()
flatten_labels = torch.cat(mlvl_labels)
flatten_cls_preds = torch.cat(temp_mlvl_cls_preds)
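        # `num_pos + 1` keeps avg_factor positive even when there are no
        # positive samples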
loss_cls = self.loss_cls(
flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
return dict(loss_mask=loss_mask, loss_cls=loss_cls)
def _get_targets_single(self,
gt_instances: InstanceData,
featmap_sizes: Optional[list] = None) -> tuple:
"""Compute targets for predictions of single image.
Args:
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes``, ``labels``,
                and ``masks`` attributes.
featmap_sizes (list[:obj:`torch.size`]): Size of each
feature map from feature pyramid, each element
means (feat_h, feat_w). Defaults to None.
Returns:
Tuple: Usually returns a tuple containing targets for predictions.
                - mlvl_pos_mask_targets (list[Tensor]): Each element
                  represents the binary mask targets for positive points
                  in this level, has shape (num_pos, out_h, out_w).
                - mlvl_labels (list[Tensor]): Each element is
                  classification labels for all
                  points in this level, has shape
                  (num_grid, num_grid).
                - mlvl_pos_masks (list[Tensor]): Each element is
                  a `BoolTensor` indicating whether the
                  corresponding point in a single level
                  is positive, has shape (num_grid**2,).
"""
gt_labels = gt_instances.labels
device = gt_labels.device
gt_bboxes = gt_instances.bboxes
gt_areas = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
(gt_bboxes[:, 3] - gt_bboxes[:, 1]))
gt_masks = gt_instances.masks.to_tensor(
dtype=torch.bool, device=device)
mlvl_pos_mask_targets = []
mlvl_labels = []
mlvl_pos_masks = []
for (lower_bound, upper_bound), stride, featmap_size, num_grid \
in zip(self.scale_ranges, self.strides,
featmap_sizes, self.num_grids):
mask_target = torch.zeros(
[num_grid**2, featmap_size[0], featmap_size[1]],
dtype=torch.uint8,
device=device)
            # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
labels = torch.zeros([num_grid, num_grid],
dtype=torch.int64,
device=device) + self.num_classes
pos_mask = torch.zeros([num_grid**2],
dtype=torch.bool,
device=device)
gt_inds = ((gt_areas >= lower_bound) &
(gt_areas <= upper_bound)).nonzero().flatten()
if len(gt_inds) == 0:
mlvl_pos_mask_targets.append(
mask_target.new_zeros(0, featmap_size[0], featmap_size[1]))
mlvl_labels.append(labels)
mlvl_pos_masks.append(pos_mask)
continue
hit_gt_bboxes = gt_bboxes[gt_inds]
hit_gt_labels = gt_labels[gt_inds]
hit_gt_masks = gt_masks[gt_inds, ...]
pos_w_ranges = 0.5 * (hit_gt_bboxes[:, 2] -
hit_gt_bboxes[:, 0]) * self.pos_scale
pos_h_ranges = 0.5 * (hit_gt_bboxes[:, 3] -
hit_gt_bboxes[:, 1]) * self.pos_scale
            # a gt mask is valid only if it has at least one foreground pixel
valid_mask_flags = hit_gt_masks.sum(dim=-1).sum(dim=-1) > 0
output_stride = stride / 2
for gt_mask, gt_label, pos_h_range, pos_w_range, \
valid_mask_flag in \
zip(hit_gt_masks, hit_gt_labels, pos_h_ranges,
pos_w_ranges, valid_mask_flags):
if not valid_mask_flag:
continue
upsampled_size = (featmap_sizes[0][0] * 4,
featmap_sizes[0][1] * 4)
center_h, center_w = center_of_mass(gt_mask)
coord_w = int(
floordiv((center_w / upsampled_size[1]), (1. / num_grid),
rounding_mode='trunc'))
coord_h = int(
floordiv((center_h / upsampled_size[0]), (1. / num_grid),
rounding_mode='trunc'))
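                # e.g. with num_grid=40 and upsampled_size=(800, 800), a
                # center at (center_h=400, center_w=200) lands in grid cell
                # coord_h = int(400 / 800 * 40) = 20,
                # coord_w = int(200 / 800 * 40) = 10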
                # left, top, right, bottom bounds of the center region
top_box = max(
0,
int(
floordiv(
(center_h - pos_h_range) / upsampled_size[0],
(1. / num_grid),
rounding_mode='trunc')))
down_box = min(
num_grid - 1,
int(
floordiv(
(center_h + pos_h_range) / upsampled_size[0],
(1. / num_grid),
rounding_mode='trunc')))
left_box = max(
0,
int(
floordiv(
(center_w - pos_w_range) / upsampled_size[1],
(1. / num_grid),
rounding_mode='trunc')))
right_box = min(
num_grid - 1,
int(
floordiv(
(center_w + pos_w_range) / upsampled_size[1],
(1. / num_grid),
rounding_mode='trunc')))
top = max(top_box, coord_h - 1)
down = min(down_box, coord_h + 1)
left = max(coord_w - 1, left_box)
right = min(right_box, coord_w + 1)
labels[top:(down + 1), left:(right + 1)] = gt_label
# ins
gt_mask = np.uint8(gt_mask.cpu().numpy())
                # Following the original implementation, rescale the mask
                # with mmcv (cv2-based), which gives slightly different
                # results from F.interpolate
gt_mask = mmcv.imrescale(gt_mask, scale=1. / output_stride)
gt_mask = torch.from_numpy(gt_mask).to(device=device)
for i in range(top, down + 1):
for j in range(left, right + 1):
index = int(i * num_grid + j)
                        mask_target[index, :gt_mask.shape[0],
                                    :gt_mask.shape[1]] = gt_mask
pos_mask[index] = True
mlvl_pos_mask_targets.append(mask_target[pos_mask])
mlvl_labels.append(labels)
mlvl_pos_masks.append(pos_mask)
return mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks
def predict_by_feat(self, mlvl_mask_preds: List[Tensor],
mlvl_cls_scores: List[Tensor],
batch_img_metas: List[dict], **kwargs) -> InstanceList:
"""Transform a batch of output features extracted from the head into
mask results.
Args:
            mlvl_mask_preds (list[Tensor]): Multi-level mask prediction.
                Each element in the list has shape
                (batch_size, num_grids**2, h, w).
            mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids, num_grids).
batch_img_metas (list[dict]): Meta information of all images.
Returns:
            list[:obj:`InstanceData`]: Processed results of multiple
            images. Each :obj:`InstanceData` usually contains the
            following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instances,).
- labels (Tensor): Has shape (num_instances,).
- masks (Tensor): Processed mask results, has
shape (num_instances, h, w).
"""
mlvl_cls_scores = [
item.permute(0, 2, 3, 1) for item in mlvl_cls_scores
]
assert len(mlvl_mask_preds) == len(mlvl_cls_scores)
num_levels = len(mlvl_cls_scores)
results_list = []
for img_id in range(len(batch_img_metas)):
cls_pred_list = [
mlvl_cls_scores[lvl][img_id].view(-1, self.cls_out_channels)
for lvl in range(num_levels)
]
mask_pred_list = [
mlvl_mask_preds[lvl][img_id] for lvl in range(num_levels)
]
cls_pred_list = torch.cat(cls_pred_list, dim=0)
mask_pred_list = torch.cat(mask_pred_list, dim=0)
img_meta = batch_img_metas[img_id]
results = self._predict_by_feat_single(
cls_pred_list, mask_pred_list, img_meta=img_meta)
results_list.append(results)
return results_list
def _predict_by_feat_single(self,
cls_scores: Tensor,
mask_preds: Tensor,
img_meta: dict,
cfg: OptConfigType = None) -> InstanceData:
"""Transform a single image's features extracted from the head into
mask results.
Args:
cls_scores (Tensor): Classification score of all points
in single image, has shape (num_points, num_classes).
mask_preds (Tensor): Mask prediction of all points in
single image, has shape (num_points, feat_h, feat_w).
img_meta (dict): Meta information of corresponding image.
cfg (dict, optional): Config used in test phase.
Defaults to None.
Returns:
:obj:`InstanceData`: Processed results of single image.
            It usually contains the following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instances,).
- labels (Tensor): Has shape (num_instances,).
- masks (Tensor): Processed mask results, has
shape (num_instances, h, w).
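
        Example:
            A sketch of the ``test_cfg`` fields this method reads; the
            values mirror a typical SOLO config and are illustrative only::

                cfg = dict(
                    score_thr=0.1,
                    mask_thr=0.5,
                    nms_pre=500,
                    max_per_img=100,
                    kernel='gaussian',
                    sigma=2.0,
                    filter_thr=0.05)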
"""
def empty_results(cls_scores, ori_shape):
"""Generate a empty results."""
results = InstanceData()
results.scores = cls_scores.new_ones(0)
results.masks = cls_scores.new_zeros(0, *ori_shape)
results.labels = cls_scores.new_ones(0)
results.bboxes = cls_scores.new_zeros(0, 4)
return results
cfg = self.test_cfg if cfg is None else cfg
assert len(cls_scores) == len(mask_preds)
featmap_size = mask_preds.size()[-2:]
h, w = img_meta['img_shape'][:2]
upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4)
score_mask = (cls_scores > cfg.score_thr)
cls_scores = cls_scores[score_mask]
if len(cls_scores) == 0:
return empty_results(cls_scores, img_meta['ori_shape'][:2])
inds = score_mask.nonzero()
cls_labels = inds[:, 1]
        # filter out masks whose area is smaller than the stride of the
        # corresponding feature level
lvl_interval = cls_labels.new_tensor(self.num_grids).pow(2).cumsum(0)
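        # with the default num_grids this gives
        # lvl_interval = [1600, 2896, 3472, 3728, 3872]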
strides = cls_scores.new_ones(lvl_interval[-1])
strides[:lvl_interval[0]] *= self.strides[0]
for lvl in range(1, self.num_levels):
            strides[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
                self.strides[lvl]
strides = strides[inds[:, 0]]
mask_preds = mask_preds[inds[:, 0]]
masks = mask_preds > cfg.mask_thr
sum_masks = masks.sum((1, 2)).float()
keep = sum_masks > strides
if keep.sum() == 0:
return empty_results(cls_scores, img_meta['ori_shape'][:2])
masks = masks[keep]
mask_preds = mask_preds[keep]
sum_masks = sum_masks[keep]
cls_scores = cls_scores[keep]
cls_labels = cls_labels[keep]
        # maskness: average foreground probability inside the binarized
        # mask, used to rescale the classification score
mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks
cls_scores *= mask_scores
scores, labels, _, keep_inds = mask_matrix_nms(
masks,
cls_labels,
cls_scores,
mask_area=sum_masks,
nms_pre=cfg.nms_pre,
max_num=cfg.max_per_img,
kernel=cfg.kernel,
sigma=cfg.sigma,
filter_thr=cfg.filter_thr)
# mask_matrix_nms may return an empty Tensor
if len(keep_inds) == 0:
return empty_results(cls_scores, img_meta['ori_shape'][:2])
mask_preds = mask_preds[keep_inds]
mask_preds = F.interpolate(
mask_preds.unsqueeze(0), size=upsampled_size,
mode='bilinear')[:, :, :h, :w]
mask_preds = F.interpolate(
mask_preds, size=img_meta['ori_shape'][:2],
mode='bilinear').squeeze(0)
masks = mask_preds > cfg.mask_thr
results = InstanceData()
results.masks = masks
results.labels = labels
results.scores = scores
        # create empty bboxes in InstanceData to avoid bugs when
        # calculating metrics
results.bboxes = results.scores.new_zeros(len(scores), 4)
return results
@MODELS.register_module()
class DecoupledSOLOHead(SOLOHead):
"""Decoupled SOLO mask head used in `SOLO: Segmenting Objects by Locations.
<https://arxiv.org/abs/1912.04488>`_
Args:
init_cfg (dict or list[dict], optional): Initialization config dict.
"""
def __init__(self,
*args,
init_cfg: MultiConfig = [
dict(type='Normal', layer='Conv2d', std=0.01),
dict(
type='Normal',
std=0.01,
bias_prob=0.01,
override=dict(name='conv_mask_list_x')),
dict(
type='Normal',
std=0.01,
bias_prob=0.01,
override=dict(name='conv_mask_list_y')),
dict(
type='Normal',
std=0.01,
bias_prob=0.01,
override=dict(name='conv_cls'))
],
**kwargs) -> None:
super().__init__(*args, init_cfg=init_cfg, **kwargs)
def _init_layers(self) -> None:
self.mask_convs_x = nn.ModuleList()
self.mask_convs_y = nn.ModuleList()
self.cls_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels + 1 if i == 0 else self.feat_channels
self.mask_convs_x.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
norm_cfg=self.norm_cfg))
self.mask_convs_y.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
norm_cfg=self.norm_cfg))
chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
norm_cfg=self.norm_cfg))
self.conv_mask_list_x = nn.ModuleList()
self.conv_mask_list_y = nn.ModuleList()
for num_grid in self.num_grids:
self.conv_mask_list_x.append(
nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
self.conv_mask_list_y.append(
nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
self.conv_cls = nn.Conv2d(
self.feat_channels, self.cls_out_channels, 3, padding=1)
def forward(self, x: Tuple[Tensor]) -> Tuple:
"""Forward features from the upstream network.
Args:
x (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
Returns:
tuple: A tuple of classification scores and mask prediction.
                - mlvl_mask_preds_x (list[Tensor]): Multi-level mask
                  prediction from x branch. Each element in the list has
                  shape (batch_size, num_grids, h, w).
                - mlvl_mask_preds_y (list[Tensor]): Multi-level mask
                  prediction from y branch. Each element in the list has
                  shape (batch_size, num_grids, h, w).
                - mlvl_cls_preds (list[Tensor]): Multi-level scores.
                  Each element in the list has shape
                  (batch_size, num_classes, num_grids, num_grids).
"""
assert len(x) == self.num_levels
feats = self.resize_feats(x)
mask_preds_x = []
mask_preds_y = []
cls_preds = []
for i in range(self.num_levels):
x = feats[i]
mask_feat = x
cls_feat = x
# generate and concat the coordinate
coord_feat = generate_coordinate(mask_feat.size(),
mask_feat.device)
mask_feat_x = torch.cat([mask_feat, coord_feat[:, 0:1, ...]], 1)
mask_feat_y = torch.cat([mask_feat, coord_feat[:, 1:2, ...]], 1)
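            # channel 0 of coord_feat is the x coordinate and channel 1 the
            # y coordinate, so each branch only sees its own axis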
for mask_layer_x, mask_layer_y in \
zip(self.mask_convs_x, self.mask_convs_y):
mask_feat_x = mask_layer_x(mask_feat_x)
mask_feat_y = mask_layer_y(mask_feat_y)
mask_feat_x = F.interpolate(
mask_feat_x, scale_factor=2, mode='bilinear')
mask_feat_y = F.interpolate(
mask_feat_y, scale_factor=2, mode='bilinear')
mask_pred_x = self.conv_mask_list_x[i](mask_feat_x)
mask_pred_y = self.conv_mask_list_y[i](mask_feat_y)
# cls branch
for j, cls_layer in enumerate(self.cls_convs):
if j == self.cls_down_index:
num_grid = self.num_grids[i]
cls_feat = F.interpolate(
cls_feat, size=num_grid, mode='bilinear')
cls_feat = cls_layer(cls_feat)
cls_pred = self.conv_cls(cls_feat)
if not self.training:
feat_wh = feats[0].size()[-2:]
upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
mask_pred_x = F.interpolate(
mask_pred_x.sigmoid(),
size=upsampled_size,
mode='bilinear')
mask_pred_y = F.interpolate(
mask_pred_y.sigmoid(),
size=upsampled_size,
mode='bilinear')
cls_pred = cls_pred.sigmoid()
# get local maximum
local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
keep_mask = local_max[:, :, :-1, :-1] == cls_pred
cls_pred = cls_pred * keep_mask
mask_preds_x.append(mask_pred_x)
mask_preds_y.append(mask_pred_y)
cls_preds.append(cls_pred)
return mask_preds_x, mask_preds_y, cls_preds
def loss_by_feat(self, mlvl_mask_preds_x: List[Tensor],
mlvl_mask_preds_y: List[Tensor],
mlvl_cls_preds: List[Tensor],
batch_gt_instances: InstanceList,
batch_img_metas: List[dict], **kwargs) -> dict:
"""Calculate the loss based on the features extracted by the mask head.
Args:
            mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction
                from x branch. Each element in the list has shape
                (batch_size, num_grids, h, w).
            mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction
                from y branch. Each element in the list has shape
                (batch_size, num_grids, h, w).
            mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids, num_grids).
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes``, ``masks``,
and ``labels`` attributes.
batch_img_metas (list[dict]): Meta information of multiple images.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
num_levels = self.num_levels
num_imgs = len(batch_img_metas)
featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds_x]
pos_mask_targets, labels, xy_pos_indexes = multi_apply(
self._get_targets_single,
batch_gt_instances,
featmap_sizes=featmap_sizes)
        # regroup predictions: the outer list currently indexes images;
        # make it index feature levels instead
mlvl_pos_mask_targets = [[] for _ in range(num_levels)]
mlvl_pos_mask_preds_x = [[] for _ in range(num_levels)]
mlvl_pos_mask_preds_y = [[] for _ in range(num_levels)]
mlvl_labels = [[] for _ in range(num_levels)]
for img_id in range(num_imgs):
for lvl in range(num_levels):
mlvl_pos_mask_targets[lvl].append(
pos_mask_targets[img_id][lvl])
mlvl_pos_mask_preds_x[lvl].append(
mlvl_mask_preds_x[lvl][img_id,
xy_pos_indexes[img_id][lvl][:, 1]])
mlvl_pos_mask_preds_y[lvl].append(
mlvl_mask_preds_y[lvl][img_id,
xy_pos_indexes[img_id][lvl][:, 0]])
mlvl_labels[lvl].append(labels[img_id][lvl].flatten())
        # concatenate predictions over images
temp_mlvl_cls_preds = []
for lvl in range(num_levels):
mlvl_pos_mask_targets[lvl] = torch.cat(
mlvl_pos_mask_targets[lvl], dim=0)
mlvl_pos_mask_preds_x[lvl] = torch.cat(
mlvl_pos_mask_preds_x[lvl], dim=0)
mlvl_pos_mask_preds_y[lvl] = torch.cat(
mlvl_pos_mask_preds_y[lvl], dim=0)
mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0)
temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute(
0, 2, 3, 1).reshape(-1, self.cls_out_channels))
num_pos = 0.
# dice loss
loss_mask = []
for pred_x, pred_y, target in \
zip(mlvl_pos_mask_preds_x,
mlvl_pos_mask_preds_y, mlvl_pos_mask_targets):
num_masks = pred_x.size(0)
if num_masks == 0:
                # make sure gradients can still flow when there are no
                # positives
loss_mask.append((pred_x.sum() + pred_y.sum()).unsqueeze(0))
continue
num_pos += num_masks
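            # the decoupled head composes each predicted mask as the
            # elementwise product of the x- and y-branch probabilities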
pred_mask = pred_y.sigmoid() * pred_x.sigmoid()
loss_mask.append(
self.loss_mask(pred_mask, target, reduction_override='none'))
if num_pos > 0:
loss_mask = torch.cat(loss_mask).sum() / num_pos
else:
loss_mask = torch.cat(loss_mask).mean()
        # classification loss
flatten_labels = torch.cat(mlvl_labels)
flatten_cls_preds = torch.cat(temp_mlvl_cls_preds)
loss_cls = self.loss_cls(
flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
return dict(loss_mask=loss_mask, loss_cls=loss_cls)
def _get_targets_single(self,
gt_instances: InstanceData,
featmap_sizes: Optional[list] = None) -> tuple:
"""Compute targets for predictions of single image.
Args:
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes``, ``labels``,
                and ``masks`` attributes.
featmap_sizes (list[:obj:`torch.size`]): Size of each
feature map from feature pyramid, each element
means (feat_h, feat_w). Defaults to None.
Returns:
Tuple: Usually returns a tuple containing targets for predictions.
                - mlvl_pos_mask_targets (list[Tensor]): Each element
                  represents the binary mask targets for positive points
                  in this level, has shape (num_pos, out_h, out_w).
                - mlvl_labels (list[Tensor]): Each element is
                  classification labels for all
                  points in this level, has shape
                  (num_grid, num_grid).
                - mlvl_xy_pos_indexes (list[Tensor]): Each element
                  in the list contains the indexes of positive samples in
                  the corresponding level, has shape (num_pos, 2), where
                  the last dimension holds (index_y, index_x).
"""
mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks = \
super()._get_targets_single(gt_instances,
featmap_sizes=featmap_sizes)
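        # background cells of ``mlvl_labels`` hold ``num_classes``, so
        # subtracting it zeroes them out and ``nonzero()`` returns the
        # (y, x) grid coordinates of the positive cells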
mlvl_xy_pos_indexes = [(item - self.num_classes).nonzero()
for item in mlvl_labels]
return mlvl_pos_mask_targets, mlvl_labels, mlvl_xy_pos_indexes
def predict_by_feat(self, mlvl_mask_preds_x: List[Tensor],
mlvl_mask_preds_y: List[Tensor],
mlvl_cls_scores: List[Tensor],
batch_img_metas: List[dict], **kwargs) -> InstanceList:
"""Transform a batch of output features extracted from the head into
mask results.
Args:
            mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction
                from x branch. Each element in the list has shape
                (batch_size, num_grids, h, w).
            mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction
                from y branch. Each element in the list has shape
                (batch_size, num_grids, h, w).
            mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids, num_grids).
batch_img_metas (list[dict]): Meta information of all images.
Returns:
            list[:obj:`InstanceData`]: Processed results of multiple
            images. Each :obj:`InstanceData` usually contains the
            following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instances,).
- labels (Tensor): Has shape (num_instances,).
- masks (Tensor): Processed mask results, has
shape (num_instances, h, w).
"""
mlvl_cls_scores = [
item.permute(0, 2, 3, 1) for item in mlvl_cls_scores
]
assert len(mlvl_mask_preds_x) == len(mlvl_cls_scores)
num_levels = len(mlvl_cls_scores)
results_list = []
for img_id in range(len(batch_img_metas)):
cls_pred_list = [
mlvl_cls_scores[i][img_id].view(
-1, self.cls_out_channels).detach()
for i in range(num_levels)
]
mask_pred_list_x = [
mlvl_mask_preds_x[i][img_id] for i in range(num_levels)
]
mask_pred_list_y = [
mlvl_mask_preds_y[i][img_id] for i in range(num_levels)
]
cls_pred_list = torch.cat(cls_pred_list, dim=0)
mask_pred_list_x = torch.cat(mask_pred_list_x, dim=0)
mask_pred_list_y = torch.cat(mask_pred_list_y, dim=0)
img_meta = batch_img_metas[img_id]
results = self._predict_by_feat_single(
cls_pred_list,
mask_pred_list_x,
mask_pred_list_y,
img_meta=img_meta)
results_list.append(results)
return results_list
def _predict_by_feat_single(self,
cls_scores: Tensor,
mask_preds_x: Tensor,
mask_preds_y: Tensor,
img_meta: dict,
cfg: OptConfigType = None) -> InstanceData:
"""Transform a single image's features extracted from the head into
mask results.
Args:
cls_scores (Tensor): Classification score of all points
in single image, has shape (num_points, num_classes).
mask_preds_x (Tensor): Mask prediction of x branch of
all points in single image, has shape
(sum_num_grids, feat_h, feat_w).
mask_preds_y (Tensor): Mask prediction of y branch of
all points in single image, has shape
(sum_num_grids, feat_h, feat_w).
img_meta (dict): Meta information of corresponding image.
            cfg (dict, optional): Config used in test phase.
                Defaults to None.
Returns:
:obj:`InstanceData`: Processed results of single image.
            It usually contains the following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instances,).
- labels (Tensor): Has shape (num_instances,).
- masks (Tensor): Processed mask results, has
shape (num_instances, h, w).
"""
def empty_results(cls_scores, ori_shape):
"""Generate a empty results."""
results = InstanceData()
results.scores = cls_scores.new_ones(0)
results.masks = cls_scores.new_zeros(0, *ori_shape)
results.labels = cls_scores.new_ones(0)
results.bboxes = cls_scores.new_zeros(0, 4)
return results
cfg = self.test_cfg if cfg is None else cfg
featmap_size = mask_preds_x.size()[-2:]
h, w = img_meta['img_shape'][:2]
upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4)
score_mask = (cls_scores > cfg.score_thr)
cls_scores = cls_scores[score_mask]
inds = score_mask.nonzero()
lvl_interval = inds.new_tensor(self.num_grids).pow(2).cumsum(0)
num_all_points = lvl_interval[-1]
lvl_start_index = inds.new_ones(num_all_points)
num_grids = inds.new_ones(num_all_points)
seg_size = inds.new_tensor(self.num_grids).cumsum(0)
mask_lvl_start_index = inds.new_ones(num_all_points)
strides = inds.new_ones(num_all_points)
lvl_start_index[:lvl_interval[0]] *= 0
mask_lvl_start_index[:lvl_interval[0]] *= 0
num_grids[:lvl_interval[0]] *= self.num_grids[0]
strides[:lvl_interval[0]] *= self.strides[0]
for lvl in range(1, self.num_levels):
lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
lvl_interval[lvl - 1]
mask_lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
seg_size[lvl - 1]
num_grids[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
self.num_grids[lvl]
strides[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
self.strides[lvl]
lvl_start_index = lvl_start_index[inds[:, 0]]
mask_lvl_start_index = mask_lvl_start_index[inds[:, 0]]
num_grids = num_grids[inds[:, 0]]
strides = strides[inds[:, 0]]
y_lvl_offset = (inds[:, 0] - lvl_start_index) // num_grids
x_lvl_offset = (inds[:, 0] - lvl_start_index) % num_grids
y_inds = mask_lvl_start_index + y_lvl_offset
x_inds = mask_lvl_start_index + x_lvl_offset
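        # e.g. with the default num_grids, flattened point 1700 lies in
        # level 1 (level 0 holds 40**2 = 1600 points); its offset 100 maps
        # to grid row 100 // 36 = 2 and column 100 % 36 = 28, and level 1's
        # mask bank starts at seg_size[0] = 40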
cls_labels = inds[:, 1]
mask_preds = mask_preds_x[x_inds, ...] * mask_preds_y[y_inds, ...]
masks = mask_preds > cfg.mask_thr
sum_masks = masks.sum((1, 2)).float()
keep = sum_masks > strides
if keep.sum() == 0:
return empty_results(cls_scores, img_meta['ori_shape'][:2])
masks = masks[keep]
mask_preds = mask_preds[keep]
sum_masks = sum_masks[keep]
cls_scores = cls_scores[keep]
cls_labels = cls_labels[keep]
        # maskness: average foreground probability inside the binarized
        # mask, used to rescale the classification score
mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks
cls_scores *= mask_scores
scores, labels, _, keep_inds = mask_matrix_nms(
masks,
cls_labels,
cls_scores,
mask_area=sum_masks,
nms_pre=cfg.nms_pre,
max_num=cfg.max_per_img,
kernel=cfg.kernel,
sigma=cfg.sigma,
filter_thr=cfg.filter_thr)
# mask_matrix_nms may return an empty Tensor
if len(keep_inds) == 0:
return empty_results(cls_scores, img_meta['ori_shape'][:2])
mask_preds = mask_preds[keep_inds]
mask_preds = F.interpolate(
mask_preds.unsqueeze(0), size=upsampled_size,
mode='bilinear')[:, :, :h, :w]
mask_preds = F.interpolate(
mask_preds, size=img_meta['ori_shape'][:2],
mode='bilinear').squeeze(0)
masks = mask_preds > cfg.mask_thr
results = InstanceData()
results.masks = masks
results.labels = labels
results.scores = scores
        # create empty bboxes in InstanceData to avoid bugs when
        # calculating metrics
results.bboxes = results.scores.new_zeros(len(scores), 4)
return results
@MODELS.register_module()
class DecoupledSOLOLightHead(DecoupledSOLOHead):
"""Decoupled Light SOLO mask head used in `SOLO: Segmenting Objects by
Locations <https://arxiv.org/abs/1912.04488>`_
Args:
        dcn_cfg (dict, optional): Config of the deformable conv layer
            applied to the last conv of ``mask_convs`` and ``cls_convs``.
            Defaults to None.
init_cfg (dict or list[dict], optional): Initialization config dict.
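
    Example:
        A minimal sketch of enabling DCN on the last stacked conv; the
        ``DCNv2`` type key is an assumption that requires mmcv's
        deformable conv ops:

        >>> head = DecoupledSOLOLightHead(
        ...     num_classes=80, in_channels=256, dcn_cfg=dict(type='DCNv2'))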
"""
def __init__(self,
*args,
dcn_cfg: OptConfigType = None,
init_cfg: MultiConfig = [
dict(type='Normal', layer='Conv2d', std=0.01),
dict(
type='Normal',
std=0.01,
bias_prob=0.01,
override=dict(name='conv_mask_list_x')),
dict(
type='Normal',
std=0.01,
bias_prob=0.01,
override=dict(name='conv_mask_list_y')),
dict(
type='Normal',
std=0.01,
bias_prob=0.01,
override=dict(name='conv_cls'))
],
**kwargs) -> None:
assert dcn_cfg is None or isinstance(dcn_cfg, dict)
self.dcn_cfg = dcn_cfg
super().__init__(*args, init_cfg=init_cfg, **kwargs)
def _init_layers(self) -> None:
self.mask_convs = nn.ModuleList()
self.cls_convs = nn.ModuleList()
for i in range(self.stacked_convs):
if self.dcn_cfg is not None \
and i == self.stacked_convs - 1:
conv_cfg = self.dcn_cfg
else:
conv_cfg = None
chn = self.in_channels + 2 if i == 0 else self.feat_channels
self.mask_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=self.norm_cfg))
chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=self.norm_cfg))
self.conv_mask_list_x = nn.ModuleList()
self.conv_mask_list_y = nn.ModuleList()
for num_grid in self.num_grids:
self.conv_mask_list_x.append(
nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
self.conv_mask_list_y.append(
nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
self.conv_cls = nn.Conv2d(
self.feat_channels, self.cls_out_channels, 3, padding=1)
def forward(self, x: Tuple[Tensor]) -> Tuple:
"""Forward features from the upstream network.
Args:
x (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
Returns:
tuple: A tuple of classification scores and mask prediction.
                - mlvl_mask_preds_x (list[Tensor]): Multi-level mask
                  prediction from x branch. Each element in the list has
                  shape (batch_size, num_grids, h, w).
                - mlvl_mask_preds_y (list[Tensor]): Multi-level mask
                  prediction from y branch. Each element in the list has
                  shape (batch_size, num_grids, h, w).
                - mlvl_cls_preds (list[Tensor]): Multi-level scores.
                  Each element in the list has shape
                  (batch_size, num_classes, num_grids, num_grids).
"""
assert len(x) == self.num_levels
feats = self.resize_feats(x)
mask_preds_x = []
mask_preds_y = []
cls_preds = []
for i in range(self.num_levels):
x = feats[i]
mask_feat = x
cls_feat = x
# generate and concat the coordinate
coord_feat = generate_coordinate(mask_feat.size(),
mask_feat.device)
mask_feat = torch.cat([mask_feat, coord_feat], 1)
for mask_layer in self.mask_convs:
mask_feat = mask_layer(mask_feat)
mask_feat = F.interpolate(
mask_feat, scale_factor=2, mode='bilinear')
mask_pred_x = self.conv_mask_list_x[i](mask_feat)
mask_pred_y = self.conv_mask_list_y[i](mask_feat)
# cls branch
for j, cls_layer in enumerate(self.cls_convs):
if j == self.cls_down_index:
num_grid = self.num_grids[i]
cls_feat = F.interpolate(
cls_feat, size=num_grid, mode='bilinear')
cls_feat = cls_layer(cls_feat)
cls_pred = self.conv_cls(cls_feat)
if not self.training:
feat_wh = feats[0].size()[-2:]
upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
mask_pred_x = F.interpolate(
mask_pred_x.sigmoid(),
size=upsampled_size,
mode='bilinear')
mask_pred_y = F.interpolate(
mask_pred_y.sigmoid(),
size=upsampled_size,
mode='bilinear')
cls_pred = cls_pred.sigmoid()
# get local maximum
local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
keep_mask = local_max[:, :, :-1, :-1] == cls_pred
cls_pred = cls_pred * keep_mask
mask_preds_x.append(mask_pred_x)
mask_preds_y.append(mask_pred_y)
cls_preds.append(cls_pred)
return mask_preds_x, mask_preds_y, cls_preds