# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, Scale
from mmengine.config import ConfigDict
from mmengine.model import BaseModule, kaiming_init
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures.bbox import cat_boxes
from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType,
                         OptInstanceList, reduce_mean)
from ..task_modules.prior_generators import MlvlPointGenerator
from ..utils import (aligned_bilinear, filter_scores_and_topk, multi_apply,
                     relative_coordinate_maps, select_single_mlvl)
from ..utils.misc import empty_instances
from .base_mask_head import BaseMaskHead
from .fcos_head import FCOSHead

INF = 1e8


class CondInstBboxHead(FCOSHead):
    """CondInst box head used in https://arxiv.org/abs/2003.05664.

    Note that the CondInst box head is an extension of the FCOS head.
    Two differences are described as follows:

    1. CondInst box head predicts a set of params for each instance.
    2. CondInst box head returns the pos_gt_inds and pos_inds.

    Args:
        num_params (int): Number of params for instance segmentation.
            Defaults to 169.
    """

    def __init__(self, *args, num_params: int = 169, **kwargs) -> None:
        self.num_params = num_params
        super().__init__(*args, **kwargs)
def _init_layers(self) -> None: | |
"""Initialize layers of the head.""" | |
super()._init_layers() | |
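        # In addition to the FCOS branches, the controller predicts
        # ``num_params`` values at every location: the flattened weights and
        # biases of the dynamic mask FCN generated for the instance centered
        # there. With the typical CondInst setting (3 dynamic layers, 8
        # channels), (8 + 2) * 8 + 8 + 8 * 8 + 8 + 8 * 1 + 1 = 169.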
self.controller = nn.Conv2d( | |
self.feat_channels, self.num_params, 3, padding=1) | |
def forward_single(self, x: Tensor, scale: Scale, | |
stride: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]: | |
"""Forward features of a single scale level. | |
Args: | |
x (Tensor): FPN feature maps of the specified stride. | |
scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize | |
the bbox prediction. | |
stride (int): The corresponding stride for feature maps, only | |
used to normalize the bbox prediction when self.norm_on_bbox | |
is True. | |
Returns: | |
tuple: scores for each class, bbox predictions, centerness | |
predictions and param predictions of input feature maps. | |
""" | |
cls_score, bbox_pred, cls_feat, reg_feat = \ | |
super(FCOSHead, self).forward_single(x) | |
if self.centerness_on_reg: | |
centerness = self.conv_centerness(reg_feat) | |
else: | |
centerness = self.conv_centerness(cls_feat) | |
# scale the bbox_pred of different level | |
# float to avoid overflow when enabling FP16 | |
bbox_pred = scale(bbox_pred).float() | |
if self.norm_on_bbox: | |
# bbox_pred needed for gradient computation has been modified | |
# by F.relu(bbox_pred) when run with PyTorch 1.10. So replace | |
# F.relu(bbox_pred) with bbox_pred.clamp(min=0) | |
bbox_pred = bbox_pred.clamp(min=0) | |
if not self.training: | |
bbox_pred *= stride | |
else: | |
bbox_pred = bbox_pred.exp() | |
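        # Dynamic filter parameters are predicted from the regression tower
        # features, in parallel with the bbox and centerness branches.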
param_pred = self.controller(reg_feat) | |
return cls_score, bbox_pred, centerness, param_pred | |
def loss_by_feat( | |
self, | |
cls_scores: List[Tensor], | |
bbox_preds: List[Tensor], | |
centernesses: List[Tensor], | |
param_preds: List[Tensor], | |
batch_gt_instances: InstanceList, | |
batch_img_metas: List[dict], | |
batch_gt_instances_ignore: OptInstanceList = None | |
) -> Dict[str, Tensor]: | |
"""Calculate the loss based on the features extracted by the detection | |
head. | |
Args: | |
cls_scores (list[Tensor]): Box scores for each scale level, | |
each is a 4D-tensor, the channel number is | |
num_points * num_classes. | |
bbox_preds (list[Tensor]): Box energies / deltas for each scale | |
level, each is a 4D-tensor, the channel number is | |
num_points * 4. | |
centernesses (list[Tensor]): centerness for each scale level, each | |
is a 4D-tensor, the channel number is num_points * 1. | |
param_preds (List[Tensor]): param_pred for each scale level, each | |
is a 4D-tensor, the channel number is num_params. | |
batch_gt_instances (list[:obj:`InstanceData`]): Batch of | |
gt_instance. It usually includes ``bboxes`` and ``labels`` | |
attributes. | |
batch_img_metas (list[dict]): Meta information of each image, e.g., | |
image size, scaling factor, etc. | |
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): | |
Batch of gt_instances_ignore. It includes ``bboxes`` attribute | |
data that is ignored during training and testing. | |
Defaults to None. | |
Returns: | |
dict[str, Tensor]: A dictionary of loss components. | |
""" | |
assert len(cls_scores) == len(bbox_preds) == len(centernesses) | |
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] | |
# Need stride for rel coord compute | |
all_level_points_strides = self.prior_generator.grid_priors( | |
featmap_sizes, | |
dtype=bbox_preds[0].dtype, | |
device=bbox_preds[0].device, | |
with_stride=True) | |
all_level_points = [i[:, :2] for i in all_level_points_strides] | |
all_level_strides = [i[:, 2] for i in all_level_points_strides] | |
labels, bbox_targets, pos_inds_list, pos_gt_inds_list = \ | |
self.get_targets(all_level_points, batch_gt_instances) | |
num_imgs = cls_scores[0].size(0) | |
# flatten cls_scores, bbox_preds and centerness | |
flatten_cls_scores = [ | |
cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels) | |
for cls_score in cls_scores | |
] | |
flatten_bbox_preds = [ | |
bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4) | |
for bbox_pred in bbox_preds | |
] | |
flatten_centerness = [ | |
centerness.permute(0, 2, 3, 1).reshape(-1) | |
for centerness in centernesses | |
] | |
flatten_cls_scores = torch.cat(flatten_cls_scores) | |
flatten_bbox_preds = torch.cat(flatten_bbox_preds) | |
flatten_centerness = torch.cat(flatten_centerness) | |
flatten_labels = torch.cat(labels) | |
flatten_bbox_targets = torch.cat(bbox_targets) | |
# repeat points to align with bbox_preds | |
flatten_points = torch.cat( | |
[points.repeat(num_imgs, 1) for points in all_level_points]) | |
# FG cat_id: [0, num_classes -1], BG cat_id: num_classes | |
bg_class_ind = self.num_classes | |
pos_inds = ((flatten_labels >= 0) | |
& (flatten_labels < bg_class_ind)).nonzero().reshape(-1) | |
num_pos = torch.tensor( | |
len(pos_inds), dtype=torch.float, device=bbox_preds[0].device) | |
num_pos = max(reduce_mean(num_pos), 1.0) | |
loss_cls = self.loss_cls( | |
flatten_cls_scores, flatten_labels, avg_factor=num_pos) | |
pos_bbox_preds = flatten_bbox_preds[pos_inds] | |
pos_centerness = flatten_centerness[pos_inds] | |
pos_bbox_targets = flatten_bbox_targets[pos_inds] | |
pos_centerness_targets = self.centerness_target(pos_bbox_targets) | |
# centerness weighted iou loss | |
centerness_denorm = max( | |
reduce_mean(pos_centerness_targets.sum().detach()), 1e-6) | |
if len(pos_inds) > 0: | |
pos_points = flatten_points[pos_inds] | |
pos_decoded_bbox_preds = self.bbox_coder.decode( | |
pos_points, pos_bbox_preds) | |
pos_decoded_target_preds = self.bbox_coder.decode( | |
pos_points, pos_bbox_targets) | |
loss_bbox = self.loss_bbox( | |
pos_decoded_bbox_preds, | |
pos_decoded_target_preds, | |
weight=pos_centerness_targets, | |
avg_factor=centerness_denorm) | |
loss_centerness = self.loss_centerness( | |
pos_centerness, pos_centerness_targets, avg_factor=num_pos) | |
else: | |
loss_bbox = pos_bbox_preds.sum() | |
loss_centerness = pos_centerness.sum() | |
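        # Cache the raw multi-level outputs and the positive-sample indices;
        # ``get_positive_infos`` later reorganizes them per image so that
        # CondInstMaskHead can fetch the dynamic params of each positive
        # location.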
self._raw_positive_infos.update(cls_scores=cls_scores) | |
self._raw_positive_infos.update(centernesses=centernesses) | |
self._raw_positive_infos.update(param_preds=param_preds) | |
self._raw_positive_infos.update(all_level_points=all_level_points) | |
self._raw_positive_infos.update(all_level_strides=all_level_strides) | |
self._raw_positive_infos.update(pos_gt_inds_list=pos_gt_inds_list) | |
self._raw_positive_infos.update(pos_inds_list=pos_inds_list) | |
return dict( | |
loss_cls=loss_cls, | |
loss_bbox=loss_bbox, | |
loss_centerness=loss_centerness) | |
def get_targets( | |
self, points: List[Tensor], batch_gt_instances: InstanceList | |
) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]: | |
"""Compute regression, classification and centerness targets for points | |
in multiple images. | |
Args: | |
points (list[Tensor]): Points of each fpn level, each has shape | |
(num_points, 2). | |
batch_gt_instances (list[:obj:`InstanceData`]): Batch of | |
gt_instance. It usually includes ``bboxes`` and ``labels`` | |
attributes. | |
Returns: | |
tuple: Targets of each level. | |
- concat_lvl_labels (list[Tensor]): Labels of each level. | |
- concat_lvl_bbox_targets (list[Tensor]): BBox targets of each \ | |
level. | |
- pos_inds_list (list[Tensor]): pos_inds of each image. | |
- pos_gt_inds_list (List[Tensor]): pos_gt_inds of each image. | |
""" | |
assert len(points) == len(self.regress_ranges) | |
num_levels = len(points) | |
# expand regress ranges to align with points | |
expanded_regress_ranges = [ | |
points[i].new_tensor(self.regress_ranges[i])[None].expand_as( | |
points[i]) for i in range(num_levels) | |
] | |
# concat all levels points and regress ranges | |
concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0) | |
concat_points = torch.cat(points, dim=0) | |
# the number of points per img, per lvl | |
num_points = [center.size(0) for center in points] | |
# get labels and bbox_targets of each image | |
labels_list, bbox_targets_list, pos_inds_list, pos_gt_inds_list = \ | |
multi_apply( | |
self._get_targets_single, | |
batch_gt_instances, | |
points=concat_points, | |
regress_ranges=concat_regress_ranges, | |
num_points_per_lvl=num_points) | |
# split to per img, per level | |
labels_list = [labels.split(num_points, 0) for labels in labels_list] | |
bbox_targets_list = [ | |
bbox_targets.split(num_points, 0) | |
for bbox_targets in bbox_targets_list | |
] | |
# concat per level image | |
concat_lvl_labels = [] | |
concat_lvl_bbox_targets = [] | |
for i in range(num_levels): | |
concat_lvl_labels.append( | |
torch.cat([labels[i] for labels in labels_list])) | |
bbox_targets = torch.cat( | |
[bbox_targets[i] for bbox_targets in bbox_targets_list]) | |
if self.norm_on_bbox: | |
bbox_targets = bbox_targets / self.strides[i] | |
concat_lvl_bbox_targets.append(bbox_targets) | |
return (concat_lvl_labels, concat_lvl_bbox_targets, pos_inds_list, | |
pos_gt_inds_list) | |
def _get_targets_single( | |
self, gt_instances: InstanceData, points: Tensor, | |
regress_ranges: Tensor, num_points_per_lvl: List[int] | |
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: | |
"""Compute regression and classification targets for a single image.""" | |
num_points = points.size(0) | |
num_gts = len(gt_instances) | |
gt_bboxes = gt_instances.bboxes | |
gt_labels = gt_instances.labels | |
gt_masks = gt_instances.get('masks', None) | |
if num_gts == 0: | |
return gt_labels.new_full((num_points,), self.num_classes), \ | |
gt_bboxes.new_zeros((num_points, 4)), \ | |
gt_bboxes.new_zeros((0,), dtype=torch.int64), \ | |
gt_bboxes.new_zeros((0,), dtype=torch.int64) | |
areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * ( | |
gt_bboxes[:, 3] - gt_bboxes[:, 1]) | |
# TODO: figure out why these two are different | |
# areas = areas[None].expand(num_points, num_gts) | |
areas = areas[None].repeat(num_points, 1) | |
regress_ranges = regress_ranges[:, None, :].expand( | |
num_points, num_gts, 2) | |
gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4) | |
xs, ys = points[:, 0], points[:, 1] | |
xs = xs[:, None].expand(num_points, num_gts) | |
ys = ys[:, None].expand(num_points, num_gts) | |
left = xs - gt_bboxes[..., 0] | |
right = gt_bboxes[..., 2] - xs | |
top = ys - gt_bboxes[..., 1] | |
bottom = gt_bboxes[..., 3] - ys | |
bbox_targets = torch.stack((left, top, right, bottom), -1) | |
if self.center_sampling: | |
# condition1: inside a `center bbox` | |
radius = self.center_sample_radius | |
# if gt_mask not None, use gt mask's centroid to determine | |
# the center region rather than gt_bbox center | |
if gt_masks is None: | |
center_xs = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) / 2 | |
center_ys = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) / 2 | |
else: | |
h, w = gt_masks.height, gt_masks.width | |
masks = gt_masks.to_tensor( | |
dtype=torch.bool, device=gt_bboxes.device) | |
yys = torch.arange( | |
0, h, dtype=torch.float32, device=masks.device) | |
xxs = torch.arange( | |
0, w, dtype=torch.float32, device=masks.device) | |
                # m00 is the mask area; m10/m01 are the first-order moments
                # along x/y. The centroid is (m10 / m00, m01 / m00).
m00 = masks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6) | |
m10 = (masks * xxs).sum(dim=-1).sum(dim=-1) | |
m01 = (masks * yys[:, None]).sum(dim=-1).sum(dim=-1) | |
center_xs = m10 / m00 | |
center_ys = m01 / m00 | |
center_xs = center_xs[None].expand(num_points, num_gts) | |
center_ys = center_ys[None].expand(num_points, num_gts) | |
center_gts = torch.zeros_like(gt_bboxes) | |
stride = center_xs.new_zeros(center_xs.shape) | |
# project the points on current lvl back to the `original` sizes | |
lvl_begin = 0 | |
for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl): | |
lvl_end = lvl_begin + num_points_lvl | |
stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius | |
lvl_begin = lvl_end | |
x_mins = center_xs - stride | |
y_mins = center_ys - stride | |
x_maxs = center_xs + stride | |
y_maxs = center_ys + stride | |
center_gts[..., 0] = torch.where(x_mins > gt_bboxes[..., 0], | |
x_mins, gt_bboxes[..., 0]) | |
center_gts[..., 1] = torch.where(y_mins > gt_bboxes[..., 1], | |
y_mins, gt_bboxes[..., 1]) | |
center_gts[..., 2] = torch.where(x_maxs > gt_bboxes[..., 2], | |
gt_bboxes[..., 2], x_maxs) | |
center_gts[..., 3] = torch.where(y_maxs > gt_bboxes[..., 3], | |
gt_bboxes[..., 3], y_maxs) | |
cb_dist_left = xs - center_gts[..., 0] | |
cb_dist_right = center_gts[..., 2] - xs | |
cb_dist_top = ys - center_gts[..., 1] | |
cb_dist_bottom = center_gts[..., 3] - ys | |
center_bbox = torch.stack( | |
(cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom), -1) | |
inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0 | |
else: | |
# condition1: inside a gt bbox | |
inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0 | |
# condition2: limit the regression range for each location | |
max_regress_distance = bbox_targets.max(-1)[0] | |
inside_regress_range = ( | |
(max_regress_distance >= regress_ranges[..., 0]) | |
& (max_regress_distance <= regress_ranges[..., 1])) | |
# if there are still more than one objects for a location, | |
# we choose the one with minimal area | |
areas[inside_gt_bbox_mask == 0] = INF | |
areas[inside_regress_range == 0] = INF | |
min_area, min_area_inds = areas.min(dim=1) | |
labels = gt_labels[min_area_inds] | |
labels[min_area == INF] = self.num_classes # set as BG | |
bbox_targets = bbox_targets[range(num_points), min_area_inds] | |
# return pos_inds & pos_gt_inds | |
bg_class_ind = self.num_classes | |
pos_inds = ((labels >= 0) | |
& (labels < bg_class_ind)).nonzero().reshape(-1) | |
pos_gt_inds = min_area_inds[labels < self.num_classes] | |
return labels, bbox_targets, pos_inds, pos_gt_inds | |
def get_positive_infos(self) -> InstanceList: | |
"""Get positive information from sampling results. | |
Returns: | |
list[:obj:`InstanceData`]: Positive information of each image, | |
usually including positive bboxes, positive labels, positive | |
priors, etc. | |
""" | |
assert len(self._raw_positive_infos) > 0 | |
pos_gt_inds_list = self._raw_positive_infos['pos_gt_inds_list'] | |
pos_inds_list = self._raw_positive_infos['pos_inds_list'] | |
num_imgs = len(pos_gt_inds_list) | |
cls_score_list = [] | |
centerness_list = [] | |
param_pred_list = [] | |
point_list = [] | |
stride_list = [] | |
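        # Flatten every level to (num_imgs, num_priors, C) and concatenate
        # along the prior dimension, reproducing the level-concatenated
        # layout used in ``get_targets`` so that ``pos_inds`` can index
        # these tensors directly.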
for cls_score_per_lvl, centerness_per_lvl, param_pred_per_lvl,\ | |
point_per_lvl, stride_per_lvl in \ | |
zip(self._raw_positive_infos['cls_scores'], | |
self._raw_positive_infos['centernesses'], | |
self._raw_positive_infos['param_preds'], | |
self._raw_positive_infos['all_level_points'], | |
self._raw_positive_infos['all_level_strides']): | |
cls_score_per_lvl = \ | |
cls_score_per_lvl.permute( | |
0, 2, 3, 1).reshape(num_imgs, -1, self.num_classes) | |
centerness_per_lvl = \ | |
centerness_per_lvl.permute( | |
0, 2, 3, 1).reshape(num_imgs, -1, 1) | |
param_pred_per_lvl = \ | |
param_pred_per_lvl.permute( | |
0, 2, 3, 1).reshape(num_imgs, -1, self.num_params) | |
point_per_lvl = point_per_lvl.unsqueeze(0).repeat(num_imgs, 1, 1) | |
stride_per_lvl = stride_per_lvl.unsqueeze(0).repeat(num_imgs, 1) | |
cls_score_list.append(cls_score_per_lvl) | |
centerness_list.append(centerness_per_lvl) | |
param_pred_list.append(param_pred_per_lvl) | |
point_list.append(point_per_lvl) | |
stride_list.append(stride_per_lvl) | |
cls_scores = torch.cat(cls_score_list, dim=1) | |
centernesses = torch.cat(centerness_list, dim=1) | |
param_preds = torch.cat(param_pred_list, dim=1) | |
all_points = torch.cat(point_list, dim=1) | |
all_strides = torch.cat(stride_list, dim=1) | |
positive_infos = [] | |
for i, (pos_gt_inds, | |
pos_inds) in enumerate(zip(pos_gt_inds_list, pos_inds_list)): | |
pos_info = InstanceData() | |
pos_info.points = all_points[i][pos_inds] | |
pos_info.strides = all_strides[i][pos_inds] | |
pos_info.scores = cls_scores[i][pos_inds] | |
pos_info.centernesses = centernesses[i][pos_inds] | |
pos_info.param_preds = param_preds[i][pos_inds] | |
pos_info.pos_assigned_gt_inds = pos_gt_inds | |
pos_info.pos_inds = pos_inds | |
positive_infos.append(pos_info) | |
return positive_infos | |
def predict_by_feat(self, | |
cls_scores: List[Tensor], | |
bbox_preds: List[Tensor], | |
score_factors: Optional[List[Tensor]] = None, | |
param_preds: Optional[List[Tensor]] = None, | |
batch_img_metas: Optional[List[dict]] = None, | |
cfg: Optional[ConfigDict] = None, | |
rescale: bool = False, | |
with_nms: bool = True) -> InstanceList: | |
"""Transform a batch of output features extracted from the head into | |
bbox results. | |
        Note: When score_factors is not None, the cls_scores are usually
            multiplied by it to obtain the real scores used in NMS, e.g. the
            centerness in FCOS or the IoU branch in ATSS.
Args: | |
cls_scores (list[Tensor]): Classification scores for all | |
scale levels, each is a 4D-tensor, has shape | |
(batch_size, num_priors * num_classes, H, W). | |
bbox_preds (list[Tensor]): Box energies / deltas for all | |
scale levels, each is a 4D-tensor, has shape | |
(batch_size, num_priors * 4, H, W). | |
score_factors (list[Tensor], optional): Score factor for | |
all scale level, each is a 4D-tensor, has shape | |
(batch_size, num_priors * 1, H, W). Defaults to None. | |
param_preds (list[Tensor], optional): Params for all scale | |
level, each is a 4D-tensor, has shape | |
(batch_size, num_priors * num_params, H, W) | |
batch_img_metas (list[dict], Optional): Batch image meta info. | |
Defaults to None. | |
cfg (ConfigDict, optional): Test / postprocessing | |
configuration, if None, test_cfg would be used. | |
Defaults to None. | |
rescale (bool): If True, return boxes in original image space. | |
Defaults to False. | |
with_nms (bool): If True, do nms before return boxes. | |
Defaults to True. | |
Returns: | |
list[:obj:`InstanceData`]: Object detection results of each image | |
after the post process. Each item usually contains following keys. | |
- scores (Tensor): Classification scores, has a shape | |
(num_instance, ) | |
- labels (Tensor): Labels of bboxes, has a shape | |
(num_instances, ). | |
- bboxes (Tensor): Has a shape (num_instances, 4), | |
the last dimension 4 arrange as (x1, y1, x2, y2). | |
""" | |
assert len(cls_scores) == len(bbox_preds) | |
if score_factors is None: | |
# e.g. Retina, FreeAnchor, Foveabox, etc. | |
with_score_factors = False | |
else: | |
# e.g. FCOS, PAA, ATSS, AutoAssign, etc. | |
with_score_factors = True | |
assert len(cls_scores) == len(score_factors) | |
num_levels = len(cls_scores) | |
featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] | |
all_level_points_strides = self.prior_generator.grid_priors( | |
featmap_sizes, | |
dtype=bbox_preds[0].dtype, | |
device=bbox_preds[0].device, | |
with_stride=True) | |
all_level_points = [i[:, :2] for i in all_level_points_strides] | |
all_level_strides = [i[:, 2] for i in all_level_points_strides] | |
result_list = [] | |
for img_id in range(len(batch_img_metas)): | |
img_meta = batch_img_metas[img_id] | |
cls_score_list = select_single_mlvl( | |
cls_scores, img_id, detach=True) | |
bbox_pred_list = select_single_mlvl( | |
bbox_preds, img_id, detach=True) | |
if with_score_factors: | |
score_factor_list = select_single_mlvl( | |
score_factors, img_id, detach=True) | |
else: | |
score_factor_list = [None for _ in range(num_levels)] | |
param_pred_list = select_single_mlvl( | |
param_preds, img_id, detach=True) | |
results = self._predict_by_feat_single( | |
cls_score_list=cls_score_list, | |
bbox_pred_list=bbox_pred_list, | |
score_factor_list=score_factor_list, | |
param_pred_list=param_pred_list, | |
mlvl_points=all_level_points, | |
mlvl_strides=all_level_strides, | |
img_meta=img_meta, | |
cfg=cfg, | |
rescale=rescale, | |
with_nms=with_nms) | |
result_list.append(results) | |
return result_list | |
def _predict_by_feat_single(self, | |
cls_score_list: List[Tensor], | |
bbox_pred_list: List[Tensor], | |
score_factor_list: List[Tensor], | |
param_pred_list: List[Tensor], | |
mlvl_points: List[Tensor], | |
mlvl_strides: List[Tensor], | |
img_meta: dict, | |
cfg: ConfigDict, | |
rescale: bool = False, | |
with_nms: bool = True) -> InstanceData: | |
"""Transform a single image's features extracted from the head into | |
bbox results. | |
Args: | |
cls_score_list (list[Tensor]): Box scores from all scale | |
levels of a single image, each item has shape | |
(num_priors * num_classes, H, W). | |
bbox_pred_list (list[Tensor]): Box energies / deltas from | |
all scale levels of a single image, each item has shape | |
(num_priors * 4, H, W). | |
score_factor_list (list[Tensor]): Score factor from all scale | |
levels of a single image, each item has shape | |
(num_priors * 1, H, W). | |
            param_pred_list (List[Tensor]): Param predictions from all scale
                levels of a single image, each item has shape
                (num_priors * num_params, H, W).
mlvl_points (list[Tensor]): Each element in the list is | |
the priors of a single level in feature pyramid. | |
It has shape (num_priors, 2) | |
mlvl_strides (List[Tensor]): Each element in the list is | |
the stride of a single level in feature pyramid. | |
It has shape (num_priors, 1) | |
img_meta (dict): Image meta info. | |
cfg (mmengine.Config): Test / postprocessing configuration, | |
if None, test_cfg would be used. | |
rescale (bool): If True, return boxes in original image space. | |
Defaults to False. | |
with_nms (bool): If True, do nms before return boxes. | |
Defaults to True. | |
Returns: | |
:obj:`InstanceData`: Detection results of each image | |
after the post process. | |
Each item usually contains following keys. | |
- scores (Tensor): Classification scores, has a shape | |
(num_instance, ) | |
- labels (Tensor): Labels of bboxes, has a shape | |
(num_instances, ). | |
- bboxes (Tensor): Has a shape (num_instances, 4), | |
the last dimension 4 arrange as (x1, y1, x2, y2). | |
""" | |
if score_factor_list[0] is None: | |
# e.g. Retina, FreeAnchor, etc. | |
with_score_factors = False | |
else: | |
# e.g. FCOS, PAA, ATSS, etc. | |
with_score_factors = True | |
cfg = self.test_cfg if cfg is None else cfg | |
cfg = copy.deepcopy(cfg) | |
img_shape = img_meta['img_shape'] | |
nms_pre = cfg.get('nms_pre', -1) | |
mlvl_bbox_preds = [] | |
mlvl_param_preds = [] | |
mlvl_valid_points = [] | |
mlvl_valid_strides = [] | |
mlvl_scores = [] | |
mlvl_labels = [] | |
if with_score_factors: | |
mlvl_score_factors = [] | |
else: | |
mlvl_score_factors = None | |
for level_idx, (cls_score, bbox_pred, score_factor, | |
param_pred, points, strides) in \ | |
enumerate(zip(cls_score_list, bbox_pred_list, | |
score_factor_list, param_pred_list, | |
mlvl_points, mlvl_strides)): | |
assert cls_score.size()[-2:] == bbox_pred.size()[-2:] | |
dim = self.bbox_coder.encode_size | |
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, dim) | |
if with_score_factors: | |
score_factor = score_factor.permute(1, 2, | |
0).reshape(-1).sigmoid() | |
cls_score = cls_score.permute(1, 2, | |
0).reshape(-1, self.cls_out_channels) | |
if self.use_sigmoid_cls: | |
scores = cls_score.sigmoid() | |
else: | |
# remind that we set FG labels to [0, num_class-1] | |
# since mmdet v2.0 | |
# BG cat_id: num_class | |
scores = cls_score.softmax(-1)[:, :-1] | |
param_pred = param_pred.permute(1, 2, | |
0).reshape(-1, self.num_params) | |
# After https://github.com/open-mmlab/mmdetection/pull/6268/, | |
# this operation keeps fewer bboxes under the same `nms_pre`. | |
# There is no difference in performance for most models. If you | |
# find a slight drop in performance, you can set a larger | |
# `nms_pre` than before. | |
score_thr = cfg.get('score_thr', 0) | |
results = filter_scores_and_topk( | |
scores, score_thr, nms_pre, | |
dict( | |
bbox_pred=bbox_pred, | |
param_pred=param_pred, | |
points=points, | |
strides=strides)) | |
scores, labels, keep_idxs, filtered_results = results | |
bbox_pred = filtered_results['bbox_pred'] | |
param_pred = filtered_results['param_pred'] | |
points = filtered_results['points'] | |
strides = filtered_results['strides'] | |
if with_score_factors: | |
score_factor = score_factor[keep_idxs] | |
mlvl_bbox_preds.append(bbox_pred) | |
mlvl_param_preds.append(param_pred) | |
mlvl_valid_points.append(points) | |
mlvl_valid_strides.append(strides) | |
mlvl_scores.append(scores) | |
mlvl_labels.append(labels) | |
if with_score_factors: | |
mlvl_score_factors.append(score_factor) | |
bbox_pred = torch.cat(mlvl_bbox_preds) | |
priors = cat_boxes(mlvl_valid_points) | |
bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape) | |
results = InstanceData() | |
results.bboxes = bboxes | |
results.scores = torch.cat(mlvl_scores) | |
results.labels = torch.cat(mlvl_labels) | |
results.param_preds = torch.cat(mlvl_param_preds) | |
results.points = torch.cat(mlvl_valid_points) | |
results.strides = torch.cat(mlvl_valid_strides) | |
if with_score_factors: | |
results.score_factors = torch.cat(mlvl_score_factors) | |
return self._bbox_post_process( | |
results=results, | |
cfg=cfg, | |
rescale=rescale, | |
with_nms=with_nms, | |
img_meta=img_meta) | |
class MaskFeatModule(BaseModule):
    """CondInst mask feature map branch used in \
        https://arxiv.org/abs/2003.05664.

    Args:
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels of the mask feature
            map branch.
        start_level (int): The starting feature map level from the FPN that
            will be used to predict the mask feature map.
        end_level (int): The ending feature map level from the FPN that
            will be used to predict the mask feature map.
        out_channels (int): Number of output channels of the mask feature
            map branch. This is the channel count of the mask feature map
            that is to be dynamically convolved with the predicted kernel.
        mask_stride (int): Downsample factor of the mask feature map output.
            Defaults to 4.
        num_stacked_convs (int): Number of convs in mask feature branch.
            Defaults to 4.
        conv_cfg (dict): Config dict for convolution layer. Defaults to None.
        norm_cfg (dict): Config dict for normalization layer. Defaults to
            None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """
def __init__(self, | |
in_channels: int, | |
feat_channels: int, | |
start_level: int, | |
end_level: int, | |
out_channels: int, | |
mask_stride: int = 4, | |
num_stacked_convs: int = 4, | |
conv_cfg: OptConfigType = None, | |
norm_cfg: OptConfigType = None, | |
init_cfg: MultiConfig = [ | |
dict(type='Normal', layer='Conv2d', std=0.01) | |
], | |
**kwargs) -> None: | |
super().__init__(init_cfg=init_cfg) | |
self.in_channels = in_channels | |
self.feat_channels = feat_channels | |
self.start_level = start_level | |
self.end_level = end_level | |
self.mask_stride = mask_stride | |
self.num_stacked_convs = num_stacked_convs | |
assert start_level >= 0 and end_level >= start_level | |
self.out_channels = out_channels | |
self.conv_cfg = conv_cfg | |
self.norm_cfg = norm_cfg | |
self._init_layers() | |
def _init_layers(self) -> None: | |
"""Initialize layers of the head.""" | |
self.convs_all_levels = nn.ModuleList() | |
for i in range(self.start_level, self.end_level + 1): | |
convs_per_level = nn.Sequential() | |
convs_per_level.add_module( | |
f'conv{i}', | |
ConvModule( | |
self.in_channels, | |
self.feat_channels, | |
3, | |
padding=1, | |
conv_cfg=self.conv_cfg, | |
norm_cfg=self.norm_cfg, | |
inplace=False, | |
bias=False)) | |
self.convs_all_levels.append(convs_per_level) | |
conv_branch = [] | |
for _ in range(self.num_stacked_convs): | |
conv_branch.append( | |
ConvModule( | |
self.feat_channels, | |
self.feat_channels, | |
3, | |
padding=1, | |
conv_cfg=self.conv_cfg, | |
norm_cfg=self.norm_cfg, | |
bias=False)) | |
self.conv_branch = nn.Sequential(*conv_branch) | |
self.conv_pred = nn.Conv2d( | |
self.feat_channels, self.out_channels, 1, stride=1) | |
def init_weights(self) -> None: | |
"""Initialize weights of the head.""" | |
super().init_weights() | |
kaiming_init(self.convs_all_levels, a=1, distribution='uniform') | |
kaiming_init(self.conv_branch, a=1, distribution='uniform') | |
kaiming_init(self.conv_pred, a=1, distribution='uniform') | |
def forward(self, x: Tuple[Tensor]) -> Tensor: | |
"""Forward features from the upstream network. | |
Args: | |
x (tuple[Tensor]): Features from the upstream network, each is | |
a 4D-tensor. | |
Returns: | |
Tensor: The predicted mask feature map. | |
""" | |
inputs = x[self.start_level:self.end_level + 1] | |
assert len(inputs) == (self.end_level - self.start_level + 1) | |
feature_add_all_level = self.convs_all_levels[0](inputs[0]) | |
target_h, target_w = feature_add_all_level.size()[2:] | |
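        # Upsample every higher level to the resolution of the start level
        # with aligned_bilinear (integer factor, alignment matching the
        # original CondInst) and sum them before the shared conv branch.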
for i in range(1, len(inputs)): | |
input_p = inputs[i] | |
x_p = self.convs_all_levels[i](input_p) | |
h, w = x_p.size()[2:] | |
factor_h = target_h // h | |
factor_w = target_w // w | |
assert factor_h == factor_w | |
feature_per_level = aligned_bilinear(x_p, factor_h) | |
feature_add_all_level = feature_add_all_level + \ | |
feature_per_level | |
feature_add_all_level = self.conv_branch(feature_add_all_level) | |
feature_pred = self.conv_pred(feature_add_all_level) | |
return feature_pred | |
class CondInstMaskHead(BaseMaskHead):
    """CondInst mask head used in https://arxiv.org/abs/2003.05664.

    This head outputs the mask for CondInst.

    Args:
        mask_feature_head (dict): Config of MaskFeatModule.
        num_layers (int): Number of dynamic conv layers. Defaults to 3.
        feat_channels (int): Number of channels in the dynamic conv.
            Defaults to 8.
        mask_out_stride (int): The stride of the output masks.
            Defaults to 4.
        size_of_interest (int): The size of the region used in rel coord.
            Defaults to 8.
        max_masks_to_train (int): Maximum number of masks to train for
            each image. Defaults to -1, which means no limit.
        topk_masks_per_img (int): Maximum number of masks to train for
            each image; candidates are selected per GT instance by
            classification score * centerness. Defaults to -1, which means
            no limit.
        loss_mask (:obj:`ConfigDict` or dict, optional): Config of
            mask loss.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config
            of head.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            head.
    """
def __init__(self, | |
mask_feature_head: ConfigType, | |
num_layers: int = 3, | |
feat_channels: int = 8, | |
mask_out_stride: int = 4, | |
size_of_interest: int = 8, | |
max_masks_to_train: int = -1, | |
topk_masks_per_img: int = -1, | |
loss_mask: ConfigType = None, | |
train_cfg: OptConfigType = None, | |
test_cfg: OptConfigType = None) -> None: | |
super().__init__() | |
self.mask_feature_head = MaskFeatModule(**mask_feature_head) | |
self.mask_feat_stride = self.mask_feature_head.mask_stride | |
self.in_channels = self.mask_feature_head.out_channels | |
self.num_layers = num_layers | |
self.feat_channels = feat_channels | |
self.size_of_interest = size_of_interest | |
self.mask_out_stride = mask_out_stride | |
self.max_masks_to_train = max_masks_to_train | |
self.topk_masks_per_img = topk_masks_per_img | |
self.prior_generator = MlvlPointGenerator([self.mask_feat_stride]) | |
self.train_cfg = train_cfg | |
self.test_cfg = test_cfg | |
self.loss_mask = MODELS.build(loss_mask) | |
self._init_layers() | |
def _init_layers(self) -> None: | |
"""Initialize layers of the head.""" | |
weight_nums, bias_nums = [], [] | |
for i in range(self.num_layers): | |
if i == 0: | |
weight_nums.append((self.in_channels + 2) * self.feat_channels) | |
bias_nums.append(self.feat_channels) | |
elif i == self.num_layers - 1: | |
weight_nums.append(self.feat_channels * 1) | |
bias_nums.append(1) | |
else: | |
weight_nums.append(self.feat_channels * self.feat_channels) | |
bias_nums.append(self.feat_channels) | |
self.weight_nums = weight_nums | |
self.bias_nums = bias_nums | |
self.num_params = sum(weight_nums) + sum(bias_nums) | |
def parse_dynamic_params( | |
self, params: Tensor) -> Tuple[List[Tensor], List[Tensor]]: | |
"""parse the dynamic params for dynamic conv.""" | |
num_insts = params.size(0) | |
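        # ``params`` has shape (num_insts, num_params). Split it into
        # per-layer weights and biases, then reshape the weights so that all
        # instances can be evaluated with a single grouped 1x1 convolution
        # in ``dynamic_conv_forward``.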
params_splits = list( | |
torch.split_with_sizes( | |
params, self.weight_nums + self.bias_nums, dim=1)) | |
weight_splits = params_splits[:self.num_layers] | |
bias_splits = params_splits[self.num_layers:] | |
for i in range(self.num_layers): | |
if i < self.num_layers - 1: | |
weight_splits[i] = weight_splits[i].reshape( | |
num_insts * self.in_channels, -1, 1, 1) | |
bias_splits[i] = bias_splits[i].reshape(num_insts * | |
self.in_channels) | |
else: | |
# out_channels x in_channels x 1 x 1 | |
weight_splits[i] = weight_splits[i].reshape( | |
num_insts * 1, -1, 1, 1) | |
bias_splits[i] = bias_splits[i].reshape(num_insts) | |
return weight_splits, bias_splits | |
    def dynamic_conv_forward(self, features: Tensor, weights: List[Tensor],
                             biases: List[Tensor], num_insts: int) -> Tensor:
        """Dynamic conv forward: every layer except the last is followed by
        a ReLU."""
n_layers = len(weights) | |
x = features | |
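        # ``features`` stacks all instances along the channel dimension;
        # with groups=num_insts each instance's channel slice is convolved
        # only with its own dynamically generated 1x1 filters. No ReLU is
        # applied after the final layer, whose single-channel output is the
        # mask logit.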
for i, (w, b) in enumerate(zip(weights, biases)): | |
x = F.conv2d(x, w, bias=b, stride=1, padding=0, groups=num_insts) | |
if i < n_layers - 1: | |
x = F.relu(x) | |
return x | |
    def forward(self, x: tuple, positive_infos: InstanceList) -> tuple:
        """Forward features from the upstream network.

        The mask feature branch first fuses the multi-level features into a
        single mask feature map; instance masks are then produced by
        applying each positive location's dynamic convolution filters to
        this map.

        Args:
            x (Tuple[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            positive_infos (List[:obj:`InstanceData`]): Positive information
                computed by the detection head.

        Returns:
            tuple: Predicted instance segmentation masks
        """
mask_feats = self.mask_feature_head(x) | |
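        # ``mask_feats`` has shape (B, C, H, W); multi_apply iterates over
        # the batch dimension, pairing each image's mask feature map with
        # its positive instance information.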
return multi_apply(self.forward_single, mask_feats, positive_infos) | |
    def forward_single(self, mask_feat: Tensor,
                       positive_info: InstanceData) -> tuple:
        """Forward features of each image."""
pos_param_preds = positive_info.get('param_preds') | |
pos_points = positive_info.get('points') | |
pos_strides = positive_info.get('strides') | |
num_inst = pos_param_preds.shape[0] | |
mask_feat = mask_feat[None].repeat(num_inst, 1, 1, 1) | |
_, _, H, W = mask_feat.size() | |
if num_inst == 0: | |
return (pos_param_preds.new_zeros((0, 1, H, W)), ) | |
locations = self.prior_generator.single_level_grid_priors( | |
mask_feat.size()[2:], 0, device=mask_feat.device) | |
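        # Build a 2-channel map of coordinates relative to each instance's
        # center point, normalized by ``size_of_interest`` and the point's
        # stride. These are the extra 2 input channels of the first dynamic
        # conv layer (see the ``in_channels + 2`` term in ``_init_layers``).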
rel_coords = relative_coordinate_maps(locations, pos_points, | |
pos_strides, | |
self.size_of_interest, | |
mask_feat.size()[2:]) | |
mask_head_inputs = torch.cat([rel_coords, mask_feat], dim=1) | |
mask_head_inputs = mask_head_inputs.reshape(1, -1, H, W) | |
weights, biases = self.parse_dynamic_params(pos_param_preds) | |
mask_preds = self.dynamic_conv_forward(mask_head_inputs, weights, | |
biases, num_inst) | |
mask_preds = mask_preds.reshape(-1, H, W) | |
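        # Upsample the mask logits from the mask-feature stride to the
        # desired output stride; with the defaults (both 4) the factor is 1
        # and the logits keep the mask-feature resolution.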
mask_preds = aligned_bilinear( | |
mask_preds.unsqueeze(0), | |
int(self.mask_feat_stride / self.mask_out_stride)).squeeze(0) | |
return (mask_preds, ) | |
def loss_by_feat(self, mask_preds: List[Tensor], | |
batch_gt_instances: InstanceList, | |
batch_img_metas: List[dict], positive_infos: InstanceList, | |
**kwargs) -> dict: | |
"""Calculate the loss based on the features extracted by the mask head. | |
Args: | |
            mask_preds (list[Tensor]): List of predicted masks of each image,
                each has shape (num_pos, mask_h, mask_w).
batch_gt_instances (list[:obj:`InstanceData`]): Batch of | |
gt_instance. It usually includes ``bboxes``, ``masks``, | |
and ``labels`` attributes. | |
batch_img_metas (list[dict]): Meta information of multiple images. | |
positive_infos (List[:obj:``InstanceData``]): Information of | |
positive samples of each image that are assigned in detection | |
head. | |
Returns: | |
dict[str, Tensor]: A dictionary of loss components. | |
""" | |
assert positive_infos is not None, \ | |
'positive_infos should not be None in `CondInstMaskHead`' | |
losses = dict() | |
loss_mask = 0. | |
num_imgs = len(mask_preds) | |
total_pos = 0 | |
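        # Sum the un-reduced mask losses of each image and normalize by the
        # total number of positive samples in the batch.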
for idx in range(num_imgs): | |
(mask_pred, pos_mask_targets, num_pos) = \ | |
self._get_targets_single( | |
mask_preds[idx], batch_gt_instances[idx], | |
positive_infos[idx]) | |
# mask loss | |
total_pos += num_pos | |
if num_pos == 0 or pos_mask_targets is None: | |
loss = mask_pred.new_zeros(1).mean() | |
else: | |
loss = self.loss_mask( | |
mask_pred, pos_mask_targets, | |
reduction_override='none').sum() | |
loss_mask += loss | |
if total_pos == 0: | |
total_pos += 1 # avoid nan | |
loss_mask = loss_mask / total_pos | |
losses.update(loss_mask=loss_mask) | |
return losses | |
def _get_targets_single(self, mask_preds: Tensor, | |
gt_instances: InstanceData, | |
positive_info: InstanceData): | |
"""Compute targets for predictions of single image. | |
Args: | |
            mask_preds (Tensor): Predicted masks of a single image, with
                shape (num_pos, mask_h, mask_w).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes``, ``labels``,
                and ``masks`` attributes.
            positive_info (:obj:`InstanceData`): Information of positive
                samples that are assigned in the detection head. It usually
                contains the following keys.
                - pos_assigned_gt_inds (Tensor): Assigned GT indexes of
                    positive proposals, has shape (num_pos, ).
                - pos_inds (Tensor): Positive indices of the image, has
                    shape (num_pos, ).
                - param_preds (Tensor): Positive param predictions
                    with shape (num_pos, num_params).
Returns: | |
tuple: Usually returns a tuple containing learning targets. | |
- mask_preds (Tensor): Positive predicted mask with shape | |
(num_pos, mask_h, mask_w). | |
- pos_mask_targets (Tensor): Positive mask targets with shape | |
(num_pos, mask_h, mask_w). | |
- num_pos (int): Positive numbers. | |
""" | |
gt_bboxes = gt_instances.bboxes | |
device = gt_bboxes.device | |
gt_masks = gt_instances.masks.to_tensor( | |
dtype=torch.bool, device=device).float() | |
# process with mask targets | |
pos_assigned_gt_inds = positive_info.get('pos_assigned_gt_inds') | |
scores = positive_info.get('scores') | |
centernesses = positive_info.get('centernesses') | |
num_pos = pos_assigned_gt_inds.size(0) | |
if gt_masks.size(0) == 0 or num_pos == 0: | |
return mask_preds, None, 0 | |
# Since we're producing (near) full image masks, | |
# it'd take too much vram to backprop on every single mask. | |
# Thus we select only a subset. | |
if (self.max_masks_to_train != -1) and \ | |
(num_pos > self.max_masks_to_train): | |
perm = torch.randperm(num_pos) | |
select = perm[:self.max_masks_to_train] | |
mask_preds = mask_preds[select] | |
pos_assigned_gt_inds = pos_assigned_gt_inds[select] | |
num_pos = self.max_masks_to_train | |
elif self.topk_masks_per_img != -1: | |
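            # Alternatively, keep at most ``topk_masks_per_img`` predictions
            # per image by selecting, for every GT instance, the top
            # candidates ranked by classification score * centerness.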
unique_gt_inds = pos_assigned_gt_inds.unique() | |
num_inst_per_gt = max( | |
int(self.topk_masks_per_img / len(unique_gt_inds)), 1) | |
keep_mask_preds = [] | |
keep_pos_assigned_gt_inds = [] | |
for gt_ind in unique_gt_inds: | |
per_inst_pos_inds = (pos_assigned_gt_inds == gt_ind) | |
mask_preds_per_inst = mask_preds[per_inst_pos_inds] | |
gt_inds_per_inst = pos_assigned_gt_inds[per_inst_pos_inds] | |
if sum(per_inst_pos_inds) > num_inst_per_gt: | |
per_inst_scores = scores[per_inst_pos_inds].sigmoid().max( | |
dim=1)[0] | |
per_inst_centerness = centernesses[ | |
per_inst_pos_inds].sigmoid().reshape(-1, ) | |
select = (per_inst_scores * per_inst_centerness).topk( | |
k=num_inst_per_gt, dim=0)[1] | |
mask_preds_per_inst = mask_preds_per_inst[select] | |
gt_inds_per_inst = gt_inds_per_inst[select] | |
keep_mask_preds.append(mask_preds_per_inst) | |
keep_pos_assigned_gt_inds.append(gt_inds_per_inst) | |
mask_preds = torch.cat(keep_mask_preds) | |
pos_assigned_gt_inds = torch.cat(keep_pos_assigned_gt_inds) | |
num_pos = pos_assigned_gt_inds.size(0) | |
        # Follow the original implementation: GT masks are at input-image
        # resolution, so sample every ``mask_out_stride``-th pixel starting
        # at ``mask_out_stride // 2`` to align the sampled grid with the
        # centers of the prediction grid.
start = int(self.mask_out_stride // 2) | |
gt_masks = gt_masks[:, start::self.mask_out_stride, | |
start::self.mask_out_stride] | |
gt_masks = gt_masks.gt(0.5).float() | |
pos_mask_targets = gt_masks[pos_assigned_gt_inds] | |
return (mask_preds, pos_mask_targets, num_pos) | |
def predict_by_feat(self, | |
mask_preds: List[Tensor], | |
results_list: InstanceList, | |
batch_img_metas: List[dict], | |
rescale: bool = True, | |
**kwargs) -> InstanceList: | |
"""Transform a batch of output features extracted from the head into | |
mask results. | |
Args: | |
            mask_preds (list[Tensor]): Predicted masks of each image,
                each has shape (num_pos, H, W).
            results_list (List[:obj:`InstanceData`]): BBoxHead results.
            batch_img_metas (list[dict]): Meta information of all images.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to True.
        Returns:
            list[:obj:`InstanceData`]: Processed results of multiple
                images. Each :obj:`InstanceData` usually contains the
                following keys.
- scores (Tensor): Classification scores, has shape | |
(num_instance,). | |
- labels (Tensor): Has shape (num_instances,). | |
- masks (Tensor): Processed mask results, has | |
shape (num_instances, h, w). | |
""" | |
assert len(mask_preds) == len(results_list) == len(batch_img_metas) | |
for img_id in range(len(batch_img_metas)): | |
img_meta = batch_img_metas[img_id] | |
results = results_list[img_id] | |
bboxes = results.bboxes | |
mask_pred = mask_preds[img_id] | |
if bboxes.shape[0] == 0 or mask_pred.shape[0] == 0: | |
results_list[img_id] = empty_instances( | |
[img_meta], | |
bboxes.device, | |
task_type='mask', | |
instance_results=[results])[0] | |
else: | |
im_mask = self._predict_by_feat_single( | |
mask_preds=mask_pred, | |
bboxes=bboxes, | |
img_meta=img_meta, | |
rescale=rescale) | |
results.masks = im_mask | |
return results_list | |
def _predict_by_feat_single(self, | |
mask_preds: Tensor, | |
bboxes: Tensor, | |
img_meta: dict, | |
rescale: bool, | |
cfg: OptConfigType = None): | |
"""Transform a single image's features extracted from the head into | |
mask results. | |
Args: | |
            mask_preds (Tensor): Predicted masks of a single image, has
                shape (num_instances, H, W).
            bboxes (Tensor): Predicted bboxes of a single image, has
                shape (num_instances, 4).
            img_meta (dict): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            rescale (bool): If True, the masks are resized to the original
                image size; otherwise they keep the scale of the (padded)
                input image.
            cfg (dict, optional): Config used in test phase.
                Defaults to None.
        Returns:
            Tensor: Processed binary mask results of a single image, has
                shape (num_instances, h, w).
""" | |
        cfg = self.test_cfg if cfg is None else cfg
img_h, img_w = img_meta['img_shape'][:2] | |
ori_h, ori_w = img_meta['ori_shape'][:2] | |
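        # Upsample the sigmoid mask probabilities by ``mask_out_stride`` to
        # the padded input resolution, crop away the padding, then (if
        # rescaling) resize to the original image size before thresholding
        # with ``cfg.mask_thr``.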
mask_preds = mask_preds.sigmoid().unsqueeze(0) | |
mask_preds = aligned_bilinear(mask_preds, self.mask_out_stride) | |
mask_preds = mask_preds[:, :, :img_h, :img_w] | |
if rescale: # in-placed rescale the bboxes | |
scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat( | |
(1, 2)) | |
bboxes /= scale_factor | |
masks = F.interpolate( | |
mask_preds, (ori_h, ori_w), | |
mode='bilinear', | |
align_corners=False).squeeze(0) > cfg.mask_thr | |
else: | |
masks = mask_preds.squeeze(0) > cfg.mask_thr | |
return masks | |