ai-photo-gallery / mmdet /models /roi_heads /trident_roi_head.py
KyanChen's picture
init
f549064
raw
history blame
No virus
4.3 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch
from mmcv.ops import batched_nms
from mmengine.structures import InstanceData
from torch import Tensor
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import InstanceList
from .standard_roi_head import StandardRoIHead
@MODELS.register_module()
class TridentRoIHead(StandardRoIHead):
    """RoI head used by TridentNet.

    Runs the standard RoI head prediction and then merges the detections
    of the individual branches with NMS.

    Args:
        num_branch (int): Number of branches in TridentNet.
        test_branch_idx (int): Branch selection for inference. With ``-1``
            the results of all branches are merged; with any other value
            only the branch with index ``test_branch_idx`` is used.
    """

    def __init__(self, num_branch: int, test_branch_idx: int,
                 **kwargs) -> None:
        # NOTE(review): both attributes are assigned ahead of the parent
        # constructor call - presumably so they are available while the
        # parent initialises; confirm before reordering.
        self.num_branch = num_branch
        self.test_branch_idx = test_branch_idx
        super().__init__(**kwargs)

    def merge_trident_bboxes(self,
                             trident_results: InstanceList) -> InstanceData:
        """Fuse the detections of several branches into a single result.

        The boxes, scores and labels of all entries are concatenated and
        filtered with ``batched_nms`` using ``self.test_cfg['nms']``. When
        ``self.test_cfg['max_per_img']`` is positive, the result is cut to
        at most that many leading instances.

        Args:
            trident_results (List[:obj:`InstanceData`]): One
                ``InstanceData`` per branch, each providing ``bboxes``,
                ``scores`` and ``labels``.

        Returns:
            :obj:`InstanceData`: Merged detections holding ``bboxes``,
            ``scores`` and ``labels``.
        """
        all_bboxes = torch.cat([branch.bboxes for branch in trident_results])
        all_scores = torch.cat([branch.scores for branch in trident_results])
        all_labels = torch.cat([branch.labels for branch in trident_results])

        nms_cfg = self.test_cfg['nms']
        merged = InstanceData()
        if all_bboxes.numel() > 0:
            kept_dets, kept_inds = batched_nms(all_bboxes, all_scores,
                                               all_labels, nms_cfg)
            # The NMS output is split here: the last column is taken as
            # the score, every column before it as the box coordinates.
            merged.bboxes = kept_dets[:, :-1]
            merged.scores = kept_dets[:, -1]
            merged.labels = all_labels[kept_inds]
        else:
            # Nothing to suppress; hand the empty tensors through as-is.
            merged.bboxes = all_bboxes
            merged.scores = all_scores
            merged.labels = all_labels

        max_per_img = self.test_cfg['max_per_img']
        if max_per_img > 0:
            merged = merged[:max_per_img]
        return merged

    def predict(self,
                x: Tuple[Tensor],
                rpn_results_list: InstanceList,
                batch_data_samples: SampleList,
                rescale: bool = False) -> InstanceList:
        """Predict detections with the RoI head and merge them per image.

        - Run the parent ``predict`` to obtain one result per entry.
        - Merge every run of ``num_branch`` consecutive results with
          :meth:`merge_trident_bboxes`. The run length is
          ``self.num_branch`` when the module is in training mode or
          ``test_branch_idx == -1``, otherwise ``1``.

        NOTE(review): the grouping assumes that the results belonging to
        the same image are adjacent in the parent output - confirm
        against the caller that builds ``batch_data_samples``.

        Args:
            x (tuple[Tensor]): Features from upstream network. Each
                has shape (N, C, H, W).
            rpn_results_list (list[:obj:`InstanceData`]): List of region
                proposals.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool): Whether to rescale the results to
                the original image. Defaults to False.

        Returns:
            list[obj:`InstanceData`]: Detection results of each image.
            Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        branch_results = super().predict(
            x=x,
            rpn_results_list=rpn_results_list,
            batch_data_samples=batch_data_samples,
            rescale=rescale)

        merge_all_branches = self.training or self.test_branch_idx == -1
        group_size = self.num_branch if merge_all_branches else 1
        num_groups = len(batch_data_samples) // group_size

        return [
            self.merge_trident_bboxes(
                branch_results[idx * group_size:(idx + 1) * group_size])
            for idx in range(num_groups)
        ]