# Copyright (c) OpenMMLab. All rights reserved. from typing import Tuple import torch from mmcv.ops import batched_nms from mmengine.structures import InstanceData from torch import Tensor from mmdet.registry import MODELS from mmdet.structures import SampleList from mmdet.utils import InstanceList from .standard_roi_head import StandardRoIHead @MODELS.register_module() class TridentRoIHead(StandardRoIHead): """Trident roi head. Args: num_branch (int): Number of branches in TridentNet. test_branch_idx (int): In inference, all 3 branches will be used if `test_branch_idx==-1`, otherwise only branch with index `test_branch_idx` will be used. """ def __init__(self, num_branch: int, test_branch_idx: int, **kwargs) -> None: self.num_branch = num_branch self.test_branch_idx = test_branch_idx super().__init__(**kwargs) def merge_trident_bboxes(self, trident_results: InstanceList) -> InstanceData: """Merge bbox predictions of each branch. Args: trident_results (List[:obj:`InstanceData`]): A list of InstanceData predicted from every branch. Returns: :obj:`InstanceData`: merged InstanceData. """ bboxes = torch.cat([res.bboxes for res in trident_results]) scores = torch.cat([res.scores for res in trident_results]) labels = torch.cat([res.labels for res in trident_results]) nms_cfg = self.test_cfg['nms'] results = InstanceData() if bboxes.numel() == 0: results.bboxes = bboxes results.scores = scores results.labels = labels else: det_bboxes, keep = batched_nms(bboxes, scores, labels, nms_cfg) results.bboxes = det_bboxes[:, :-1] results.scores = det_bboxes[:, -1] results.labels = labels[keep] if self.test_cfg['max_per_img'] > 0: results = results[:self.test_cfg['max_per_img']] return results def predict(self, x: Tuple[Tensor], rpn_results_list: InstanceList, batch_data_samples: SampleList, rescale: bool = False) -> InstanceList: """Perform forward propagation of the roi head and predict detection results on the features of the upstream network. - Compute prediction bbox and label per branch. - Merge predictions of each branch according to scores of bboxes, i.e., bboxes with higher score are kept to give top-k prediction. Args: x (tuple[Tensor]): Features from upstream network. Each has shape (N, C, H, W). rpn_results_list (list[:obj:`InstanceData`]): list of region proposals. batch_data_samples (List[:obj:`DetDataSample`]): The Data Samples. It usually includes information such as `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. rescale (bool): Whether to rescale the results to the original image. Defaults to True. Returns: list[obj:`InstanceData`]: Detection results of each image. Each item usually contains following keys. - scores (Tensor): Classification scores, has a shape (num_instance, ) - labels (Tensor): Labels of bboxes, has a shape (num_instances, ). - bboxes (Tensor): Has a shape (num_instances, 4), the last dimension 4 arrange as (x1, y1, x2, y2). """ results_list = super().predict( x=x, rpn_results_list=rpn_results_list, batch_data_samples=batch_data_samples, rescale=rescale) num_branch = self.num_branch \ if self.training or self.test_branch_idx == -1 else 1 merged_results_list = [] for i in range(len(batch_data_samples) // num_branch): merged_results_list.append( self.merge_trident_bboxes(results_list[i * num_branch:(i + 1) * num_branch])) return merged_results_list