Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Tuple | |
from torch import Tensor | |
from mmdet.registry import MODELS | |
from .standard_roi_head import StandardRoIHead | |
@MODELS.register_module()
class DoubleHeadRoIHead(StandardRoIHead):
    """RoI head for `Double Head RCNN <https://arxiv.org/abs/1904.06493>`_.

    Classification and regression features are extracted separately from
    the same RoIs. The regression extraction additionally passes
    ``reg_roi_scale_factor`` to the RoI extractor. Both feature sets are
    then given to ``bbox_head`` as two separate inputs.

    Args:
        reg_roi_scale_factor (float): The scale factor to extend the rois
            used to extract the regression features.
        **kwargs: Forwarded unchanged to :class:`StandardRoIHead`.
    """

    def __init__(self, reg_roi_scale_factor: float, **kwargs) -> None:
        super().__init__(**kwargs)
        self.reg_roi_scale_factor = reg_roi_scale_factor

    def _bbox_forward(self, x: Tuple[Tensor], rois: Tensor) -> dict:
        """Box head forward function used in both training and testing.

        Args:
            x (tuple[Tensor]): Multi-level image features.
            rois (Tensor): RoIs with the shape (n, 5) where the first
                column indicates batch id of each RoI.

        Returns:
            dict[str, Tensor]: A dictionary with keys:

            - `cls_score` (Tensor): Classification scores.
            - `bbox_pred` (Tensor): Box energies / deltas.
            - `bbox_feats` (Tensor): RoI features of the classification
              branch (the regression-branch features are not returned).
        """
        # Only the first ``num_inputs`` feature levels are fed to the
        # extractor; slice once and reuse the result for both branches.
        feats = x[:self.bbox_roi_extractor.num_inputs]
        bbox_cls_feats = self.bbox_roi_extractor(feats, rois)
        # Same RoIs for the regression branch, but with the extra
        # ``roi_scale_factor`` argument so the extractor can extend them.
        bbox_reg_feats = self.bbox_roi_extractor(
            feats, rois, roi_scale_factor=self.reg_roi_scale_factor)
        if self.with_shared_head:
            bbox_cls_feats = self.shared_head(bbox_cls_feats)
            bbox_reg_feats = self.shared_head(bbox_reg_feats)
        cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats)
        return dict(
            cls_score=cls_score,
            bbox_pred=bbox_pred,
            bbox_feats=bbox_cls_feats)