KyanChen's picture
init
f549064
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
import torch
from mmcv.ops import batched_nms
from mmengine.model import BaseTTAModel
from mmengine.registry import MODELS
from mmengine.structures import InstanceData
from torch import Tensor
from mmdet.structures import DetDataSample
from mmdet.structures.bbox import bbox_flip
@MODELS.register_module()
class DetTTAModel(BaseTTAModel):
    """Merge augmented detection results.

    Only bboxes and their corresponding scores produced under flipping and
    multi-scale resizing can be merged for now.

    Examples:
        >>> tta_model = dict(
        >>>     type='DetTTAModel',
        >>>     tta_cfg=dict(nms=dict(
        >>>                     type='nms',
        >>>                     iou_threshold=0.5),
        >>>                     max_per_img=100))
        >>>
        >>> tta_pipeline = [
        >>>     dict(type='LoadImageFromFile',
        >>>          file_client_args=dict(backend='disk')),
        >>>     dict(
        >>>         type='TestTimeAug',
        >>>         transforms=[[
        >>>             dict(type='Resize',
        >>>                  scale=(1333, 800),
        >>>                  keep_ratio=True),
        >>>         ], [
        >>>             dict(type='RandomFlip', prob=1.),
        >>>             dict(type='RandomFlip', prob=0.)
        >>>         ], [
        >>>             dict(
        >>>                type='PackDetInputs',
        >>>                meta_keys=('img_id', 'img_path', 'ori_shape',
        >>>                        'img_shape', 'scale_factor', 'flip',
        >>>                        'flip_direction'))
        >>>         ]])]
    """

    def __init__(self, tta_cfg=None, **kwargs):
        super().__init__(**kwargs)
        # Read with attribute access in ``_merge_single_sample``: ``nms`` is
        # forwarded to ``batched_nms`` and ``max_per_img`` caps the output.
        # NOTE(review): the default ``None`` will fail there; presumably the
        # config always supplies it.
        self.tta_cfg = tta_cfg

    def merge_aug_bboxes(self, aug_bboxes: List[Tensor],
                         aug_scores: List[Tensor],
                         img_metas: List[dict]) -> Tuple[Tensor, Tensor]:
        """Undo flips on each view's bboxes and concatenate all views.

        Args:
            aug_bboxes (list[Tensor]): Per-view boxes, each of shape
                (n_i, 4) in (tl_x, tl_y, br_x, br_y) order.
            aug_scores (list[Tensor] or None): Per-view scores, each of
                shape (n_i,). If ``None``, only the bboxes are returned.
            img_metas (list[dict]): Per-view meta information. Each dict
                must contain ``'flip'``. Flipped views also need
                ``'ori_shape'`` and ``'flip_direction'``.

        Returns:
            tuple[Tensor]: ``bboxes`` with shape (n, 4) and ``scores`` with
            shape (n,), where n is the total number of boxes over all views.
            When ``aug_scores`` is ``None``, a single ``bboxes`` Tensor is
            returned instead of a tuple.
        """
        recovered_bboxes = []
        for bboxes, img_info in zip(aug_bboxes, img_metas):
            if img_info['flip']:
                # Flipped w.r.t. ``ori_shape``. This assumes predictions were
                # already rescaled to the original image size.
                bboxes = bbox_flip(
                    bboxes=bboxes,
                    img_shape=img_info['ori_shape'],
                    direction=img_info['flip_direction'])
            recovered_bboxes.append(bboxes)
        bboxes = torch.cat(recovered_bboxes, dim=0)
        if aug_scores is None:
            return bboxes
        scores = torch.cat(aug_scores, dim=0)
        return bboxes, scores

    def merge_preds(
            self,
            data_samples_list: List[List[DetDataSample]]
    ) -> List[DetDataSample]:
        """Merge batch predictions of augmented data.

        Args:
            data_samples_list (List[List[DetDataSample]]): Predictions of all
                augmented data. The outer list indexes images; the inner list
                holds the different augmented views of one image.

        Returns:
            List[DetDataSample]: One merged prediction per image.
        """
        return [
            self._merge_single_sample(data_samples)
            for data_samples in data_samples_list
        ]

    def _merge_single_sample(
            self, data_samples: List[DetDataSample]) -> DetDataSample:
        """Merge the predictions of all augmented views of one image.

        Boxes from every view are mapped back to the un-flipped frame and
        concatenated. Class-aware NMS (``batched_nms`` with
        ``self.tta_cfg.nms``) is then applied, and at most
        ``self.tta_cfg.max_per_img`` detections are kept.

        Args:
            data_samples (List[DetDataSample]): Predictions of the augmented
                views of one image.

        Returns:
            DetDataSample: The first view's data sample, with its
            ``pred_instances`` replaced in place by the merged result. If
            no view produced any box, it is returned unchanged.
        """
        # TODO: support instance segmentation TTA
        assert data_samples[0].pred_instances.get('masks', None) is None, \
            'TTA of instance segmentation does not support now.'

        aug_bboxes = []
        aug_scores = []
        aug_labels = []
        img_metas = []
        for data_sample in data_samples:
            pred = data_sample.pred_instances
            aug_bboxes.append(pred.bboxes)
            aug_scores.append(pred.scores)
            aug_labels.append(pred.labels)
            img_metas.append(data_sample.metainfo)

        merged_bboxes, merged_scores = self.merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas)
        merged_labels = torch.cat(aug_labels, dim=0)

        if merged_bboxes.numel() == 0:
            return data_samples[0]

        # ``det_bboxes`` rows are (x1, y1, x2, y2, score). ``keep_idxs`` index
        # the merged inputs. Slicing to ``max_per_img`` assumes the dets are
        # ordered by descending score.
        det_bboxes, keep_idxs = batched_nms(merged_bboxes, merged_scores,
                                            merged_labels, self.tta_cfg.nms)
        max_per_img = self.tta_cfg.max_per_img
        det_bboxes = det_bboxes[:max_per_img]
        det_labels = merged_labels[keep_idxs][:max_per_img]

        results = InstanceData()
        results.bboxes = det_bboxes[:, :-1]
        results.scores = det_bboxes[:, -1]
        results.labels = det_labels

        det_results = data_samples[0]
        det_results.pred_instances = results
        return det_results