File size: 12,664 Bytes
f549064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py  # noqa

from typing import List, Tuple

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.ops import point_sample, rel_roi_point_to_rel_img_point
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.models.utils import (get_uncertain_point_coords_with_randomness,
                                get_uncertainty)
from mmdet.registry import MODELS
from mmdet.structures.bbox import bbox2roi
from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptConfigType


@MODELS.register_module()
class MaskPointHead(BaseModule):
    """A mask point head used in PointRend.

    ``MaskPointHead`` uses a shared multi-layer perceptron (built from
    kernel-size-1 conv layers, ``Conv1d`` by default) to predict the logits
    of the input points. The fine-grained features and the coarse features
    are concatenated along the channel dimension for prediction.

    Args:
        num_classes (int): Number of classes for logits. It is also the
            number of coarse feature channels that are concatenated to the
            fc inputs. This argument is required.
        num_fcs (int): Number of fc layers in the head. Defaults to 3.
        in_channels (int): Number of input channels. Defaults to 256.
        fc_channels (int): Number of fc channels. Defaults to 256.
        class_agnostic (bool): Whether use class agnostic classification.
            If so, the output channels of logits will be 1. Defaults to False.
        coarse_pred_each_layer (bool): Whether concatenate coarse feature with
            the output of each fc layer. Defaults to True.
        conv_cfg (:obj:`ConfigDict` or dict): Dictionary to construct
            and config conv layer. Defaults to dict(type='Conv1d').
        norm_cfg (:obj:`ConfigDict` or dict, optional): Dictionary to construct
            and config norm layer. Defaults to None.
        act_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
            config the activation layer of each fc. Defaults to
            dict(type='ReLU').
        loss_point (:obj:`ConfigDict` or dict): Dictionary to construct and
            config loss layer of point head. Defaults to
            dict(type='CrossEntropyLoss', use_mask=True, loss_weight=1.0).
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict], optional): Initialization config dict. Defaults to
            dict(type='Normal', std=0.001, override=dict(name='fc_logits')).
    """

    def __init__(
        self,
        num_classes: int,
        num_fcs: int = 3,
        in_channels: int = 256,
        fc_channels: int = 256,
        class_agnostic: bool = False,
        coarse_pred_each_layer: bool = True,
        conv_cfg: ConfigType = dict(type='Conv1d'),
        norm_cfg: OptConfigType = None,
        act_cfg: ConfigType = dict(type='ReLU'),
        loss_point: ConfigType = dict(
            type='CrossEntropyLoss', use_mask=True, loss_weight=1.0),
        init_cfg: MultiConfig = dict(
            type='Normal', std=0.001, override=dict(name='fc_logits'))
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.num_fcs = num_fcs
        self.in_channels = in_channels
        self.fc_channels = fc_channels
        self.num_classes = num_classes
        self.class_agnostic = class_agnostic
        self.coarse_pred_each_layer = coarse_pred_each_layer
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        # Kept alongside conv_cfg / norm_cfg so the layer config can be
        # inspected from the instance.
        self.act_cfg = act_cfg
        self.loss_point = MODELS.build(loss_point)

        # The first fc consumes fine-grained and coarse features together.
        fc_in_channels = in_channels + num_classes
        # When the coarse prediction is re-attached after every fc, each
        # following layer receives ``num_classes`` additional channels.
        extra_channels = num_classes if self.coarse_pred_each_layer else 0
        self.fcs = nn.ModuleList()
        for _ in range(num_fcs):
            self.fcs.append(
                ConvModule(
                    fc_in_channels,
                    fc_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
            fc_in_channels = fc_channels + extra_channels

        out_channels = 1 if self.class_agnostic else self.num_classes
        self.fc_logits = nn.Conv1d(
            fc_in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, fine_grained_feats: Tensor,
                coarse_feats: Tensor) -> Tensor:
        """Classify each point based on fine grained and coarse feats.

        Args:
            fine_grained_feats (Tensor): Fine grained feature sampled from FPN,
                shape (num_rois, in_channels, num_points).
            coarse_feats (Tensor): Coarse feature sampled from CoarseMaskHead,
                shape (num_rois, num_classes, num_points).

        Returns:
            Tensor: Point classification results, shape
            (num_rois, out_channels, num_points), where ``out_channels`` is
            1 if ``class_agnostic`` else ``num_classes``.
        """

        x = torch.cat([fine_grained_feats, coarse_feats], dim=1)
        for fc in self.fcs:
            x = fc(x)
            if self.coarse_pred_each_layer:
                # Re-attach the coarse prediction so that every layer (and
                # the final logits conv) sees it directly.
                x = torch.cat((x, coarse_feats), dim=1)
        return self.fc_logits(x)

    def get_targets(self, rois: Tensor, rel_roi_points: Tensor,
                    sampling_results: List[SamplingResult],
                    batch_gt_instances: InstanceList,
                    cfg: ConfigType) -> Tensor:
        """Get training targets of MaskPointHead for all images.

        Args:
            rois (Tensor): Region of Interest, shape (num_rois, 5). The first
                column is used as the batch (image) index.
            rel_roi_points (Tensor): Points coordinates relative to RoI, shape
                (num_rois, num_points, 2).
            sampling_results (list[:obj:`SamplingResult`]): Sampling results
                after sampling and assignment, one per image.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``, and
                ``masks`` attributes.
            cfg (obj:`ConfigDict` or dict): Training cfg. ``cfg.num_points``
                is read to shape empty targets.

        Returns:
            Tensor: Point target, shape (num_rois, num_points). An empty
            tensor of shape (0, ``cfg.num_points``) is returned when there is
            no per-image target to concatenate.
        """

        num_imgs = len(sampling_results)
        rois_list = []
        rel_roi_points_list = []
        for batch_ind in range(num_imgs):
            # Select the RoIs (and their points) that belong to this image.
            inds = (rois[:, 0] == batch_ind)
            rois_list.append(rois[inds])
            rel_roi_points_list.append(rel_roi_points[inds])

        point_targets = [
            self._get_targets_single(img_rois, img_points,
                                     res.pos_assigned_gt_inds, gt_instances,
                                     cfg)
            for img_rois, img_points, res, gt_instances in zip(
                rois_list, rel_roi_points_list, sampling_results,
                batch_gt_instances)
        ]

        if not point_targets:
            # ``torch.cat`` rejects an empty list; return an empty tensor so
            # the result type always matches the annotation.
            return rois.new_zeros((0, cfg.num_points))

        return torch.cat(point_targets)

    def _get_targets_single(self, rois: Tensor, rel_roi_points: Tensor,
                            pos_assigned_gt_inds: Tensor,
                            gt_instances: InstanceData,
                            cfg: ConfigType) -> Tensor:
        """Get training target of MaskPointHead for each image.

        Args:
            rois (Tensor): RoIs of a single image, shape (num_rois, 5).
            rel_roi_points (Tensor): Points coordinates relative to RoI, shape
                (num_rois, num_points, 2).
            pos_assigned_gt_inds (Tensor): Index of the assigned gt instance
                for each RoI; used to pick the matching gt mask.
            gt_instances (:obj:`InstanceData`): Ground truth of the image.
                Its ``masks`` attribute must provide
                ``to_tensor(dtype, device)``.
            cfg (obj:`ConfigDict` or dict): Training cfg. ``cfg.num_points``
                is only read when there is no RoI.

        Returns:
            Tensor: Point target sampled from the gt masks, shape
            (num_rois, num_points).
        """
        if rois.size(0) == 0:
            # No positive RoI in this image: keep the point dimension so the
            # per-image results can still be concatenated.
            return rois.new_zeros((0, cfg.num_points))

        gt_masks_th = gt_instances.masks.to_tensor(rois.dtype, rois.device)
        # One gt mask per RoI, with a singleton channel dimension added for
        # sampling.
        gt_masks_th = gt_masks_th.index_select(0, pos_assigned_gt_inds)
        gt_masks_th = gt_masks_th.unsqueeze(1)
        # Convert RoI-relative coordinates to image-relative coordinates
        # before sampling the gt masks at those points.
        rel_img_points = rel_roi_point_to_rel_img_point(
            rois, rel_roi_points, gt_masks_th)
        return point_sample(gt_masks_th, rel_img_points).squeeze(1)

    def loss_and_target(self, point_pred: Tensor, rel_roi_points: Tensor,
                        sampling_results: List[SamplingResult],
                        batch_gt_instances: InstanceList,
                        cfg: ConfigType) -> dict:
        """Calculate loss for MaskPointHead.

        Args:
            point_pred (Tensor): Point prediction result, shape
                (num_rois, num_classes, num_points).
            rel_roi_points (Tensor): Points coordinates relative to RoI, shape
                (num_rois, num_points, 2).
            sampling_results (list[:obj:`SamplingResult`]): Sampling results
                after sampling and assignment, one per image.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``, and
                ``masks`` attributes.
            cfg (obj:`ConfigDict` or dict): Training cfg.

        Returns:
            dict: A dictionary with keys ``loss_point`` (the point loss) and
            ``point_target`` (the point target).
        """
        rois = bbox2roi([res.pos_bboxes for res in sampling_results])
        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])

        point_target = self.get_targets(rois, rel_roi_points, sampling_results,
                                        batch_gt_instances, cfg)
        if self.class_agnostic:
            # Class-agnostic prediction: all labels are passed as zero,
            # presumably so the loss selects the single output channel.
            labels = torch.zeros_like(pos_labels)
        else:
            labels = pos_labels
        loss_point = self.loss_point(point_pred, point_target, labels)

        return dict(loss_point=loss_point, point_target=point_target)

    def get_roi_rel_points_train(self, mask_preds: Tensor, labels: Tensor,
                                 cfg: ConfigType) -> Tensor:
        """Get ``num_points`` most uncertain points with random points during
        train.

        Sample points in [0, 1] x [0, 1] coordinate space based on their
        uncertainty. The sampling is delegated to
        ``get_uncertain_point_coords_with_randomness``, configured by
        ``cfg.num_points``, ``cfg.oversample_ratio`` and
        ``cfg.importance_sample_ratio``.

        Args:
            mask_preds (Tensor): A tensor of shape (num_rois, num_classes,
                mask_height, mask_width) for class-specific or class-agnostic
                prediction.
            labels (Tensor): The ground truth class for each instance.
            cfg (:obj:`ConfigDict` or dict): Training config of point head.

        Returns:
            point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
            that contains the coordinates of the sampled points.
        """
        return get_uncertain_point_coords_with_randomness(
            mask_preds, labels, cfg.num_points, cfg.oversample_ratio,
            cfg.importance_sample_ratio)

    def get_roi_rel_points_test(self, mask_preds: Tensor, label_preds: Tensor,
                                cfg: ConfigType) -> Tuple[Tensor, Tensor]:
        """Get ``num_points`` most uncertain points during test.

        Args:
            mask_preds (Tensor): A tensor of shape (num_rois, num_classes,
                mask_height, mask_width) for class-specific or class-agnostic
                prediction.
            label_preds (Tensor): The predicted class for each instance.
            cfg (:obj:`ConfigDict` or dict): Testing config of point head.
                ``cfg.subdivision_num_points`` is the requested number of
                points; it is capped at ``mask_height * mask_width``.

        Returns:
            tuple:

            - point_indices (Tensor): A tensor of shape (num_rois, num_points)
              that contains indices from [0, mask_height x mask_width) of the
              most uncertain points.
            - point_coords (Tensor): A tensor of shape (num_rois, num_points,
              2) that contains [0, 1] x [0, 1] normalized coordinates of the
              most uncertain points from the [mask_height, mask_width] grid.
        """
        num_points = cfg.subdivision_num_points
        uncertainty_map = get_uncertainty(mask_preds, label_preds)
        num_rois, _, mask_height, mask_width = uncertainty_map.shape

        # While tracing for ONNX export the elements of ``shape`` may be
        # Tensors, whereas they are plain Python ints in eager PyTorch.
        if isinstance(mask_height, torch.Tensor):
            h_step = 1.0 / mask_height.float()
            w_step = 1.0 / mask_width.float()
        else:
            h_step = 1.0 / mask_height
            w_step = 1.0 / mask_width
        # cast to int to avoid dynamic K for TopK op in ONNX
        mask_size = int(mask_height * mask_width)
        uncertainty_map = uncertainty_map.view(num_rois, mask_size)
        # Cannot select more points than there are grid cells.
        num_points = min(mask_size, num_points)
        point_indices = uncertainty_map.topk(num_points, dim=1)[1]
        # Flat index -> (column, row) -> normalized coordinates of the cell
        # centre (hence the half-step offset).
        xs = w_step / 2.0 + (point_indices % mask_width).float() * w_step
        ys = h_step / 2.0 + (point_indices // mask_width).float() * h_step
        point_coords = torch.stack([xs, ys], dim=2)
        return point_indices, point_coords