aliciiavs committed on
Commit
e946a46
1 Parent(s): 831bc4d

Create app.py

Files changed (1)
  1. app.py +166 -0
app.py ADDED
@@ -0,0 +1,166 @@
+ """Object detection demo with MobileNet SSD.
+ This model and code are based on
+ https://github.com/robmarkcole/object-detection-app
+ """
+
+ import logging
+ import queue
+ from pathlib import Path
+ from typing import List, NamedTuple
+
+ import av
+ import cv2
+ import numpy as np
+ import streamlit as st
+ from streamlit_webrtc import WebRtcMode, webrtc_streamer
+
+ from sample_utils.download import download_file
+ from sample_utils.turn import get_ice_servers
+
+ HERE = Path(__file__).parent
+ ROOT = HERE.parent
+
+ logger = logging.getLogger(__name__)
+
+
+ MODEL_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.caffemodel"  # noqa: E501
+ MODEL_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.caffemodel"
+ PROTOTXT_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.prototxt.txt"  # noqa: E501
+ PROTOTXT_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.prototxt.txt"
+
+ CLASSES = [
+     "background",
+     "aeroplane",
+     "bicycle",
+     "bird",
+     "boat",
+     "bottle",
+     "bus",
+     "car",
+     "cat",
+     "chair",
+     "cow",
+     "diningtable",
+     "dog",
+     "horse",
+     "motorbike",
+     "person",
+     "pottedplant",
+     "sheep",
+     "sofa",
+     "train",
+     "tvmonitor",
+ ]
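+ # The 20 PASCAL VOC object classes plus "background", the label set this
+ # MobileNet SSD model was trained on.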
+
+
+ class Detection(NamedTuple):
+     class_id: int
+     label: str
+     score: float
+     box: np.ndarray
+
+
+ @st.cache_resource  # type: ignore
+ def generate_label_colors():
+     return np.random.uniform(0, 255, size=(len(CLASSES), 3))
+
+
+ COLORS = generate_label_colors()
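+ # st.cache_resource computes the color table once per server process, so each
+ # class keeps a stable box color across Streamlit reruns.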
+
+ download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=23147564)
+ download_file(PROTOTXT_URL, PROTOTXT_LOCAL_PATH, expected_size=29353)
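+ # expected_size presumably lets the sample's download helper skip re-fetching
+ # the weights when a complete local copy already exists.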
+
+
+ # Session-specific caching
+ cache_key = "object_detection_dnn"
+ if cache_key in st.session_state:
+     net = st.session_state[cache_key]
+ else:
+     net = cv2.dnn.readNetFromCaffe(str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH))
+     st.session_state[cache_key] = net
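+ # st.session_state gives each viewer session its own Net; setInput()/forward()
+ # mutate the network, so sharing one instance across sessions could race.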
+
+ score_threshold = st.slider("Score threshold", 0.0, 1.0, 0.5, 0.05)
+
+ # NOTE: The callback will be called in another thread,
+ # so use a queue here for thread-safety to pass the data
+ # from inside to outside the callback.
+ # TODO: A general-purpose shared state object may be more useful.
+ result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
+
+
+ def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
+     image = frame.to_ndarray(format="bgr24")
+
+     # Run inference
+     blob = cv2.dnn.blobFromImage(
+         cv2.resize(image, (300, 300)), 0.007843, (300, 300), 127.5
+     )
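+     # The scale factor 0.007843 ~= 1 / 127.5 and the mean 127.5 normalize
+     # pixels to roughly [-1, 1], the input range this Caffe model expects.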
+     net.setInput(blob)
+     output = net.forward()
+
+     h, w = image.shape[:2]
+
+     # Convert the output array into a structured form.
+     output = output.squeeze()  # (1, 1, N, 7) -> (N, 7)
+     output = output[output[:, 2] >= score_threshold]
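+     # Each remaining row is [image_id, class_id, score, x1, y1, x2, y2],
+     # with corner coordinates normalized to [0, 1].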
+     detections = [
+         Detection(
+             class_id=int(detection[1]),
+             label=CLASSES[int(detection[1])],
+             score=float(detection[2]),
+             box=(detection[3:7] * np.array([w, h, w, h])),
+         )
+         for detection in output
+     ]
+
+     # Render bounding boxes and captions
+     for detection in detections:
+         caption = f"{detection.label}: {round(detection.score * 100, 2)}%"
+         color = COLORS[detection.class_id]
+         xmin, ymin, xmax, ymax = detection.box.astype("int")
+
+         cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, 2)
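+         # Place the caption just above the box, or inside its top edge when
+         # the box sits too close to the top of the frame.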
+         cv2.putText(
+             image,
+             caption,
+             (xmin, ymin - 15 if ymin - 15 > 15 else ymin + 15),
+             cv2.FONT_HERSHEY_SIMPLEX,
+             0.5,
+             color,
+             2,
+         )
+
+     result_queue.put(detections)
+
+     return av.VideoFrame.from_ndarray(image, format="bgr24")
+
+
+ webrtc_ctx = webrtc_streamer(
+     key="object-detection",
+     mode=WebRtcMode.SENDRECV,
+     rtc_configuration={
+         "iceServers": get_ice_servers(),
+         "iceTransportPolicy": "relay",
+     },
+     video_frame_callback=video_frame_callback,
+     media_stream_constraints={"video": True, "audio": False},
+     async_processing=True,
+ )
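+ # "iceTransportPolicy": "relay" forces media through a TURN server, which
+ # helps the stream traverse strict NATs/firewalls at the cost of some latency.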
+
+ if st.checkbox("Show the detected labels", value=True):
+     if webrtc_ctx.state.playing:
+         labels_placeholder = st.empty()
+         # NOTE: The video transformation with object detection and
+         # this loop displaying the result labels are running
+         # in different threads asynchronously.
+         # Therefore the rendered video frames and the labels displayed
+         # here are not strictly synchronized.
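+         # result_queue.get() blocks until the callback publishes the next
+         # detections, so the table refreshes roughly once per processed frame.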
+         while True:
+             result = result_queue.get()
+             labels_placeholder.table(result)
+
+ st.markdown(
+     "This demo uses a model and code from "
+     "https://github.com/robmarkcole/object-detection-app. "
+     "Many thanks to the project."
+ )