File size: 8,469 Bytes
f549064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from typing import List, Optional, Union

import numpy as np
import torch
from mmcv.ops import nms
from mmengine.config import ConfigDict
from torch import Tensor

from mmdet.structures.bbox import bbox_mapping_back


# TODO: remove this, it is never used in mmdet
def merge_aug_proposals(aug_proposals, img_metas, cfg):
    """Merge augmented proposals (multiscale, flip, etc.)

    Every set of proposals is mapped back to the original image frame, all
    sets are concatenated, NMS is applied and the highest-scoring
    ``cfg.max_per_img`` proposals are kept.

    Args:
        aug_proposals (list[Tensor]): proposals from different testing
            schemes, shape (n, 5). The first four columns are the box
            coordinates and the last column is used as the score. Note that
            they are not rescaled to the original image size.

        img_metas (list[dict]): list of image info dict, one per entry of
            ``aug_proposals``. The keys read here are 'img_shape',
            'scale_factor', 'flip' and 'flip_direction'.

        cfg (dict): rpn test config. It is accessed both with ``in`` and
            with attribute syntax (``cfg.nms``, ``cfg.max_per_img``), so it
            must be a dict type supporting attribute access, e.g.
            ``ConfigDict``. The deprecated keys ``nms_thr`` and ``max_num``
            are still accepted and translated to ``nms.iou_threshold`` and
            ``max_per_img``. The config passed in is not modified.

    Returns:
        Tensor: shape (k, 5), the kept proposals in the original image
        scale sorted by descending score; column 4 holds the score.
    """

    # Work on a copy so the legacy-key translation below never leaks into
    # the caller's config.
    cfg = copy.deepcopy(cfg)

    # deprecate arguments warning
    if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
        warnings.warn(
            'In rpn_proposal or test_cfg, '
            'nms_thr has been moved to a dict named nms as '
            'iou_threshold, max_num has been renamed as max_per_img, '
            'name of original arguments and the way to specify '
            'iou_threshold of NMS will be deprecated.')
    if 'nms' not in cfg:
        cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
    if 'max_num' in cfg:
        if 'max_per_img' in cfg:
            assert cfg.max_num == cfg.max_per_img, (
                f'You set max_num and max_per_img at the same time, '
                f'but get {cfg.max_num} and {cfg.max_per_img} '
                f'respectively. '
                f'Please delete max_num which will be deprecated.')
        else:
            cfg.max_per_img = cfg.max_num
    if 'nms_thr' in cfg:
        assert cfg.nms.iou_threshold == cfg.nms_thr, (
            f'You set iou_threshold in nms and nms_thr at the same time, '
            f'but get {cfg.nms.iou_threshold} and {cfg.nms_thr} '
            f'respectively. Please delete the nms_thr '
            f'which will be deprecated.')

    recovered_proposals = []
    for proposals, img_info in zip(aug_proposals, img_metas):
        img_shape = img_info['img_shape']
        scale_factor = img_info['scale_factor']
        flip = img_info['flip']
        flip_direction = img_info['flip_direction']
        # Clone so the caller's proposal tensors are left untouched when the
        # box columns are overwritten.
        _proposals = proposals.clone()
        _proposals[:, :4] = bbox_mapping_back(_proposals[:, :4], img_shape,
                                              scale_factor, flip,
                                              flip_direction)
        recovered_proposals.append(_proposals)
    aug_proposals = torch.cat(recovered_proposals, dim=0)
    merged_proposals, _ = nms(aug_proposals[:, :4].contiguous(),
                              aug_proposals[:, -1].contiguous(),
                              cfg.nms.iou_threshold)
    # Column 4 of the NMS output is treated as the score.
    scores = merged_proposals[:, 4]
    _, order = scores.sort(0, descending=True)
    num = min(cfg.max_per_img, merged_proposals.shape[0])
    order = order[:num]
    merged_proposals = merged_proposals[order, :]
    return merged_proposals


# TODO: remove this, it is never used in mmdet
def merge_aug_bboxes(aug_bboxes, aug_scores, img_metas, rcnn_test_cfg):
    """Average augmented detection bboxes (and scores) over augmentations.

    Each bbox tensor is first mapped back with ``bbox_mapping_back`` using
    the meta info of its augmentation, then all of them are stacked and
    averaged element-wise.

    Args:
        aug_bboxes (list[Tensor]): shape (n, 4*#class)
        aug_scores (list[Tensor] or None): shape (n, #class)
        img_metas (list[list[dict]]): one entry per augmentation; only the
            first dict of each entry is read, through the keys 'img_shape',
            'scale_factor', 'flip' and 'flip_direction'.
        rcnn_test_cfg (dict): rcnn test config. Not used by this function.

    Returns:
        Tensor or tuple: the averaged bboxes alone when ``aug_scores`` is
        None, otherwise the tuple (bboxes, scores).
    """
    restored = []
    for aug_box, metas in zip(aug_bboxes, img_metas):
        img_shape = metas[0]['img_shape']
        scale_factor = metas[0]['scale_factor']
        flip = metas[0]['flip']
        flip_direction = metas[0]['flip_direction']
        restored.append(
            bbox_mapping_back(aug_box, img_shape, scale_factor, flip,
                              flip_direction))

    merged_bboxes = torch.stack(restored).mean(dim=0)
    if aug_scores is not None:
        merged_scores = torch.stack(aug_scores).mean(dim=0)
        return merged_bboxes, merged_scores
    return merged_bboxes


def merge_aug_results(aug_batch_results, aug_batch_img_metas):
    """Gather the detections of every augmentation for each image.

    Only bboxes (with their scores) produced under flipping and multi-scale
    resizing can be processed now.

    Args:
        aug_batch_results (list[list[[obj:`InstanceData`]]):
            Detection results of multiple images with
            different augmentations.
            The outer list indicates the augmentation. The inner
            list indicates the batch dimension.
            Each item usually contains the following keys.

            - scores (Tensor): Classification scores, in shape
              (num_instance,)
            - labels (Tensor): Labels of bboxes, in shape
              (num_instances,).
            - bboxes (Tensor): In shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        aug_batch_img_metas (list[list[dict]]): The outer list
            indicates test-time augs (multiscale, flip, etc.)
            and the inner list indicates
            images in a batch. Each dict must provide 'img_shape',
            'scale_factor', 'flip' and 'flip_direction'.

    Returns:
        batch_results (list[obj:`InstanceData`]): One item per image,
        holding the concatenation of that image's results from all
        augmentations, with every bbox mapped back through
        ``bbox_mapping_back``.
    """
    aug_count = len(aug_batch_results)
    img_count = len(aug_batch_results[0])

    # The bboxes are overwritten below, so operate on a private copy and
    # leave the caller's results intact.
    aug_batch_results = copy.deepcopy(aug_batch_results)

    batch_results = []
    for img_idx in range(img_count):
        per_img = []
        for aug_idx in range(aug_count):
            meta = aug_batch_img_metas[aug_idx][img_idx]
            instances = aug_batch_results[aug_idx][img_idx]

            img_shape = meta['img_shape']
            scale_factor = meta['scale_factor']
            flip = meta['flip']
            flip_direction = meta['flip_direction']
            instances.bboxes = bbox_mapping_back(instances.bboxes, img_shape,
                                                 scale_factor, flip,
                                                 flip_direction)
            per_img.append(instances)
        # ``cat`` is invoked through the last result object of this image.
        batch_results.append(per_img[-1].cat(per_img))

    return batch_results


def merge_aug_scores(aug_scores):
    """Average augmented bbox scores over the augmentations.

    Args:
        aug_scores (list): per-augmentation scores. When the first item is
            a ``torch.Tensor`` the items are stacked and averaged with
            torch; otherwise the list is averaged with ``np.mean``.

    Returns:
        Tensor or np.ndarray: element-wise mean along the augmentation axis.
    """
    first = aug_scores[0]
    if not isinstance(first, torch.Tensor):
        return np.mean(aug_scores, axis=0)
    stacked = torch.stack(aug_scores)
    return stacked.mean(dim=0)


def merge_aug_masks(aug_masks: List[Tensor],
                    img_metas: dict,
                    weights: Optional[Union[list, Tensor]] = None) -> Tensor:
    """Merge augmented mask prediction.

    Each mask is flipped back when ``img_metas`` marks the image as flipped,
    then the masks are combined by a plain mean, or by a weighted average
    when ``weights`` is given.

    Args:
        aug_masks (list[Tensor]): each has shape
            (n, c, h, w).
        img_metas (dict): Image information. The optional key 'flip'
            (default False) is read; when it is truthy, 'flip_direction'
            must be one of 'horizontal', 'vertical' or 'diagonal'.
        weights (list or Tensor, optional): Weight of each item of
            ``aug_masks``; its length must equal ``len(aug_masks)``.
            Defaults to None, which weighs every mask equally.

    Returns:
        Tensor: has shape (n, c, h, w)

    Raises:
        ValueError: If the image is flipped with an unknown
            'flip_direction'.
    """
    if weights is not None:
        assert len(weights) == len(aug_masks)

    # Resolve which axes to flip back once; it is the same for every mask.
    flip_dims = None
    if img_metas.get('flip', False):
        flip_direction = img_metas['flip_direction']
        if flip_direction == 'horizontal':
            flip_dims = (3, )
        elif flip_direction == 'vertical':
            flip_dims = (2, )
        elif flip_direction == 'diagonal':
            flip_dims = (2, 3)
        else:
            raise ValueError(
                f"Invalid flipping direction '{flip_direction}'")

    recovered_masks = []
    for i, mask in enumerate(aug_masks):
        weight = 1 if weights is None else weights[i]
        if flip_dims is not None:
            # Tensors do not support negative-step slicing, so use flip().
            mask = mask.flip(dims=flip_dims)
        recovered_masks.append(mask[None, :] * weight)

    merged_masks = torch.cat(recovered_masks, 0).mean(dim=0)
    if weights is not None:
        # mean() divided by the number of masks; rescale so the result is
        # divided by the sum of the weights instead.
        merged_masks = merged_masks * len(weights) / sum(weights)
    return merged_masks