Spaces:
Running
Running
Upload utils.py
Browse files- src/spin/utils.py +146 -0
src/spin/utils.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from skimage.transform import resize, rotate
|
7 |
+
from torchvision.transforms import Normalize
|
8 |
+
|
9 |
+
from .constants import IMG_NORM_MEAN, IMG_NORM_STD, IMG_RES
|
10 |
+
|
11 |
+
|
12 |
+
def get_transform(center, scale, res, rot=0):
|
13 |
+
"""Generate transformation matrix."""
|
14 |
+
h = 200 * scale
|
15 |
+
t = np.zeros((3, 3))
|
16 |
+
t[0, 0] = float(res[1]) / h
|
17 |
+
t[1, 1] = float(res[0]) / h
|
18 |
+
t[0, 2] = res[1] * (-float(center[0]) / h + 0.5)
|
19 |
+
t[1, 2] = res[0] * (-float(center[1]) / h + 0.5)
|
20 |
+
t[2, 2] = 1
|
21 |
+
if not rot == 0:
|
22 |
+
rot = -rot # To match direction of rotation from cropping
|
23 |
+
rot_mat = np.zeros((3, 3))
|
24 |
+
rot_rad = rot * np.pi / 180
|
25 |
+
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
|
26 |
+
rot_mat[0, :2] = [cs, -sn]
|
27 |
+
rot_mat[1, :2] = [sn, cs]
|
28 |
+
rot_mat[2, 2] = 1
|
29 |
+
# Need to rotate around center
|
30 |
+
t_mat = np.eye(3)
|
31 |
+
t_mat[0, 2] = -res[1] / 2
|
32 |
+
t_mat[1, 2] = -res[0] / 2
|
33 |
+
t_inv = t_mat.copy()
|
34 |
+
t_inv[:2, 2] *= -1
|
35 |
+
t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
|
36 |
+
|
37 |
+
return t
|
38 |
+
|
39 |
+
|
40 |
+
def transform(pt, center, scale, res, invert=0, rot=0):
|
41 |
+
"""Transform pixel location to different reference."""
|
42 |
+
t = get_transform(center, scale, res, rot=rot)
|
43 |
+
if invert:
|
44 |
+
t = np.linalg.inv(t)
|
45 |
+
new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.0]).T
|
46 |
+
new_pt = np.dot(t, new_pt)
|
47 |
+
|
48 |
+
return new_pt[:2].astype(int) + 1
|
49 |
+
|
50 |
+
|
51 |
+
def crop(img, center, scale, res, rot=0):
|
52 |
+
"""Crop image according to the supplied bounding box."""
|
53 |
+
# Upper left point
|
54 |
+
ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
|
55 |
+
# Bottom right point
|
56 |
+
br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1
|
57 |
+
|
58 |
+
# Padding so that when rotated proper amount of context is included
|
59 |
+
pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
|
60 |
+
if not rot == 0:
|
61 |
+
ul -= pad
|
62 |
+
br += pad
|
63 |
+
|
64 |
+
new_shape = [br[1] - ul[1], br[0] - ul[0]]
|
65 |
+
if len(img.shape) > 2:
|
66 |
+
new_shape += [img.shape[2]]
|
67 |
+
new_img = np.zeros(new_shape)
|
68 |
+
|
69 |
+
# Range to fill new array
|
70 |
+
new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
|
71 |
+
new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
|
72 |
+
# Range to sample from original image
|
73 |
+
old_x = max(0, ul[0]), min(len(img[0]), br[0])
|
74 |
+
old_y = max(0, ul[1]), min(len(img), br[1])
|
75 |
+
new_img[new_y[0] : new_y[1], new_x[0] : new_x[1]] = img[
|
76 |
+
old_y[0] : old_y[1], old_x[0] : old_x[1]
|
77 |
+
]
|
78 |
+
|
79 |
+
if not rot == 0:
|
80 |
+
# Remove padding
|
81 |
+
new_img = rotate(new_img, rot)
|
82 |
+
new_img = new_img[pad:-pad, pad:-pad]
|
83 |
+
|
84 |
+
new_img = resize(new_img, res)
|
85 |
+
|
86 |
+
return new_img
|
87 |
+
|
88 |
+
|
89 |
+
def bbox_from_openpose(openpose_file, rescale=1.2, detection_thresh=0.2):
|
90 |
+
"""Get center and scale for bounding box from openpose detections."""
|
91 |
+
with open(openpose_file, "r") as f:
|
92 |
+
keypoints = json.load(f)["people"][0]["pose_keypoints_2d"]
|
93 |
+
keypoints = np.reshape(np.array(keypoints), (-1, 3))
|
94 |
+
valid = keypoints[:, -1] > detection_thresh
|
95 |
+
valid_keypoints = keypoints[valid][:, :-1]
|
96 |
+
center = valid_keypoints.mean(axis=0)
|
97 |
+
bbox_size = (valid_keypoints.max(axis=0) - valid_keypoints.min(axis=0)).max()
|
98 |
+
# adjust bounding box tightness
|
99 |
+
scale = bbox_size / 200.0
|
100 |
+
scale *= rescale
|
101 |
+
|
102 |
+
return center, scale
|
103 |
+
|
104 |
+
|
105 |
+
def bbox_from_json(bbox_file):
|
106 |
+
"""Get center and scale of bounding box from bounding box annotations.
|
107 |
+
The expected format is [top_left(x), top_left(y), width, height].
|
108 |
+
"""
|
109 |
+
with open(bbox_file, "r") as f:
|
110 |
+
bbox = np.array(json.load(f)["bbox"]).astype(np.float32)
|
111 |
+
ul_corner = bbox[:2]
|
112 |
+
center = ul_corner + 0.5 * bbox[2:]
|
113 |
+
width = max(bbox[2], bbox[3])
|
114 |
+
scale = width / 200.0
|
115 |
+
# make sure the bounding box is rectangular
|
116 |
+
return center, scale
|
117 |
+
|
118 |
+
|
119 |
+
def process_image(img_file, bbox_file=None, openpose_file=None, input_res=IMG_RES):
|
120 |
+
"""Read image, do preprocessing and possibly crop it according to the bounding box.
|
121 |
+
If there are bounding box annotations, use them to crop the image.
|
122 |
+
If no bounding box is specified but openpose detections are available, use them to get the bounding box.
|
123 |
+
"""
|
124 |
+
img_file = str(img_file)
|
125 |
+
normalize_img = Normalize(mean=IMG_NORM_MEAN, std=IMG_NORM_STD)
|
126 |
+
img = cv2.imread(img_file)[
|
127 |
+
:, :, ::-1
|
128 |
+
].copy() # PyTorch does not support negative stride at the moment
|
129 |
+
if bbox_file is None and openpose_file is None:
|
130 |
+
# Assume that the person is centerered in the image
|
131 |
+
height = img.shape[0]
|
132 |
+
width = img.shape[1]
|
133 |
+
center = np.array([width // 2, height // 2])
|
134 |
+
scale = max(height, width) / 200
|
135 |
+
else:
|
136 |
+
if bbox_file is not None:
|
137 |
+
center, scale = bbox_from_json(bbox_file)
|
138 |
+
elif openpose_file is not None:
|
139 |
+
center, scale = bbox_from_openpose(openpose_file)
|
140 |
+
|
141 |
+
img = crop(img, center, scale, (input_res, input_res))
|
142 |
+
img = img.astype(np.float32) / 255.0
|
143 |
+
img = torch.from_numpy(img).permute(2, 0, 1)
|
144 |
+
norm_img = normalize_img(img.clone())
|
145 |
+
|
146 |
+
return img, norm_img
|