aliciiavs commited on
Commit
06219bd
·
verified ·
1 Parent(s): 1567d03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -165
app.py CHANGED
@@ -1,166 +1,26 @@
1
- """Object detection demo with MobileNet SSD.
2
- This model and code are based on
3
- https://github.com/robmarkcole/object-detection-app
4
- """
5
-
6
- import logging
7
- import queue
8
- from pathlib import Path
9
- from typing import List, NamedTuple
10
-
11
- import av
12
- import cv2
13
- import numpy as np
14
  import streamlit as st
15
- from streamlit_webrtc import WebRtcMode, webrtc_streamer
16
-
17
- from sample_utils.download import download_file
18
- from sample_utils.turn import get_ice_servers
19
-
20
- HERE = Path(__file__).parent
21
- ROOT = HERE.parent
22
-
23
- logger = logging.getLogger(__name__)
24
-
25
-
26
- MODEL_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.caffemodel" # noqa: E501
27
- MODEL_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.caffemodel"
28
- PROTOTXT_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.prototxt.txt" # noqa: E501
29
- PROTOTXT_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.prototxt.txt"
30
-
31
- CLASSES = [
32
- "background",
33
- "aeroplane",
34
- "bicycle",
35
- "bird",
36
- "boat",
37
- "bottle",
38
- "bus",
39
- "car",
40
- "cat",
41
- "chair",
42
- "cow",
43
- "diningtable",
44
- "dog",
45
- "horse",
46
- "motorbike",
47
- "person",
48
- "pottedplant",
49
- "sheep",
50
- "sofa",
51
- "train",
52
- "tvmonitor",
53
- ]
54
-
55
-
56
- class Detection(NamedTuple):
57
- class_id: int
58
- label: str
59
- score: float
60
- box: np.ndarray
61
-
62
-
63
- @st.cache_resource # type: ignore
64
- def generate_label_colors():
65
- return np.random.uniform(0, 255, size=(len(CLASSES), 3))
66
-
67
-
68
- COLORS = generate_label_colors()
69
-
70
- download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=23147564)
71
- download_file(PROTOTXT_URL, PROTOTXT_LOCAL_PATH, expected_size=29353)
72
-
73
-
74
- # Session-specific caching
75
- cache_key = "object_detection_dnn"
76
- if cache_key in st.session_state:
77
- net = st.session_state[cache_key]
78
- else:
79
- net = cv2.dnn.readNetFromCaffe(str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH))
80
- st.session_state[cache_key] = net
81
-
82
- score_threshold = st.slider("Score threshold", 0.0, 1.0, 0.5, 0.05)
83
-
84
- # NOTE: The callback will be called in another thread,
85
- # so use a queue here for thread-safety to pass the data
86
- # from inside to outside the callback.
87
- # TODO: A general-purpose shared state object may be more useful.
88
- result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
89
-
90
-
91
- def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
92
- image = frame.to_ndarray(format="bgr24")
93
-
94
- # Run inference
95
- blob = cv2.dnn.blobFromImage(
96
- cv2.resize(image, (300, 300)), 0.007843, (300, 300), 127.5
97
- )
98
- net.setInput(blob)
99
- output = net.forward()
100
-
101
- h, w = image.shape[:2]
102
-
103
- # Convert the output array into a structured form.
104
- output = output.squeeze() # (1, 1, N, 7) -> (N, 7)
105
- output = output[output[:, 2] >= score_threshold]
106
- detections = [
107
- Detection(
108
- class_id=int(detection[1]),
109
- label=CLASSES[int(detection[1])],
110
- score=float(detection[2]),
111
- box=(detection[3:7] * np.array([w, h, w, h])),
112
- )
113
- for detection in output
114
- ]
115
-
116
- # Render bounding boxes and captions
117
- for detection in detections:
118
- caption = f"{detection.label}: {round(detection.score * 100, 2)}%"
119
- color = COLORS[detection.class_id]
120
- xmin, ymin, xmax, ymax = detection.box.astype("int")
121
-
122
- cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, 2)
123
- cv2.putText(
124
- image,
125
- caption,
126
- (xmin, ymin - 15 if ymin - 15 > 15 else ymin + 15),
127
- cv2.FONT_HERSHEY_SIMPLEX,
128
- 0.5,
129
- color,
130
- 2,
131
- )
132
-
133
- result_queue.put(detections)
134
-
135
- return av.VideoFrame.from_ndarray(image, format="bgr24")
136
-
137
-
138
- webrtc_ctx = webrtc_streamer(
139
- key="object-detection",
140
- mode=WebRtcMode.SENDRECV,
141
- rtc_configuration={
142
- "iceServers": get_ice_servers(),
143
- "iceTransportPolicy": "relay",
144
- },
145
- video_frame_callback=video_frame_callback,
146
- media_stream_constraints={"video": True, "audio": False},
147
- async_processing=True,
148
- )
149
-
150
- if st.checkbox("Show the detected labels", value=True):
151
- if webrtc_ctx.state.playing:
152
- labels_placeholder = st.empty()
153
- # NOTE: The video transformation with object detection and
154
- # this loop displaying the result labels are running
155
- # in different threads asynchronously.
156
- # Then the rendered video frames and the labels displayed here
157
- # are not strictly synchronized.
158
- while True:
159
- result = result_queue.get()
160
- labels_placeholder.table(result)
161
-
162
- st.markdown(
163
- "This demo uses a model and code from "
164
- "https://github.com/robmarkcole/object-detection-app. "
165
- "Many thanks to the project. hehehhe"
166
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import pandas as pd
3
+ import pickle
4
+ # Load Model
5
+ model = pickle.load(open('logreg_model.pkl', 'rb'))
6
+ st.title('Iris Variety Prediction')
7
+ # Form
8
+ with st.form(key='form_parameters'):
9
+ sepal_length = st.slider('Sepal Length', 4.0, 8.0, 4.0)
10
+ sepal_width = st.slider('Sepal Width', 2.0, 4.5, 2.0)
11
+ petal_length = st.slider('Petal Length', 1.0, 7.0, 1.0)
12
+ petal_width = st.slider('Petal Width', 0.1, 2.5, 0.1)
13
+ st.markdown('---')
14
+ submitted = st.form_submit_button('Predict')
15
+ # Data Inference
16
+ data_inf = {
17
+ 'sepal.length': sepal_length,
18
+ 'sepal.width': sepal_width,
19
+ 'petal.length': petal_length,
20
+ 'petal.width': petal_width
21
+ }
22
+ data_inf = pd.DataFrame([data_inf])
23
+ if submitted:
24
+ # Predict using Logistic Regression
25
+ y_pred_inf = model.predict(data_inf)
26
+ st.write('## Iris Variety = '+ str(y_pred_inf))