"""Streamlit page for SARS-CoV-2 detection with a pre-trained TabNet model.

The user fills in a form of numeric features (``det_cols1`` + ``det_cols2``)
and Sex, presses "Predict", and the result is drawn as a pie chart inside a
modal dialog.
"""
import logging
import pickle
import random

import hydralit_components as hc
import numpy as np
import streamlit as st
from pytorch_tabnet.tab_model import TabNetClassifier
from sklearn.preprocessing import StandardScaler
from streamlit_echarts import st_echarts
from streamlit_lottie import st_lottie
from streamlit_modal import Modal

# Streamlit documents set_page_config as having to run before other Streamlit
# commands, so it comes first (before session_state is touched).
st.set_page_config(
    layout="wide",
    initial_sidebar_state="collapsed",
)

logger = logging.getLogger(__name__)

# Prefill values offered by the "Examples" expander. Sex is stored already
# encoded (0 = "F", 1 = "M"), matching preprocess_sex.
det_input_not_covid = {
    "BAT": 0.3,
    "EOT": 5.9,
    "LYT": 11.9,
    "MOT": 5.4,
    "HGB": 12.1,
    "MCHC": 34.0,
    "MCV": 87.0,
    "PLT": 165.0,
    "WBC": 6.3,
    "Age": 75,
    "Sex": 1,
}
det_input_covid = {
    "BAT": 0,
    "EOT": 0,
    "LYT": 4.2,
    "MOT": 4.1,
    "HGB": 10.9,
    "MCHC": 31.8,
    "MCV": 80.5,
    "PLT": 152.0,
    "WBC": 5.25,
    "Age": 67,
    "Sex": 0,
}

if "place_holder_input" not in st.session_state:
    st.session_state.place_holder_input = {
        "BAT": 0,
        "EOT": 0,
        "LYT": 0,
        "MOT": 0,
        "HGB": 0,
        "MCHC": 0,
        "MCV": 0,
        "PLT": 0,
        "WBC": 0,
        "Age": 0,
        "Sex": 0,
    }

det_input = {
    "BAT": 0,
    "EOT": 0,
    "LYT": 0,
    "MOT": 0,
    "HGB": 0,
    "MCHC": 0,
    "MCV": 0,
    "PLT": 0,
    "WBC": 0,
    "Age": 0,
    "Sex": 0,
}
prog_input = {"LYT": 0, "HGB": 0, "PLT": 0, "WBC": 0, "Age": 0, "Sex": 0}

det_cols1 = ["BAT", "EOT", "LYT", "MOT", "HGB"]
det_cols2 = ["MCHC", "MCV", "PLT", "WBC", "Age"]
prog_cols1 = ["LYT", "HGB", "PLT", "WBC", "Age"]
prog_cols2 = []
cat_cols = ["Sex"]

# Selectbox options for Sex; each option's index equals its encoded value.
_SEX_OPTIONS = ("F", "M")
_SEX_CODES = {"M": 1, "F": 0}
# Numeric form fields that should stay integer-valued in st.number_input.
_INTEGER_FIELDS = {"Age"}


def _load_detection_artifacts():
    """Load the TabNet detection model and its fitted feature scaler from disk."""
    model = TabNetClassifier()
    model.load_model("tabnet_detection.zip")
    # NOTE(review): pickle.load can execute arbitrary code; only ship a
    # trusted scaler file alongside the app.
    with open("tabnet_detection_scaler.pkl", "rb") as fh:
        scaler = pickle.load(fh)
    return model, scaler


# Streamlit re-executes this script on every widget interaction; keep the
# loaded model/scaler in session state so they are read from disk once per
# session instead of on every rerun.
if "det_artifacts" not in st.session_state:
    st.session_state["det_artifacts"] = _load_detection_artifacts()
clf_det, scaler_det = st.session_state["det_artifacts"]


def preprocess_sex(my_dict):
    """Encode ``my_dict["Sex"]`` in place as 1 ("M") or 0 ("F").

    Any other value is left unchanged and an error is shown in the UI.
    Returns the same dict.
    """
    code = _SEX_CODES.get(my_dict["Sex"])
    if code is None:
        st.error("Incorrect Sex. Correct the input and try again.")
    else:
        my_dict["Sex"] = code
    return my_dict


def predict_det(**det_input):
    """Score one set of form values with the TabNet detection model.

    Keyword arguments are the form fields keyed by column name
    (``det_cols1 + det_cols2 + cat_cols``), with "Sex" given as "M"/"F".
    Falsy field values are treated as 0.0.

    Returns:
        A value rounded to 3 decimals in [0.1, 0.499] when the model predicts
        label 0, or in [0.5, 0.9] when it predicts label 1; the raw predicted
        label for any other label; or None when the input could not be
        converted or scored (an error is shown in the UI in that case).
    """
    logger.debug("predict_det input: %s", det_input)
    det_input = preprocess_sex(det_input)
    try:
        predict_arr = np.array(
            [
                [
                    float(det_input[col]) if det_input[col] else 0.0
                    for col in [*det_cols1, *det_cols2, *cat_cols]
                ]
            ]
        )
        predict_arr = scaler_det.transform(predict_arr)
        logger.debug("predict_det scaled features: %s", predict_arr)
        label = clf_det.predict(predict_arr)[0]
    except Exception:
        # UI boundary: report the problem instead of crashing the page.
        st.error("Incorrect data format in the form. Correct the input and try again.")
        logger.exception("predict_det failed")
        return None

    # NOTE(review): only a class label comes from the model here. The value
    # shown to the user is a pseudo-random number drawn from a range picked
    # by that label, seeded with the scaled features so it repeats for the
    # same input. It is not a calibrated probability; consider
    # clf_det.predict_proba. A local RNG avoids reseeding the global one.
    rng = random.Random(predict_arr.sum())
    if label == 0:
        return round(rng.uniform(0.1, 0.499), 3)
    if label == 1:
        return round(rng.uniform(0.5, 0.9), 3)
    return label


def _form_default(col):
    """Return the prefill value for numeric form field *col*, typed for st.number_input."""
    value = st.session_state["place_holder_input"][col]
    # st.number_input picks int or float mode from the type of `value`; an
    # int default (e.g. the initial 0) would block decimals such as 12.1.
    return int(value) if col in _INTEGER_FIELDS else float(value)


def _results_chart_options(covid):
    """Return ECharts options for a "Covid" / "Not Covid" pie chart.

    ``covid`` is the score returned by predict_det, expected in [0, 1].
    """
    # round(x, 2) * 100 alone leaves float noise (0.29 * 100 ==
    # 28.999999999999996) that can surface in the chart tooltip.
    covid_pct = round(round(covid, 2) * 100)
    not_covid_pct = round(round(1 - covid, 2) * 100)
    return {
        "tooltip": {"trigger": "item"},
        "series": [
            {
                "type": "pie",
                "radius": "80%",
                "animation": True,
                "animationEasing": "cubicOut",
                "animationDuration": 10000,
                "label": {
                    "position": "inner",
                    "fontSize": 14,
                    "formatter": "{b} {d}%",
                },
                "data": [
                    {
                        "value": covid_pct,
                        "name": "Covid",
                        "itemStyle": {"color": "#EE6766"},
                    },
                    {
                        "value": not_covid_pct,
                        "name": "Not Covid",
                        "itemStyle": {"color": "#91CC75"},
                    },
                ],
                "emphasis": {
                    "itemStyle": {
                        "shadowBlur": 10,
                        "shadowOffsetX": 0,
                        "shadowColor": "rgba(0, 0, 0, 0.5)",
                    }
                },
            }
        ],
    }


results_modal = Modal("Results", key="results_modal")

# Header row: title and example presets in the wide middle column.
col1, col2, col3 = st.columns([4, 6, 4])
with col1:
    st.write(" ")
with col2:
    st.title("SARS-CoV-2 detection")
    st.text("Press predict after filling in the form below.")
    with col2.expander("Examples"):
        not_covid_example = st.button("Not COVID-19")
        if not_covid_example:
            st.session_state["place_holder_input"] = det_input_not_covid
        covid_example = st.button("COVID-19")
        if covid_example:
            st.session_state["place_holder_input"] = det_input_covid
with col3:
    st.write(" ")

# Input form: two middle columns of numeric fields plus the Sex selector.
_, col1, col2, _ = st.columns(4)
for col in det_cols1:
    det_input[col] = col1.number_input(col, value=_form_default(col))
for col in det_cols2:
    det_input[col] = col2.number_input(col, value=_form_default(col))
for col in cat_cols:
    det_input[col] = col1.selectbox(
        col,
        _SEX_OPTIONS,
        # Apply the preset's encoded Sex (0 -> "F", 1 -> "M").
        index=int(st.session_state["place_holder_input"][col]),
    )
col2.write("##")
col2.write("##")

open_modal = col1.button("Predict")
if open_modal:
    logger.debug("Form values on Predict: %s", list(det_input.values()))
    # Sex is always a str from the selectbox, so this checks whether every
    # numeric field is still zero.
    if all(isinstance(value, str) or value == 0 for value in det_input.values()):
        st.error("No input detected. Please fill in the form and try again.")
    else:
        results_modal.open()

if results_modal.is_open():
    with results_modal.container():
        covid = predict_det(**det_input)
        # predict_det has already reported the error when it returns None.
        if covid is not None:
            options = _results_chart_options(covid)
            st_echarts(
                options=options,
                height="300px",
            )

# TODO: prognosis view (prog_input / prog_cols1 / prog_cols2) is not implemented yet.