File size: 20,849 Bytes
f549064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.config import ConfigDict
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.registry import MODELS
from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptConfigType


@MODELS.register_module()
class GridHead(BaseModule):
    """Implementation of `Grid Head <https://arxiv.org/abs/1811.12030>`_

    Args:
        grid_points (int): The number of grid points. Defaults to 9.
        num_convs (int): The number of convolution layers. Defaults to 8.
        roi_feat_size (int): RoI feature size. Default to 14.
        in_channels (int): The channel number of inputs features.
            Defaults to 256.
        conv_kernel_size (int): The kernel size of convolution layers.
            Defaults to 3.
        point_feat_channels (int): The number of channels of each point
            features. Defaults to 64.
        class_agnostic (bool): Whether use class agnostic classification.
            If so, the output channels of logits will be 1. Defaults to False.
        loss_grid (:obj:`ConfigDict` or dict): Config of grid loss.
        conv_cfg (:obj:`ConfigDict` or dict, optional) dictionary to
            construct and config conv layer.
        norm_cfg (:obj:`ConfigDict` or dict): dictionary to construct and
            config norm layer.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict]): Initialization config dict.
    """

    def __init__(
        self,
        grid_points: int = 9,
        num_convs: int = 8,
        roi_feat_size: int = 14,
        in_channels: int = 256,
        conv_kernel_size: int = 3,
        point_feat_channels: int = 64,
        deconv_kernel_size: int = 4,
        class_agnostic: bool = False,
        loss_grid: ConfigType = dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=15),
        conv_cfg: OptConfigType = None,
        norm_cfg: ConfigType = dict(type='GN', num_groups=36),
        init_cfg: MultiConfig = [
            dict(type='Kaiming', layer=['Conv2d', 'Linear']),
            dict(
                type='Normal',
                layer='ConvTranspose2d',
                std=0.001,
                override=dict(
                    type='Normal',
                    name='deconv2',
                    std=0.001,
                    bias=-np.log(0.99 / 0.01)))
        ]
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.grid_points = grid_points
        self.num_convs = num_convs
        self.roi_feat_size = roi_feat_size
        self.in_channels = in_channels
        self.conv_kernel_size = conv_kernel_size
        self.point_feat_channels = point_feat_channels
        self.conv_out_channels = self.point_feat_channels * self.grid_points
        self.class_agnostic = class_agnostic
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        if isinstance(norm_cfg, dict) and norm_cfg['type'] == 'GN':
            assert self.conv_out_channels % norm_cfg['num_groups'] == 0

        assert self.grid_points >= 4
        self.grid_size = int(np.sqrt(self.grid_points))
        if self.grid_size * self.grid_size != self.grid_points:
            raise ValueError('grid_points must be a square number')

        # the predicted heatmap is half of whole_map_size
        if not isinstance(self.roi_feat_size, int):
            raise ValueError('Only square RoIs are supporeted in Grid R-CNN')
        self.whole_map_size = self.roi_feat_size * 4

        # compute point-wise sub-regions
        self.sub_regions = self.calc_sub_regions()

        self.convs = []
        for i in range(self.num_convs):
            in_channels = (
                self.in_channels if i == 0 else self.conv_out_channels)
            stride = 2 if i == 0 else 1
            padding = (self.conv_kernel_size - 1) // 2
            self.convs.append(
                ConvModule(
                    in_channels,
                    self.conv_out_channels,
                    self.conv_kernel_size,
                    stride=stride,
                    padding=padding,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=True))
        self.convs = nn.Sequential(*self.convs)

        self.deconv1 = nn.ConvTranspose2d(
            self.conv_out_channels,
            self.conv_out_channels,
            kernel_size=deconv_kernel_size,
            stride=2,
            padding=(deconv_kernel_size - 2) // 2,
            groups=grid_points)
        self.norm1 = nn.GroupNorm(grid_points, self.conv_out_channels)
        self.deconv2 = nn.ConvTranspose2d(
            self.conv_out_channels,
            grid_points,
            kernel_size=deconv_kernel_size,
            stride=2,
            padding=(deconv_kernel_size - 2) // 2,
            groups=grid_points)

        # find the 4-neighbor of each grid point
        self.neighbor_points = []
        grid_size = self.grid_size
        for i in range(grid_size):  # i-th column
            for j in range(grid_size):  # j-th row
                neighbors = []
                if i > 0:  # left: (i - 1, j)
                    neighbors.append((i - 1) * grid_size + j)
                if j > 0:  # up: (i, j - 1)
                    neighbors.append(i * grid_size + j - 1)
                if j < grid_size - 1:  # down: (i, j + 1)
                    neighbors.append(i * grid_size + j + 1)
                if i < grid_size - 1:  # right: (i + 1, j)
                    neighbors.append((i + 1) * grid_size + j)
                self.neighbor_points.append(tuple(neighbors))
        # total edges in the grid
        self.num_edges = sum([len(p) for p in self.neighbor_points])

        self.forder_trans = nn.ModuleList()  # first-order feature transition
        self.sorder_trans = nn.ModuleList()  # second-order feature transition
        for neighbors in self.neighbor_points:
            fo_trans = nn.ModuleList()
            so_trans = nn.ModuleList()
            for _ in range(len(neighbors)):
                # each transition module consists of a 5x5 depth-wise conv and
                # 1x1 conv.
                fo_trans.append(
                    nn.Sequential(
                        nn.Conv2d(
                            self.point_feat_channels,
                            self.point_feat_channels,
                            5,
                            stride=1,
                            padding=2,
                            groups=self.point_feat_channels),
                        nn.Conv2d(self.point_feat_channels,
                                  self.point_feat_channels, 1)))
                so_trans.append(
                    nn.Sequential(
                        nn.Conv2d(
                            self.point_feat_channels,
                            self.point_feat_channels,
                            5,
                            1,
                            2,
                            groups=self.point_feat_channels),
                        nn.Conv2d(self.point_feat_channels,
                                  self.point_feat_channels, 1)))
            self.forder_trans.append(fo_trans)
            self.sorder_trans.append(so_trans)

        self.loss_grid = MODELS.build(loss_grid)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """forward function of ``GridHead``.

        Args:
            x (Tensor): RoI features, has shape
                (num_rois, num_channels, roi_feat_size, roi_feat_size).

        Returns:
            Dict[str, Tensor]: Return a dict including fused and unfused
            heatmap.
        """
        assert x.shape[-1] == x.shape[-2] == self.roi_feat_size
        # RoI feature transformation, downsample 2x
        x = self.convs(x)

        c = self.point_feat_channels
        # first-order fusion
        x_fo = [None for _ in range(self.grid_points)]
        for i, points in enumerate(self.neighbor_points):
            x_fo[i] = x[:, i * c:(i + 1) * c]
            for j, point_idx in enumerate(points):
                x_fo[i] = x_fo[i] + self.forder_trans[i][j](
                    x[:, point_idx * c:(point_idx + 1) * c])

        # second-order fusion
        x_so = [None for _ in range(self.grid_points)]
        for i, points in enumerate(self.neighbor_points):
            x_so[i] = x[:, i * c:(i + 1) * c]
            for j, point_idx in enumerate(points):
                x_so[i] = x_so[i] + self.sorder_trans[i][j](x_fo[point_idx])

        # predicted heatmap with fused features
        x2 = torch.cat(x_so, dim=1)
        x2 = self.deconv1(x2)
        x2 = F.relu(self.norm1(x2), inplace=True)
        heatmap = self.deconv2(x2)

        # predicted heatmap with original features (applicable during training)
        if self.training:
            x1 = x
            x1 = self.deconv1(x1)
            x1 = F.relu(self.norm1(x1), inplace=True)
            heatmap_unfused = self.deconv2(x1)
        else:
            heatmap_unfused = heatmap

        return dict(fused=heatmap, unfused=heatmap_unfused)

    def calc_sub_regions(self) -> List[Tuple[float]]:
        """Compute point specific representation regions.

        See `Grid R-CNN Plus <https://arxiv.org/abs/1906.05688>`_ for details.
        """
        # to make it consistent with the original implementation, half_size
        # is computed as 2 * quarter_size, which is smaller
        half_size = self.whole_map_size // 4 * 2
        sub_regions = []
        for i in range(self.grid_points):
            x_idx = i // self.grid_size
            y_idx = i % self.grid_size
            if x_idx == 0:
                sub_x1 = 0
            elif x_idx == self.grid_size - 1:
                sub_x1 = half_size
            else:
                ratio = x_idx / (self.grid_size - 1) - 0.25
                sub_x1 = max(int(ratio * self.whole_map_size), 0)

            if y_idx == 0:
                sub_y1 = 0
            elif y_idx == self.grid_size - 1:
                sub_y1 = half_size
            else:
                ratio = y_idx / (self.grid_size - 1) - 0.25
                sub_y1 = max(int(ratio * self.whole_map_size), 0)
            sub_regions.append(
                (sub_x1, sub_y1, sub_x1 + half_size, sub_y1 + half_size))
        return sub_regions

    def get_targets(self, sampling_results: List[SamplingResult],
                    rcnn_train_cfg: ConfigDict) -> Tensor:
        """Calculate the ground truth for all samples in a batch according to
        the sampling_results.".

        Args:
            sampling_results (List[:obj:`SamplingResult`]): Assign results of
                all images in a batch after sampling.
            rcnn_train_cfg (:obj:`ConfigDict`): `train_cfg` of RCNN.

        Returns:
            Tensor: Grid heatmap targets.
        """
        # mix all samples (across images) together.
        pos_bboxes = torch.cat([res.pos_bboxes for res in sampling_results],
                               dim=0).cpu()
        pos_gt_bboxes = torch.cat(
            [res.pos_gt_bboxes for res in sampling_results], dim=0).cpu()
        assert pos_bboxes.shape == pos_gt_bboxes.shape

        # expand pos_bboxes to 2x of original size
        x1 = pos_bboxes[:, 0] - (pos_bboxes[:, 2] - pos_bboxes[:, 0]) / 2
        y1 = pos_bboxes[:, 1] - (pos_bboxes[:, 3] - pos_bboxes[:, 1]) / 2
        x2 = pos_bboxes[:, 2] + (pos_bboxes[:, 2] - pos_bboxes[:, 0]) / 2
        y2 = pos_bboxes[:, 3] + (pos_bboxes[:, 3] - pos_bboxes[:, 1]) / 2
        pos_bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
        pos_bbox_ws = (pos_bboxes[:, 2] - pos_bboxes[:, 0]).unsqueeze(-1)
        pos_bbox_hs = (pos_bboxes[:, 3] - pos_bboxes[:, 1]).unsqueeze(-1)

        num_rois = pos_bboxes.shape[0]
        map_size = self.whole_map_size
        # this is not the final target shape
        targets = torch.zeros((num_rois, self.grid_points, map_size, map_size),
                              dtype=torch.float)

        # pre-compute interpolation factors for all grid points.
        # the first item is the factor of x-dim, and the second is y-dim.
        # for a 9-point grid, factors are like (1, 0), (0.5, 0.5), (0, 1)
        factors = []
        for j in range(self.grid_points):
            x_idx = j // self.grid_size
            y_idx = j % self.grid_size
            factors.append((1 - x_idx / (self.grid_size - 1),
                            1 - y_idx / (self.grid_size - 1)))

        radius = rcnn_train_cfg.pos_radius
        radius2 = radius**2
        for i in range(num_rois):
            # ignore small bboxes
            if (pos_bbox_ws[i] <= self.grid_size
                    or pos_bbox_hs[i] <= self.grid_size):
                continue
            # for each grid point, mark a small circle as positive
            for j in range(self.grid_points):
                factor_x, factor_y = factors[j]
                gridpoint_x = factor_x * pos_gt_bboxes[i, 0] + (
                    1 - factor_x) * pos_gt_bboxes[i, 2]
                gridpoint_y = factor_y * pos_gt_bboxes[i, 1] + (
                    1 - factor_y) * pos_gt_bboxes[i, 3]

                cx = int((gridpoint_x - pos_bboxes[i, 0]) / pos_bbox_ws[i] *
                         map_size)
                cy = int((gridpoint_y - pos_bboxes[i, 1]) / pos_bbox_hs[i] *
                         map_size)

                for x in range(cx - radius, cx + radius + 1):
                    for y in range(cy - radius, cy + radius + 1):
                        if x >= 0 and x < map_size and y >= 0 and y < map_size:
                            if (x - cx)**2 + (y - cy)**2 <= radius2:
                                targets[i, j, y, x] = 1
        # reduce the target heatmap size by a half
        # proposed in Grid R-CNN Plus (https://arxiv.org/abs/1906.05688).
        sub_targets = []
        for i in range(self.grid_points):
            sub_x1, sub_y1, sub_x2, sub_y2 = self.sub_regions[i]
            sub_targets.append(targets[:, [i], sub_y1:sub_y2, sub_x1:sub_x2])
        sub_targets = torch.cat(sub_targets, dim=1)
        sub_targets = sub_targets.to(sampling_results[0].pos_bboxes.device)
        return sub_targets

    def loss(self, grid_pred: Tensor, sample_idx: Tensor,
             sampling_results: List[SamplingResult],
             rcnn_train_cfg: ConfigDict) -> dict:
        """Calculate the loss based on the features extracted by the grid head.

        Args:
            grid_pred (dict[str, Tensor]): Outputs of grid_head forward.
            sample_idx (Tensor): The sampling index of ``grid_pred``.
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            rcnn_train_cfg (obj:`ConfigDict`): `train_cfg` of RCNN.

        Returns:
            dict: A dictionary of loss and targets components.
        """
        grid_targets = self.get_targets(sampling_results, rcnn_train_cfg)
        grid_targets = grid_targets[sample_idx]

        loss_fused = self.loss_grid(grid_pred['fused'], grid_targets)
        loss_unfused = self.loss_grid(grid_pred['unfused'], grid_targets)
        loss_grid = loss_fused + loss_unfused
        return dict(loss_grid=loss_grid)

    def predict_by_feat(self,
                        grid_preds: Dict[str, Tensor],
                        results_list: List[InstanceData],
                        batch_img_metas: List[dict],
                        rescale: bool = False) -> InstanceList:
        """Adjust the predicted bboxes from bbox head.

        Args:
            grid_preds (dict[str, Tensor]): dictionary outputted by forward
                function.
            results_list (list[:obj:`InstanceData`]): Detection results of
                each image.
            batch_img_metas (list[dict]): List of image information.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape \
            (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4), the last \
            dimension 4 arrange as (x1, y1, x2, y2).
        """
        num_roi_per_img = tuple(res.bboxes.size(0) for res in results_list)
        grid_preds = {
            k: v.split(num_roi_per_img, 0)
            for k, v in grid_preds.items()
        }

        for i, results in enumerate(results_list):
            if len(results) != 0:
                bboxes = self._predict_by_feat_single(
                    grid_pred=grid_preds['fused'][i],
                    bboxes=results.bboxes,
                    img_meta=batch_img_metas[i],
                    rescale=rescale)
                results.bboxes = bboxes
        return results_list

    def _predict_by_feat_single(self,
                                grid_pred: Tensor,
                                bboxes: Tensor,
                                img_meta: dict,
                                rescale: bool = False) -> Tensor:
        """Adjust ``bboxes`` according to ``grid_pred``.

        Args:
            grid_pred (Tensor): Grid fused heatmap.
            bboxes (Tensor): Predicted bboxes, has shape (n, 4)
            img_meta (dict): image information.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            Tensor: adjusted bboxes.
        """
        assert bboxes.size(0) == grid_pred.size(0)
        grid_pred = grid_pred.sigmoid()

        R, c, h, w = grid_pred.shape
        half_size = self.whole_map_size // 4 * 2
        assert h == w == half_size
        assert c == self.grid_points

        # find the point with max scores in the half-sized heatmap
        grid_pred = grid_pred.view(R * c, h * w)
        pred_scores, pred_position = grid_pred.max(dim=1)
        xs = pred_position % w
        ys = pred_position // w

        # get the position in the whole heatmap instead of half-sized heatmap
        for i in range(self.grid_points):
            xs[i::self.grid_points] += self.sub_regions[i][0]
            ys[i::self.grid_points] += self.sub_regions[i][1]

        # reshape to (num_rois, grid_points)
        pred_scores, xs, ys = tuple(
            map(lambda x: x.view(R, c), [pred_scores, xs, ys]))

        # get expanded pos_bboxes
        widths = (bboxes[:, 2] - bboxes[:, 0]).unsqueeze(-1)
        heights = (bboxes[:, 3] - bboxes[:, 1]).unsqueeze(-1)
        x1 = (bboxes[:, 0, None] - widths / 2)
        y1 = (bboxes[:, 1, None] - heights / 2)
        # map the grid point to the absolute coordinates
        abs_xs = (xs.float() + 0.5) / w * widths + x1
        abs_ys = (ys.float() + 0.5) / h * heights + y1

        # get the grid points indices that fall on the bbox boundaries
        x1_inds = [i for i in range(self.grid_size)]
        y1_inds = [i * self.grid_size for i in range(self.grid_size)]
        x2_inds = [
            self.grid_points - self.grid_size + i
            for i in range(self.grid_size)
        ]
        y2_inds = [(i + 1) * self.grid_size - 1 for i in range(self.grid_size)]

        # voting of all grid points on some boundary
        bboxes_x1 = (abs_xs[:, x1_inds] * pred_scores[:, x1_inds]).sum(
            dim=1, keepdim=True) / (
                pred_scores[:, x1_inds].sum(dim=1, keepdim=True))
        bboxes_y1 = (abs_ys[:, y1_inds] * pred_scores[:, y1_inds]).sum(
            dim=1, keepdim=True) / (
                pred_scores[:, y1_inds].sum(dim=1, keepdim=True))
        bboxes_x2 = (abs_xs[:, x2_inds] * pred_scores[:, x2_inds]).sum(
            dim=1, keepdim=True) / (
                pred_scores[:, x2_inds].sum(dim=1, keepdim=True))
        bboxes_y2 = (abs_ys[:, y2_inds] * pred_scores[:, y2_inds]).sum(
            dim=1, keepdim=True) / (
                pred_scores[:, y2_inds].sum(dim=1, keepdim=True))

        bboxes = torch.cat([bboxes_x1, bboxes_y1, bboxes_x2, bboxes_y2], dim=1)
        bboxes[:, [0, 2]].clamp_(min=0, max=img_meta['img_shape'][1])
        bboxes[:, [1, 3]].clamp_(min=0, max=img_meta['img_shape'][0])

        if rescale:
            assert img_meta.get('scale_factor') is not None
            bboxes /= bboxes.new_tensor(img_meta['scale_factor']).repeat(
                (1, 2))

        return bboxes