# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import ModuleList
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from ..layers import ConvUpsample
from ..utils import interpolate_as
from .base_semantic_head import BaseSemanticHead


@MODELS.register_module()
class PanopticFPNHead(BaseSemanticHead):
    """PanopticFPNHead used in Panoptic FPN.

    In this head, the number of output channels is ``num_stuff_classes
    + 1``, including all stuff classes and one thing class. The stuff
    classes are remapped to the range ``0`` to ``num_stuff_classes - 1``,
    and all thing classes are merged into the ``num_stuff_classes``-th
    channel.

    Args:
        num_things_classes (int): Number of thing classes. Default: 80.
        num_stuff_classes (int): Number of stuff classes. Default: 53.
        in_channels (int): Number of channels in the input feature
            map.
        inner_channels (int): Number of channels in inner features.
        start_level (int): The start level of the input features
            used in PanopticFPN.
        end_level (int): The end level of the used features; the
            ``end_level``-th layer itself is not used.
        conv_cfg (Optional[Union[ConfigDict, dict]]): Dictionary to construct
            and config conv layer.
        norm_cfg (Union[ConfigDict, dict]): Dictionary to construct and config
            norm layer. Use ``GN`` by default.
        init_cfg (Optional[Union[ConfigDict, dict]]): Initialization config
            dict.
        loss_seg (Union[ConfigDict, dict]): Config of the loss for the
            semantic head.
    """

    def __init__(self,
                 num_things_classes: int = 80,
                 num_stuff_classes: int = 53,
                 in_channels: int = 256,
                 inner_channels: int = 128,
                 start_level: int = 0,
                 end_level: int = 4,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(
                     type='GN', num_groups=32, requires_grad=True),
                 loss_seg: ConfigType = dict(
                     type='CrossEntropyLoss', ignore_index=-1,
                     loss_weight=1.0),
                 init_cfg: OptMultiConfig = None) -> None:
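        # For a standard FPN, the ``start_level``-th feature map has stride
        # ``2 ** (start_level + 2)`` (stride 4 for the default
        # ``start_level=0``), so this factor lets ``loss`` downsample the
        # ground-truth masks to roughly the prediction resolution; any
        # remaining size mismatch is handled there by ``interpolate_as``.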
        seg_rescale_factor = 1 / 2**(start_level + 2)
        super().__init__(
            num_classes=num_stuff_classes + 1,
            seg_rescale_factor=seg_rescale_factor,
            loss_seg=loss_seg,
            init_cfg=init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        # Used feature layers are [start_level, end_level)
        self.start_level = start_level
        self.end_level = end_level
        self.num_stages = end_level - start_level
        self.inner_channels = inner_channels

        self.conv_upsample_layers = ModuleList()
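        # One ``ConvUpsample`` branch per used FPN level: level ``i`` applies
        # ``max(i, 1)`` convs and ``i`` 2x upsamplings, so for the usual FPN
        # strides (4, 8, 16, 32) and the default ``start_level=0`` every
        # branch ends at the stride-4 resolution of the first level before
        # the branches are summed in ``forward``.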
        for i in range(start_level, end_level):
            self.conv_upsample_layers.append(
                ConvUpsample(
                    in_channels,
                    inner_channels,
                    num_layers=i if i > 0 else 1,
                    num_upsample=i if i > 0 else 0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                ))
        self.conv_logits = nn.Conv2d(inner_channels, self.num_classes, 1)

    def _set_things_to_void(self, gt_semantic_seg: Tensor) -> Tensor:
        """Merge thing classes to one class.

        In PanopticFPN, the stuff (background) labels are remapped to the
        range `0` to `self.num_stuff_classes - 1`, and all thing (foreground)
        labels are merged into the single label `self.num_stuff_classes`.
        """
        gt_semantic_seg = gt_semantic_seg.int()
        fg_mask = gt_semantic_seg < self.num_things_classes
        bg_mask = (gt_semantic_seg >= self.num_things_classes) * (
            gt_semantic_seg < self.num_things_classes + self.num_stuff_classes)

        new_gt_seg = torch.clone(gt_semantic_seg)
        new_gt_seg = torch.where(bg_mask,
                                 gt_semantic_seg - self.num_things_classes,
                                 new_gt_seg)
        new_gt_seg = torch.where(fg_mask,
                                 fg_mask.int() * self.num_stuff_classes,
                                 new_gt_seg)
        return new_gt_seg
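    # Worked example of the remapping above with the defaults
    # (num_things_classes=80, num_stuff_classes=53):
    #   gt label   3 (a thing)   -> 53  (all things collapse into one label)
    #   gt label  80 (stuff #0)  ->  0
    #   gt label 132 (stuff #52) -> 52
    #   gt label >= 133          -> unchanged (matches neither mask)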

    def loss(self, x: Union[Tensor, Tuple[Tensor]],
             batch_data_samples: SampleList) -> Dict[str, Tensor]:
        """
        Args:
            x (Union[Tensor, Tuple[Tensor]]): Feature maps.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such as
                ``gt_instances``, ``gt_panoptic_seg`` or ``gt_sem_seg``.

        Returns:
            Dict[str, Tensor]: The loss of the semantic head.
        """
        seg_preds = self(x)['seg_preds']
        gt_semantic_segs = [
            data_sample.gt_sem_seg.sem_seg
            for data_sample in batch_data_samples
        ]

        gt_semantic_segs = torch.stack(gt_semantic_segs)
        if self.seg_rescale_factor != 1.0:
            gt_semantic_segs = F.interpolate(
                gt_semantic_segs.float(),
                scale_factor=self.seg_rescale_factor,
                mode='nearest').squeeze(1)

        # Thing classes will be merged into one class in PanopticFPN.
        gt_semantic_segs = self._set_things_to_void(gt_semantic_segs)

        if seg_preds.shape[-2:] != gt_semantic_segs.shape[-2:]:
            seg_preds = interpolate_as(seg_preds, gt_semantic_segs)
        seg_preds = seg_preds.permute((0, 2, 3, 1))

        loss_seg = self.loss_seg(
            seg_preds.reshape(-1, self.num_classes),  # => [NxHxW, C]
            gt_semantic_segs.reshape(-1).long())

        return dict(loss_seg=loss_seg)
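    # Shape walk-through for ``loss`` with N images and C = num_classes:
    #   seg_preds from ``forward``:    (N, C, h, w) at the prediction stride
    #   stacked ``gt_semantic_segs``:  (N, 1, H, W), downsampled and squeezed
    #                                  to (N, H', W') by ``seg_rescale_factor``
    #   ``interpolate_as`` (if needed): resizes seg_preds to (N, C, H', W')
    #   final reshape:                 predictions (N*H'*W', C),
    #                                  targets (N*H'*W',)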

    def init_weights(self) -> None:
        """Initialize weights."""
        super().init_weights()
        nn.init.normal_(self.conv_logits.weight.data, 0, 0.01)
        self.conv_logits.bias.data.zero_()

    def forward(self, x: Tuple[Tensor]) -> Dict[str, Tensor]:
        """Forward.

        Args:
            x (Tuple[Tensor]): Multi-scale feature maps.

        Returns:
            dict[str, Tensor]: Semantic segmentation predictions and
                feature maps.
        """
        # The number of subnets must not exceed the number of
        # input feature levels.
        assert self.num_stages <= len(x)

        feats = []
        for i, layer in enumerate(self.conv_upsample_layers):
            f = layer(x[self.start_level + i])
            feats.append(f)

        seg_feats = torch.sum(torch.stack(feats, dim=0), dim=0)
        seg_preds = self.conv_logits(seg_feats)
        out = dict(seg_preds=seg_preds, seg_feats=seg_feats)
        return out
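

# Usage sketch (illustrative, not part of the upstream module; kept as a
# comment so the file itself is unchanged). It assumes a working mmdet
# install, where importing the class through the package also registers the
# ``CrossEntropyLoss`` built by the base class for ``loss_seg``:
#
#   import torch
#   from mmdet.models.seg_heads.panoptic_fpn_head import PanopticFPNHead
#
#   head = PanopticFPNHead()  # defaults: 80 things + 53 stuff -> 54 channels
#   # Four FPN levels with strides 4, 8, 16, 32 for a 256x256 input.
#   feats = tuple(
#       torch.rand(2, 256, 256 // s, 256 // s) for s in (4, 8, 16, 32))
#   out = head(feats)
#   # Each branch is upsampled back to stride 4 and summed, so
#   # out['seg_preds'].shape == (2, 54, 64, 64)
#   # out['seg_feats'].shape == (2, 128, 64, 64)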