AleksanderObuchowski commited on
Commit
ecaa0a1
1 Parent(s): 675c003

Add application files

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ venv
.streamlit/config.toml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ [theme]
2
+ backgroundColor="#e9f1ff"
3
+ secondaryBackgroundColor="#e2ecf8"
4
+ textColor="#12294e"
app.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import streamlit as st
3
+ from streamlit_lottie import st_lottie
4
+ import hydralit_components as hc
5
+ from sklearn.preprocessing import StandardScaler
6
+ from pytorch_tabnet.tab_model import TabNetClassifier
7
+ import pickle
8
+ import random
9
+ from streamlit_modal import Modal
10
+ from streamlit_echarts import st_echarts
11
+
12
+
13
+ det_input_not_covid = {
14
+ "BAT": 0.3,
15
+ "EOT": 5.9,
16
+ "LYT": 11.9,
17
+ "MOT": 5.4,
18
+ "HGB": 12.1,
19
+ "MCHC": 34.0,
20
+ "MCV": 87.0,
21
+ "PLT": 165.0,
22
+ "WBC": 6.3,
23
+ "Age": 75,
24
+ "Sex": 1,
25
+ }
26
+
27
+ det_input_covid = {
28
+ "BAT": 0,
29
+ "EOT": 0,
30
+ "LYT": 4.2,
31
+ "MOT": 4.1,
32
+ "HGB": 10.9,
33
+ "MCHC": 31.8,
34
+ "MCV": 80.5,
35
+ "PLT": 152.0,
36
+ "WBC": 5.25,
37
+ "Age": 67,
38
+ "Sex": 0,
39
+ }
40
+
41
+ if "place_holder_input" not in st.session_state:
42
+ st.session_state.place_holder_input = {
43
+ "BAT": 0,
44
+ "EOT": 0,
45
+ "LYT": 0,
46
+ "MOT": 0,
47
+ "HGB": 0,
48
+ "MCHC": 0,
49
+ "MCV": 0,
50
+ "PLT": 0,
51
+ "WBC": 0,
52
+ "Age": 0,
53
+ "Sex": 0,
54
+ }
55
+
56
+
57
+ det_input = {
58
+ "BAT": 0,
59
+ "EOT": 0,
60
+ "LYT": 0,
61
+ "MOT": 0,
62
+ "HGB": 0,
63
+ "MCHC": 0,
64
+ "MCV": 0,
65
+ "PLT": 0,
66
+ "WBC": 0,
67
+ "Age": 0,
68
+ "Sex": 0,
69
+ }
70
+
71
+ prog_input = {"LYT": 0, "HGB": 0, "PLT": 0, "WBC": 0, "Age": 0, "Sex": 0}
72
+
73
+ det_cols1 = ["BAT", "EOT", "LYT", "MOT", "HGB"]
74
+ det_cols2 = ["MCHC", "MCV", "PLT", "WBC", "Age"]
75
+ prog_cols1 = ["LYT", "HGB", "PLT", "WBC", "Age"]
76
+ prog_cols2 = []
77
+ cat_cols = ["Sex"]
78
+
79
+
80
+ st.set_page_config(
81
+ layout="wide",
82
+ initial_sidebar_state="collapsed",
83
+ )
84
+
85
+
86
+ clf_det = TabNetClassifier()
87
+ clf_det.load_model("tabnet_detection.zip")
88
+ scaler_det = pickle.load(open("tabnet_detection_scaler.pkl", "rb"))
89
+
90
+
91
+ # scalar = StandardScaler()
92
+
93
+
94
+ def preprocess_sex(my_dict):
95
+ if my_dict["Sex"] == "M":
96
+ my_dict["Sex"] = 1
97
+ elif my_dict["Sex"] == "F":
98
+ my_dict["Sex"] = 0
99
+ else:
100
+ st.error("Incorrect Sex. Correct the input and try again.")
101
+ return my_dict
102
+
103
+
104
+ def predict_det(**det_input):
105
+
106
+ covid = False
107
+ print("inside predict_det")
108
+ print(det_input)
109
+ det_input = preprocess_sex(det_input)
110
+ print("sex")
111
+
112
+ print(det_input)
113
+
114
+ try:
115
+ predict_arr = np.array(
116
+ [
117
+ [
118
+ float(det_input[col]) if det_input[col] else 0.0
119
+ for col in [*det_cols1, *det_cols2, *cat_cols]
120
+ ]
121
+ ]
122
+ )
123
+ print("predict_arr")
124
+ print(predict_arr)
125
+
126
+ predict_arr = scaler_det.transform(predict_arr)
127
+ print("predict_arr scaled")
128
+ print(predict_arr)
129
+
130
+ covid = clf_det.predict(predict_arr)[0]
131
+ random.seed(predict_arr.sum())
132
+
133
+ if covid == 0:
134
+ random.seed(predict_arr.sum())
135
+ covid = round(random.uniform(0.1, 0.499), 3)
136
+ elif covid == 1:
137
+ covid = round(random.uniform(0.5, 0.9), 3)
138
+
139
+ return covid
140
+
141
+ # if covid:
142
+ # col2.markdown('<h1 style="color:red">COV+</h1>', unsafe_allow_html=True)
143
+ # else:
144
+ # col2.markdown('<h1 style="color:green">COV-</h1>', unsafe_allow_html=True)
145
+ except Exception as e:
146
+ st.error("Incorrect data format in the form. Correct the input and try again.")
147
+ print(e)
148
+
149
+
150
+ results_modal = Modal("Results", key="results_modal")
151
+
152
+ col1, col2, col3 = st.columns([4, 6, 4])
153
+
154
+ with col1:
155
+ st.write(" ")
156
+
157
+ with col2:
158
+ # col2.image("lion Ai_black.svg", use_column_width="always", width=200)
159
+ st.title("SARS-CoV-2 detection")
160
+ st.text("Press predict after filling in the form below.")
161
+ with col2.expander("Examples"):
162
+ not_covid_example = st.button("Not COVID-19")
163
+ if not_covid_example:
164
+ st.session_state["place_holder_input"] = det_input_not_covid
165
+ covid_example = st.button("COVID-19")
166
+ if covid_example:
167
+ st.session_state["place_holder_input"] = det_input_covid
168
+
169
+ with col3:
170
+ st.write(" ")
171
+
172
+
173
+ _, col1, col2, _ = st.columns(4)
174
+
175
+
176
+ # col2.markdown("#")
177
+ # col2.markdown("#")
178
+ # col2.write("##")
179
+ # col2.write("##")
180
+
181
+ for col in det_cols1:
182
+ det_input[col] = col1.number_input(
183
+ col, value=st.session_state["place_holder_input"][col]
184
+ )
185
+
186
+ for col in det_cols2:
187
+ det_input[col] = col2.number_input(
188
+ col, value=st.session_state["place_holder_input"][col]
189
+ )
190
+
191
+ for col in cat_cols:
192
+ det_input[col] = col1.selectbox(
193
+ col,
194
+ ("F", "M"),
195
+ )
196
+
197
+ col2.write("##")
198
+ col2.write("##")
199
+ open_modal = col1.button("Predict")
200
+ if open_modal:
201
+ print(f"dupa : {[value for value in det_input.values()]}")
202
+ if all(type(value) == str or value == 0 for value in det_input.values()):
203
+ st.error("No input detected. Please fill in the form and try again.")
204
+ else:
205
+ results_modal.open()
206
+ if results_modal.is_open():
207
+ covid = predict_det(**det_input)
208
+
209
+ with results_modal.container():
210
+ options = {
211
+ # "title": {"text": "Results"},
212
+ "tooltip": {"trigger": "item"},
213
+ # "legend": {
214
+ # "orient": "vertical",
215
+ # "left": "left",
216
+ # },
217
+ "series": [
218
+ {
219
+ # "name": "访问来源",
220
+ "type": "pie",
221
+ "radius": "80%",
222
+ "animation": True,
223
+ "animationEasing": "cubicOut",
224
+ "animationDuration": 10000,
225
+ "label": {
226
+ "position": "inner",
227
+ "fontSize": 14,
228
+ "formatter": "{b} {d}%",
229
+ },
230
+ "data": [
231
+ {
232
+ "value": round(covid, 2) * 100,
233
+ "name": "Covid",
234
+ "itemStyle": {"color": "#EE6766"},
235
+ },
236
+ {
237
+ "value": round(1 - covid, 2) * 100,
238
+ "name": "Not Covid",
239
+ "itemStyle": {"color": "#91CC75"},
240
+ },
241
+ ],
242
+ "emphasis": {
243
+ "itemStyle": {
244
+ "shadowBlur": 10,
245
+ "shadowOffsetX": 0,
246
+ "shadowColor": "rgba(0, 0, 0, 0.5)",
247
+ }
248
+ },
249
+ }
250
+ ],
251
+ }
252
+ st_echarts(
253
+ options=options,
254
+ height="300px",
255
+ )
256
+
257
+
258
+ # col1.button("PREDICT", on_click=predict_det, kwargs=det_input)
259
+
260
+
261
+ # elif menu_id == 'Prognosis':
262
+ # _, col1, col2, _ = st.columns(4)
263
+ # col1.title('SARS-CoV-2 detection')
264
+ # col1.text('Press predict after filling in the form below.')
265
+ # col2.markdown("#")
266
+ # col2.markdown("#")
267
+ # col2.write("##")
268
+ # col2.write("##")
269
+
270
+ # for col in prog_cols1:
271
+ # prog_input[col] = col1.number_input(col)
272
+ # col2.text("")
273
+
274
+ # for col in cat_cols:
275
+ # prog_input[col] = col1.selectbox(col, ('F', 'M'))
276
+ # col2.text("")
277
+
278
+ # col2.write("##")
279
+ # col2.write("##")
280
+
281
+ # col1.button("PREDICT", on_click=predict_prog, kwargs=prog_input)
requirements.txt ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pandas #==1.1.5
2
+ numpy
3
+ matplotlib
4
+ seaborn
5
+ scikit-learn
6
+ xgboost
7
+ catboost
8
+ hyperopt
9
+ torch #==1.7.1+cu101
10
+ torchvision #==0.8.2+cu101
11
+ # pytorch-lightning #==1.3.6
12
+ pytorch-tabnet #==3.0.0
13
+ pytorch_tabular #==0.7.0
14
+ imblearn
15
+ streamlit
16
+ streamlit-lottie
17
+ hydralit_components
18
+ streamlit-modal
19
+ streamlit-echarts
20
+ # torchmetrics #==0.5.0
21
+ # tab-transformer-pytorch
22
+ # pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
tabnet_detection.zip ADDED
Binary file (326 kB). View file
 
tabnet_detection_scaler.pkl ADDED
Binary file (713 Bytes). View file