freddyaboulton (HF staff) committed
Commit c8120da · verified · 1 parent: 8cacdf4

Upload folder using huggingface_hub
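This commit message is the one `huggingface_hub` emits when a whole folder is pushed; a minimal sketch of the equivalent call is below (the local folder path and Space id are placeholders, not taken from this commit):

from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path=".",                            # local folder containing app.py, index.html, etc.
    repo_id="freddyaboulton/object-detection",  # placeholder Space id
    repo_type="space",
    commit_message="Upload folder using huggingface_hub",
)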

Files changed (6)
  1. README.md +7 -4
  2. app.py +64 -0
  3. index.html +262 -0
  4. inference.py +153 -0
  5. requirements.txt +2 -0
  6. utils.py +237 -0
README.md CHANGED
@@ -1,12 +1,15 @@
 ---
 title: Object Detection
-emoji: 🔥
-colorFrom: yellow
-colorTo: purple
+emoji: 📸
+colorFrom: purple
+colorTo: red
 sdk: gradio
 sdk_version: 5.16.0
 app_file: app.py
 pinned: false
+license: mit
+short_description: Use YOLOv10 to detect objects in real-time
+tags: [webrtc, websocket, gradio, secret|TWILIO_ACCOUNT_SID, secret|TWILIO_AUTH_TOKEN]
 ---

-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,64 @@
+import cv2
+from fastrtc import Stream, get_twilio_turn_credentials
+from huggingface_hub import hf_hub_download
+from fastapi.responses import HTMLResponse
+from pathlib import Path
+import gradio as gr
+from gradio.utils import get_space
+import json
+from pydantic import BaseModel, Field
+
+
+try:
+    from demo.object_detection.inference import YOLOv10
+except ImportError:
+    from .inference import YOLOv10
+
+
+cur_dir = Path(__file__).parent
+
+model_file = hf_hub_download(
+    repo_id="onnx-community/yolov10n", filename="onnx/model.onnx"
+)
+
+model = YOLOv10(model_file)
+
+
+def detection(image, conf_threshold=0.3):
+    image = cv2.resize(image, (model.input_width, model.input_height))
+    print("conf_threshold", conf_threshold)
+    new_image = model.detect_objects(image, conf_threshold)
+    return cv2.resize(new_image, (500, 500))
+
+
+stream = Stream(
+    handler=detection,
+    modality="video",
+    mode="send-receive",
+    additional_inputs=[gr.Slider(minimum=0, maximum=1, step=0.01, value=0.3)],
+    rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
+)
+
+
+@stream.get("/")
+async def _():
+    rtc_config = get_twilio_turn_credentials() if get_space() else None
+    html_content = open(cur_dir / "index.html").read()
+    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
+    return HTMLResponse(content=html_content)
+
+
+class InputData(BaseModel):
+    webrtc_id: str
+    conf_threshold: float = Field(ge=0, le=1)
+
+
+@stream.post("/input_hook")
+async def _(data: InputData):
+    stream.set_input(data.webrtc_id, data.conf_threshold)
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(stream, host="0.0.0.0", port=7860)
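Outside the WebRTC flow, the `detection` handler can be sanity-checked on a single frame. A minimal sketch, assuming the Space folder is on the import path and a local `test.jpg` exists (both are assumptions, not part of the Space):

import cv2
from app import detection  # importing app downloads the ONNX model from the Hub

frame = cv2.imread("test.jpg")                  # stand-in for one webcam frame (BGR)
annotated = detection(frame, conf_threshold=0.5)
cv2.imwrite("annotated.jpg", annotated)         # 500x500 frame with boxes drawn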
index.html ADDED
@@ -0,0 +1,262 @@
+<!DOCTYPE html>
+<html lang="en">
+
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Object Detection</title>
+    <style>
+        body {
+            font-family: system-ui, -apple-system, sans-serif;
+            background: linear-gradient(135deg, #2d2b52 0%, #191731 100%);
+            color: white;
+            margin: 0;
+            padding: 20px;
+            height: 100vh;
+            box-sizing: border-box;
+            display: flex;
+            flex-direction: column;
+            align-items: center;
+            justify-content: center;
+        }
+
+        .container {
+            width: 100%;
+            max-width: 800px;
+            text-align: center;
+        }
+
+        .video-container {
+            width: 100%;
+            aspect-ratio: 16/9;
+            background: rgba(255, 255, 255, 0.1);
+            border-radius: 12px;
+            overflow: hidden;
+            box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2);
+            margin: 20px 0;
+        }
+
+        #video-output {
+            width: 100%;
+            height: 100%;
+            object-fit: cover;
+        }
+
+        button {
+            background: white;
+            color: #2d2b52;
+            border: none;
+            padding: 12px 32px;
+            border-radius: 24px;
+            font-size: 16px;
+            font-weight: 600;
+            cursor: pointer;
+            transition: all 0.3s ease;
+            box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
+        }
+
+        button:hover {
+            transform: translateY(-2px);
+            box-shadow: 0 6px 16px rgba(0, 0, 0, 0.2);
+        }
+
+        h1 {
+            font-size: 2.5em;
+            margin-bottom: 0.5em;
+        }
+
+        p {
+            color: rgba(255, 255, 255, 0.8);
+            margin-bottom: 2em;
+        }
+
+        .controls {
+            display: flex;
+            flex-direction: column;
+            gap: 20px;
+            align-items: center;
+            margin-top: 20px;
+        }
+
+        .slider-container {
+            width: 100%;
+            max-width: 300px;
+            display: flex;
+            flex-direction: column;
+            gap: 8px;
+        }
+
+        .slider-container label {
+            color: rgba(255, 255, 255, 0.8);
+            font-size: 14px;
+        }
+
+        input[type="range"] {
+            width: 100%;
+            height: 6px;
+            -webkit-appearance: none;
+            background: rgba(255, 255, 255, 0.1);
+            border-radius: 3px;
+            outline: none;
+        }
+
+        input[type="range"]::-webkit-slider-thumb {
+            -webkit-appearance: none;
+            width: 18px;
+            height: 18px;
+            background: white;
+            border-radius: 50%;
+            cursor: pointer;
+        }
+    </style>
+</head>
+
+<body>
+    <div class="container">
+        <h1>Real-time Object Detection</h1>
+        <p>Using YOLOv10 to detect objects in your webcam feed</p>
+        <div class="video-container">
+            <video id="video-output" autoplay playsinline></video>
+        </div>
+        <div class="controls">
+            <div class="slider-container">
+                <label>Confidence Threshold: <span id="conf-value">0.3</span></label>
+                <input type="range" id="conf-threshold" min="0" max="1" step="0.01" value="0.3">
+            </div>
+            <button id="start-button">Start</button>
+        </div>
+    </div>
+
+    <script>
+        let peerConnection;
+        let webrtc_id;
+        const startButton = document.getElementById('start-button');
+        const videoOutput = document.getElementById('video-output');
+        const confThreshold = document.getElementById('conf-threshold');
+        const confValue = document.getElementById('conf-value');
+
+        // Update confidence value display
+        confThreshold.addEventListener('input', (e) => {
+            confValue.textContent = e.target.value;
+            if (peerConnection) {
+                updateConfThreshold(e.target.value);
+            }
+        });
+
+        function updateConfThreshold(value) {
+            fetch('/input_hook', {
+                method: 'POST',
+                headers: {
+                    'Content-Type': 'application/json',
+                },
+                body: JSON.stringify({
+                    webrtc_id: webrtc_id,
+                    conf_threshold: parseFloat(value)
+                })
+            });
+        }
+
+        async function setupWebRTC() {
+            const config = __RTC_CONFIGURATION__;
+            peerConnection = new RTCPeerConnection(config);
+
+            try {
+                const stream = await navigator.mediaDevices.getUserMedia({
+                    video: true
+                });
+
+                stream.getTracks().forEach(track => {
+                    peerConnection.addTrack(track, stream);
+                });
+
+                peerConnection.addEventListener('track', (evt) => {
+                    if (videoOutput && videoOutput.srcObject !== evt.streams[0]) {
+                        videoOutput.srcObject = evt.streams[0];
+                    }
+                });
+
+                const dataChannel = peerConnection.createDataChannel('text');
+                dataChannel.onmessage = (event) => {
+                    const eventJson = JSON.parse(event.data);
+                    if (eventJson.type === "send_input") {
+                        updateConfThreshold(confThreshold.value);
+                    }
+                };
+
+                const offer = await peerConnection.createOffer();
+                await peerConnection.setLocalDescription(offer);
+
+                await new Promise((resolve) => {
+                    if (peerConnection.iceGatheringState === "complete") {
+                        resolve();
+                    } else {
+                        const checkState = () => {
+                            if (peerConnection.iceGatheringState === "complete") {
+                                peerConnection.removeEventListener("icegatheringstatechange", checkState);
+                                resolve();
+                            }
+                        };
+                        peerConnection.addEventListener("icegatheringstatechange", checkState);
+                    }
+                });
+
+                webrtc_id = Math.random().toString(36).substring(7);
+
+                const response = await fetch('/webrtc/offer', {
+                    method: 'POST',
+                    headers: { 'Content-Type': 'application/json' },
+                    body: JSON.stringify({
+                        sdp: peerConnection.localDescription.sdp,
+                        type: peerConnection.localDescription.type,
+                        webrtc_id: webrtc_id
+                    })
+                });
+
+                const serverResponse = await response.json();
+                await peerConnection.setRemoteDescription(serverResponse);
+
+                // Send initial confidence threshold
+                updateConfThreshold(confThreshold.value);
+
+            } catch (err) {
+                console.error('Error setting up WebRTC:', err);
+            }
+        }
+
+        function stop() {
+            if (peerConnection) {
+                if (peerConnection.getTransceivers) {
+                    peerConnection.getTransceivers().forEach(transceiver => {
+                        if (transceiver.stop) {
+                            transceiver.stop();
+                        }
+                    });
+                }
+
+                if (peerConnection.getSenders) {
+                    peerConnection.getSenders().forEach(sender => {
+                        if (sender.track && sender.track.stop) sender.track.stop();
+                    });
+                }
+
+                setTimeout(() => {
+                    peerConnection.close();
+                }, 500);
+            }
+
+            videoOutput.srcObject = null;
+        }
+
+        startButton.addEventListener('click', () => {
+            if (startButton.textContent === 'Start') {
+                setupWebRTC();
+                startButton.textContent = 'Stop';
+            } else {
+                stop();
+                startButton.textContent = 'Start';
+            }
+        });
+    </script>
+</body>
+
+</html>
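The page talks to two routes: `/webrtc/offer`, which the fastrtc `Stream` exposes for SDP negotiation, and the custom `/input_hook` defined in app.py for live threshold updates. The POST that `updateConfThreshold()` issues can also be reproduced outside the browser; a sketch, where the URL is a local-dev assumption and `webrtc_id` must match an active connection:

import requests

requests.post(
    "http://localhost:7860/input_hook",
    json={"webrtc_id": "abc123", "conf_threshold": 0.5},  # "abc123" is a placeholder id
)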
inference.py ADDED
@@ -0,0 +1,153 @@
+import time
+
+import cv2
+import numpy as np
+import onnxruntime
+
+try:
+    from demo.object_detection.utils import draw_detections
+except ImportError:
+    from .utils import draw_detections
+
+
+class YOLOv10:
+    def __init__(self, path):
+        # Initialize model
+        self.initialize_model(path)
+
+    def __call__(self, image):
+        return self.detect_objects(image)
+
+    def initialize_model(self, path):
+        self.session = onnxruntime.InferenceSession(
+            path, providers=onnxruntime.get_available_providers()
+        )
+        # Get model info
+        self.get_input_details()
+        self.get_output_details()
+
+    def detect_objects(self, image, conf_threshold=0.3):
+        input_tensor = self.prepare_input(image)
+
+        # Perform inference on the image
+        new_image = self.inference(image, input_tensor, conf_threshold)
+
+        return new_image
+
+    def prepare_input(self, image):
+        self.img_height, self.img_width = image.shape[:2]
+
+        input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+        # Resize input image
+        input_img = cv2.resize(input_img, (self.input_width, self.input_height))
+
+        # Scale input pixel values to 0 to 1
+        input_img = input_img / 255.0
+        input_img = input_img.transpose(2, 0, 1)
+        input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)
+
+        return input_tensor
+
+    def inference(self, image, input_tensor, conf_threshold=0.3):
+        start = time.perf_counter()
+        outputs = self.session.run(
+            self.output_names, {self.input_names[0]: input_tensor}
+        )
+
+        print(f"Inference time: {(time.perf_counter() - start) * 1000:.2f} ms")
+        (
+            boxes,
+            scores,
+            class_ids,
+        ) = self.process_output(outputs, conf_threshold)
+        return self.draw_detections(image, boxes, scores, class_ids)
+
+    def process_output(self, output, conf_threshold=0.3):
+        predictions = np.squeeze(output[0])
+
+        # Filter out object confidence scores below threshold
+        scores = predictions[:, 4]
+        predictions = predictions[scores > conf_threshold, :]
+        scores = scores[scores > conf_threshold]
+
+        if len(scores) == 0:
+            return [], [], []
+
+        # Get the class with the highest confidence
+        class_ids = predictions[:, 5].astype(int)
+
+        # Get bounding boxes for each object
+        boxes = self.extract_boxes(predictions)
+
+        return boxes, scores, class_ids
+
+    def extract_boxes(self, predictions):
+        # Extract boxes from predictions
+        boxes = predictions[:, :4]
+
+        # Scale boxes to original image dimensions
+        boxes = self.rescale_boxes(boxes)
+
+        # Convert boxes to xyxy format
+        # boxes = xywh2xyxy(boxes)
+
+        return boxes
+
+    def rescale_boxes(self, boxes):
+        # Rescale boxes to original image dimensions
+        input_shape = np.array(
+            [self.input_width, self.input_height, self.input_width, self.input_height]
+        )
+        boxes = np.divide(boxes, input_shape, dtype=np.float32)
+        boxes *= np.array(
+            [self.img_width, self.img_height, self.img_width, self.img_height]
+        )
+        return boxes
+
+    def draw_detections(
+        self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4
+    ):
+        return draw_detections(image, boxes, scores, class_ids, mask_alpha)
+
+    def get_input_details(self):
+        model_inputs = self.session.get_inputs()
+        self.input_names = [model_inputs[i].name for i in range(len(model_inputs))]
+
+        self.input_shape = model_inputs[0].shape
+        self.input_height = self.input_shape[2]
+        self.input_width = self.input_shape[3]
+
+    def get_output_details(self):
+        model_outputs = self.session.get_outputs()
+        self.output_names = [model_outputs[i].name for i in range(len(model_outputs))]
+
+
+if __name__ == "__main__":
+    import tempfile
+
+    import requests
+    from huggingface_hub import hf_hub_download
+
+    model_file = hf_hub_download(
+        repo_id="onnx-community/yolov10s", filename="onnx/model.onnx"
+    )
+
+    yolov8_detector = YOLOv10(model_file)
+
+    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
+        f.write(
+            requests.get(
+                "https://live.staticflickr.com/13/19041780_d6fd803de0_3k.jpg"
+            ).content
+        )
+        f.seek(0)
+        img = cv2.imread(f.name)
+
+    # Detect Objects
+    combined_image = yolov8_detector.detect_objects(img)
+
+    # Draw detections
+    cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
+    cv2.imshow("Output", combined_image)
+    cv2.waitKey(0)
+
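`process_output` assumes the exported YOLOv10 ONNX graph returns a `(1, num_detections, 6)` tensor whose columns are `[x1, y1, x2, y2, score, class_id]` in input-resolution coordinates, which is why the `xywh2xyxy` call above stays commented out. A toy illustration of that indexing with a dummy array (not a real model output):

import numpy as np

dummy = np.zeros((1, 300, 6), dtype=np.float32)
dummy[0, 0] = [10, 20, 110, 220, 0.9, 0]      # one confident "person" detection
preds = np.squeeze(dummy)                      # mirrors process_output
keep = preds[preds[:, 4] > 0.3]                # confidence filter
boxes, scores, class_ids = keep[:, :4], keep[:, 4], keep[:, 5].astype(int)
print(boxes, scores, class_ids)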
requirements.txt ADDED
@@ -0,0 +1,2 @@
+fastrtc[vad]==0.0.32rc1
+opencv-python
utils.py ADDED
@@ -0,0 +1,237 @@
+import cv2
+import numpy as np
+
+class_names = [
+    "person",
+    "bicycle",
+    "car",
+    "motorcycle",
+    "airplane",
+    "bus",
+    "train",
+    "truck",
+    "boat",
+    "traffic light",
+    "fire hydrant",
+    "stop sign",
+    "parking meter",
+    "bench",
+    "bird",
+    "cat",
+    "dog",
+    "horse",
+    "sheep",
+    "cow",
+    "elephant",
+    "bear",
+    "zebra",
+    "giraffe",
+    "backpack",
+    "umbrella",
+    "handbag",
+    "tie",
+    "suitcase",
+    "frisbee",
+    "skis",
+    "snowboard",
+    "sports ball",
+    "kite",
+    "baseball bat",
+    "baseball glove",
+    "skateboard",
+    "surfboard",
+    "tennis racket",
+    "bottle",
+    "wine glass",
+    "cup",
+    "fork",
+    "knife",
+    "spoon",
+    "bowl",
+    "banana",
+    "apple",
+    "sandwich",
+    "orange",
+    "broccoli",
+    "carrot",
+    "hot dog",
+    "pizza",
+    "donut",
+    "cake",
+    "chair",
+    "couch",
+    "potted plant",
+    "bed",
+    "dining table",
+    "toilet",
+    "tv",
+    "laptop",
+    "mouse",
+    "remote",
+    "keyboard",
+    "cell phone",
+    "microwave",
+    "oven",
+    "toaster",
+    "sink",
+    "refrigerator",
+    "book",
+    "clock",
+    "vase",
+    "scissors",
+    "teddy bear",
+    "hair drier",
+    "toothbrush",
+]
+
+# Create a list of colors for each class where each color is a tuple of 3 integer values
+rng = np.random.default_rng(3)
+colors = rng.uniform(0, 255, size=(len(class_names), 3))
+
+
+def nms(boxes, scores, iou_threshold):
+    # Sort by score
+    sorted_indices = np.argsort(scores)[::-1]
+
+    keep_boxes = []
+    while sorted_indices.size > 0:
+        # Pick the last box
+        box_id = sorted_indices[0]
+        keep_boxes.append(box_id)
+
+        # Compute IoU of the picked box with the rest
+        ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])
+
+        # Remove boxes with IoU over the threshold
+        keep_indices = np.where(ious < iou_threshold)[0]
+
+        # print(keep_indices.shape, sorted_indices.shape)
+        sorted_indices = sorted_indices[keep_indices + 1]
+
+    return keep_boxes
+
+
+def multiclass_nms(boxes, scores, class_ids, iou_threshold):
+    unique_class_ids = np.unique(class_ids)
+
+    keep_boxes = []
+    for class_id in unique_class_ids:
+        class_indices = np.where(class_ids == class_id)[0]
+        class_boxes = boxes[class_indices, :]
+        class_scores = scores[class_indices]
+
+        class_keep_boxes = nms(class_boxes, class_scores, iou_threshold)
+        keep_boxes.extend(class_indices[class_keep_boxes])
+
+    return keep_boxes
+
+
+def compute_iou(box, boxes):
+    # Compute xmin, ymin, xmax, ymax for both boxes
+    xmin = np.maximum(box[0], boxes[:, 0])
+    ymin = np.maximum(box[1], boxes[:, 1])
+    xmax = np.minimum(box[2], boxes[:, 2])
+    ymax = np.minimum(box[3], boxes[:, 3])
+
+    # Compute intersection area
+    intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)
+
+    # Compute union area
+    box_area = (box[2] - box[0]) * (box[3] - box[1])
+    boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+    union_area = box_area + boxes_area - intersection_area
+
+    # Compute IoU
+    iou = intersection_area / union_area
+
+    return iou
+
+
+def xywh2xyxy(x):
+    # Convert bounding box (x, y, w, h) to bounding box (x1, y1, x2, y2)
+    y = np.copy(x)
+    y[..., 0] = x[..., 0] - x[..., 2] / 2
+    y[..., 1] = x[..., 1] - x[..., 3] / 2
+    y[..., 2] = x[..., 0] + x[..., 2] / 2
+    y[..., 3] = x[..., 1] + x[..., 3] / 2
+    return y
+
+
+def draw_detections(image, boxes, scores, class_ids, mask_alpha=0.3):
+    det_img = image.copy()
+
+    img_height, img_width = image.shape[:2]
+    font_size = min([img_height, img_width]) * 0.0006
+    text_thickness = int(min([img_height, img_width]) * 0.001)
+
+    # det_img = draw_masks(det_img, boxes, class_ids, mask_alpha)
+
+    # Draw bounding boxes and labels of detections
+    for class_id, box, score in zip(class_ids, boxes, scores):
+        color = colors[class_id]
+
+        draw_box(det_img, box, color)  # type: ignore
+
+        label = class_names[class_id]
+        caption = f"{label} {int(score * 100)}%"
+        draw_text(det_img, caption, box, color, font_size, text_thickness)  # type: ignore
+
+    return det_img
+
+
+def draw_box(
+    image: np.ndarray,
+    box: np.ndarray,
+    color: tuple[int, int, int] = (0, 0, 255),
+    thickness: int = 2,
+) -> np.ndarray:
+    x1, y1, x2, y2 = box.astype(int)
+    return cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)
+
+
+def draw_text(
+    image: np.ndarray,
+    text: str,
+    box: np.ndarray,
+    color: tuple[int, int, int] = (0, 0, 255),
+    font_size: float = 0.001,
+    text_thickness: int = 2,
+) -> np.ndarray:
+    x1, y1, x2, y2 = box.astype(int)
+    (tw, th), _ = cv2.getTextSize(
+        text=text,
+        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
+        fontScale=font_size,
+        thickness=text_thickness,
+    )
+    th = int(th * 1.2)
+
+    cv2.rectangle(image, (x1, y1), (x1 + tw, y1 - th), color, -1)
+
+    return cv2.putText(
+        image,
+        text,
+        (x1, y1),
+        cv2.FONT_HERSHEY_SIMPLEX,
+        font_size,
+        (255, 255, 255),
+        text_thickness,
+        cv2.LINE_AA,
+    )
+
+
+def draw_masks(
+    image: np.ndarray, boxes: np.ndarray, classes: np.ndarray, mask_alpha: float = 0.3
+) -> np.ndarray:
+    mask_img = image.copy()
+
+    # Draw bounding boxes and labels of detections
+    for box, class_id in zip(boxes, classes):
+        color = colors[class_id]
+
+        x1, y1, x2, y2 = box.astype(int)
+
+        # Draw fill rectangle in mask image
+        cv2.rectangle(mask_img, (x1, y1), (x2, y2), color, -1)  # type: ignore
+
+    return cv2.addWeighted(mask_img, mask_alpha, image, 1 - mask_alpha, 0)
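A quick check of the `nms` helper on toy boxes, as a sketch that assumes this file is importable as `utils`: two heavily overlapping boxes collapse to the higher-scoring one, while a disjoint box survives.

import numpy as np
from utils import nms  # assumes utils.py is on the import path

boxes = np.array(
    [[0, 0, 100, 100], [10, 10, 110, 110], [200, 200, 300, 300]], dtype=np.float32
)
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)
print(nms(boxes, scores, iou_threshold=0.5))  # indices 0 and 2 survive; box 1 is suppressed by box 0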