Spaces:
Running
on
L40S
Running
on
L40S
import math

import torch
def ciou(bboxes1, bboxes2):
    """Return the element-wise Complete-IoU (CIoU) loss between two box sets.

    Both inputs are passed through ``torch.sigmoid`` first. Columns 0 and 1
    of the result are then used as the box centre (x, y), while columns 2 and
    3 are exponentiated to obtain the box width and height.

    Boxes are paired row by row (``bboxes1[i]`` with ``bboxes2[i]``); this is
    not an all-pairs matrix. The leading sizes therefore have to be equal or
    broadcastable (one of them being 1).

    Args:
        bboxes1: Tensor indexed as ``[:, 0..3]``.
            NOTE(review): presumably raw (cx, cy, log-w, log-h) predictions
            of shape (N, 4) -- confirm against the caller.
        bboxes2: Tensor with the same layout as ``bboxes1``.

    Returns:
        ``1 - clamp(iou - u - alpha * v, -1, 1)`` per box pair, where ``u`` is
        the squared centre distance over the squared enclosing-box diagonal
        and ``v`` is the aspect-ratio term. If either input has zero rows, a
        zero tensor of shape ``(rows, cols)`` with the dtype and device of
        the inputs is returned instead.
    """
    bboxes1 = torch.sigmoid(bboxes1)
    bboxes2 = torch.sigmoid(bboxes2)
    rows = bboxes1.shape[0]
    cols = bboxes2.shape[0]
    if rows * cols == 0:
        # new_zeros keeps the dtype/device of the inputs; a plain
        # torch.zeros would always hand back a default-dtype CPU tensor.
        return bboxes1.new_zeros((rows, cols))

    # Keep the smaller set first, as the original implementation did. Every
    # term below is symmetric in the two boxes, so only the layout of a
    # multi-dimensional result depends on this swap.
    swapped = rows > cols
    if swapped:
        bboxes1, bboxes2 = bboxes2, bboxes1

    cx1, cy1 = bboxes1[:, 0], bboxes1[:, 1]
    cx2, cy2 = bboxes2[:, 0], bboxes2[:, 1]
    w1, h1 = torch.exp(bboxes1[:, 2]), torch.exp(bboxes1[:, 3])
    w2, h2 = torch.exp(bboxes2[:, 2]), torch.exp(bboxes2[:, 3])

    # Box edges derived from centre and size.
    left1, right1 = cx1 - w1 / 2, cx1 + w1 / 2
    top1, bottom1 = cy1 - h1 / 2, cy1 + h1 / 2
    left2, right2 = cx2 - w2 / 2, cx2 + w2 / 2
    top2, bottom2 = cy2 - h2 / 2, cy2 + h2 / 2

    # Intersection rectangle (clamped so disjoint boxes give zero area).
    inter_w = torch.clamp(
        torch.min(right1, right2) - torch.max(left1, left2), min=0)
    inter_h = torch.clamp(
        torch.min(bottom1, bottom2) - torch.max(top1, top2), min=0)
    inter_area = inter_w * inter_h

    # Smallest rectangle enclosing both boxes and its squared diagonal.
    hull_w = torch.clamp(
        torch.max(right1, right2) - torch.min(left1, left2), min=0)
    hull_h = torch.clamp(
        torch.max(bottom1, bottom2) - torch.min(top1, top2), min=0)
    c_diag = hull_w**2 + hull_h**2

    # Squared distance between the two centres.
    inter_diag = (cx2 - cx1)**2 + (cy2 - cy1)**2

    union = w1 * h1 + w2 * h2 - inter_area
    iou = inter_area / union
    u = inter_diag / c_diag

    # Aspect-ratio consistency term.
    v = (4 / (math.pi**2)) * torch.pow(
        torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
    with torch.no_grad():
        # The aspect-ratio penalty is only switched on for IoU > 0.5.
        gate = (iou > 0.5).float()
        # For coinciding boxes iou == 1 and v == 0, which used to evaluate
        # 0 / 0 = NaN. Clamping the denominator keeps alpha finite.
        denom = torch.clamp(1 - iou + v, min=torch.finfo(iou.dtype).eps)
        alpha = gate * v / denom

    cious = torch.clamp(iou - u - alpha * v, min=-1.0, max=1.0)
    if swapped and cious.dim() > 1:
        # Same dimension reversal the legacy `.T` performed, without the
        # deprecation warning `.T` raises on non-2-D tensors. A 1-D result
        # is left as is (reversing a single dimension is a no-op).
        cious = cious.permute(*range(cious.dim() - 1, -1, -1))
    return 1 - cious
def diou(bboxes1, bboxes2):
    """Return the element-wise Distance-IoU (DIoU) loss between two box sets.

    Both inputs are passed through ``torch.sigmoid`` first. Columns 0 and 1
    of the result are then used as the box centre (x, y), while columns 2 and
    3 are exponentiated to obtain the box width and height.

    Boxes are paired row by row (``bboxes1[i]`` with ``bboxes2[i]``); this is
    not an all-pairs matrix. The leading sizes therefore have to be equal or
    broadcastable (one of them being 1).

    Args:
        bboxes1: Tensor indexed as ``[:, 0..3]``.
            NOTE(review): presumably raw (cx, cy, log-w, log-h) predictions
            of shape (N, 4) -- confirm against the caller.
        bboxes2: Tensor with the same layout as ``bboxes1``.

    Returns:
        ``1 - clamp(iou - u, -1, 1)`` per box pair, where ``u`` is the squared
        centre distance over the squared diagonal of the enclosing box. If
        either input has zero rows, a zero tensor of shape ``(rows, cols)``
        with the dtype and device of the inputs is returned instead.
    """
    bboxes1 = torch.sigmoid(bboxes1)
    bboxes2 = torch.sigmoid(bboxes2)
    rows = bboxes1.shape[0]
    cols = bboxes2.shape[0]
    if rows * cols == 0:
        # new_zeros keeps the dtype/device of the inputs; a plain
        # torch.zeros would always hand back a default-dtype CPU tensor.
        return bboxes1.new_zeros((rows, cols))

    # Keep the smaller set first, as the original implementation did. Every
    # term below is symmetric in the two boxes, so only the layout of a
    # multi-dimensional result depends on this swap.
    swapped = rows > cols
    if swapped:
        bboxes1, bboxes2 = bboxes2, bboxes1

    cx1, cy1 = bboxes1[:, 0], bboxes1[:, 1]
    cx2, cy2 = bboxes2[:, 0], bboxes2[:, 1]
    w1, h1 = torch.exp(bboxes1[:, 2]), torch.exp(bboxes1[:, 3])
    w2, h2 = torch.exp(bboxes2[:, 2]), torch.exp(bboxes2[:, 3])

    # Box edges derived from centre and size.
    left1, right1 = cx1 - w1 / 2, cx1 + w1 / 2
    top1, bottom1 = cy1 - h1 / 2, cy1 + h1 / 2
    left2, right2 = cx2 - w2 / 2, cx2 + w2 / 2
    top2, bottom2 = cy2 - h2 / 2, cy2 + h2 / 2

    # Intersection rectangle (clamped so disjoint boxes give zero area).
    inter_w = torch.clamp(
        torch.min(right1, right2) - torch.max(left1, left2), min=0)
    inter_h = torch.clamp(
        torch.min(bottom1, bottom2) - torch.max(top1, top2), min=0)
    inter_area = inter_w * inter_h

    # Smallest rectangle enclosing both boxes and its squared diagonal.
    hull_w = torch.clamp(
        torch.max(right1, right2) - torch.min(left1, left2), min=0)
    hull_h = torch.clamp(
        torch.max(bottom1, bottom2) - torch.min(top1, top2), min=0)
    c_diag = hull_w**2 + hull_h**2

    # Squared distance between the two centres.
    inter_diag = (cx2 - cx1)**2 + (cy2 - cy1)**2

    union = w1 * h1 + w2 * h2 - inter_area
    iou = inter_area / union
    u = inter_diag / c_diag

    dious = torch.clamp(iou - u, min=-1.0, max=1.0)
    if swapped and dious.dim() > 1:
        # Same dimension reversal the legacy `.T` performed, without the
        # deprecation warning `.T` raises on non-2-D tensors. A 1-D result
        # is left as is (reversing a single dimension is a no-op).
        dious = dious.permute(*range(dious.dim() - 1, -1, -1))
    return 1 - dious
if __name__ == '__main__':
    # Smoke test: run both losses on random boxes and print the result
    # shapes. The leftover pdb breakpoints were dropped so the script runs
    # to completion without an interactive terminal.
    x = torch.rand(10, 4)
    y = torch.rand(10, 4)
    cxy = ciou(x, y)
    dxy = diou(x, y)
    print(cxy.shape, dxy.shape)