Spaces:
Runtime error
Runtime error
File size: 2,836 Bytes
f549064 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union
from torch import Tensor
from mmdet.registry import MODELS
from .convfc_bbox_head import ConvFCBBoxHead
@MODELS.register_module()
class SCNetBBoxHead(ConvFCBBoxHead):
    """BBox head for `SCNet <https://arxiv.org/abs/2012.10150>`_.

    Subclass of ``ConvFCBBoxHead`` whose forward pass is split into a shared
    stage and a cls/reg stage, so that ``forward()`` can also hand back the
    intermediate shared feature.
    """

    def _forward_shared(self, x: Tensor) -> Tensor:
        """Run the layers shared by the classification and regression parts.

        Args:
            x (Tensor): Input feature.

        Returns:
            Tensor: Shared feature.
        """
        feat = x
        if self.num_shared_convs > 0:
            for layer in self.shared_convs:
                feat = layer(feat)

        # Without shared FC layers the feature is returned as-is: it is
        # neither pooled nor flattened here.
        if self.num_shared_fcs <= 0:
            return feat

        if self.with_avg_pool:
            feat = self.avg_pool(feat)
        feat = feat.flatten(1)
        for layer in self.shared_fcs:
            feat = self.relu(layer(feat))
        return feat

    def _forward_cls_reg(self, x: Tensor) -> Tuple[Tensor]:
        """Run the classification and regression parts on a shared feature.

        Args:
            x (Tensor): Input feature.

        Returns:
            tuple[Tensor]:

            - cls_score (Tensor): classification prediction, or None when
              ``self.with_cls`` is false.
            - bbox_pred (Tensor): bbox prediction, or None when
              ``self.with_reg`` is false.
        """
        # Both branches start from the same input and apply the same
        # sequence of steps; the classification branch runs first.
        branches = (
            (self.cls_convs, self.cls_fcs),
            (self.reg_convs, self.reg_fcs),
        )
        branch_feats = []
        for convs, fcs in branches:
            feat = x
            for layer in convs:
                feat = layer(feat)
            # Only a feature with more than two dimensions is (optionally)
            # pooled and then flattened from dim 1 onward.
            if feat.dim() > 2:
                if self.with_avg_pool:
                    feat = self.avg_pool(feat)
                feat = feat.flatten(1)
            for layer in fcs:
                feat = self.relu(layer(feat))
            branch_feats.append(feat)
        cls_feat, reg_feat = branch_feats

        cls_score = None
        if self.with_cls:
            cls_score = self.fc_cls(cls_feat)
        bbox_pred = None
        if self.with_reg:
            bbox_pred = self.fc_reg(reg_feat)
        return cls_score, bbox_pred

    def forward(
            self,
            x: Tensor,
            return_shared_feat: bool = False) -> Union[Tensor, Tuple[Tensor]]:
        """Forward function.

        Args:
            x (Tensor): input features
            return_shared_feat (bool): If True, return cls-reg-shared feature.

        Returns:
            out (tuple[Tensor]): contain ``cls_score`` and ``bbox_pred``,
            if ``return_shared_feat`` is True, append ``x_shared`` to the
            returned tuple.
        """
        shared_feat = self._forward_shared(x)
        outputs = self._forward_cls_reg(shared_feat)
        if not return_shared_feat:
            return outputs
        return outputs + (shared_feat, )
|