# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple

import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.models.utils.misc import floordiv
from mmdet.registry import MODELS
from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptConfigType
from ..layers import mask_matrix_nms
from ..utils import center_of_mass, generate_coordinate, multi_apply
from .base_mask_head import BaseMaskHead


@MODELS.register_module()
class SOLOHead(BaseMaskHead):
    """SOLO mask head used in `SOLO: Segmenting Objects by Locations.
    <https://arxiv.org/abs/1912.04488>`_

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels. Defaults to 256.
        stacked_convs (int): Number of stacking convs of the head.
            Defaults to 4.
        strides (tuple): Downsample factor of each feature map.
        scale_ranges (tuple[tuple[int, int]]): Scale range of each level,
            in the format [(min1, max1), (min2, max2), ...]. A gt instance
            is assigned to every level whose inclusive range contains the
            square root of its box area, i.e. ``sqrt(w * h)``.
        pos_scale (float): Constant scale factor to control the center
            region.
        num_grids (list[int]): The image is divided into a uniform
            ``num_grid x num_grid`` grid; each feature level uses its own
            grid size. The number of mask output channels of a level is
            ``num_grid ** 2``. Defaults to [40, 36, 24, 16, 12].
        cls_down_index (int): Index of the conv in the classification
            branch before which the feature is resized to the
            ``num_grid x num_grid`` grid. Defaults to 0.
        loss_mask (dict): Config of mask loss.
        loss_cls (dict): Config of classification loss.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to norm_cfg=dict(type='GN', num_groups=32,
            requires_grad=True).
        train_cfg (dict): Training config of head.
        test_cfg (dict): Testing config of head. Inference reads
            ``score_thr``, ``mask_thr``, ``nms_pre``, ``max_per_img``,
            ``kernel``, ``sigma`` and ``filter_thr`` from it.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(
        self,
        num_classes: int,
        in_channels: int,
        feat_channels: int = 256,
        stacked_convs: int = 4,
        strides: tuple = (4, 8, 16, 32, 64),
        scale_ranges: tuple = ((8, 32), (16, 64), (32, 128), (64, 256),
                               (128, 512)),
        pos_scale: float = 0.2,
        num_grids: list = [40, 36, 24, 16, 12],
        cls_down_index: int = 0,
        loss_mask: ConfigType = dict(
            type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
        loss_cls: ConfigType = dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        norm_cfg: ConfigType = dict(
            type='GN', num_groups=32, requires_grad=True),
        train_cfg: OptConfigType = None,
        test_cfg: OptConfigType = None,
        init_cfg: MultiConfig = [
            dict(type='Normal', layer='Conv2d', std=0.01),
            dict(
                type='Normal',
                std=0.01,
                bias_prob=0.01,
                override=dict(name='conv_mask_list')),
            dict(
                type='Normal',
                std=0.01,
                bias_prob=0.01,
                override=dict(name='conv_cls'))
        ]
    ) -> None:
        # ``init_cfg`` is handed to the base class here (mmengine's
        # ``BaseModule`` keeps its own deep copy). It is deliberately NOT
        # re-assigned below: re-binding it to the argument would make every
        # instance alias the shared mutable default list.
        super().__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.cls_out_channels = self.num_classes
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.num_grids = num_grids
        # number of FPN feats
        self.num_levels = len(strides)
        assert self.num_levels == len(scale_ranges) == len(num_grids)
        self.scale_ranges = scale_ranges
        self.pos_scale = pos_scale

        self.cls_down_index = cls_down_index
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_mask = MODELS.build(loss_mask)
        self.norm_cfg = norm_cfg
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize layers of the head.

        The mask tower takes ``in_channels + 2`` channels at its first conv
        because two coordinate channels are concatenated in ``forward``.
        """
        self.mask_convs = nn.ModuleList()
        self.cls_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels + 2 if i == 0 else self.feat_channels
            self.mask_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))
        # one 1x1 mask predictor per level, with num_grid**2 output channels
        self.conv_mask_list = nn.ModuleList()
        for num_grid in self.num_grids:
            self.conv_mask_list.append(
                nn.Conv2d(self.feat_channels, num_grid**2, 1))

        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)

    def resize_feats(self, x: Tuple[Tensor]) -> List[Tensor]:
        """Downsample the first feat and upsample last feat in feats.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            list[Tensor]: Features after resizing, each is a 4D-tensor. The
            first level is halved in size and the last level is resized to
            the spatial size of the second-to-last input level.
        """
        out = []
        for i in range(len(x)):
            if i == 0:
                out.append(
                    F.interpolate(x[0], scale_factor=0.5, mode='bilinear'))
            elif i == len(x) - 1:
                out.append(
                    F.interpolate(
                        x[i], size=x[i - 1].shape[-2:], mode='bilinear'))
            else:
                out.append(x[i])
        return out

    def forward(self, x: Tuple[Tensor]) -> tuple:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: A tuple of classification scores and mask prediction.

                - mlvl_mask_preds (list[Tensor]): Multi-level mask prediction.
                  Each element in the list has shape
                  (batch_size, num_grids**2 ,h ,w).
                - mlvl_cls_preds (list[Tensor]): Multi-level scores.
                  Each element in the list has shape
                  (batch_size, num_classes, num_grids ,num_grids).

            In eval mode the mask predictions are sigmoid probabilities
            resized to twice the size of the first resized feature, and the
            scores are sigmoid probabilities with non-local-maxima zeroed.
        """
        assert len(x) == self.num_levels
        feats = self.resize_feats(x)
        mlvl_mask_preds = []
        mlvl_cls_preds = []
        for i in range(self.num_levels):
            feat = feats[i]
            mask_feat = feat
            cls_feat = feat
            # generate and concat the coordinate
            coord_feat = generate_coordinate(mask_feat.size(),
                                             mask_feat.device)
            mask_feat = torch.cat([mask_feat, coord_feat], 1)

            for mask_layer in self.mask_convs:
                mask_feat = mask_layer(mask_feat)

            mask_feat = F.interpolate(
                mask_feat, scale_factor=2, mode='bilinear')
            mask_preds = self.conv_mask_list[i](mask_feat)

            # cls branch: resize to the level's grid before conv
            # ``cls_down_index`` so the output is num_grid x num_grid
            for j, cls_layer in enumerate(self.cls_convs):
                if j == self.cls_down_index:
                    num_grid = self.num_grids[i]
                    cls_feat = F.interpolate(
                        cls_feat, size=num_grid, mode='bilinear')
                cls_feat = cls_layer(cls_feat)

            cls_pred = self.conv_cls(cls_feat)

            if not self.training:
                feat_wh = feats[0].size()[-2:]
                upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
                mask_preds = F.interpolate(
                    mask_preds.sigmoid(), size=upsampled_size, mode='bilinear')
                cls_pred = cls_pred.sigmoid()
                # get local maximum: keep only points equal to the max of
                # their 2x2 neighbourhood (a cheap point NMS)
                local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
                keep_mask = local_max[:, :, :-1, :-1] == cls_pred
                cls_pred = cls_pred * keep_mask

            mlvl_mask_preds.append(mask_preds)
            mlvl_cls_preds.append(cls_pred)
        return mlvl_mask_preds, mlvl_cls_preds

    def loss_by_feat(self, mlvl_mask_preds: List[Tensor],
                     mlvl_cls_preds: List[Tensor],
                     batch_gt_instances: InstanceList,
                     batch_img_metas: List[dict], **kwargs) -> dict:
        """Calculate the loss based on the features extracted by the mask
        head.

        Args:
            mlvl_mask_preds (list[Tensor]): Multi-level mask prediction.
                Each element in the list has shape
                (batch_size, num_grids**2 ,h ,w).
            mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids ,num_grids).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``masks``,
                and ``labels`` attributes.
            batch_img_metas (list[dict]): Meta information of multiple images.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        num_levels = self.num_levels
        num_imgs = len(batch_img_metas)

        featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds]

        # `BoolTensor` in `pos_masks` represent
        # whether the corresponding point is
        # positive
        pos_mask_targets, labels, pos_masks = multi_apply(
            self._get_targets_single,
            batch_gt_instances,
            featmap_sizes=featmap_sizes)

        # change from the outside list meaning multi images
        # to the outside list meaning multi levels
        mlvl_pos_mask_targets = [[] for _ in range(num_levels)]
        mlvl_pos_mask_preds = [[] for _ in range(num_levels)]
        mlvl_pos_masks = [[] for _ in range(num_levels)]
        mlvl_labels = [[] for _ in range(num_levels)]
        for img_id in range(num_imgs):
            assert num_levels == len(pos_mask_targets[img_id])
            for lvl in range(num_levels):
                mlvl_pos_mask_targets[lvl].append(
                    pos_mask_targets[img_id][lvl])
                mlvl_pos_mask_preds[lvl].append(
                    mlvl_mask_preds[lvl][img_id, pos_masks[img_id][lvl], ...])
                mlvl_pos_masks[lvl].append(pos_masks[img_id][lvl].flatten())
                mlvl_labels[lvl].append(labels[img_id][lvl].flatten())

        # cat multiple image
        temp_mlvl_cls_preds = []
        for lvl in range(num_levels):
            mlvl_pos_mask_targets[lvl] = torch.cat(
                mlvl_pos_mask_targets[lvl], dim=0)
            mlvl_pos_mask_preds[lvl] = torch.cat(
                mlvl_pos_mask_preds[lvl], dim=0)
            mlvl_pos_masks[lvl] = torch.cat(mlvl_pos_masks[lvl], dim=0)
            mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0)
            temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute(
                0, 2, 3, 1).reshape(-1, self.cls_out_channels))

        num_pos = sum(item.sum() for item in mlvl_pos_masks)
        # dice loss
        loss_mask = []
        for pred, target in zip(mlvl_pos_mask_preds, mlvl_pos_mask_targets):
            if pred.size()[0] == 0:
                # zero-valued term that still keeps pred in the graph
                loss_mask.append(pred.sum().unsqueeze(0))
                continue
            loss_mask.append(
                self.loss_mask(pred, target, reduction_override='none'))
        if num_pos > 0:
            loss_mask = torch.cat(loss_mask).sum() / num_pos
        else:
            loss_mask = torch.cat(loss_mask).mean()

        flatten_labels = torch.cat(mlvl_labels)
        flatten_cls_preds = torch.cat(temp_mlvl_cls_preds)
        loss_cls = self.loss_cls(
            flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
        return dict(loss_mask=loss_mask, loss_cls=loss_cls)

    @staticmethod
    def _coord_to_grid(coord, img_size, num_grid: int) -> int:
        """Return the (truncated) grid-cell index containing ``coord``.

        ``coord`` is normalised by ``img_size`` and divided by the cell size
        ``1 / num_grid``; the result is truncated towards zero.
        """
        return int(
            floordiv(coord / img_size, (1. / num_grid),
                     rounding_mode='trunc'))

    def _get_targets_single(self,
                            gt_instances: InstanceData,
                            featmap_sizes: Optional[list] = None) -> tuple:
        """Compute targets for predictions of single image.

        Args:
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should includes ``bboxes``, ``labels``,
                and ``masks`` attributes.
            featmap_sizes (list[:obj:`torch.size`]): Size of each
                feature map from feature pyramid, each element
                means (feat_h, feat_w). Defaults to None.

        Returns:
            Tuple: Usually returns a tuple containing targets for predictions.

                - mlvl_pos_mask_targets (list[Tensor]): Each element represent
                  the binary mask targets for positive points in this
                  level, has shape (num_pos, out_h, out_w).
                - mlvl_labels (list[Tensor]): Each element is
                  classification labels for all
                  points in this level, has shape
                  (num_grid, num_grid).
                - mlvl_pos_masks (list[Tensor]): Each element is
                  a `BoolTensor` to represent whether the
                  corresponding point in single level
                  is positive, has shape (num_grid **2).
        """
        gt_labels = gt_instances.labels
        device = gt_labels.device

        gt_bboxes = gt_instances.bboxes
        # scale of each gt = sqrt(box area); compared with self.scale_ranges
        gt_areas = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
                              (gt_bboxes[:, 3] - gt_bboxes[:, 1]))

        gt_masks = gt_instances.masks.to_tensor(
            dtype=torch.bool, device=device)

        mlvl_pos_mask_targets = []
        mlvl_labels = []
        mlvl_pos_masks = []
        # Grid coordinates are normalised by 4x the first feature map size.
        # This does not depend on the level or the instance, so compute once.
        upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4)
        for (lower_bound, upper_bound), stride, featmap_size, num_grid \
                in zip(self.scale_ranges, self.strides,
                       featmap_sizes, self.num_grids):

            mask_target = torch.zeros(
                [num_grid**2, featmap_size[0], featmap_size[1]],
                dtype=torch.uint8,
                device=device)
            # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
            labels = torch.zeros([num_grid, num_grid],
                                 dtype=torch.int64,
                                 device=device) + self.num_classes
            pos_mask = torch.zeros([num_grid**2],
                                   dtype=torch.bool,
                                   device=device)

            gt_inds = ((gt_areas >= lower_bound) &
                       (gt_areas <= upper_bound)).nonzero().flatten()
            if len(gt_inds) == 0:
                mlvl_pos_mask_targets.append(
                    mask_target.new_zeros(0, featmap_size[0], featmap_size[1]))
                mlvl_labels.append(labels)
                mlvl_pos_masks.append(pos_mask)
                continue
            hit_gt_bboxes = gt_bboxes[gt_inds]
            hit_gt_labels = gt_labels[gt_inds]
            hit_gt_masks = gt_masks[gt_inds, ...]

            # half-extent of the shrunken "center region" of each gt box
            pos_w_ranges = 0.5 * (hit_gt_bboxes[:, 2] -
                                  hit_gt_bboxes[:, 0]) * self.pos_scale
            pos_h_ranges = 0.5 * (hit_gt_bboxes[:, 3] -
                                  hit_gt_bboxes[:, 1]) * self.pos_scale

            # Make sure hit_gt_masks has a value
            valid_mask_flags = hit_gt_masks.sum(dim=-1).sum(dim=-1) > 0
            output_stride = stride / 2

            for gt_mask, gt_label, pos_h_range, pos_w_range, \
                    valid_mask_flag in \
                    zip(hit_gt_masks, hit_gt_labels, pos_h_ranges,
                        pos_w_ranges, valid_mask_flags):
                if not valid_mask_flag:
                    continue
                center_h, center_w = center_of_mass(gt_mask)

                # grid cell containing the mask's center of mass
                coord_w = self._coord_to_grid(center_w, upsampled_size[1],
                                              num_grid)
                coord_h = self._coord_to_grid(center_h, upsampled_size[0],
                                              num_grid)

                # left, top, right, down of the center region, clipped to
                # the grid
                top_box = max(
                    0,
                    self._coord_to_grid(center_h - pos_h_range,
                                        upsampled_size[0], num_grid))
                down_box = min(
                    num_grid - 1,
                    self._coord_to_grid(center_h + pos_h_range,
                                        upsampled_size[0], num_grid))
                left_box = max(
                    0,
                    self._coord_to_grid(center_w - pos_w_range,
                                        upsampled_size[1], num_grid))
                right_box = min(
                    num_grid - 1,
                    self._coord_to_grid(center_w + pos_w_range,
                                        upsampled_size[1], num_grid))

                # positives are additionally limited to the 3x3 cells
                # around the center cell
                top = max(top_box, coord_h - 1)
                down = min(down_box, coord_h + 1)
                left = max(coord_w - 1, left_box)
                right = min(right_box, coord_w + 1)

                labels[top:(down + 1), left:(right + 1)] = gt_label
                # ins
                gt_mask = np.uint8(gt_mask.cpu().numpy())
                # Follow the original implementation, F.interpolate is
                # different from cv2 and opencv
                gt_mask = mmcv.imrescale(gt_mask, scale=1. / output_stride)
                gt_mask = torch.from_numpy(gt_mask).to(device=device)

                for i in range(top, down + 1):
                    for j in range(left, right + 1):
                        index = int(i * num_grid + j)
                        mask_target[index, :gt_mask.shape[0],
                                    :gt_mask.shape[1]] = gt_mask
                        pos_mask[index] = True
            mlvl_pos_mask_targets.append(mask_target[pos_mask])
            mlvl_labels.append(labels)
            mlvl_pos_masks.append(pos_mask)
        return mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks

    def predict_by_feat(self, mlvl_mask_preds: List[Tensor],
                        mlvl_cls_scores: List[Tensor],
                        batch_img_metas: List[dict], **kwargs) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        mask results.

        Args:
            mlvl_mask_preds (list[Tensor]): Multi-level mask prediction.
                Each element in the list has shape
                (batch_size, num_grids**2 ,h ,w).
            mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids ,num_grids).
            batch_img_metas (list[dict]): Meta information of all images.

        Returns:
            list[:obj:`InstanceData`]: Processed results of multiple
            images.Each :obj:`InstanceData` usually contains
            following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instance,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has
                  shape (num_instances, h, w).
        """
        mlvl_cls_scores = [
            item.permute(0, 2, 3, 1) for item in mlvl_cls_scores
        ]
        assert len(mlvl_mask_preds) == len(mlvl_cls_scores)
        num_levels = len(mlvl_cls_scores)

        results_list = []
        for img_id in range(len(batch_img_metas)):
            cls_pred_list = [
                mlvl_cls_scores[lvl][img_id].view(-1, self.cls_out_channels)
                for lvl in range(num_levels)
            ]
            mask_pred_list = [
                mlvl_mask_preds[lvl][img_id] for lvl in range(num_levels)
            ]

            cls_pred_list = torch.cat(cls_pred_list, dim=0)
            mask_pred_list = torch.cat(mask_pred_list, dim=0)
            img_meta = batch_img_metas[img_id]

            results = self._predict_by_feat_single(
                cls_pred_list, mask_pred_list, img_meta=img_meta)
            results_list.append(results)

        return results_list

    @staticmethod
    def _empty_mask_results(cls_scores: Tensor,
                            ori_shape: tuple) -> InstanceData:
        """Generate empty results.

        The dtypes match the non-empty inference path: ``masks`` are bool
        (thresholded predictions) and ``labels`` are long (class indices).

        Args:
            cls_scores (Tensor): Any tensor on the target device; only used
                to allocate the empty outputs.
            ori_shape (tuple): (h, w) of the original image.

        Returns:
            :obj:`InstanceData`: Results with zero instances.
        """
        results = InstanceData()
        results.scores = cls_scores.new_ones(0)
        results.masks = cls_scores.new_zeros(0, *ori_shape, dtype=torch.bool)
        results.labels = cls_scores.new_ones(0, dtype=torch.long)
        results.bboxes = cls_scores.new_zeros(0, 4)
        return results

    def _predict_by_feat_single(self,
                                cls_scores: Tensor,
                                mask_preds: Tensor,
                                img_meta: dict,
                                cfg: OptConfigType = None) -> InstanceData:
        """Transform a single image's features extracted from the head into
        mask results.

        Args:
            cls_scores (Tensor): Classification score of all points
                in single image, has shape (num_points, num_classes).
            mask_preds (Tensor): Mask prediction of all points in
                single image, has shape (num_points, feat_h, feat_w).
            img_meta (dict): Meta information of corresponding image.
            cfg (dict, optional): Config used in test phase.
                Defaults to None.

        Returns:
            :obj:`InstanceData`: Processed results of single image.
             it usually contains following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instance,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has
                  shape (num_instances, h, w).
        """
        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_scores) == len(mask_preds)
        ori_shape = img_meta['ori_shape'][:2]

        featmap_size = mask_preds.size()[-2:]

        h, w = img_meta['img_shape'][:2]
        upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4)

        score_mask = (cls_scores > cfg.score_thr)
        cls_scores = cls_scores[score_mask]
        if len(cls_scores) == 0:
            return self._empty_mask_results(cls_scores, ori_shape)

        # each row is (point index, class index)
        inds = score_mask.nonzero()
        cls_labels = inds[:, 1]

        # Filter out masks whose area is not larger than the stride of
        # the corresponding feature level
        lvl_interval = cls_labels.new_tensor(self.num_grids).pow(2).cumsum(0)
        strides = cls_scores.new_ones(lvl_interval[-1])
        strides[:lvl_interval[0]] *= self.strides[0]
        for lvl in range(1, self.num_levels):
            strides[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
                self.strides[lvl]
        strides = strides[inds[:, 0]]
        mask_preds = mask_preds[inds[:, 0]]

        masks = mask_preds > cfg.mask_thr
        sum_masks = masks.sum((1, 2)).float()
        keep = sum_masks > strides
        if keep.sum() == 0:
            return self._empty_mask_results(cls_scores, ori_shape)
        masks = masks[keep]
        mask_preds = mask_preds[keep]
        sum_masks = sum_masks[keep]
        cls_scores = cls_scores[keep]
        cls_labels = cls_labels[keep]

        # maskness: mean probability inside the binarised mask
        mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks
        cls_scores *= mask_scores

        scores, labels, _, keep_inds = mask_matrix_nms(
            masks,
            cls_labels,
            cls_scores,
            mask_area=sum_masks,
            nms_pre=cfg.nms_pre,
            max_num=cfg.max_per_img,
            kernel=cfg.kernel,
            sigma=cfg.sigma,
            filter_thr=cfg.filter_thr)
        # mask_matrix_nms may return an empty Tensor
        if len(keep_inds) == 0:
            return self._empty_mask_results(cls_scores, ori_shape)
        mask_preds = mask_preds[keep_inds]
        # upsample, crop the padded area, then resize to the original image
        mask_preds = F.interpolate(
            mask_preds.unsqueeze(0), size=upsampled_size,
            mode='bilinear')[:, :, :h, :w]
        mask_preds = F.interpolate(
            mask_preds, size=ori_shape, mode='bilinear').squeeze(0)
        masks = mask_preds > cfg.mask_thr

        results = InstanceData()
        results.masks = masks
        results.labels = labels
        results.scores = scores
        # create an empty bbox in InstanceData to avoid bugs when
        # calculating metrics.
        results.bboxes = results.scores.new_zeros(len(scores), 4)
        return results


@MODELS.register_module()
class DecoupledSOLOHead(SOLOHead):
    """Decoupled SOLO mask head used in `SOLO: Segmenting Objects by Locations.
    <https://arxiv.org/abs/1912.04488>`_

    Instead of ``num_grid ** 2`` mask channels per level, two branches
    predict ``num_grid`` channels each (one per grid column for x, one per
    grid row for y); the mask of cell (row, col) is the product of the
    sigmoid of the y channel ``row`` and the x channel ``col``.

    Args:
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 *args,
                 init_cfg: MultiConfig = [
                     dict(type='Normal', layer='Conv2d', std=0.01),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_mask_list_x')),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_mask_list_y')),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_cls'))
                 ],
                 **kwargs) -> None:
        super().__init__(*args, init_cfg=init_cfg, **kwargs)

    def _init_layers(self) -> None:
        """Initialize layers of the head.

        The x and y mask towers each take one extra coordinate channel at
        their first conv (the x or y coordinate map, see ``forward``).
        """
        self.mask_convs_x = nn.ModuleList()
        self.mask_convs_y = nn.ModuleList()
        self.cls_convs = nn.ModuleList()

        for i in range(self.stacked_convs):
            chn = self.in_channels + 1 if i == 0 else self.feat_channels
            self.mask_convs_x.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))
            self.mask_convs_y.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))

            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))

        self.conv_mask_list_x = nn.ModuleList()
        self.conv_mask_list_y = nn.ModuleList()
        for num_grid in self.num_grids:
            self.conv_mask_list_x.append(
                nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
            self.conv_mask_list_y.append(
                nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)

    def forward(self, x: Tuple[Tensor]) -> Tuple:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: A tuple of classification scores and mask prediction.

                - mlvl_mask_preds_x (list[Tensor]): Multi-level mask
                  prediction from x branch. Each element in the list has shape
                  (batch_size, num_grids ,h ,w).
                - mlvl_mask_preds_y (list[Tensor]): Multi-level mask
                  prediction from y branch. Each element in the list has shape
                  (batch_size, num_grids ,h ,w).
                - mlvl_cls_preds (list[Tensor]): Multi-level scores.
                  Each element in the list has shape
                  (batch_size, num_classes, num_grids ,num_grids).
        """
        assert len(x) == self.num_levels
        feats = self.resize_feats(x)
        mask_preds_x = []
        mask_preds_y = []
        cls_preds = []
        for i in range(self.num_levels):
            feat = feats[i]
            mask_feat = feat
            cls_feat = feat
            # generate and concat the coordinate; the x tower gets channel 0
            # of the coordinate map, the y tower gets channel 1
            coord_feat = generate_coordinate(mask_feat.size(),
                                             mask_feat.device)
            mask_feat_x = torch.cat([mask_feat, coord_feat[:, 0:1, ...]], 1)
            mask_feat_y = torch.cat([mask_feat, coord_feat[:, 1:2, ...]], 1)

            for mask_layer_x, mask_layer_y in \
                    zip(self.mask_convs_x, self.mask_convs_y):
                mask_feat_x = mask_layer_x(mask_feat_x)
                mask_feat_y = mask_layer_y(mask_feat_y)

            mask_feat_x = F.interpolate(
                mask_feat_x, scale_factor=2, mode='bilinear')
            mask_feat_y = F.interpolate(
                mask_feat_y, scale_factor=2, mode='bilinear')

            mask_pred_x = self.conv_mask_list_x[i](mask_feat_x)
            mask_pred_y = self.conv_mask_list_y[i](mask_feat_y)

            # cls branch
            for j, cls_layer in enumerate(self.cls_convs):
                if j == self.cls_down_index:
                    num_grid = self.num_grids[i]
                    cls_feat = F.interpolate(
                        cls_feat, size=num_grid, mode='bilinear')
                cls_feat = cls_layer(cls_feat)

            cls_pred = self.conv_cls(cls_feat)

            if not self.training:
                feat_wh = feats[0].size()[-2:]
                upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
                mask_pred_x = F.interpolate(
                    mask_pred_x.sigmoid(),
                    size=upsampled_size,
                    mode='bilinear')
                mask_pred_y = F.interpolate(
                    mask_pred_y.sigmoid(),
                    size=upsampled_size,
                    mode='bilinear')
                cls_pred = cls_pred.sigmoid()
                # get local maximum
                local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
                keep_mask = local_max[:, :, :-1, :-1] == cls_pred
                cls_pred = cls_pred * keep_mask

            mask_preds_x.append(mask_pred_x)
            mask_preds_y.append(mask_pred_y)
            cls_preds.append(cls_pred)
        return mask_preds_x, mask_preds_y, cls_preds

    def loss_by_feat(self, mlvl_mask_preds_x: List[Tensor],
                     mlvl_mask_preds_y: List[Tensor],
                     mlvl_cls_preds: List[Tensor],
                     batch_gt_instances: InstanceList,
                     batch_img_metas: List[dict], **kwargs) -> dict:
        """Calculate the loss based on the features extracted by the mask
        head.

        Args:
            mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction
                from x branch. Each element in the list has shape
                (batch_size, num_grids ,h ,w).
            mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction
                from y branch. Each element in the list has shape
                (batch_size, num_grids ,h ,w).
            mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids ,num_grids).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``masks``,
                and ``labels`` attributes.
            batch_img_metas (list[dict]): Meta information of multiple images.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        num_levels = self.num_levels
        num_imgs = len(batch_img_metas)

        featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds_x]

        pos_mask_targets, labels, xy_pos_indexes = multi_apply(
            self._get_targets_single,
            batch_gt_instances,
            featmap_sizes=featmap_sizes)

        # change from the outside list meaning multi images
        # to the outside list meaning multi levels
        mlvl_pos_mask_targets = [[] for _ in range(num_levels)]
        mlvl_pos_mask_preds_x = [[] for _ in range(num_levels)]
        mlvl_pos_mask_preds_y = [[] for _ in range(num_levels)]
        mlvl_labels = [[] for _ in range(num_levels)]
        for img_id in range(num_imgs):

            for lvl in range(num_levels):
                mlvl_pos_mask_targets[lvl].append(
                    pos_mask_targets[img_id][lvl])
                # xy_pos_indexes rows are (row=y, col=x) grid indexes
                mlvl_pos_mask_preds_x[lvl].append(
                    mlvl_mask_preds_x[lvl][img_id,
                                           xy_pos_indexes[img_id][lvl][:, 1]])
                mlvl_pos_mask_preds_y[lvl].append(
                    mlvl_mask_preds_y[lvl][img_id,
                                           xy_pos_indexes[img_id][lvl][:, 0]])
                mlvl_labels[lvl].append(labels[img_id][lvl].flatten())

        # cat multiple image
        temp_mlvl_cls_preds = []
        for lvl in range(num_levels):
            mlvl_pos_mask_targets[lvl] = torch.cat(
                mlvl_pos_mask_targets[lvl], dim=0)
            mlvl_pos_mask_preds_x[lvl] = torch.cat(
                mlvl_pos_mask_preds_x[lvl], dim=0)
            mlvl_pos_mask_preds_y[lvl] = torch.cat(
                mlvl_pos_mask_preds_y[lvl], dim=0)
            mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0)
            temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute(
                0, 2, 3, 1).reshape(-1, self.cls_out_channels))

        num_pos = 0.
        # dice loss
        loss_mask = []
        for pred_x, pred_y, target in \
                zip(mlvl_pos_mask_preds_x,
                    mlvl_pos_mask_preds_y, mlvl_pos_mask_targets):
            num_masks = pred_x.size(0)
            if num_masks == 0:
                # make sure can get grad
                loss_mask.append((pred_x.sum() + pred_y.sum()).unsqueeze(0))
                continue
            num_pos += num_masks
            pred_mask = pred_y.sigmoid() * pred_x.sigmoid()
            loss_mask.append(
                self.loss_mask(pred_mask, target, reduction_override='none'))
        if num_pos > 0:
            loss_mask = torch.cat(loss_mask).sum() / num_pos
        else:
            loss_mask = torch.cat(loss_mask).mean()

        # classification loss
        flatten_labels = torch.cat(mlvl_labels)
        flatten_cls_preds = torch.cat(temp_mlvl_cls_preds)

        loss_cls = self.loss_cls(
            flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
        return dict(loss_mask=loss_mask, loss_cls=loss_cls)

    def _get_targets_single(self,
                            gt_instances: InstanceData,
                            featmap_sizes: Optional[list] = None) -> tuple:
        """Compute targets for predictions of single image.

        Args:
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should includes ``bboxes``, ``labels``,
                and ``masks`` attributes.
            featmap_sizes (list[:obj:`torch.size`]): Size of each
                feature map from feature pyramid, each element
                means (feat_h, feat_w). Defaults to None.

        Returns:
            Tuple: Usually returns a tuple containing targets for predictions.

                - mlvl_pos_mask_targets (list[Tensor]): Each element represent
                  the binary mask targets for positive points in this
                  level, has shape (num_pos, out_h, out_w).
                - mlvl_labels (list[Tensor]): Each element is
                  classification labels for all
                  points in this level, has shape
                  (num_grid, num_grid).
                - mlvl_xy_pos_indexes (list[Tensor]): Each element
                  in the list contains the index of positive samples in
                  corresponding level, has shape (num_pos, 2). The last
                  dimension is (index_y, index_x), i.e. the (row, column)
                  of the positive cell in the ``num_grid x num_grid`` grid.
        """
        mlvl_pos_mask_targets, mlvl_labels, _ = \
            super()._get_targets_single(gt_instances,
                                        featmap_sizes=featmap_sizes)

        # a cell is positive iff its label is not the background class
        mlvl_xy_pos_indexes = [(item - self.num_classes).nonzero()
                               for item in mlvl_labels]

        return mlvl_pos_mask_targets, mlvl_labels, mlvl_xy_pos_indexes

    def predict_by_feat(self, mlvl_mask_preds_x: List[Tensor],
                        mlvl_mask_preds_y: List[Tensor],
                        mlvl_cls_scores: List[Tensor],
                        batch_img_metas: List[dict], **kwargs) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        mask results.

        Args:
            mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction
                from x branch. Each element in the list has shape
                (batch_size, num_grids ,h ,w).
            mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction
                from y branch. Each element in the list has shape
                (batch_size, num_grids ,h ,w).
            mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes ,num_grids ,num_grids).
            batch_img_metas (list[dict]): Meta information of all images.

        Returns:
            list[:obj:`InstanceData`]: Processed results of multiple
            images.Each :obj:`InstanceData` usually contains
            following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instance,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has
                  shape (num_instances, h, w).
        """
        mlvl_cls_scores = [
            item.permute(0, 2, 3, 1) for item in mlvl_cls_scores
        ]
        assert len(mlvl_mask_preds_x) == len(mlvl_cls_scores)
        num_levels = len(mlvl_cls_scores)

        results_list = []
        for img_id in range(len(batch_img_metas)):
            cls_pred_list = [
                mlvl_cls_scores[i][img_id].view(
                    -1, self.cls_out_channels).detach()
                for i in range(num_levels)
            ]
            mask_pred_list_x = [
                mlvl_mask_preds_x[i][img_id] for i in range(num_levels)
            ]
            mask_pred_list_y = [
                mlvl_mask_preds_y[i][img_id] for i in range(num_levels)
            ]

            cls_pred_list = torch.cat(cls_pred_list, dim=0)
            mask_pred_list_x = torch.cat(mask_pred_list_x, dim=0)
            mask_pred_list_y = torch.cat(mask_pred_list_y, dim=0)
            img_meta = batch_img_metas[img_id]

            results = self._predict_by_feat_single(
                cls_pred_list,
                mask_pred_list_x,
                mask_pred_list_y,
                img_meta=img_meta)
            results_list.append(results)
        return results_list

    def _predict_by_feat_single(self,
                                cls_scores: Tensor,
                                mask_preds_x: Tensor,
                                mask_preds_y: Tensor,
                                img_meta: dict,
                                cfg: OptConfigType = None) -> InstanceData:
        """Transform a single image's features extracted from the head into
        mask results.

        Args:
            cls_scores (Tensor): Classification score of all points
                in single image, has shape (num_points, num_classes).
            mask_preds_x (Tensor): Mask prediction of x branch of
                all points in single image, has shape
                (sum_num_grids, feat_h, feat_w).
            mask_preds_y (Tensor): Mask prediction of y branch of
                all points in single image, has shape
                (sum_num_grids, feat_h, feat_w).
            img_meta (dict): Meta information of corresponding image.
            cfg (dict): Config used in test phase.

        Returns:
            :obj:`InstanceData`: Processed results of single image.
             it usually contains following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instance,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has
                  shape (num_instances, h, w).
        """
        cfg = self.test_cfg if cfg is None else cfg
        ori_shape = img_meta['ori_shape'][:2]

        featmap_size = mask_preds_x.size()[-2:]

        h, w = img_meta['img_shape'][:2]
        upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4)

        score_mask = (cls_scores > cfg.score_thr)
        cls_scores = cls_scores[score_mask]
        # same early exit as SOLOHead: nothing passed the score threshold
        if len(cls_scores) == 0:
            return self._empty_mask_results(cls_scores, ori_shape)
        # each row is (flat point index over all levels, class index)
        inds = score_mask.nonzero()
        lvl_interval = inds.new_tensor(self.num_grids).pow(2).cumsum(0)
        num_all_points = lvl_interval[-1]
        # per point: flat index where its level starts
        lvl_start_index = inds.new_ones(num_all_points)
        # per point: grid size of its level
        num_grids = inds.new_ones(num_all_points)
        # channel offset of each level in the concatenated x / y predictions
        seg_size = inds.new_tensor(self.num_grids).cumsum(0)
        mask_lvl_start_index = inds.new_ones(num_all_points)
        strides = inds.new_ones(num_all_points)

        lvl_start_index[:lvl_interval[0]] *= 0
        mask_lvl_start_index[:lvl_interval[0]] *= 0
        num_grids[:lvl_interval[0]] *= self.num_grids[0]
        strides[:lvl_interval[0]] *= self.strides[0]

        for lvl in range(1, self.num_levels):
            lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
                lvl_interval[lvl - 1]
            mask_lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
                seg_size[lvl - 1]
            num_grids[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
                self.num_grids[lvl]
            strides[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
                self.strides[lvl]

        lvl_start_index = lvl_start_index[inds[:, 0]]
        mask_lvl_start_index = mask_lvl_start_index[inds[:, 0]]
        num_grids = num_grids[inds[:, 0]]
        strides = strides[inds[:, 0]]

        # row (y) and column (x) of each point inside its level's grid
        y_lvl_offset = (inds[:, 0] - lvl_start_index) // num_grids
        x_lvl_offset = (inds[:, 0] - lvl_start_index) % num_grids
        y_inds = mask_lvl_start_index + y_lvl_offset
        x_inds = mask_lvl_start_index + x_lvl_offset

        cls_labels = inds[:, 1]
        mask_preds = mask_preds_x[x_inds, ...] * mask_preds_y[y_inds, ...]

        masks = mask_preds > cfg.mask_thr
        sum_masks = masks.sum((1, 2)).float()
        keep = sum_masks > strides
        if keep.sum() == 0:
            return self._empty_mask_results(cls_scores, ori_shape)

        masks = masks[keep]
        mask_preds = mask_preds[keep]
        sum_masks = sum_masks[keep]
        cls_scores = cls_scores[keep]
        cls_labels = cls_labels[keep]

        # maskness: mean probability inside the binarised mask
        mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks
        cls_scores *= mask_scores

        scores, labels, _, keep_inds = mask_matrix_nms(
            masks,
            cls_labels,
            cls_scores,
            mask_area=sum_masks,
            nms_pre=cfg.nms_pre,
            max_num=cfg.max_per_img,
            kernel=cfg.kernel,
            sigma=cfg.sigma,
            filter_thr=cfg.filter_thr)
        # mask_matrix_nms may return an empty Tensor
        if len(keep_inds) == 0:
            return self._empty_mask_results(cls_scores, ori_shape)
        mask_preds = mask_preds[keep_inds]
        mask_preds = F.interpolate(
            mask_preds.unsqueeze(0), size=upsampled_size,
            mode='bilinear')[:, :, :h, :w]
        mask_preds = F.interpolate(
            mask_preds, size=ori_shape, mode='bilinear').squeeze(0)
        masks = mask_preds > cfg.mask_thr

        results = InstanceData()
        results.masks = masks
        results.labels = labels
        results.scores = scores
        # create an empty bbox in InstanceData to avoid bugs when
        # calculating metrics.
        results.bboxes = results.scores.new_zeros(len(scores), 4)

        return results


@MODELS.register_module()
class DecoupledSOLOLightHead(DecoupledSOLOHead):
    """Decoupled Light SOLO mask head used in `SOLO: Segmenting Objects by
    Locations <https://arxiv.org/abs/1912.04488>`_

    Unlike :class:`DecoupledSOLOHead`, the x and y mask predictors share a
    single mask conv tower.

    Args:
        dcn_cfg (dict, optional): Conv config used for the last conv of
            both the mask and the classification towers (e.g. a DCN config).
            Defaults to None, i.e. plain convs everywhere.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 *args,
                 dcn_cfg: OptConfigType = None,
                 init_cfg: MultiConfig = [
                     dict(type='Normal', layer='Conv2d', std=0.01),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_mask_list_x')),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_mask_list_y')),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_cls'))
                 ],
                 **kwargs) -> None:
        assert dcn_cfg is None or isinstance(dcn_cfg, dict)
        # must be set before super().__init__, which calls _init_layers()
        self.dcn_cfg = dcn_cfg
        super().__init__(*args, init_cfg=init_cfg, **kwargs)

    def _init_layers(self) -> None:
        """Initialize layers of the head (shared mask tower with two extra
        coordinate channels at its first conv)."""
        self.mask_convs = nn.ModuleList()
        self.cls_convs = nn.ModuleList()

        for i in range(self.stacked_convs):
            if self.dcn_cfg is not None \
                    and i == self.stacked_convs - 1:
                conv_cfg = self.dcn_cfg
            else:
                conv_cfg = None

            chn = self.in_channels + 2 if i == 0 else self.feat_channels
            self.mask_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg))

            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg))

        self.conv_mask_list_x = nn.ModuleList()
        self.conv_mask_list_y = nn.ModuleList()
        for num_grid in self.num_grids:
            self.conv_mask_list_x.append(
                nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
            self.conv_mask_list_y.append(
                nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)

    def forward(self, x: Tuple[Tensor]) -> Tuple:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: A tuple of classification scores and mask prediction.

                - mlvl_mask_preds_x (list[Tensor]): Multi-level mask
                  prediction from x branch. Each element in the list has shape
                  (batch_size, num_grids ,h ,w).
                - mlvl_mask_preds_y (list[Tensor]): Multi-level mask
                  prediction from y branch. Each element in the list has shape
                  (batch_size, num_grids ,h ,w).
                - mlvl_cls_preds (list[Tensor]): Multi-level scores.
                  Each element in the list has shape
                  (batch_size, num_classes, num_grids ,num_grids).
        """
        assert len(x) == self.num_levels
        feats = self.resize_feats(x)
        mask_preds_x = []
        mask_preds_y = []
        cls_preds = []
        for i in range(self.num_levels):
            feat = feats[i]
            mask_feat = feat
            cls_feat = feat
            # generate and concat the coordinate
            coord_feat = generate_coordinate(mask_feat.size(),
                                             mask_feat.device)
            mask_feat = torch.cat([mask_feat, coord_feat], 1)

            for mask_layer in self.mask_convs:
                mask_feat = mask_layer(mask_feat)

            mask_feat = F.interpolate(
                mask_feat, scale_factor=2, mode='bilinear')

            # both predictors read the shared mask feature
            mask_pred_x = self.conv_mask_list_x[i](mask_feat)
            mask_pred_y = self.conv_mask_list_y[i](mask_feat)

            # cls branch
            for j, cls_layer in enumerate(self.cls_convs):
                if j == self.cls_down_index:
                    num_grid = self.num_grids[i]
                    cls_feat = F.interpolate(
                        cls_feat, size=num_grid, mode='bilinear')
                cls_feat = cls_layer(cls_feat)

            cls_pred = self.conv_cls(cls_feat)

            if not self.training:
                feat_wh = feats[0].size()[-2:]
                upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
                mask_pred_x = F.interpolate(
                    mask_pred_x.sigmoid(),
                    size=upsampled_size,
                    mode='bilinear')
                mask_pred_y = F.interpolate(
                    mask_pred_y.sigmoid(),
                    size=upsampled_size,
                    mode='bilinear')
                cls_pred = cls_pred.sigmoid()
                # get local maximum
                local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
                keep_mask = local_max[:, :, :-1, :-1] == cls_pred
                cls_pred = cls_pred * keep_mask

            mask_preds_x.append(mask_pred_x)
            mask_preds_y.append(mask_pred_y)
            cls_preds.append(cls_pred)
        return mask_preds_x, mask_preds_y, cls_preds