# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
import torch
from torchvision.ops import boxes as box_ops
from torchvision.ops import nms # noqa . for compatibility


def batched_nms(
    boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float
):
    """
    Same as torchvision.ops.boxes.batched_nms, but with float().
    """
    assert boxes.shape[-1] == 4
    # Note: Torchvision already has a strategy (https://github.com/pytorch/vision/issues/1311)
    # to decide whether to use coordinate trick or for loop to implement batched_nms. So we
    # just call it directly.
    # Fp16 does not have enough range for batched NMS, so adding float().
    return box_ops.batched_nms(boxes.float(), scores, idxs, iou_threshold)


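"""
Illustrative note (not part of the original module): a minimal sketch of why
the float() cast above matters. It uses only the standard library's
half-precision struct format instead of torch, so it shows the limits of the
fp16 number format rather than torch's own behavior (torch would typically
produce inf for an out-of-range half value instead of raising). The helper
name to_fp16 and the sample values are hypothetical.

```python
import struct


def to_fp16(x: float) -> float:
    # Round-trip a Python float through IEEE 754 half precision.
    return struct.unpack("<e", struct.pack("<e", x))[0]


# Above 2048, fp16 can no longer represent every integer: the gap between
# adjacent values is 2, then 4, and so on. 4097 snaps to the nearest
# representable value.
coarse = to_fp16(4097.0)

# The largest finite fp16 value is 65504. Per-class coordinate offsets can
# exceed that easily, e.g. class index 80 with ~1000-pixel coordinates.
try:
    to_fp16(80 * 1000.0)
    overflowed = False
except OverflowError:
    overflowed = True
```

With boxes kept in fp16, both effects could distort the overlap computation,
which is the concern the comment in batched_nms points at.
"""

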
# Note: this function (nms_rotated) might be moved into
# torchvision/ops/boxes.py in the future
def nms_rotated(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float):
    """
    Performs non-maximum suppression (NMS) on the rotated boxes according
    to their intersection-over-union (IoU).

    Rotated NMS iteratively removes lower scoring rotated boxes which have an
    IoU greater than iou_threshold with another (higher scoring) rotated box.

    Note that RotatedBox (5, 3, 4, 2, -90) covers exactly the same region as
    RotatedBox (5, 3, 4, 2, 90) does, and their IoU will be 1. However, they
    can represent completely different objects in certain tasks, e.g., OCR.

    Whether rotated NMS should treat them as faraway boxes even though their
    IoU is 1 depends on the application and/or the ground truth annotation.

    As an extreme example, consider a single character v and the square box around it.

    If the angle is 0 degrees, the object (text) would be read as 'v';

    If the angle is 90 degrees, the object (text) would become '>';

    If the angle is 180 degrees, the object (text) would become '^';

    If the angle is 270/-90 degrees, the object (text) would become '<'.

    All of these cases have an IoU of 1 with each other, and rotated NMS that only
    uses IoU as its criterion would keep only the one with the highest score,
    which, practically, still makes sense in most cases because typically
    only one of these orientations is the correct one. Also, it does not matter
    as much if the box is only used to classify the object (instead of transcribing
    it with a sequential OCR recognition model) later.

    On the other hand, when we use IoU to filter proposals that are close to the
    ground truth during training, we should definitely take the angle into account if
    we know the ground truth is labeled with the strictly correct orientation (as in,
    upside-down words are annotated with -180 degrees even though they can be covered
    with a 0/90/-90 degree box, etc.)

    The way the original dataset is annotated also matters. For example, if the dataset
    is a 4-point polygon dataset that does not enforce ordering of vertices/orientation,
    we can estimate a minimum rotated bounding box for this polygon, but there is no way
    we can tell the correct angle with 100% confidence (as shown above, there could be
    4 different rotated boxes, with angles differing by 90 degrees from each other,
    covering exactly the same region). In that case we have to just use IoU to determine
    the box proximity (as many detection benchmarks (even for text) do) unless there are
    other assumptions we can make (like width is always larger than height, or the object
    is not rotated by more than 90 degrees CCW/CW, etc.)

    In summary, not considering angles in rotated NMS seems to be a good option for now,
    but we should be aware of its implications.

    Args:
        boxes (Tensor[N, 5]): Rotated boxes to perform NMS on. They are expected to be in
           (x_center, y_center, width, height, angle_degrees) format.
        scores (Tensor[N]): Scores for each one of the rotated boxes
        iou_threshold (float): Discards all overlapping rotated boxes with IoU > iou_threshold

    Returns:
        keep (Tensor): int64 tensor with the indices of the elements that have been kept
        by Rotated NMS, sorted in decreasing order of scores
    """
    return torch.ops.detectron2.nms_rotated(boxes, scores, iou_threshold)


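"""
Illustrative note (not part of the original module): a small standard-library
sketch of the claim in the nms_rotated docstring that boxes whose angles
differ by 180 degrees (here -90 and 90) cover the same region. The helper
name rotated_box_corners is hypothetical, and the rotation sign convention
used here is an assumption; the conclusion does not depend on it.

```python
import math


def rotated_box_corners(box):
    # box is (x_center, y_center, width, height, angle_degrees)
    cx, cy, w, h, angle = box
    theta = math.radians(angle)
    c, s = math.cos(theta), math.sin(theta)
    corners = []
    for dx, dy in ((-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2), (-w / 2, h / 2)):
        # Rotate the corner offset about the center, then translate.
        x = cx + dx * c - dy * s
        y = cy + dx * s + dy * c
        # Round away floating-point noise so corner sets can be compared.
        corners.append((round(x, 6), round(y, 6)))
    return corners


a = rotated_box_corners((5, 3, 4, 2, -90))
b = rotated_box_corners((5, 3, 4, 2, 90))
# Same four corners, visited in a different order.
same_region = sorted(a) == sorted(b)
```

Because the corner sets coincide, an IoU-only criterion cannot tell the two
orientations apart, which is the trade-off the docstring discusses.
"""

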
# Note: this function (batched_nms_rotated) might be moved into
# torchvision/ops/boxes.py in the future
@torch.jit.script_if_tracing
def batched_nms_rotated(
    boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float
):
    """
    Performs non-maximum suppression in a batched fashion.

    Each index value corresponds to a category, and NMS
    will not be applied between elements of different categories.

    Args:
        boxes (Tensor[N, 5]):
           boxes where NMS will be performed. They
           are expected to be in (x_ctr, y_ctr, width, height, angle_degrees) format
        scores (Tensor[N]):
           scores for each one of the boxes
        idxs (Tensor[N]):
           indices of the categories for each one of the boxes.
        iou_threshold (float):
           discards all overlapping boxes
           with IoU > iou_threshold

    Returns:
        Tensor:
            int64 tensor with the indices of the elements that have been kept
            by NMS, sorted in decreasing order of scores
    """
    assert boxes.shape[-1] == 5

    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    boxes = boxes.float()  # fp16 does not have enough range for batched NMS
    # Strategy: in order to perform NMS independently per class,
    # we add an offset to all the boxes. The offset is dependent
    # only on the class idx, and is large enough so that boxes
    # from different classes do not overlap

    # Note that batched_nms in torchvision/ops/boxes.py only uses max_coordinate,
    # which won't handle negative coordinates correctly.
    # Here by using min_coordinate we can make sure the negative coordinates are
    # correctly handled.
    max_coordinate = (
        torch.max(boxes[:, 0], boxes[:, 1]) + torch.max(boxes[:, 2], boxes[:, 3]) / 2
    ).max()
    min_coordinate = (
        torch.min(boxes[:, 0], boxes[:, 1]) - torch.max(boxes[:, 2], boxes[:, 3]) / 2
    ).min()
    offsets = idxs.to(boxes) * (max_coordinate - min_coordinate + 1)
    boxes_for_nms = boxes.clone()  # avoid modifying the original values in boxes
    boxes_for_nms[:, :2] += offsets[:, None]
    keep = nms_rotated(boxes_for_nms, scores, iou_threshold)
    return keep
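

"""
Illustrative note (not part of the original module): a plain-Python sketch of
the per-class offset strategy used in batched_nms_rotated, mirroring its
max_coordinate / min_coordinate arithmetic on lists instead of tensors. The
helper name class_offsets and the sample boxes are hypothetical.

```python
def class_offsets(boxes, idxs):
    # boxes: list of (x_ctr, y_ctr, width, height, angle_degrees)
    # idxs: one integer class index per box
    max_coordinate = max(max(x, y) + max(w, h) / 2 for x, y, w, h, _ in boxes)
    min_coordinate = min(min(x, y) - max(w, h) / 2 for x, y, w, h, _ in boxes)
    # Using the full span (max - min) keeps negative coordinates separated too.
    stride = max_coordinate - min_coordinate + 1
    return [i * stride for i in idxs]


# Two boxes with the same center and size, but in different classes.
boxes = [(-3.0, 2.0, 4.0, 2.0, 0.0), (-3.0, 2.0, 4.0, 2.0, 30.0)]
idxs = [0, 1]
offsets = class_offsets(boxes, idxs)

# Only the centers are shifted; width, height and angle stay untouched.
shifted = [(x + o, y + o, w, h, a) for (x, y, w, h, a), o in zip(boxes, offsets)]
```

After the shift the two centers are a full stride apart on both axes, so a
single NMS pass over the shifted boxes cannot suppress one class's box in
favor of the other's.
"""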