Spaces:
Runtime error
Runtime error
File size: 8,469 Bytes
f549064 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 |
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from typing import List, Optional, Union
import numpy as np
import torch
from mmcv.ops import nms
from mmengine.config import ConfigDict
from torch import Tensor
from mmdet.structures.bbox import bbox_mapping_back
# TODO: remove this function; it is never used in mmdet.
def merge_aug_proposals(aug_proposals, img_metas, cfg):
    """Merge augmented proposals (multiscale, flip, etc.).

    Every proposal set is mapped back to the original image frame, all sets
    are concatenated, NMS is applied, and the highest-scoring
    ``cfg.max_per_img`` proposals are kept.

    Args:
        aug_proposals (list[Tensor]): proposals from different testing
            schemes, shape (n, 5): 4 box columns followed by a score column.
            Note that they are not rescaled to the original image size.
        img_metas (list[dict]): list of image info dict where each dict has:
            'img_shape', 'scale_factor', 'flip', 'flip_direction', and may
            also contain 'filename', 'ori_shape', 'pad_shape', and
            'img_norm_cfg'. For details on the values of these keys see
            `mmdet/datasets/pipelines/formatting.py:Collect`.
        cfg (dict): rpn test config. It must support attribute access
            (``cfg.nms``, ``cfg.max_per_img``), e.g. a ``ConfigDict``. The
            deprecated keys ``nms_thr`` and ``max_num`` are still accepted.
            The config is deep-copied, so the caller's object is not modified.

    Returns:
        Tensor: shape (k, 5), proposals corresponding to original image
        scale, sorted by descending score (last column), with
        ``k <= cfg.max_per_img``.
    """
    cfg = copy.deepcopy(cfg)

    # deprecate arguments warning
    if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
        warnings.warn(
            'In rpn_proposal or test_cfg, '
            'nms_thr has been moved to a dict named nms as '
            'iou_threshold, max_num has been renamed as max_per_img, '
            'name of original arguments and the way to specify '
            'iou_threshold of NMS will be deprecated.')
    if 'nms' not in cfg:
        cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
    if 'max_num' in cfg:
        if 'max_per_img' in cfg:
            assert cfg.max_num == cfg.max_per_img, f'You set max_num and ' \
                f'max_per_img at the same time, but get {cfg.max_num} ' \
                f'and {cfg.max_per_img} respectively. ' \
                f'Please delete max_num which will be deprecated.'
        else:
            cfg.max_per_img = cfg.max_num
    if 'nms_thr' in cfg:
        assert cfg.nms.iou_threshold == cfg.nms_thr, f'You set ' \
            f'iou_threshold in nms and ' \
            f'nms_thr at the same time, but get ' \
            f'{cfg.nms.iou_threshold} and {cfg.nms_thr}' \
            f' respectively. Please delete the nms_thr ' \
            f'which will be deprecated.'

    recovered_proposals = []
    for proposals, img_info in zip(aug_proposals, img_metas):
        img_shape = img_info['img_shape']
        scale_factor = img_info['scale_factor']
        flip = img_info['flip']
        flip_direction = img_info['flip_direction']
        # Clone so the caller's tensors are untouched; only the box columns
        # are mapped back, the score column is kept as-is.
        _proposals = proposals.clone()
        _proposals[:, :4] = bbox_mapping_back(_proposals[:, :4], img_shape,
                                              scale_factor, flip,
                                              flip_direction)
        recovered_proposals.append(_proposals)
    all_proposals = torch.cat(recovered_proposals, dim=0)
    merged_proposals, _ = nms(all_proposals[:, :4].contiguous(),
                              all_proposals[:, -1].contiguous(),
                              cfg.nms.iou_threshold)
    # Keep at most max_per_img proposals, highest score first.
    scores = merged_proposals[:, 4]
    _, order = scores.sort(0, descending=True)
    num = min(cfg.max_per_img, merged_proposals.shape[0])
    order = order[:num]
    merged_proposals = merged_proposals[order, :]
    return merged_proposals
# TODO: remove this function; it is never used in mmdet.
def merge_aug_bboxes(aug_bboxes, aug_scores, img_metas, rcnn_test_cfg):
    """Merge augmented detection bboxes and scores.

    Boxes from each augmentation are mapped back to the original image frame
    and then averaged element-wise across augmentations; scores, when given,
    are averaged the same way. All inputs must therefore share one shape.

    Args:
        aug_bboxes (list[Tensor]): shape (n, 4*#class), one per augmentation.
        aug_scores (list[Tensor] or None): shape (n, #class), one per
            augmentation.
        img_metas (list): image meta info, one entry per augmentation. Only
            the first dict of each entry (``img_info[0]``) is read, using its
            'img_shape', 'scale_factor', 'flip' and 'flip_direction' keys.
        rcnn_test_cfg (dict): rcnn test config. Not used by this function.

    Returns:
        Tensor or tuple[Tensor, Tensor]: the merged bboxes alone when
        ``aug_scores`` is None, otherwise ``(bboxes, scores)``.
    """
    recovered_bboxes = []
    for bboxes, img_info in zip(aug_bboxes, img_metas):
        meta = img_info[0]
        recovered_bboxes.append(
            bbox_mapping_back(bboxes, meta['img_shape'], meta['scale_factor'],
                              meta['flip'], meta['flip_direction']))
    bboxes = torch.stack(recovered_bboxes).mean(dim=0)
    if aug_scores is None:
        return bboxes
    scores = torch.stack(aug_scores).mean(dim=0)
    return bboxes, scores
def merge_aug_results(aug_batch_results, aug_batch_img_metas):
    """Merge augmented detection results.

    Only bboxes (with their corresponding scores and labels) under flipping
    and multi-scale resizing can be processed now.

    Args:
        aug_batch_results (list[list[:obj:`InstanceData`]]):
            Detection results of multiple images with
            different augmentations.
            The outer list indicates the augmentation. The inner
            list indicates the batch dimension.
            Each item usually contains the following keys.

            - scores (Tensor): Classification scores, in shape
              (num_instances,).
            - labels (Tensor): Labels of bboxes, in shape
              (num_instances,).
            - bboxes (Tensor): In shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        aug_batch_img_metas (list[list[dict]]): The outer list
            indicates test-time augs (multiscale, flip, etc.)
            and the inner list indicates
            images in a batch. Each dict in the list contains
            information of an image in the batch ('img_shape',
            'scale_factor', 'flip' and 'flip_direction' are read).

    Returns:
        list[:obj:`InstanceData`]: One entry per image, holding that image's
        results from all augmentations concatenated together, with bboxes
        mapped back to the original scale. The inputs are deep-copied and
        not modified. An empty list is returned when
        ``aug_batch_results`` is empty.
    """
    if not aug_batch_results:
        return []
    num_augs = len(aug_batch_results)
    num_imgs = len(aug_batch_results[0])
    # Work on a copy: bboxes are reassigned below and the caller's results
    # must stay untouched.
    aug_batch_results = copy.deepcopy(aug_batch_results)
    batch_results = []
    for img_id in range(num_imgs):
        aug_results = []
        for aug_id in range(num_augs):
            img_metas = aug_batch_img_metas[aug_id][img_id]
            results = aug_batch_results[aug_id][img_id]
            img_shape = img_metas['img_shape']
            scale_factor = img_metas['scale_factor']
            flip = img_metas['flip']
            flip_direction = img_metas['flip_direction']
            results.bboxes = bbox_mapping_back(results.bboxes, img_shape,
                                               scale_factor, flip,
                                               flip_direction)
            aug_results.append(results)
        batch_results.append(results.cat(aug_results))
    return batch_results
def merge_aug_scores(aug_scores):
    """Average bbox scores over test-time augmentations.

    Args:
        aug_scores (list[Tensor] | list): Scores produced by each
            augmentation. The type of the first element selects the backend:
            torch for tensors, numpy otherwise.

    Returns:
        Tensor | np.ndarray: Element-wise mean over the augmentation axis.
    """
    first = aug_scores[0]
    if not isinstance(first, torch.Tensor):
        return np.mean(aug_scores, axis=0)
    stacked = torch.stack(aug_scores)
    return stacked.mean(dim=0)
def merge_aug_masks(aug_masks: List[Tensor],
                    img_metas: dict,
                    weights: Optional[Union[list, Tensor]] = None) -> Tensor:
    """Merge augmented mask predictions by (weighted) averaging.

    Args:
        aug_masks (list[Tensor]): Mask predictions to merge, each of shape
            (n, c, h, w).
        img_metas (dict): Image information. When ``img_metas['flip']`` is
            truthy, every mask is flipped back along
            ``img_metas['flip_direction']`` ('horizontal', 'vertical' or
            'diagonal') before merging.
        weights (list or Tensor, optional): Weight of each element of
            ``aug_masks``; its length must equal ``len(aug_masks)``.
            When None, all masks are weighted equally.

    Returns:
        Tensor: has shape (n, c, h, w). With weights the result is
        ``sum(w_i * mask_i) / sum(w_i)``, otherwise the plain mean.

    Raises:
        ValueError: If a flip is requested with an unknown direction.
    """
    if weights is not None:
        assert len(weights) == len(aug_masks)

    # A single meta dict describes all masks, so resolve which spatial dims
    # to flip back once, outside the loop. (The key was previously misspelled
    # 'filp', which silently disabled this branch.)
    flip_dims = None
    if img_metas.get('flip', False):
        flip_direction = img_metas['flip_direction']
        if flip_direction == 'horizontal':
            flip_dims = [3]
        elif flip_direction == 'vertical':
            flip_dims = [2]
        elif flip_direction == 'diagonal':
            flip_dims = [2, 3]
        else:
            raise ValueError(
                f"Invalid flipping direction '{flip_direction}'")

    recovered_masks = []
    for i, mask in enumerate(aug_masks):
        weight = 1 if weights is None else weights[i]
        if flip_dims is not None:
            # torch tensors reject negative-step slicing (``::-1``), so use
            # torch.flip to reverse the spatial dims.
            mask = torch.flip(mask, dims=flip_dims)
        recovered_masks.append(mask[None, :] * weight)
    merged_masks = torch.cat(recovered_masks, 0).mean(dim=0)
    if weights is not None:
        # mean() divided by len(weights); rescale to divide by sum(weights).
        merged_masks = merged_masks * len(weights) / sum(weights)
    return merged_masks
|