kbrodt committed on
Commit
df84fee
·
1 Parent(s): f73c7b1

Upload utils.py

Browse files
Files changed (1) hide show
  1. src/spin/utils.py +146 -0
src/spin/utils.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from skimage.transform import resize, rotate
7
+ from torchvision.transforms import Normalize
8
+
9
+ from .constants import IMG_NORM_MEAN, IMG_NORM_STD, IMG_RES
10
+
11
+
12
def get_transform(center, scale, res, rot=0):
    """Build the 3x3 matrix that maps image pixels into the crop frame.

    The bounding box is a square of side ``200 * scale`` centred on
    ``center`` (x, y). It is mapped onto an output whose size is given by
    ``res`` as ``(rows, cols)``. When ``rot`` (in degrees) is non-zero, the
    result is additionally rotated about the centre of the output frame.
    """
    box_side = 200 * scale

    # Scale + translation: box centre lands on the centre of the output.
    matrix = np.zeros((3, 3))
    matrix[0, 0] = float(res[1]) / box_side
    matrix[1, 1] = float(res[0]) / box_side
    matrix[0, 2] = res[1] * (-float(center[0]) / box_side + 0.5)
    matrix[1, 2] = res[0] * (-float(center[1]) / box_side + 0.5)
    matrix[2, 2] = 1

    if rot == 0:
        return matrix

    # The sign is flipped to match the direction of rotation used when cropping.
    angle = -rot * np.pi / 180
    sin_a, cos_a = np.sin(angle), np.cos(angle)
    rotation = np.array(
        [
            [cos_a, -sin_a, 0.0],
            [sin_a, cos_a, 0.0],
            [0.0, 0.0, 1.0],
        ]
    )

    # Rotate around the output centre: shift it to the origin, rotate, shift back.
    to_origin = np.eye(3)
    to_origin[0, 2] = -res[1] / 2
    to_origin[1, 2] = -res[0] / 2
    from_origin = to_origin.copy()
    from_origin[:2, 2] *= -1

    return np.dot(from_origin, np.dot(rotation, np.dot(to_origin, matrix)))
38
+
39
+
40
def transform(pt, center, scale, res, invert=0, rot=0):
    """Map a 1-based pixel location between the image and the crop frame.

    With ``invert`` falsy the point is mapped with the matrix returned by
    ``get_transform``; with ``invert`` truthy the inverse of that matrix is
    used instead. Returns an integer array ``[x, y]`` in 1-based
    coordinates.
    """
    matrix = get_transform(center, scale, res, rot=rot)
    matrix = np.linalg.inv(matrix) if invert else matrix

    # Convert to 0-based homogeneous coordinates before applying the matrix.
    homogeneous = np.array([pt[0] - 1, pt[1] - 1, 1.0]).T
    mapped = np.dot(matrix, homogeneous)

    # Cast to int (drops the fractional part) and go back to 1-based indexing.
    return mapped[:2].astype(int) + 1
49
+
50
+
51
def crop(img, center, scale, res, rot=0):
    """Crop ``img`` to the supplied bounding box and resize it to ``res``.

    The box is the square of side ``200 * scale`` centred on ``center``
    (x, y). Parts of the box that fall outside the image are filled with
    zeros. When ``rot`` (degrees) is non-zero, a padded region is cropped,
    rotated, and the padding is then trimmed off again so that the rotated
    crop still contains the proper amount of context.

    Returns a float array (allocated with ``np.zeros``) resized to ``res``.
    """
    # Upper-left and bottom-right corners of the box in 0-based image coords.
    ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
    br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1

    # Padding so that when rotated the proper amount of context is included.
    pad = 0
    if rot != 0:
        pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
        ul -= pad
        br += pad

    new_shape = [br[1] - ul[1], br[0] - ul[0]]
    if len(img.shape) > 2:
        new_shape += [img.shape[2]]
    new_img = np.zeros(new_shape)

    # Range to fill in the new array
    new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
    # Range to sample from the original image
    old_x = max(0, ul[0]), min(len(img[0]), br[0])
    old_y = max(0, ul[1]), min(len(img), br[1])
    new_img[new_y[0] : new_y[1], new_x[0] : new_x[1]] = img[
        old_y[0] : old_y[1], old_x[0] : old_x[1]
    ]

    if rot != 0:
        new_img = rotate(new_img, rot)
        # Remove padding. A zero pad must be skipped: ``[0:-0]`` would
        # select an empty array instead of leaving the image untouched.
        if pad > 0:
            new_img = new_img[pad:-pad, pad:-pad]

    return resize(new_img, res)
87
+
88
+
89
def bbox_from_openpose(openpose_file, rescale=1.2, detection_thresh=0.2):
    """Get center and scale for a bounding box from OpenPose detections.

    Reads the first entry of ``"people"`` in the JSON file and interprets
    its ``"pose_keypoints_2d"`` list as rows of three values, the last of
    which is compared against ``detection_thresh``. Only keypoints above
    the threshold are used.

    Args:
        openpose_file: Path to the OpenPose JSON file.
        rescale: Factor applied to the scale to adjust box tightness.
        detection_thresh: Minimum value of the third column for a keypoint
            to be taken into account.

    Returns:
        ``(center, scale)`` where ``center`` is the mean of the valid
        keypoints and ``scale`` is the largest extent of the valid
        keypoints divided by 200 and multiplied by ``rescale``.

    Raises:
        ValueError: If no keypoint exceeds ``detection_thresh``.
    """
    with open(openpose_file, "r") as f:
        keypoints = json.load(f)["people"][0]["pose_keypoints_2d"]
    keypoints = np.reshape(np.array(keypoints), (-1, 3))

    valid = keypoints[:, -1] > detection_thresh
    if not valid.any():
        # Fail early with a clear message instead of computing the mean of
        # an empty slice (NaN + RuntimeWarning) and failing in ``max``.
        raise ValueError(
            "No keypoints above detection threshold {} in {}".format(
                detection_thresh, openpose_file
            )
        )

    valid_keypoints = keypoints[valid][:, :-1]
    center = valid_keypoints.mean(axis=0)
    bbox_size = (valid_keypoints.max(axis=0) - valid_keypoints.min(axis=0)).max()
    # adjust bounding box tightness
    scale = bbox_size / 200.0
    scale *= rescale

    return center, scale
103
+
104
+
105
def bbox_from_json(bbox_file):
    """Read a bounding-box annotation file and return its center and scale.

    The file must contain a JSON object whose ``"bbox"`` entry has the
    format [top_left(x), top_left(y), width, height]. The scale is the
    longer side of the box divided by 200.
    """
    with open(bbox_file, "r") as f:
        annotation = json.load(f)

    box = np.array(annotation["bbox"]).astype(np.float32)
    top_left, extent = box[:2], box[2:]
    center = top_left + 0.5 * extent

    # Use the longer side so that the resulting box is square.
    side = max(box[2], box[3])

    return center, side / 200.0
117
+
118
+
119
def process_image(img_file, bbox_file=None, openpose_file=None, input_res=IMG_RES):
    """Read an image, crop it to a person bounding box and normalize it.

    The bounding box is chosen in this order of precedence:
    1. ``bbox_file`` annotations, if given;
    2. otherwise OpenPose detections from ``openpose_file``, if given;
    3. otherwise the person is assumed to be centered in the image and the
       box covers the longer image side.

    Args:
        img_file: Path to the image; converted with ``str`` before reading.
        bbox_file: Optional path passed to ``bbox_from_json``.
        openpose_file: Optional path passed to ``bbox_from_openpose``.
        input_res: Side length of the square crop passed to ``crop``.

    Returns:
        ``(img, norm_img)``: the cropped image as a float32 tensor divided
        by 255 with the last axis moved to the front, and a copy of it
        normalized with ``IMG_NORM_MEAN`` / ``IMG_NORM_STD``.

    Raises:
        OSError: If the image cannot be read (``cv2.imread`` returned None).
    """
    img_file = str(img_file)
    normalize_img = Normalize(mean=IMG_NORM_MEAN, std=IMG_NORM_STD)

    bgr = cv2.imread(img_file)
    if bgr is None:
        # cv2.imread signals a missing/undecodable file by returning None;
        # fail here with a clear message instead of a TypeError on indexing.
        raise OSError("Could not read image file: {}".format(img_file))
    # Reverse the channel order (BGR -> RGB); the copy is needed because
    # PyTorch does not support negative strides at the moment.
    img = bgr[:, :, ::-1].copy()

    if bbox_file is not None:
        center, scale = bbox_from_json(bbox_file)
    elif openpose_file is not None:
        center, scale = bbox_from_openpose(openpose_file)
    else:
        # Assume that the person is centered in the image
        height = img.shape[0]
        width = img.shape[1]
        center = np.array([width // 2, height // 2])
        scale = max(height, width) / 200

    img = crop(img, center, scale, (input_res, input_res))
    img = img.astype(np.float32) / 255.0
    img = torch.from_numpy(img).permute(2, 0, 1)
    norm_img = normalize_img(img.clone())

    return img, norm_img