Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend # noqa | |
from typing import List, Tuple | |
import torch | |
import torch.nn.functional as F | |
from mmcv.ops import point_sample, rel_roi_point_to_rel_img_point | |
from torch import Tensor | |
from mmdet.registry import MODELS | |
from mmdet.structures.bbox import bbox2roi | |
from mmdet.utils import ConfigType, InstanceList | |
from ..task_modules.samplers import SamplingResult | |
from ..utils import empty_instances | |
from .standard_roi_head import StandardRoIHead | |
class PointRendRoIHead(StandardRoIHead): | |
"""`PointRend <https://arxiv.org/abs/1912.08193>`_.""" | |
def __init__(self, point_head: ConfigType, *args, **kwargs) -> None: | |
super().__init__(*args, **kwargs) | |
assert self.with_bbox and self.with_mask | |
self.init_point_head(point_head) | |
def init_point_head(self, point_head: ConfigType) -> None: | |
"""Initialize ``point_head``""" | |
self.point_head = MODELS.build(point_head) | |
def mask_loss(self, x: Tuple[Tensor], | |
sampling_results: List[SamplingResult], bbox_feats: Tensor, | |
batch_gt_instances: InstanceList) -> dict: | |
"""Run forward function and calculate loss for mask head and point head | |
in training.""" | |
mask_results = super().mask_loss( | |
x=x, | |
sampling_results=sampling_results, | |
bbox_feats=bbox_feats, | |
batch_gt_instances=batch_gt_instances) | |
mask_point_results = self._mask_point_loss( | |
x=x, | |
sampling_results=sampling_results, | |
mask_preds=mask_results['mask_preds'], | |
batch_gt_instances=batch_gt_instances) | |
mask_results['loss_mask'].update( | |
loss_point=mask_point_results['loss_point']) | |
return mask_results | |
def _mask_point_loss(self, x: Tuple[Tensor], | |
sampling_results: List[SamplingResult], | |
mask_preds: Tensor, | |
batch_gt_instances: InstanceList) -> dict: | |
"""Run forward function and calculate loss for point head in | |
training.""" | |
pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results]) | |
rel_roi_points = self.point_head.get_roi_rel_points_train( | |
mask_preds, pos_labels, cfg=self.train_cfg) | |
rois = bbox2roi([res.pos_bboxes for res in sampling_results]) | |
fine_grained_point_feats = self._get_fine_grained_point_feats( | |
x, rois, rel_roi_points) | |
coarse_point_feats = point_sample(mask_preds, rel_roi_points) | |
mask_point_pred = self.point_head(fine_grained_point_feats, | |
coarse_point_feats) | |
loss_and_target = self.point_head.loss_and_target( | |
point_pred=mask_point_pred, | |
rel_roi_points=rel_roi_points, | |
sampling_results=sampling_results, | |
batch_gt_instances=batch_gt_instances, | |
cfg=self.train_cfg) | |
return loss_and_target | |
def _mask_point_forward_test(self, x: Tuple[Tensor], rois: Tensor, | |
label_preds: Tensor, | |
mask_preds: Tensor) -> Tensor: | |
"""Mask refining process with point head in testing. | |
Args: | |
x (tuple[Tensor]): Feature maps of all scale level. | |
rois (Tensor): shape (num_rois, 5). | |
label_preds (Tensor): The predication class for each rois. | |
mask_preds (Tensor): The predication coarse masks of | |
shape (num_rois, num_classes, small_size, small_size). | |
Returns: | |
Tensor: The refined masks of shape (num_rois, num_classes, | |
large_size, large_size). | |
""" | |
refined_mask_pred = mask_preds.clone() | |
for subdivision_step in range(self.test_cfg.subdivision_steps): | |
refined_mask_pred = F.interpolate( | |
refined_mask_pred, | |
scale_factor=self.test_cfg.scale_factor, | |
mode='bilinear', | |
align_corners=False) | |
# If `subdivision_num_points` is larger or equal to the | |
# resolution of the next step, then we can skip this step | |
num_rois, channels, mask_height, mask_width = \ | |
refined_mask_pred.shape | |
if (self.test_cfg.subdivision_num_points >= | |
self.test_cfg.scale_factor**2 * mask_height * mask_width | |
and | |
subdivision_step < self.test_cfg.subdivision_steps - 1): | |
continue | |
point_indices, rel_roi_points = \ | |
self.point_head.get_roi_rel_points_test( | |
refined_mask_pred, label_preds, cfg=self.test_cfg) | |
fine_grained_point_feats = self._get_fine_grained_point_feats( | |
x=x, rois=rois, rel_roi_points=rel_roi_points) | |
coarse_point_feats = point_sample(mask_preds, rel_roi_points) | |
mask_point_pred = self.point_head(fine_grained_point_feats, | |
coarse_point_feats) | |
point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1) | |
refined_mask_pred = refined_mask_pred.reshape( | |
num_rois, channels, mask_height * mask_width) | |
refined_mask_pred = refined_mask_pred.scatter_( | |
2, point_indices, mask_point_pred) | |
refined_mask_pred = refined_mask_pred.view(num_rois, channels, | |
mask_height, mask_width) | |
return refined_mask_pred | |
def _get_fine_grained_point_feats(self, x: Tuple[Tensor], rois: Tensor, | |
rel_roi_points: Tensor) -> Tensor: | |
"""Sample fine grained feats from each level feature map and | |
concatenate them together. | |
Args: | |
x (tuple[Tensor]): Feature maps of all scale level. | |
rois (Tensor): shape (num_rois, 5). | |
rel_roi_points (Tensor): A tensor of shape (num_rois, num_points, | |
2) that contains [0, 1] x [0, 1] normalized coordinates of the | |
most uncertain points from the [mask_height, mask_width] grid. | |
Returns: | |
Tensor: The fine grained features for each points, | |
has shape (num_rois, feats_channels, num_points). | |
""" | |
assert rois.shape[0] > 0, 'RoI is a empty tensor.' | |
num_imgs = x[0].shape[0] | |
fine_grained_feats = [] | |
for idx in range(self.mask_roi_extractor.num_inputs): | |
feats = x[idx] | |
spatial_scale = 1. / float( | |
self.mask_roi_extractor.featmap_strides[idx]) | |
point_feats = [] | |
for batch_ind in range(num_imgs): | |
# unravel batch dim | |
feat = feats[batch_ind].unsqueeze(0) | |
inds = (rois[:, 0].long() == batch_ind) | |
if inds.any(): | |
rel_img_points = rel_roi_point_to_rel_img_point( | |
rois=rois[inds], | |
rel_roi_points=rel_roi_points[inds], | |
img=feat.shape[2:], | |
spatial_scale=spatial_scale).unsqueeze(0) | |
point_feat = point_sample(feat, rel_img_points) | |
point_feat = point_feat.squeeze(0).transpose(0, 1) | |
point_feats.append(point_feat) | |
fine_grained_feats.append(torch.cat(point_feats, dim=0)) | |
return torch.cat(fine_grained_feats, dim=1) | |
def predict_mask(self, | |
x: Tuple[Tensor], | |
batch_img_metas: List[dict], | |
results_list: InstanceList, | |
rescale: bool = False) -> InstanceList: | |
"""Perform forward propagation of the mask head and predict detection | |
results on the features of the upstream network. | |
Args: | |
x (tuple[Tensor]): Feature maps of all scale level. | |
batch_img_metas (list[dict]): List of image information. | |
results_list (list[:obj:`InstanceData`]): Detection results of | |
each image. | |
rescale (bool): If True, return boxes in original image space. | |
Defaults to False. | |
Returns: | |
list[:obj:`InstanceData`]: Detection results of each image | |
after the post process. | |
Each item usually contains following keys. | |
- scores (Tensor): Classification scores, has a shape | |
(num_instance, ) | |
- labels (Tensor): Labels of bboxes, has a shape | |
(num_instances, ). | |
- bboxes (Tensor): Has a shape (num_instances, 4), | |
the last dimension 4 arrange as (x1, y1, x2, y2). | |
- masks (Tensor): Has a shape (num_instances, H, W). | |
""" | |
# don't need to consider aug_test. | |
bboxes = [res.bboxes for res in results_list] | |
mask_rois = bbox2roi(bboxes) | |
if mask_rois.shape[0] == 0: | |
results_list = empty_instances( | |
batch_img_metas, | |
mask_rois.device, | |
task_type='mask', | |
instance_results=results_list, | |
mask_thr_binary=self.test_cfg.mask_thr_binary) | |
return results_list | |
mask_results = self._mask_forward(x, mask_rois) | |
mask_preds = mask_results['mask_preds'] | |
# split batch mask prediction back to each image | |
num_mask_rois_per_img = [len(res) for res in results_list] | |
mask_preds = mask_preds.split(num_mask_rois_per_img, 0) | |
# refine mask_preds | |
mask_rois = mask_rois.split(num_mask_rois_per_img, 0) | |
mask_preds_refined = [] | |
for i in range(len(batch_img_metas)): | |
labels = results_list[i].labels | |
x_i = [xx[[i]] for xx in x] | |
mask_rois_i = mask_rois[i] | |
mask_rois_i[:, 0] = 0 | |
mask_pred_i = self._mask_point_forward_test( | |
x_i, mask_rois_i, labels, mask_preds[i]) | |
mask_preds_refined.append(mask_pred_i) | |
# TODO: Handle the case where rescale is false | |
results_list = self.mask_head.predict_by_feat( | |
mask_preds=mask_preds_refined, | |
results_list=results_list, | |
batch_img_metas=batch_img_metas, | |
rcnn_test_cfg=self.test_cfg, | |
rescale=rescale) | |
return results_list | |