KyanChen's picture
init
f549064
raw
history blame
No virus
32.3 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.config import ConfigDict
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import Tensor
from torch.nn.modules.utils import _pair
from mmdet.models.layers import multiclass_nms
from mmdet.models.losses import accuracy
from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.models.utils import empty_instances, multi_apply
from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures.bbox import get_box_tensor, scale_boxes
from mmdet.utils import ConfigType, InstanceList, OptMultiConfig
@MODELS.register_module()
class BBoxHead(BaseModule):
    """Simplest RoI head, with only two fc layers for classification and
    regression respectively.

    Args:
        with_avg_pool (bool): Whether to average-pool the RoI features
            before the predictors. Defaults to False.
        with_cls (bool): Whether to build the classification branch
            (``fc_cls``). Defaults to True.
        with_reg (bool): Whether to build the regression branch
            (``fc_reg``). Defaults to True.
        roi_feat_size (int): Spatial size of the RoI features; expanded to
            a pair with ``_pair``. Defaults to 7.
        in_channels (int): Number of input feature channels.
            Defaults to 256.
        num_classes (int): Number of foreground classes. Defaults to 80.
        bbox_coder (:obj:`ConfigDict` or dict): Config of the box coder.
        predict_box_type (str): Box type forwarded to ``empty_instances``
            when an image has no RoIs at prediction time.
            Defaults to 'hbox'.
        reg_class_agnostic (bool): Whether a single regression output is
            shared by all classes. Defaults to False.
        reg_decoded_bbox (bool): Whether the regression loss is computed on
            decoded (absolute) boxes instead of encoded deltas.
            Defaults to False.
        reg_predictor_cfg (:obj:`ConfigDict` or dict): Config of the
            regression predictor. Defaults to ``dict(type='Linear')``.
        cls_predictor_cfg (:obj:`ConfigDict` or dict): Config of the
            classification predictor. Defaults to ``dict(type='Linear')``.
        loss_cls (:obj:`ConfigDict` or dict): Config of the classification
            loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of the regression
            loss.
        init_cfg (:obj:`ConfigDict` or dict or list, optional): Init config.
            When None, ``fc_cls`` gets Normal(std=0.01) and ``fc_reg`` gets
            Normal(std=0.001) for the branches that are enabled.
            Defaults to None.
    """

    def __init__(self,
                 with_avg_pool: bool = False,
                 with_cls: bool = True,
                 with_reg: bool = True,
                 roi_feat_size: int = 7,
                 in_channels: int = 256,
                 num_classes: int = 80,
                 bbox_coder: ConfigType = dict(
                     type='DeltaXYWHBBoxCoder',
                     clip_border=True,
                     target_means=[0., 0., 0., 0.],
                     target_stds=[0.1, 0.1, 0.2, 0.2]),
                 predict_box_type: str = 'hbox',
                 reg_class_agnostic: bool = False,
                 reg_decoded_bbox: bool = False,
                 reg_predictor_cfg: ConfigType = dict(type='Linear'),
                 cls_predictor_cfg: ConfigType = dict(type='Linear'),
                 loss_cls: ConfigType = dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 loss_bbox: ConfigType = dict(
                     type='SmoothL1Loss', beta=1.0, loss_weight=1.0),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        # At least one output branch is required.
        assert with_cls or with_reg
        self.with_avg_pool = with_avg_pool
        self.with_cls = with_cls
        self.with_reg = with_reg
        self.roi_feat_size = _pair(roi_feat_size)
        self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1]
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.predict_box_type = predict_box_type
        self.reg_class_agnostic = reg_class_agnostic
        self.reg_decoded_bbox = reg_decoded_bbox
        self.reg_predictor_cfg = reg_predictor_cfg
        self.cls_predictor_cfg = cls_predictor_cfg

        self.bbox_coder = TASK_UTILS.build(bbox_coder)
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox = MODELS.build(loss_bbox)

        in_channels = self.in_channels
        if self.with_avg_pool:
            self.avg_pool = nn.AvgPool2d(self.roi_feat_size)
        else:
            # Without pooling the predictors consume the flattened RoI.
            in_channels *= self.roi_feat_area
        if self.with_cls:
            # need to add background class
            if self.custom_cls_channels:
                cls_channels = self.loss_cls.get_cls_channels(self.num_classes)
            else:
                cls_channels = num_classes + 1
            # Copy so the (shared default) config dict is not mutated.
            cls_predictor_cfg_ = self.cls_predictor_cfg.copy()
            cls_predictor_cfg_.update(
                in_features=in_channels, out_features=cls_channels)
            self.fc_cls = MODELS.build(cls_predictor_cfg_)
        if self.with_reg:
            box_dim = self.bbox_coder.encode_size
            out_dim_reg = box_dim if reg_class_agnostic else \
                box_dim * num_classes
            reg_predictor_cfg_ = self.reg_predictor_cfg.copy()
            if isinstance(reg_predictor_cfg_, (dict, ConfigDict)):
                reg_predictor_cfg_.update(
                    in_features=in_channels, out_features=out_dim_reg)
            self.fc_reg = MODELS.build(reg_predictor_cfg_)
        self.debug_imgs = None
        if init_cfg is None:
            self.init_cfg = []
            if self.with_cls:
                self.init_cfg += [
                    dict(
                        type='Normal', std=0.01, override=dict(name='fc_cls'))
                ]
            if self.with_reg:
                self.init_cfg += [
                    dict(
                        type='Normal', std=0.001, override=dict(name='fc_reg'))
                ]

    # TODO: Create a SeesawBBoxHead to simplify the logic in BBoxHead
    @property
    def custom_cls_channels(self) -> bool:
        """bool: ``loss_cls.custom_cls_channels``, or False if the loss
        does not define that attribute."""
        return getattr(self.loss_cls, 'custom_cls_channels', False)

    # TODO: Create a SeesawBBoxHead to simplify the logic in BBoxHead
    @property
    def custom_activation(self) -> bool:
        """bool: ``loss_cls.custom_activation``, or False if the loss does
        not define that attribute."""
        return getattr(self.loss_cls, 'custom_activation', False)

    # TODO: Create a SeesawBBoxHead to simplify the logic in BBoxHead
    @property
    def custom_accuracy(self) -> bool:
        """bool: ``loss_cls.custom_accuracy``, or False if the loss does not
        define that attribute."""
        return getattr(self.loss_cls, 'custom_accuracy', False)

    def forward(self, x: Tuple[Tensor]) -> tuple:
        """Forward features from the upstream network.

        Args:
            x (Tensor): RoI features. Although annotated as a tuple, it is
                used as a single tensor. When ``with_avg_pool`` is True, it
                is average-pooled and flattened to ``(num_rois, -1)``; an
                empty tensor is instead reduced with ``torch.mean`` over the
                last two dims. Otherwise it is passed to the predictors
                unchanged.

        Returns:
            tuple: A tuple of classification scores and bbox prediction.

                - cls_score (Tensor | None): Output of ``fc_cls``, or None
                  when ``with_cls`` is False.
                - bbox_pred (Tensor | None): Output of ``fc_reg`` (box
                  energies / deltas), or None when ``with_reg`` is False.
        """
        if self.with_avg_pool:
            if x.numel() > 0:
                x = self.avg_pool(x)
                x = x.view(x.size(0), -1)
            else:
                # avg_pool does not support empty tensor,
                # so use torch.mean instead it
                x = torch.mean(x, dim=(-1, -2))
        cls_score = self.fc_cls(x) if self.with_cls else None
        bbox_pred = self.fc_reg(x) if self.with_reg else None
        return cls_score, bbox_pred

    def _get_targets_single(self, pos_priors: Tensor, neg_priors: Tensor,
                            pos_gt_bboxes: Tensor, pos_gt_labels: Tensor,
                            cfg: ConfigDict) -> tuple:
        """Calculate the ground truth for proposals in the single image
        according to the sampling results.

        Positive samples occupy the first ``num_pos`` rows of every output
        and negatives the last ``num_neg`` rows.

        Args:
            pos_priors (Tensor): Contains all the positive boxes,
                has shape (num_pos, 4), the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            neg_priors (Tensor): Contains all the negative boxes,
                has shape (num_neg, 4), the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            pos_gt_bboxes (Tensor): Contains gt_boxes for
                all positive samples, has shape (num_pos, 4),
                the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            pos_gt_labels (Tensor): Contains gt_labels for
                all positive samples, has shape (num_pos, ).
            cfg (obj:`ConfigDict`): `train_cfg` of R-CNN. ``cfg.pos_weight``
                is used as the label weight of positives when > 0,
                otherwise 1.0 is used.

        Returns:
            Tuple[Tensor]: Ground truth for proposals
            in a single image. Containing the following Tensors:

                - labels(Tensor): Gt_labels for all proposals, has
                  shape (num_proposals,). Negatives get ``num_classes``.
                - label_weights(Tensor): Labels_weights for all
                  proposals, has shape (num_proposals,).
                - bbox_targets(Tensor): Regression target for all
                  proposals, has shape (num_proposals, reg_dim). These are
                  encoded deltas unless ``reg_decoded_bbox`` is True, in
                  which case they are absolute boxes.
                - bbox_weights(Tensor): Regression weights for all
                  proposals, has shape (num_proposals, reg_dim).
        """
        num_pos = pos_priors.size(0)
        num_neg = neg_priors.size(0)
        num_samples = num_pos + num_neg

        # original implementation uses new_zeros since BG are set to be 0
        # now use empty & fill because BG cat_id = num_classes,
        # FG cat_id = [0, num_classes-1]
        labels = pos_priors.new_full((num_samples, ),
                                     self.num_classes,
                                     dtype=torch.long)
        reg_dim = pos_gt_bboxes.size(-1) if self.reg_decoded_bbox \
            else self.bbox_coder.encode_size
        label_weights = pos_priors.new_zeros(num_samples)
        bbox_targets = pos_priors.new_zeros(num_samples, reg_dim)
        bbox_weights = pos_priors.new_zeros(num_samples, reg_dim)
        if num_pos > 0:
            labels[:num_pos] = pos_gt_labels
            pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
            label_weights[:num_pos] = pos_weight
            if not self.reg_decoded_bbox:
                pos_bbox_targets = self.bbox_coder.encode(
                    pos_priors, pos_gt_bboxes)
            else:
                # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
                # is applied directly on the decoded bounding boxes, both
                # the predicted boxes and regression targets should be with
                # absolute coordinate format.
                pos_bbox_targets = get_box_tensor(pos_gt_bboxes)
            bbox_targets[:num_pos, :] = pos_bbox_targets
            bbox_weights[:num_pos, :] = 1
        if num_neg > 0:
            label_weights[-num_neg:] = 1.0

        return labels, label_weights, bbox_targets, bbox_weights

    def get_targets(self,
                    sampling_results: List[SamplingResult],
                    rcnn_train_cfg: ConfigDict,
                    concat: bool = True) -> tuple:
        """Calculate the ground truth for all samples in a batch according to
        the sampling_results.

        The positive priors, negative priors, positive gt boxes and positive
        gt labels of every sampling result are passed to
        :meth:`_get_targets_single` image by image via ``multi_apply``.

        Args:
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
            concat (bool): Whether to concatenate the results of all
                the images in a single batch.

        Returns:
            Tuple[Tensor]: Ground truth for proposals in a batch.
            Containing the following list of Tensors:

                - labels (list[Tensor],Tensor): Gt_labels for all
                  proposals in a batch, each tensor in list has
                  shape (num_proposals,) when `concat=False`, otherwise
                  just a single tensor has shape (num_all_proposals,).
                - label_weights (list[Tensor]): Labels_weights for
                  all proposals in a batch, each tensor in list has
                  shape (num_proposals,) when `concat=False`, otherwise
                  just a single tensor has shape (num_all_proposals,).
                - bbox_targets (list[Tensor],Tensor): Regression target
                  for all proposals in a batch, each tensor in list
                  has shape (num_proposals, 4) when `concat=False`,
                  otherwise just a single tensor has shape
                  (num_all_proposals, 4), the last dimension 4 represents
                  [tl_x, tl_y, br_x, br_y].
                - bbox_weights (list[tensor],Tensor): Regression weights for
                  all proposals in a batch, each tensor in list has shape
                  (num_proposals, 4) when `concat=False`, otherwise just a
                  single tensor has shape (num_all_proposals, 4).
        """
        pos_priors_list = [res.pos_priors for res in sampling_results]
        neg_priors_list = [res.neg_priors for res in sampling_results]
        pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results]
        pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]
        labels, label_weights, bbox_targets, bbox_weights = multi_apply(
            self._get_targets_single,
            pos_priors_list,
            neg_priors_list,
            pos_gt_bboxes_list,
            pos_gt_labels_list,
            cfg=rcnn_train_cfg)

        if concat:
            labels = torch.cat(labels, 0)
            label_weights = torch.cat(label_weights, 0)
            bbox_targets = torch.cat(bbox_targets, 0)
            bbox_weights = torch.cat(bbox_weights, 0)
        return labels, label_weights, bbox_targets, bbox_weights

    def loss_and_target(self,
                        cls_score: Tensor,
                        bbox_pred: Tensor,
                        rois: Tensor,
                        sampling_results: List[SamplingResult],
                        rcnn_train_cfg: ConfigDict,
                        concat: bool = True,
                        reduction_override: Optional[str] = None) -> dict:
        """Calculate the loss based on the features extracted by the bbox head.

        Args:
            cls_score (Tensor): Classification prediction
                results of all class, has shape
                (batch_size * num_proposals_single_image, num_classes + 1)
                for the default (non-custom) classifier.
            bbox_pred (Tensor): Regression prediction results, has shape
                (batch_size * num_proposals_single_image, encode_size) when
                ``reg_class_agnostic`` is True, otherwise
                (batch_size * num_proposals_single_image,
                num_classes * encode_size).
            rois (Tensor): RoIs with the shape
                (batch_size * num_proposals_single_image, 5) where the first
                column indicates batch id of each RoI.
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
            concat (bool): Whether to concatenate the results of all
                the images in a single batch. Defaults to True.
            reduction_override (str, optional): The reduction
                method used to override the original reduction
                method of the loss. Options are "none",
                "mean" and "sum". Defaults to None.

        Returns:
            dict: A dictionary with ``loss_bbox`` (the dict returned by
            :meth:`loss`) and ``bbox_targets`` (the tuple returned by
            :meth:`get_targets`). The targets are only used for cascade rcnn.
        """

        cls_reg_targets = self.get_targets(
            sampling_results, rcnn_train_cfg, concat=concat)
        losses = self.loss(
            cls_score,
            bbox_pred,
            rois,
            *cls_reg_targets,
            reduction_override=reduction_override)

        # cls_reg_targets is only for cascade rcnn
        return dict(loss_bbox=losses, bbox_targets=cls_reg_targets)

    def loss(self,
             cls_score: Tensor,
             bbox_pred: Tensor,
             rois: Tensor,
             labels: Tensor,
             label_weights: Tensor,
             bbox_targets: Tensor,
             bbox_weights: Tensor,
             reduction_override: Optional[str] = None) -> dict:
        """Calculate the loss based on the network predictions and targets.

        Args:
            cls_score (Tensor): Classification prediction
                results of all class, has shape
                (batch_size * num_proposals_single_image, num_classes + 1)
                for the default (non-custom) classifier.
            bbox_pred (Tensor): Regression prediction results, has shape
                (batch_size * num_proposals_single_image, encode_size) when
                ``reg_class_agnostic`` is True, otherwise
                (batch_size * num_proposals_single_image,
                num_classes * encode_size).
            rois (Tensor): RoIs with the shape
                (batch_size * num_proposals_single_image, 5) where the first
                column indicates batch id of each RoI.
            labels (Tensor): Gt_labels for all proposals in a batch, has
                shape (batch_size * num_proposals_single_image, ).
            label_weights (Tensor): Labels_weights for all proposals in a
                batch, has shape (batch_size * num_proposals_single_image, ).
            bbox_targets (Tensor): Regression target for all proposals in a
                batch, has shape (batch_size * num_proposals_single_image, 4),
                the last dimension 4 represents [tl_x, tl_y, br_x, br_y].
            bbox_weights (Tensor): Regression weights for all proposals in a
                batch, has shape (batch_size * num_proposals_single_image, 4).
            reduction_override (str, optional): The reduction
                method used to override the original reduction
                method of the loss. Options are "none",
                "mean" and "sum". Defaults to None.

        Returns:
            dict: A dictionary of loss. It may contain ``loss_cls`` and
            ``acc`` (or the keys produced by a custom ``loss_cls``) and
            ``loss_bbox``.
        """

        losses = dict()

        if cls_score is not None:
            # Normalize by the number of samples that carry a label weight;
            # clamp to 1 to avoid dividing by zero.
            avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
            if cls_score.numel() > 0:
                loss_cls_ = self.loss_cls(
                    cls_score,
                    labels,
                    label_weights,
                    avg_factor=avg_factor,
                    reduction_override=reduction_override)
                if isinstance(loss_cls_, dict):
                    losses.update(loss_cls_)
                else:
                    losses['loss_cls'] = loss_cls_
                if self.custom_activation:
                    acc_ = self.loss_cls.get_accuracy(cls_score, labels)
                    losses.update(acc_)
                else:
                    losses['acc'] = accuracy(cls_score, labels)
        if bbox_pred is not None:
            bg_class_ind = self.num_classes
            # 0~self.num_classes-1 are FG, self.num_classes is BG
            pos_inds = (labels >= 0) & (labels < bg_class_ind)
            # do not perform bounding box regression for BG anymore.
            if pos_inds.any():
                if self.reg_decoded_bbox:
                    # When the regression loss (e.g. `IouLoss`,
                    # `GIouLoss`, `DIouLoss`) is applied directly on
                    # the decoded bounding boxes, it decodes the
                    # already encoded coordinates to absolute format.
                    bbox_pred = self.bbox_coder.decode(rois[:, 1:], bbox_pred)
                    bbox_pred = get_box_tensor(bbox_pred)
                if self.reg_class_agnostic:
                    pos_bbox_pred = bbox_pred.view(
                        bbox_pred.size(0), -1)[pos_inds.type(torch.bool)]
                else:
                    # Pick, for every positive sample, the regression
                    # output that belongs to its gt class.
                    pos_bbox_pred = bbox_pred.view(
                        bbox_pred.size(0), self.num_classes,
                        -1)[pos_inds.type(torch.bool),
                            labels[pos_inds.type(torch.bool)]]
                losses['loss_bbox'] = self.loss_bbox(
                    pos_bbox_pred,
                    bbox_targets[pos_inds.type(torch.bool)],
                    bbox_weights[pos_inds.type(torch.bool)],
                    avg_factor=bbox_targets.size(0),
                    reduction_override=reduction_override)
            else:
                # No positives: emit a zero loss that stays connected to
                # the graph of ``bbox_pred``.
                losses['loss_bbox'] = bbox_pred[pos_inds].sum()

        return losses

    def predict_by_feat(self,
                        rois: Tuple[Tensor],
                        cls_scores: Tuple[Tensor],
                        bbox_preds: Tuple[Tensor],
                        batch_img_metas: List[dict],
                        rcnn_test_cfg: Optional[ConfigDict] = None,
                        rescale: bool = False) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        Args:
            rois (tuple[Tensor]): Tuple of boxes to be transformed.
                Each has shape (num_boxes, 5). last dimension 5 arrange as
                (batch_index, x1, y1, x2, y2).
            cls_scores (tuple[Tensor]): Tuple of box scores, each has shape
                (num_boxes, num_classes + 1).
            bbox_preds (tuple[Tensor]): Tuple of box energies / deltas, each
                has shape (num_boxes, num_classes * 4).
            batch_img_metas (list[dict]): List of image information.
            rcnn_test_cfg (obj:`ConfigDict`, optional): `test_cfg` of R-CNN.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Instance segmentation
            results of each image after the post process.
            Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        assert len(cls_scores) == len(bbox_preds)
        result_list = []
        for img_id in range(len(batch_img_metas)):
            img_meta = batch_img_metas[img_id]
            results = self._predict_by_feat_single(
                roi=rois[img_id],
                cls_score=cls_scores[img_id],
                bbox_pred=bbox_preds[img_id],
                img_meta=img_meta,
                rescale=rescale,
                rcnn_test_cfg=rcnn_test_cfg)
            result_list.append(results)

        return result_list

    def _predict_by_feat_single(
            self,
            roi: Tensor,
            cls_score: Tensor,
            bbox_pred: Tensor,
            img_meta: dict,
            rescale: bool = False,
            rcnn_test_cfg: Optional[ConfigDict] = None) -> InstanceData:
        """Transform a single image's features extracted from the head into
        bbox results.

        Args:
            roi (Tensor): Boxes to be transformed. Has shape (num_boxes, 5).
                last dimension 5 arrange as (batch_index, x1, y1, x2, y2).
            cls_score (Tensor): Box scores, has shape
                (num_boxes, num_classes + 1).
            bbox_pred (Tensor): Box energies / deltas.
                has shape (num_boxes, num_classes * 4).
            img_meta (dict): image information. ``img_shape`` is read; when
                ``rescale`` is True, ``scale_factor`` is required.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head.
                When None, the raw (pre-NMS) boxes and scores are returned.
                Defaults to None.

        Returns:
            :obj:`InstanceData`: Detection results of each image.
            Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        results = InstanceData()
        if roi.shape[0] == 0:
            return empty_instances([img_meta],
                                   roi.device,
                                   task_type='bbox',
                                   instance_results=[results],
                                   box_type=self.predict_box_type,
                                   use_box_type=False,
                                   num_classes=self.num_classes,
                                   score_per_cls=rcnn_test_cfg is None)[0]

        # some loss (Seesaw loss..) may have custom activation
        if self.custom_cls_channels:
            scores = self.loss_cls.get_activation(cls_score)
        else:
            scores = F.softmax(
                cls_score, dim=-1) if cls_score is not None else None

        img_shape = img_meta['img_shape']
        num_rois = roi.size(0)
        # bbox_pred would be None in some detector when with_reg is False,
        # e.g. Grid R-CNN.
        if bbox_pred is not None:
            num_classes = 1 if self.reg_class_agnostic else self.num_classes
            # One prior per (roi, class) pair so that per-class deltas can
            # be decoded in a single call.
            roi = roi.repeat_interleave(num_classes, dim=0)
            bbox_pred = bbox_pred.view(-1, self.bbox_coder.encode_size)
            bboxes = self.bbox_coder.decode(
                roi[..., 1:], bbox_pred, max_shape=img_shape)
        else:
            bboxes = roi[:, 1:].clone()
            if img_shape is not None and bboxes.size(-1) == 4:
                # Use basic (strided) slicing: it returns views, so the
                # in-place ``clamp_`` writes through to ``bboxes``. List
                # indexing such as ``bboxes[:, [0, 2]]`` returns a copy and
                # the clamp would be silently discarded.
                bboxes[:, 0::2].clamp_(min=0, max=img_shape[1])
                bboxes[:, 1::2].clamp_(min=0, max=img_shape[0])

        if rescale and bboxes.size(0) > 0:
            assert img_meta.get('scale_factor') is not None
            scale_factor = [1 / s for s in img_meta['scale_factor']]
            bboxes = scale_boxes(bboxes, scale_factor)

        # Get the inside tensor when `bboxes` is a box type
        bboxes = get_box_tensor(bboxes)
        box_dim = bboxes.size(-1)
        bboxes = bboxes.view(num_rois, -1)

        if rcnn_test_cfg is None:
            # This means that it is aug test.
            # It needs to return the raw results without nms.
            results.bboxes = bboxes
            results.scores = scores
        else:
            det_bboxes, det_labels = multiclass_nms(
                bboxes,
                scores,
                rcnn_test_cfg.score_thr,
                rcnn_test_cfg.nms,
                rcnn_test_cfg.max_per_img,
                box_dim=box_dim)
            # The last column of ``det_bboxes`` holds the score.
            results.bboxes = det_bboxes[:, :-1]
            results.scores = det_bboxes[:, -1]
            results.labels = det_labels
        return results

    def refine_bboxes(self, sampling_results: Union[List[SamplingResult],
                                                    InstanceList],
                      bbox_results: dict,
                      batch_img_metas: List[dict]) -> InstanceList:
        """Refine bboxes during training.

        Args:
            sampling_results (List[:obj:`SamplingResult`] or
                List[:obj:`InstanceData`]): Sampling results.
                :obj:`SamplingResult` is the real sampling results
                calculate from bbox_head, while :obj:`InstanceData` is
                fake sampling results, e.g., in Sparse R-CNN or QueryInst, etc.
            bbox_results (dict): Usually is a dictionary with keys:

                - `cls_score` (Tensor): Classification scores.
                - `bbox_pred` (Tensor): Box energies / deltas.
                - `rois` (Tensor): RoIs with the shape (n, 5) where the first
                  column indicates batch id of each RoI.
                - `bbox_targets` (tuple): Ground truth for proposals in a
                  single image. Containing the following list of Tensors:
                  (labels, label_weights, bbox_targets, bbox_weights)
            batch_img_metas (List[dict]): List of image information.

        Returns:
            list[:obj:`InstanceData`] or None: Refined bboxes of each image,
            with ground-truth boxes that were added as positives filtered
            out. None is returned when ``cls_score`` is empty.

        Raises:
            ValueError: If the last dim of ``cls_score`` is neither
                ``num_classes`` nor ``num_classes + 1``.

        Example:
            >>> # xdoctest: +REQUIRES(module:kwarray)
            >>> import numpy as np
            >>> from mmdet.models.task_modules.samplers.sampling_result import (
            ...     random_boxes)
            >>> from mmdet.models.task_modules.samplers import SamplingResult
            >>> self = BBoxHead(reg_class_agnostic=True)
            >>> n_roi = 2
            >>> n_img = 4
            >>> scale = 512
            >>> rng = np.random.RandomState(0)
            >>> batch_img_metas = [{'img_shape': (scale, scale)}
            ...                    for _ in range(n_img)]
            >>> sampling_results = [SamplingResult.random(rng=10)
            ...                     for _ in range(n_img)]
            >>> # Create rois in the expected format
            >>> roi_boxes = random_boxes(n_roi, scale=scale, rng=rng)
            >>> img_ids = torch.randint(0, n_img, (n_roi,))
            >>> img_ids = img_ids.float()
            >>> rois = torch.cat([img_ids[:, None], roi_boxes], dim=1)
            >>> # Create other args
            >>> labels = torch.randint(0, 81, (scale,)).long()
            >>> bbox_preds = random_boxes(n_roi, scale=scale, rng=rng)
            >>> cls_score = torch.randn((scale, 81))
            >>> # For each image, pretend random positive boxes are gts
            >>> bbox_targets = (labels, None, None, None)
            >>> bbox_results = dict(rois=rois, bbox_pred=bbox_preds,
            ...                     cls_score=cls_score,
            ...                     bbox_targets=bbox_targets)
            >>> bboxes_list = self.refine_bboxes(sampling_results,
            ...                                  bbox_results,
            ...                                  batch_img_metas)
            >>> print(bboxes_list)
        """
        pos_is_gts = [res.pos_is_gt for res in sampling_results]
        # bbox_targets is a tuple
        labels = bbox_results['bbox_targets'][0]
        cls_scores = bbox_results['cls_score']
        rois = bbox_results['rois']
        bbox_preds = bbox_results['bbox_pred']
        if self.custom_activation:
            # TODO: Create a SeesawBBoxHead to simplify the logic in BBoxHead
            cls_scores = self.loss_cls.get_activation(cls_scores)
        if cls_scores.numel() == 0:
            return None
        if cls_scores.shape[-1] == self.num_classes + 1:
            # remove background class
            cls_scores = cls_scores[:, :-1]
        elif cls_scores.shape[-1] != self.num_classes:
            raise ValueError('The last dim of `cls_scores` should equal to '
                             '`num_classes` or `num_classes + 1`, '
                             f'but got {cls_scores.shape[-1]}.')
        # Background samples are regressed with their highest-scoring
        # foreground class.
        labels = torch.where(labels == self.num_classes, cls_scores.argmax(1),
                             labels)

        img_ids = rois[:, 0].long().unique(sorted=True)
        assert img_ids.numel() <= len(batch_img_metas)

        results_list = []
        for i in range(len(batch_img_metas)):
            inds = torch.nonzero(
                rois[:, 0] == i, as_tuple=False).squeeze(dim=1)
            num_rois = inds.numel()

            bboxes_ = rois[inds, 1:]
            label_ = labels[inds]
            bbox_pred_ = bbox_preds[inds]
            img_meta_ = batch_img_metas[i]
            pos_is_gts_ = pos_is_gts[i]

            bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
                                           img_meta_)

            # filter gt bboxes
            pos_keep = 1 - pos_is_gts_
            keep_inds = pos_is_gts_.new_ones(num_rois)
            keep_inds[:len(pos_is_gts_)] = pos_keep
            results = InstanceData(bboxes=bboxes[keep_inds.type(torch.bool)])
            results_list.append(results)

        return results_list

    def regress_by_class(self, priors: Tensor, label: Tensor,
                         bbox_pred: Tensor, img_meta: dict) -> Tensor:
        """Regress the bbox for the predicted class. Used in Cascade R-CNN.

        Args:
            priors (Tensor): Priors from `rpn_head` or last stage
                `bbox_head`, has shape (num_proposals, 4).
            label (Tensor): Only used when `self.reg_class_agnostic`
                is False, has shape (num_proposals, ).
            bbox_pred (Tensor): Regression prediction of
                current stage `bbox_head`. When `self.reg_class_agnostic`
                is False, it has shape (n, num_classes * 4), otherwise
                it has shape (n, 4).
            img_meta (dict): Image meta info; ``img_shape`` is passed to the
                coder as ``max_shape``.

        Returns:
            Tensor: Regressed bboxes, the same shape as input rois.
        """
        reg_dim = self.bbox_coder.encode_size
        if not self.reg_class_agnostic:
            # Gather the reg_dim consecutive columns that belong to each
            # sample's label.
            label = label * reg_dim
            inds = torch.stack([label + i for i in range(reg_dim)], 1)
            bbox_pred = torch.gather(bbox_pred, 1, inds)
        assert bbox_pred.size()[1] == reg_dim

        max_shape = img_meta['img_shape']
        regressed_bboxes = self.bbox_coder.decode(
            priors, bbox_pred, max_shape=max_shape)
        return regressed_bboxes