Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
from mmdet.registry import TASK_UTILS | |
from mmdet.structures.bbox import bbox_overlaps, get_box_tensor | |
def cast_tensor_type(x, scale=1., dtype=None): | |
if dtype == 'fp16': | |
# scale is for preventing overflows | |
x = (x / scale).half() | |
return x | |
class BboxOverlaps2D: | |
"""2D Overlaps (e.g. IoUs, GIoUs) Calculator.""" | |
def __init__(self, scale=1., dtype=None): | |
self.scale = scale | |
self.dtype = dtype | |
def __call__(self, bboxes1, bboxes2, mode='iou', is_aligned=False): | |
"""Calculate IoU between 2D bboxes. | |
Args: | |
bboxes1 (Tensor or :obj:`BaseBoxes`): bboxes have shape (m, 4) | |
in <x1, y1, x2, y2> format, or shape (m, 5) in <x1, y1, x2, | |
y2, score> format. | |
bboxes2 (Tensor or :obj:`BaseBoxes`): bboxes have shape (m, 4) | |
in <x1, y1, x2, y2> format, shape (m, 5) in <x1, y1, x2, y2, | |
score> format, or be empty. If ``is_aligned `` is ``True``, | |
then m and n must be equal. | |
mode (str): "iou" (intersection over union), "iof" (intersection | |
over foreground), or "giou" (generalized intersection over | |
union). | |
is_aligned (bool, optional): If True, then m and n must be equal. | |
Default False. | |
Returns: | |
Tensor: shape (m, n) if ``is_aligned `` is False else shape (m,) | |
""" | |
bboxes1 = get_box_tensor(bboxes1) | |
bboxes2 = get_box_tensor(bboxes2) | |
assert bboxes1.size(-1) in [0, 4, 5] | |
assert bboxes2.size(-1) in [0, 4, 5] | |
if bboxes2.size(-1) == 5: | |
bboxes2 = bboxes2[..., :4] | |
if bboxes1.size(-1) == 5: | |
bboxes1 = bboxes1[..., :4] | |
if self.dtype == 'fp16': | |
# change tensor type to save cpu and cuda memory and keep speed | |
bboxes1 = cast_tensor_type(bboxes1, self.scale, self.dtype) | |
bboxes2 = cast_tensor_type(bboxes2, self.scale, self.dtype) | |
overlaps = bbox_overlaps(bboxes1, bboxes2, mode, is_aligned) | |
if not overlaps.is_cuda and overlaps.dtype == torch.float16: | |
# resume cpu float32 | |
overlaps = overlaps.float() | |
return overlaps | |
return bbox_overlaps(bboxes1, bboxes2, mode, is_aligned) | |
def __repr__(self): | |
"""str: a string describing the module""" | |
repr_str = self.__class__.__name__ + f'(' \ | |
f'scale={self.scale}, dtype={self.dtype})' | |
return repr_str | |