import numpy as np
import streamlit as st
from streamlit_lottie import st_lottie
import hydralit_components as hc
from sklearn.preprocessing import StandardScaler
from pytorch_tabnet.tab_model import TabNetClassifier
import pickle
import random
from streamlit_modal import Modal
from streamlit_echarts import st_echarts
# Canned example records that the "Examples" buttons copy into the form
# defaults. Keys are the detection model's input columns; Sex is stored
# already encoded (1 for "M", 0 for "F", matching preprocess_sex).
det_input_not_covid = dict(
    BAT=0.3,
    EOT=5.9,
    LYT=11.9,
    MOT=5.4,
    HGB=12.1,
    MCHC=34.0,
    MCV=87.0,
    PLT=165.0,
    WBC=6.3,
    Age=75,
    Sex=1,
)
det_input_covid = dict(
    BAT=0,
    EOT=0,
    LYT=4.2,
    MOT=4.1,
    HGB=10.9,
    MCHC=31.8,
    MCV=80.5,
    PLT=152.0,
    WBC=5.25,
    Age=67,
    Sex=0,
)
# Per-session form defaults used as the number inputs' initial values.
# Created once (all zeros) and replaced when an example button is pressed.
if "place_holder_input" not in st.session_state:
    st.session_state["place_holder_input"] = dict.fromkeys(
        ("BAT", "EOT", "LYT", "MOT", "HGB", "MCHC", "MCV", "PLT", "WBC", "Age", "Sex"),
        0,
    )
# Detection model input columns, in the order predict_det assembles the
# feature vector; the form shows det_cols1 and det_cols2 side by side.
det_cols1 = ["BAT", "EOT", "LYT", "MOT", "HGB"]
det_cols2 = ["MCHC", "MCV", "PLT", "WBC", "Age"]
# Columns for the prognosis form (only referenced by commented-out code).
prog_cols1 = ["LYT", "HGB", "PLT", "WBC", "Age"]
prog_cols2 = []
cat_cols = ["Sex"]

# Values collected from the form widgets, one entry per column, starting at 0.
det_input = dict.fromkeys([*det_cols1, *det_cols2, *cat_cols], 0)
prog_input = dict.fromkeys([*prog_cols1, *cat_cols], 0)
st.set_page_config(
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Detection model and the feature scaler applied to its inputs in predict_det.
# NOTE(review): Streamlit re-runs this script on every interaction, so both
# are reloaded each time; consider caching them if the installed Streamlit
# version offers a resource cache.
clf_det = TabNetClassifier()
clf_det.load_model("tabnet_detection.zip")
# Close the file deterministically instead of leaking the handle.
# Unpickling executes arbitrary code: only ever load this app-bundled file.
with open("tabnet_detection_scaler.pkl", "rb") as scaler_file:
    scaler_det = pickle.load(scaler_file)
def preprocess_sex(my_dict):
    """Encode ``my_dict["Sex"]`` as the integer the model expects (M -> 1, F -> 0).

    The dict is updated in place and also returned. Letter codes are matched
    case-insensitively after stripping surrounding whitespace. Values that are
    already encoded (0 or 1, as stored in the example records) are accepted
    unchanged, so calling this twice is harmless. Any other value is left as-is
    and an error is shown in the Streamlit UI.

    Args:
        my_dict: Mapping of form field names to values; must contain "Sex".

    Returns:
        The same dict object, with "Sex" encoded when the value was valid.
    """
    sex = my_dict["Sex"]
    if isinstance(sex, str):
        code = {"M": 1, "F": 0}.get(sex.strip().upper())
    elif sex in (0, 1):
        # Already encoded, e.g. one of the example records.
        code = int(sex)
    else:
        code = None

    if code is None:
        st.error("Incorrect Sex. Correct the input and try again.")
    else:
        my_dict["Sex"] = code
    return my_dict
def predict_det(**det_input):
    """Score one detection-form submission with the TabNet model.

    Keyword arguments are the form fields keyed by column name. The columns in
    ``det_cols1``, ``det_cols2`` and ``cat_cols`` are read in that order. "Sex"
    is encoded by ``preprocess_sex``, and falsy field values (0, empty) are
    sent to the model as 0.0.

    Returns:
        The model's probability for the positive (COVID-19) class, as a float,
        or ``None`` if the input could not be converted or scored. In that case
        an error has already been shown in the UI.
    """
    det_input = preprocess_sex(det_input)
    try:
        predict_arr = np.array(
            [
                [
                    float(det_input[col]) if det_input[col] else 0.0
                    for col in [*det_cols1, *det_cols2, *cat_cols]
                ]
            ]
        )
        predict_arr = scaler_det.transform(predict_arr)
        # Report the model's own class probability instead of a pseudo-random
        # number that was merely bucketed by the hard 0/1 prediction.
        # NOTE(review): assumes the model was trained on labels {0, 1}, so
        # column 1 of predict_proba is the COVID-positive class; confirm.
        proba = clf_det.predict_proba(predict_arr)
        return float(proba[0, 1])
    except Exception as e:
        # UI boundary: surface any conversion/model failure as a form error.
        st.error("Incorrect data format in the form. Correct the input and try again.")
        print(e)
        return None
results_modal = Modal("Results", key="results_modal")

# Header row: title, instructions and example loaders in the wide middle column.
col1, col2, col3 = st.columns([4, 6, 4])
with col1:
    st.write(" ")
with col2:
    st.title("SARS-CoV-2 detection")
    st.text("Press predict after filling in the form below.")
    with col2.expander("Examples"):
        # Each button replaces the form defaults with a canned example record.
        not_covid_example = st.button("Not COVID-19")
        if not_covid_example:
            st.session_state["place_holder_input"] = det_input_not_covid
        covid_example = st.button("COVID-19")
        if covid_example:
            st.session_state["place_holder_input"] = det_input_covid
with col3:
    st.write(" ")

# Form row: two input columns between two empty spacer columns.
_, col1, col2, _ = st.columns(4)
# st.number_input derives its type and step from `value`. The initial defaults
# are ints, which would make every field integer-only. Cast the measurement
# defaults to float so decimal values can be entered. Age stays integral, and
# Sex is used as the selectbox index below.
defaults = {
    name: value if name in ("Age", "Sex") else float(value)
    for name, value in st.session_state["place_holder_input"].items()
}
for col in det_cols1:
    det_input[col] = col1.number_input(col, value=defaults[col])
for col in det_cols2:
    det_input[col] = col2.number_input(col, value=defaults[col])
for col in cat_cols:
    # preprocess_sex encodes F -> 0 and M -> 1, so the stored code is also the
    # option index; this makes the selectbox follow a loaded example record.
    det_input[col] = col1.selectbox(
        col,
        ("F", "M"),
        index=int(defaults[col]),
    )
col2.write("##")
col2.write("##")

open_modal = col1.button("Predict")
if open_modal:
    # Sex always arrives as a string from the selectbox, so the "empty form"
    # check effectively looks only at the numeric fields.
    if all(isinstance(value, str) or value == 0 for value in det_input.values()):
        st.error("No input detected. Please fill in the form and try again.")
    else:
        results_modal.open()

if results_modal.is_open():
    covid = predict_det(**det_input)
    with results_modal.container():
        if covid is None:
            # predict_det has already reported the problem; building the chart
            # would crash on round(None).
            st.write("No result: the form input could not be processed.")
        else:
            # Whole percentages that always add up to 100.
            covid_pct = round(covid * 100)
            options = {
                "tooltip": {"trigger": "item"},
                "series": [
                    {
                        "type": "pie",
                        "radius": "80%",
                        "animation": True,
                        "animationEasing": "cubicOut",
                        "animationDuration": 10000,
                        "label": {
                            "position": "inner",
                            "fontSize": 14,
                            "formatter": "{b} {d}%",
                        },
                        "data": [
                            {
                                "value": covid_pct,
                                "name": "Covid",
                                "itemStyle": {"color": "#EE6766"},
                            },
                            {
                                "value": 100 - covid_pct,
                                "name": "Not Covid",
                                "itemStyle": {"color": "#91CC75"},
                            },
                        ],
                        "emphasis": {
                            "itemStyle": {
                                "shadowBlur": 10,
                                "shadowOffsetX": 0,
                                "shadowColor": "rgba(0, 0, 0, 0.5)",
                            }
                        },
                    }
                ],
            }
            st_echarts(
                options=options,
                height="300px",
            )
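# Earlier callback-style submit button, superseded by the modal flow above:
# col1.button("PREDICT", on_click=predict_det, kwargs=det_input)
# Disabled prognosis form, kept for reference. NOTE(review): it refers to
# `menu_id` and `predict_prog`, neither of which is defined in the visible source.
# elif menu_id == 'Prognosis':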
# _, col1, col2, _ = st.columns(4)
# col1.title('SARS-CoV-2 detection')
# col1.text('Press predict after filling in the form below.')
# col2.markdown("#")
# col2.markdown("#")
# col2.write("##")
# col2.write("##")
# for col in prog_cols1:
# prog_input[col] = col1.number_input(col)
# col2.text("")
# for col in cat_cols:
# prog_input[col] = col1.selectbox(col, ('F', 'M'))
# col2.text("")
# col2.write("##")
# col2.write("##")
# col1.button("PREDICT", on_click=predict_prog, kwargs=prog_input)