File size: 2,836 Bytes
f549064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union

from torch import Tensor

from mmdet.registry import MODELS
from .convfc_bbox_head import ConvFCBBoxHead


@MODELS.register_module()
class SCNetBBoxHead(ConvFCBBoxHead):
    """BBox head for `SCNet <https://arxiv.org/abs/2012.10150>`_.

    This inherits ``ConvFCBBoxHead`` with a modified ``forward()`` that splits
    the computation into a shared stage and separate cls/reg stages, allowing
    callers to also retrieve the intermediate shared feature.
    """

    def _forward_shared(self, x: Tensor) -> Tensor:
        """Forward function for the part shared by the cls and reg branches.

        Applies ``self.shared_convs`` in order. Then, only if there are shared
        fc layers, optionally average-pools, flattens from dim 1 onward and
        applies each shared fc followed by ``self.relu``.

        Args:
            x (Tensor): Input feature (presumably pooled RoI features of shape
                (N, C, H, W) -- the shape is not checked here).

        Returns:
            Tensor: Shared feature. It is flattened (from dim 1 onward) only
            when shared fc layers exist; otherwise it is the output of the
            shared convs, or ``x`` unchanged if there are none.
        """
        if self.num_shared_convs > 0:
            for conv in self.shared_convs:
                x = conv(x)

        if self.num_shared_fcs > 0:
            if self.with_avg_pool:
                x = self.avg_pool(x)

            x = x.flatten(1)

            for fc in self.shared_fcs:
                x = self.relu(fc(x))

        return x

    def _forward_branch(self, x: Tensor, convs, fcs) -> Tensor:
        """Run one task-specific branch (cls or reg) on the shared feature.

        Args:
            x (Tensor): Shared feature from :meth:`_forward_shared`.
            convs: Branch conv layers (e.g. ``self.cls_convs``), applied in
                order.
            fcs: Branch fc layers (e.g. ``self.cls_fcs``), each followed by
                ``self.relu``.

        Returns:
            Tensor: Branch feature, to be fed to the final cls/reg layer.
        """
        for conv in convs:
            x = conv(x)
        # The shared stage only flattens when it has fc layers, so the feature
        # may still be spatial here; (optionally) pool and flatten it before
        # the branch fc layers.
        if x.dim() > 2:
            if self.with_avg_pool:
                x = self.avg_pool(x)
            x = x.flatten(1)
        for fc in fcs:
            x = self.relu(fc(x))
        return x

    def _forward_cls_reg(
            self, x: Tensor) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        """Forward function for classification and regression parts.

        Args:
            x (Tensor): Shared feature from :meth:`_forward_shared`.

        Returns:
            tuple[Tensor | None, Tensor | None]:

                - cls_score (Tensor | None): classification prediction, or
                  ``None`` if ``self.with_cls`` is False.
                - bbox_pred (Tensor | None): bbox prediction, or ``None`` if
                  ``self.with_reg`` is False.
        """
        # Both branches start from the same shared feature; cls runs first.
        x_cls = self._forward_branch(x, self.cls_convs, self.cls_fcs)
        x_reg = self._forward_branch(x, self.reg_convs, self.reg_fcs)

        cls_score = self.fc_cls(x_cls) if self.with_cls else None
        bbox_pred = self.fc_reg(x_reg) if self.with_reg else None

        return cls_score, bbox_pred

    def forward(
        self,
        x: Tensor,
        return_shared_feat: bool = False
    ) -> Union[Tuple[Optional[Tensor], Optional[Tensor]],
               Tuple[Optional[Tensor], Optional[Tensor], Tensor]]:
        """Forward function.

        Args:
            x (Tensor): Input features.
            return_shared_feat (bool): If True, also return the feature shared
                by the cls and reg branches. Defaults to False.

        Returns:
            tuple: ``(cls_score, bbox_pred)``; if ``return_shared_feat`` is
            True, ``(cls_score, bbox_pred, x_shared)``. ``cls_score`` /
            ``bbox_pred`` are ``None`` when ``self.with_cls`` /
            ``self.with_reg`` is False.
        """
        x_shared = self._forward_shared(x)
        out = self._forward_cls_reg(x_shared)

        if return_shared_feat:
            # Tuple concatenation builds a new tuple; ``out`` is not mutated.
            out += (x_shared, )

        return out