# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa
from typing import List, Tuple
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.ops import point_sample, rel_roi_point_to_rel_img_point
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import Tensor
from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.models.utils import (get_uncertain_point_coords_with_randomness,
get_uncertainty)
from mmdet.registry import MODELS
from mmdet.structures.bbox import bbox2roi
from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptConfigType


@MODELS.register_module()
class MaskPointHead(BaseModule):
"""A mask point head use in PointRend.
``MaskPointHead`` use shared multi-layer perceptron (equivalent to
nn.Conv1d) to predict the logit of input points. The fine-grained feature
and coarse feature will be concatenate together for predication.
Args:
num_fcs (int): Number of fc layers in the head. Defaults to 3.
in_channels (int): Number of input channels. Defaults to 256.
fc_channels (int): Number of fc channels. Defaults to 256.
num_classes (int): Number of classes for logits. Defaults to 80.
        class_agnostic (bool): Whether to use class-agnostic classification.
            If so, the output channels of logits will be 1. Defaults to False.
        coarse_pred_each_layer (bool): Whether to concatenate the coarse
            feature with the output of each fc layer. Defaults to True.
        conv_cfg (:obj:`ConfigDict` or dict): Dictionary to construct
            and config conv layer. Defaults to ``dict(type='Conv1d')``.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Dictionary to construct
            and config norm layer. Defaults to None.
        act_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
            config activation layer. Defaults to ``dict(type='ReLU')``.
loss_point (:obj:`ConfigDict` or dict): Dictionary to construct and
config loss layer of point head. Defaults to
dict(type='CrossEntropyLoss', use_mask=True, loss_weight=1.0).
init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
dict], optional): Initialization config dict.
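
    Example:
        >>> # A minimal, purely illustrative sketch (shapes are assumptions,
        >>> # not values from a config): 4 RoIs, the default 256 input
        >>> # channels, 80 classes, and 14 sampled points per RoI.
        >>> import torch
        >>> head = MaskPointHead(num_classes=80)
        >>> fine_grained_feats = torch.rand(4, 256, 14)
        >>> coarse_feats = torch.rand(4, 80, 14)
        >>> head(fine_grained_feats, coarse_feats).shape
        torch.Size([4, 80, 14])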
"""

    def __init__(
self,
num_classes: int,
num_fcs: int = 3,
in_channels: int = 256,
fc_channels: int = 256,
class_agnostic: bool = False,
coarse_pred_each_layer: bool = True,
conv_cfg: ConfigType = dict(type='Conv1d'),
norm_cfg: OptConfigType = None,
act_cfg: ConfigType = dict(type='ReLU'),
loss_point: ConfigType = dict(
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0),
init_cfg: MultiConfig = dict(
type='Normal', std=0.001, override=dict(name='fc_logits'))
) -> None:
super().__init__(init_cfg=init_cfg)
self.num_fcs = num_fcs
self.in_channels = in_channels
self.fc_channels = fc_channels
self.num_classes = num_classes
self.class_agnostic = class_agnostic
self.coarse_pred_each_layer = coarse_pred_each_layer
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.loss_point = MODELS.build(loss_point)
fc_in_channels = in_channels + num_classes
self.fcs = nn.ModuleList()
for _ in range(num_fcs):
fc = ConvModule(
fc_in_channels,
fc_channels,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.fcs.append(fc)
fc_in_channels = fc_channels
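            # Re-concatenate the coarse prediction channels so that every
            # subsequent fc layer (not only the first) sees the coarse
            # mask logits.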
fc_in_channels += num_classes if self.coarse_pred_each_layer else 0
out_channels = 1 if self.class_agnostic else self.num_classes
self.fc_logits = nn.Conv1d(
fc_in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, fine_grained_feats: Tensor,
                coarse_feats: Tensor) -> Tensor:
        """Classify each point based on fine-grained and coarse features.

Args:
            fine_grained_feats (Tensor): Fine-grained features sampled from
                FPN, shape (num_rois, in_channels, num_points).
            coarse_feats (Tensor): Coarse features sampled from
                CoarseMaskHead, shape (num_rois, num_classes, num_points).

        Returns:
            Tensor: Point classification results,
                shape (num_rois, num_classes, num_points).
"""
x = torch.cat([fine_grained_feats, coarse_feats], dim=1)
for fc in self.fcs:
x = fc(x)
if self.coarse_pred_each_layer:
x = torch.cat((x, coarse_feats), dim=1)
return self.fc_logits(x)

    def get_targets(self, rois: Tensor, rel_roi_points: Tensor,
                    sampling_results: List[SamplingResult],
                    batch_gt_instances: InstanceList,
                    cfg: ConfigType) -> Tensor:
        """Get training targets of MaskPointHead for all images.

Args:
rois (Tensor): Region of Interest, shape (num_rois, 5).
rel_roi_points (Tensor): Points coordinates relative to RoI, shape
(num_rois, num_points, 2).
            sampling_results (list[:obj:`SamplingResult`]): Sampling results
                after sampling and assignment.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``, and
                ``masks`` attributes.
            cfg (:obj:`ConfigDict` or dict): Training cfg.

        Returns:
Tensor: Point target, shape (num_rois, num_points).
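
        Example:
            >>> # A minimal, hypothetical sketch; ``SimpleNamespace`` stands
            >>> # in for :obj:`SamplingResult` since only
            >>> # ``pos_assigned_gt_inds`` is used here, and all shapes and
            >>> # values are illustrative.
            >>> import numpy as np
            >>> import torch
            >>> from types import SimpleNamespace
            >>> from mmengine.config import ConfigDict
            >>> from mmengine.structures import InstanceData
            >>> from mmdet.structures.mask import BitmapMasks
            >>> head = MaskPointHead(num_classes=80)
            >>> gt = InstanceData(
            ...     masks=BitmapMasks(np.zeros((2, 28, 28)), 28, 28))
            >>> res = SimpleNamespace(
            ...     pos_assigned_gt_inds=torch.tensor([0, 1]))
            >>> rois = torch.tensor([[0., 0., 0., 14., 14.],
            ...                      [0., 7., 7., 21., 21.]])
            >>> targets = head.get_targets(rois, torch.rand(2, 14, 2),
            ...                            [res], [gt],
            ...                            ConfigDict(num_points=14))
            >>> targets.shape
            torch.Size([2, 14])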
"""
num_imgs = len(sampling_results)
rois_list = []
rel_roi_points_list = []
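        # Split RoIs and point coordinates per image using the batch index
        # stored in rois[:, 0].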
for batch_ind in range(num_imgs):
inds = (rois[:, 0] == batch_ind)
rois_list.append(rois[inds])
rel_roi_points_list.append(rel_roi_points[inds])
pos_assigned_gt_inds_list = [
res.pos_assigned_gt_inds for res in sampling_results
]
cfg_list = [cfg for _ in range(num_imgs)]
point_targets = map(self._get_targets_single, rois_list,
rel_roi_points_list, pos_assigned_gt_inds_list,
batch_gt_instances, cfg_list)
point_targets = list(point_targets)
if len(point_targets) > 0:
point_targets = torch.cat(point_targets)
return point_targets

    def _get_targets_single(self, rois: Tensor, rel_roi_points: Tensor,
pos_assigned_gt_inds: Tensor,
gt_instances: InstanceData,
cfg: ConfigType) -> Tensor:
"""Get training target of MaskPointHead for each image."""
num_pos = rois.size(0)
num_points = cfg.num_points
if num_pos > 0:
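            # Select the GT mask matched to each positive RoI.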
gt_masks_th = (
gt_instances.masks.to_tensor(rois.dtype,
rois.device).index_select(
0, pos_assigned_gt_inds))
gt_masks_th = gt_masks_th.unsqueeze(1)
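            # Map RoI-relative point coordinates to mask-relative
            # coordinates, then sample the GT masks at those points.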
rel_img_points = rel_roi_point_to_rel_img_point(
rois, rel_roi_points, gt_masks_th)
point_targets = point_sample(gt_masks_th,
rel_img_points).squeeze(1)
else:
point_targets = rois.new_zeros((0, num_points))
return point_targets

    def loss_and_target(self, point_pred: Tensor, rel_roi_points: Tensor,
sampling_results: List[SamplingResult],
batch_gt_instances: InstanceList,
cfg: ConfigType) -> dict:
"""Calculate loss for MaskPointHead.
Args:
point_pred (Tensor): Point predication result, shape
(num_rois, num_classes, num_points).
rel_roi_points (Tensor): Points coordinates relative to RoI, shape
(num_rois, num_points, 2).
sampling_results (:obj:`SamplingResult`): Sampling result after
sampling and assignment.
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes``, ``labels``, and
``masks`` attributes.
cfg (obj:`ConfigDict` or dict): Training cfg.
Returns:
dict: a dictionary of point loss and point target.
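
        Example:
            >>> # A hypothetical sketch mirroring the ``get_targets`` example;
            >>> # ``SimpleNamespace`` stands in for :obj:`SamplingResult` with
            >>> # only the attributes used here, and all values are
            >>> # illustrative.
            >>> import numpy as np
            >>> import torch
            >>> from types import SimpleNamespace
            >>> from mmengine.config import ConfigDict
            >>> from mmengine.structures import InstanceData
            >>> from mmdet.structures.mask import BitmapMasks
            >>> head = MaskPointHead(num_classes=80)
            >>> res = SimpleNamespace(
            ...     pos_bboxes=torch.tensor([[0., 0., 14., 14.]]),
            ...     pos_gt_labels=torch.tensor([3]),
            ...     pos_assigned_gt_inds=torch.tensor([0]))
            >>> gt = InstanceData(
            ...     masks=BitmapMasks(np.zeros((1, 28, 28)), 28, 28))
            >>> out = head.loss_and_target(torch.rand(1, 80, 14),
            ...                            torch.rand(1, 14, 2), [res], [gt],
            ...                            ConfigDict(num_points=14))
            >>> sorted(out.keys())
            ['loss_point', 'point_target']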
"""
rois = bbox2roi([res.pos_bboxes for res in sampling_results])
pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
point_target = self.get_targets(rois, rel_roi_points, sampling_results,
batch_gt_instances, cfg)
if self.class_agnostic:
loss_point = self.loss_point(point_pred, point_target,
torch.zeros_like(pos_labels))
else:
loss_point = self.loss_point(point_pred, point_target, pos_labels)
return dict(loss_point=loss_point, point_target=point_target)

    def get_roi_rel_points_train(self, mask_preds: Tensor, labels: Tensor,
                                 cfg: ConfigType) -> Tensor:
        """Get ``num_points`` most uncertain points with random points during
        training.

        Sample points in [0, 1] x [0, 1] coordinate space based on their
        uncertainty. The uncertainty is calculated for each point using the
        ``get_uncertainty()`` function, which takes the point's logit
        prediction as input.

Args:
mask_preds (Tensor): A tensor of shape (num_rois, num_classes,
mask_height, mask_width) for class-specific or class-agnostic
prediction.
labels (Tensor): The ground truth class for each instance.
cfg (:obj:`ConfigDict` or dict): Training config of point head.

        Returns:
            point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
                that contains the coordinates of sampled points.
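
        Example:
            >>> # Illustrative values only; the ``cfg`` fields follow the
            >>> # PointRend training convention for the point head.
            >>> import torch
            >>> from mmengine.config import ConfigDict
            >>> head = MaskPointHead(num_classes=80)
            >>> cfg = ConfigDict(num_points=14, oversample_ratio=3,
            ...                  importance_sample_ratio=0.75)
            >>> mask_preds = torch.rand(2, 80, 7, 7)
            >>> labels = torch.tensor([1, 3])
            >>> head.get_roi_rel_points_train(mask_preds, labels, cfg).shape
            torch.Size([2, 14, 2])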
"""
point_coords = get_uncertain_point_coords_with_randomness(
mask_preds, labels, cfg.num_points, cfg.oversample_ratio,
cfg.importance_sample_ratio)
return point_coords

    def get_roi_rel_points_test(self, mask_preds: Tensor, label_preds: Tensor,
                                cfg: ConfigType) -> Tuple[Tensor, Tensor]:
        """Get ``num_points`` most uncertain points during test.

        Args:
mask_preds (Tensor): A tensor of shape (num_rois, num_classes,
mask_height, mask_width) for class-specific or class-agnostic
prediction.
            label_preds (Tensor): The predicted class for each instance.
            cfg (:obj:`ConfigDict` or dict): Testing config of point head.

        Returns:
tuple:
- point_indices (Tensor): A tensor of shape (num_rois, num_points)
that contains indices from [0, mask_height x mask_width) of the
most uncertain points.
- point_coords (Tensor): A tensor of shape (num_rois, num_points,
2) that contains [0, 1] x [0, 1] normalized coordinates of the
most uncertain points from the [mask_height, mask_width] grid.
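
        Example:
            >>> # Illustrative values only; ``subdivision_num_points`` is
            >>> # capped at mask_height * mask_width inside this method.
            >>> import torch
            >>> from mmengine.config import ConfigDict
            >>> head = MaskPointHead(num_classes=80)
            >>> cfg = ConfigDict(subdivision_num_points=32)
            >>> mask_preds = torch.rand(2, 80, 7, 7)
            >>> label_preds = torch.tensor([1, 3])
            >>> inds, coords = head.get_roi_rel_points_test(
            ...     mask_preds, label_preds, cfg)
            >>> inds.shape, coords.shape
            (torch.Size([2, 32]), torch.Size([2, 32, 2]))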
"""
num_points = cfg.subdivision_num_points
uncertainty_map = get_uncertainty(mask_preds, label_preds)
num_rois, _, mask_height, mask_width = uncertainty_map.shape
        # During ONNX export (tracing), each element of `shape` is a
        # `Tensor`, while it is an `int` during normal PyTorch inference.
if isinstance(mask_height, torch.Tensor):
h_step = 1.0 / mask_height.float()
w_step = 1.0 / mask_width.float()
else:
h_step = 1.0 / mask_height
w_step = 1.0 / mask_width
# cast to int to avoid dynamic K for TopK op in ONNX
mask_size = int(mask_height * mask_width)
uncertainty_map = uncertainty_map.view(num_rois, mask_size)
num_points = min(mask_size, num_points)
point_indices = uncertainty_map.topk(num_points, dim=1)[1]
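        # Convert the flat top-k indices into normalized (x, y) coordinates
        # at the centers of the corresponding uncertainty-map cells.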
xs = w_step / 2.0 + (point_indices % mask_width).float() * w_step
ys = h_step / 2.0 + (point_indices // mask_width).float() * h_step
point_coords = torch.stack([xs, ys], dim=2)
return point_indices, point_coords