KyanChen's picture
init
f549064
raw
history blame
9.05 kB
# Copyright (c) OpenMMLab. All rights reserved.
import sys
import warnings
from inspect import signature
import torch
from mmcv.ops import batched_nms
from mmengine.structures import InstanceData
from mmdet.structures.bbox import bbox_mapping_back
from ..test_time_augs import merge_aug_proposals
if sys.version_info >= (3, 7):
from mmdet.utils.contextmanagers import completed
class BBoxTestMixin:
    """Mixin class for testing det bboxes via DenseHead.

    The host class is expected to provide ``forward``, ``get_results``,
    ``_get_results_single``, ``simple_test_rpn`` and a ``test_cfg``
    attribute, all of which are used by the methods below.
    """

    def simple_test_bboxes(self, feats, img_metas, rescale=False):
        """Test det bboxes without test-time augmentation, can be applied in
        DenseHead except for ``RPNHead`` and its variants, e.g., ``GARPNHead``,
        etc.

        Args:
            feats (tuple[torch.Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[obj:`InstanceData`]: Detection results of each
            image after the post process.
            Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance,)
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances,).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        warnings.warn('You are calling `simple_test_bboxes` in '
                      '`dense_test_mixins`, but the `dense_test_mixins` '
                      'will be deprecated soon. Please use '
                      '`simple_test` instead.')
        outs = self.forward(feats)
        return self.get_results(*outs, img_metas=img_metas, rescale=rescale)

    def aug_test_bboxes(self, feats, img_metas, rescale=False):
        """Test det bboxes with test time augmentation, can be applied in
        DenseHead except for ``RPNHead`` and its variants, e.g., ``GARPNHead``,
        etc.

        Args:
            feats (list[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains features for all images in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch. each dict has image information.
            rescale (bool, optional): If True, the merged boxes are returned
                as produced by ``merge_aug_bboxes``; if False they are
                multiplied by the ``scale_factor`` of the first image of the
                first augmentation. Defaults to False.

        Returns:
            list: A single-element list. When the merged boxes are empty the
            item is a tuple ``(det_bboxes, merged_labels)`` where
            ``det_bboxes`` is the merged boxes concatenated with their
            scores. Otherwise the item is an ``InstanceData`` with fields
            ``bboxes`` (first 4 columns of the NMS output), ``scores``
            (column 4) and ``labels``.

        Raises:
            AssertionError: If ``get_results`` or ``_get_results_single``
                does not accept a ``with_nms`` argument.
        """
        warnings.warn('You are calling `aug_test_bboxes` in '
                      '`dense_test_mixins`, but the `dense_test_mixins` '
                      'will be deprecated soon. Please use '
                      '`aug_test` instead.')
        # check with_nms argument
        gb_args = signature(self.get_results).parameters
        gbs_args = signature(self._get_results_single).parameters
        assert ('with_nms' in gb_args) and ('with_nms' in gbs_args), \
            f'{self.__class__.__name__}' \
            ' does not support test-time augmentation'

        aug_bboxes = []
        aug_scores = []
        aug_labels = []
        for x, img_meta in zip(feats, img_metas):
            # only one image in the batch
            outs = self.forward(x)
            bbox_outputs = self.get_results(
                *outs,
                img_metas=img_meta,
                cfg=self.test_cfg,
                rescale=False,
                with_nms=False)[0]
            aug_bboxes.append(bbox_outputs.bboxes)
            aug_scores.append(bbox_outputs.scores)
            # Test for the field itself: ``len()`` of an ``InstanceData``
            # counts instances, not fields, so a length check would drop
            # the labels of augmentations with only a few detections and
            # misalign ``merged_labels`` with ``merged_bboxes``.
            if 'labels' in bbox_outputs:
                aug_labels.append(bbox_outputs.labels)

        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes, merged_scores = self.merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas)
        merged_labels = torch.cat(aug_labels, dim=0) if aug_labels else None

        if merged_bboxes.numel() == 0:
            # NOTE(review): this early return yields a plain tuple while the
            # regular path yields ``InstanceData``; kept unchanged so that
            # existing callers are not broken.
            det_bboxes = torch.cat([merged_bboxes, merged_scores[:, None]], -1)
            return [
                (det_bboxes, merged_labels),
            ]

        det_bboxes, keep_idxs = batched_nms(merged_bboxes, merged_scores,
                                            merged_labels, self.test_cfg.nms)
        det_bboxes = det_bboxes[:self.test_cfg.max_per_img]
        det_labels = merged_labels[keep_idxs][:self.test_cfg.max_per_img]

        if rescale:
            _det_bboxes = det_bboxes
        else:
            # Work on a copy so the NMS output is not modified in place.
            _det_bboxes = det_bboxes.clone()
            _det_bboxes[:, :4] *= det_bboxes.new_tensor(
                img_metas[0][0]['scale_factor'])

        results = InstanceData()
        results.bboxes = _det_bboxes[:, :4]
        results.scores = _det_bboxes[:, 4]
        results.labels = det_labels
        return [results]

    def aug_test_rpn(self, feats, img_metas):
        """Test with augmentation for only for ``RPNHead`` and its variants,
        e.g., ``GARPNHead``, etc.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor. One entry is consumed per augmentation.
            img_metas (list[list[dict]]): Meta info; the outer list is
                iterated per augmentation and the inner list per image.

        Returns:
            list[obj:`InstanceData`]: Merged proposals of each image, with
            ``bboxes`` taken from the first 4 columns of the merged
            proposals and ``scores`` from column 4.
        """
        samples_per_gpu = len(img_metas[0])
        aug_proposals = [[] for _ in range(samples_per_gpu)]
        for x, img_meta in zip(feats, img_metas):
            results_list = self.simple_test_rpn(x, img_meta)
            for i, results in enumerate(results_list):
                proposals = torch.cat(
                    [results.bboxes, results.scores[:, None]], dim=-1)
                aug_proposals[i].append(proposals)

        # reorganize the order of 'img_metas' to match the dimensions
        # of 'aug_proposals' (per image first, then per augmentation)
        aug_img_metas = [[metas_of_aug[i] for metas_of_aug in img_metas]
                         for i in range(samples_per_gpu)]

        # after merging, proposals will be rescaled to the original image size
        merged_proposals = []
        for proposals, aug_img_meta in zip(aug_proposals, aug_img_metas):
            merged_proposal = merge_aug_proposals(proposals, aug_img_meta,
                                                  self.test_cfg)
            results = InstanceData()
            results.bboxes = merged_proposal[:, :4]
            results.scores = merged_proposal[:, 4]
            merged_proposals.append(results)
        return merged_proposals

    if sys.version_info >= (3, 7):

        async def async_simple_test_rpn(self, x, img_metas):
            """Run the head on ``x`` inside the ``completed`` async context
            and decode the outputs with ``get_results``.

            Args:
                x: Input passed to ``self(x)``.
                img_metas: Image meta info forwarded to ``get_results``.

            Returns:
                The value returned by ``get_results``.
            """
            # Read without popping: ``test_cfg`` is shared across calls, and
            # removing the key would make every call after the first fall
            # back to the default interval.
            sleep_interval = self.test_cfg.get('async_sleep_interval', 0.025)
            async with completed(
                    __name__, 'rpn_head_forward',
                    sleep_interval=sleep_interval):
                rpn_outs = self(x)

            proposal_list = self.get_results(*rpn_outs, img_metas=img_metas)
            return proposal_list

    def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas):
        """Merge augmented detection bboxes and scores.

        Args:
            aug_bboxes (list[Tensor]): shape (n, 4*#class)
            aug_scores (list[Tensor] or None): shape (n, #class)
            img_metas (list[list[dict]]): Meta info per augmentation; only
                the first dict of each inner list is read, using its
                ``img_shape``, ``scale_factor``, ``flip`` and
                ``flip_direction`` entries.

        Returns:
            tuple[Tensor] or Tensor: ``bboxes`` with shape (n,4), where
            4 represent (tl_x, tl_y, br_x, br_y)
            and ``scores`` with shape (n,). Only ``bboxes`` is returned
            when ``aug_scores`` is None.
        """
        recovered_bboxes = []
        for bboxes, img_info in zip(aug_bboxes, img_metas):
            meta = img_info[0]
            recovered_bboxes.append(
                bbox_mapping_back(bboxes, meta['img_shape'],
                                  meta['scale_factor'], meta['flip'],
                                  meta['flip_direction']))
        bboxes = torch.cat(recovered_bboxes, dim=0)
        if aug_scores is None:
            return bboxes
        scores = torch.cat(aug_scores, dim=0)
        return bboxes, scores