Spaces:
Running
on
L40S
Running
on
L40S
File size: 3,888 Bytes
d7e58f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import torch, math
def ciou(bboxes1, bboxes2):
    """Return the element-wise Complete-IoU loss ``1 - CIoU`` of two box sets.

    Both inputs are decoded the same way. Every value is passed through
    ``torch.sigmoid``. Columns 0 and 1 are then used as the box centre (x, y),
    and ``exp`` of columns 2 and 3 as the box width and height, so
    widths/heights always lie between 1 and e.
    NOTE(review): presumably the inputs are raw network outputs; confirm the
    intended box parameterisation with the caller.

    Args:
        bboxes1: tensor of shape ``(N, 4)``.
        bboxes2: tensor of shape ``(M, 4)``. Rows are paired element-wise
            (row i with row i), not all-pairs, so ``N`` and ``M`` must be
            equal or broadcastable (one of them 1).

    Returns:
        Tensor of shape ``(max(N, M),)`` holding ``1 - clamp(CIoU, -1, 1)``,
        i.e. values in ``[0, 2]``. If either set is empty, a zero tensor of
        shape ``(N, M)`` on the inputs' device is returned instead.
    """
    bboxes1 = torch.sigmoid(bboxes1)
    bboxes2 = torch.sigmoid(bboxes2)
    rows, cols = bboxes1.shape[0], bboxes2.shape[0]
    if rows == 0 or cols == 0:
        return bboxes1.new_zeros((rows, cols))

    # Every term below is symmetric in the two box sets, so the operands never
    # need swapping. The old swap-and-`.T` step was a no-op on the 1-D result,
    # and `.T` on non-2-D tensors is deprecated in PyTorch.
    w1, h1 = torch.exp(bboxes1[:, 2]), torch.exp(bboxes1[:, 3])
    w2, h2 = torch.exp(bboxes2[:, 2]), torch.exp(bboxes2[:, 3])
    cx1, cy1 = bboxes1[:, 0], bboxes1[:, 1]
    cx2, cy2 = bboxes2[:, 0], bboxes2[:, 1]

    # Edges (left, right, top, bottom) of each box.
    left1, right1 = cx1 - w1 / 2, cx1 + w1 / 2
    top1, bottom1 = cy1 - h1 / 2, cy1 + h1 / 2
    left2, right2 = cx2 - w2 / 2, cx2 + w2 / 2
    top2, bottom2 = cy2 - h2 / 2, cy2 + h2 / 2

    inter_w = torch.clamp(torch.min(right1, right2) - torch.max(left1, left2), min=0)
    inter_h = torch.clamp(torch.min(bottom1, bottom2) - torch.max(top1, top2), min=0)
    inter_area = inter_w * inter_h
    union = w1 * h1 + w2 * h2 - inter_area
    iou = inter_area / union

    # Centre-distance penalty: squared centre distance divided by the squared
    # diagonal of the smallest box enclosing both boxes.
    enclose_w = torch.clamp(torch.max(right1, right2) - torch.min(left1, left2), min=0)
    enclose_h = torch.clamp(torch.max(bottom1, bottom2) - torch.min(top1, top2), min=0)
    center_dist = (cx2 - cx1) ** 2 + (cy2 - cy1) ** 2
    u = center_dist / (enclose_w ** 2 + enclose_h ** 2)

    # Aspect-ratio consistency term.
    v = (4 / (math.pi ** 2)) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
    with torch.no_grad():
        # alpha is a constant weight (no gradient), switched on only where
        # IoU > 0.5. For a perfect match (IoU == 1, v == 0) the denominator
        # is 0 and 0/0 would be NaN, so define alpha = 0 there instead.
        S = (iou > 0.5).float()
        denom = 1 - iou + v
        alpha = torch.where(denom > 0, S * v / denom, torch.zeros_like(v))
    cious = torch.clamp(iou - u - alpha * v, min=-1.0, max=1.0)
    return 1 - cious
def diou(bboxes1, bboxes2):
    """Return the element-wise Distance-IoU loss ``1 - DIoU`` of two box sets.

    Both inputs are decoded the same way. Every value is passed through
    ``torch.sigmoid``. Columns 0 and 1 are then used as the box centre (x, y),
    and ``exp`` of columns 2 and 3 as the box width and height, so
    widths/heights always lie between 1 and e.
    NOTE(review): presumably the inputs are raw network outputs; confirm the
    intended box parameterisation with the caller.

    Args:
        bboxes1: tensor of shape ``(N, 4)``.
        bboxes2: tensor of shape ``(M, 4)``. Rows are paired element-wise
            (row i with row i), not all-pairs, so ``N`` and ``M`` must be
            equal or broadcastable (one of them 1).

    Returns:
        Tensor of shape ``(max(N, M),)`` holding
        ``1 - clamp(IoU - d^2 / c^2, -1, 1)``, where ``d`` is the centre
        distance and ``c`` the diagonal of the smallest enclosing box. Values
        lie in ``[0, 2]``. If either set is empty, a zero tensor of shape
        ``(N, M)`` on the inputs' device is returned instead.
    """
    bboxes1 = torch.sigmoid(bboxes1)
    bboxes2 = torch.sigmoid(bboxes2)
    rows, cols = bboxes1.shape[0], bboxes2.shape[0]
    if rows == 0 or cols == 0:
        return bboxes1.new_zeros((rows, cols))

    # Every term below is symmetric in the two box sets, so the operands never
    # need swapping. The old swap-and-`.T` step was a no-op on the 1-D result,
    # and `.T` on non-2-D tensors is deprecated in PyTorch.
    w1, h1 = torch.exp(bboxes1[:, 2]), torch.exp(bboxes1[:, 3])
    w2, h2 = torch.exp(bboxes2[:, 2]), torch.exp(bboxes2[:, 3])
    cx1, cy1 = bboxes1[:, 0], bboxes1[:, 1]
    cx2, cy2 = bboxes2[:, 0], bboxes2[:, 1]

    # Edges (left, right, top, bottom) of each box.
    left1, right1 = cx1 - w1 / 2, cx1 + w1 / 2
    top1, bottom1 = cy1 - h1 / 2, cy1 + h1 / 2
    left2, right2 = cx2 - w2 / 2, cx2 + w2 / 2
    top2, bottom2 = cy2 - h2 / 2, cy2 + h2 / 2

    inter_w = torch.clamp(torch.min(right1, right2) - torch.max(left1, left2), min=0)
    inter_h = torch.clamp(torch.min(bottom1, bottom2) - torch.max(top1, top2), min=0)
    inter_area = inter_w * inter_h
    union = w1 * h1 + w2 * h2 - inter_area
    iou = inter_area / union

    # Centre-distance penalty: squared centre distance divided by the squared
    # diagonal of the smallest box enclosing both boxes.
    enclose_w = torch.clamp(torch.max(right1, right2) - torch.min(left1, left2), min=0)
    enclose_h = torch.clamp(torch.max(bottom1, bottom2) - torch.min(top1, top2), min=0)
    center_dist = (cx2 - cx1) ** 2 + (cy2 - cy1) ** 2
    u = center_dist / (enclose_w ** 2 + enclose_h ** 2)

    dious = torch.clamp(iou - u, min=-1.0, max=1.0)
    return 1 - dious
if __name__ == '__main__':
    # Quick smoke test: two sets of ten random raw box parameter rows,
    # compared element-wise. The leftover pdb breakpoints were removed so the
    # script runs to completion instead of stopping in the debugger.
    x = torch.rand(10, 4)
    y = torch.rand(10, 4)
    cxy = ciou(x, y)
    dxy = diou(x, y)
    print(cxy.shape, dxy.shape)
|