# test / app.py
# Author: Yassine EL OUAHIDI
# Commit message: remove shap 2
# Commit: 9ab6fdd
import logging
import pickle
import time

import numpy as np
import pandas as pd
import shap
import streamlit as st
import streamlit.components.v1 as components
#import xgboost
def impute_missing_values(df_, dict_impute):
    """Return a new frame with the gaps in *df_* filled from *dict_impute*.

    ``dict_impute`` is handed unchanged to ``DataFrame.fillna`` as its
    ``value``; presumably it maps column names to fill values — TODO confirm
    against how ``dict_impute_pre.pkl`` was produced. *df_* is not modified.
    """
    filled = df_.fillna(value=dict_impute)
    return filled
# Numeric features that are passed through `scaler.transform`, in the
# column order handed to the scaler.
to_rescale_ = ['diff_MS', 'MS_d', 'pr_pre_tavi',
               'surface_systole', 'MS_s', 'age',
               'ncc_calcif_n', 'calc_risque_n']
# Features fed to the model without rescaling; presumably 0/1 indicator
# (one-hot) columns, judging by their names — TODO confirm.
to_encode_ = ['syncope_Oui', 'lcc_calc_1.0', 'bloc_branche_pre_bbd']

st.title("Prédiction du risque d'IPM post TAVI, Pre Opération")

# Full input column order: rescaled features first, then the unscaled ones.
# Derived from the two lists above (instead of being retyped) so the order
# can never drift away from the scaled/unscaled split used at prediction time.
cols = to_rescale_ + to_encode_
def st_shap(plot, height=None):
    """Embed a SHAP JavaScript plot in the Streamlit page.

    Args:
        plot: object exposing an ``html()`` method; presumably the visualizer
            returned by ``shap.force_plot`` — TODO confirm for other plots.
        height: forwarded unchanged to ``components.html``.
    """
    js_bundle = shap.getjs()
    page = "".join((f"<head>{js_bundle}</head>", f"<body>{plot.html()}</body>"))
    components.html(page, height=height)
# Patient input form. Per Streamlit's st.form semantics, edits to these
# widgets are only sent back to the script when "Predict" is pressed.
with st.form(key='cols_form'):
    c1, c2, c3 = st.columns(3)
    c4, c5, c6 = st.columns(3)
    c7, c8, c9 = st.columns(3)
    c10, c11 = st.columns(2)

    # Numeric measurements and scores. The int/float type of each default
    # also determines the type the widget returns.
    # NOTE(review): the defaults satisfy diff_MS == MS_d - MS_s
    # (5.3 - 4.082 = 1.218); if the model expects exactly that difference,
    # deriving it would avoid inconsistent entries — TODO confirm.
    with c1:
        diff_MS = st.number_input('Différence SM (mm)', value=1.218)
    with c2:
        MS_d = st.number_input('Mesure diastolique SM (mm)', value=5.3)
    with c3:
        pr_pre_tavi = st.number_input('PR Pre TAVI (ms)', value=240)
    with c4:
        surface_systole = st.number_input("Surface systolique de l'anneau (mm²) ", value=518)
    with c5:
        MS_s = st.number_input('Mesure systolique SM (mm)', value=4.082)
    with c6:
        age = st.number_input('Age (année)', value=85.775)
    with c7:
        ncc_calcif_n = st.number_input('Degré calcification CNC', value=2)
    with c8:
        calc_risque_n = st.number_input('Degré calcificaton zone à risque', value=4)

    # Inputs for the syncope_Oui / lcc_calc_1.0 / bloc_branche_pre_bbd
    # indicator columns. Bounded to 0/1 so out-of-range codes cannot reach the
    # model; all bounds are ints because Streamlit requires value, min_value,
    # max_value and step to share one numeric type.
    with c9:
        syncope_Oui = st.number_input('Syncope', value=0, min_value=0, max_value=1, step=1)
    with c10:
        lcc_calc_1 = st.number_input('Calcification CCG', value=1, min_value=0, max_value=1, step=1)
    with c11:
        bloc_branche_pre_bbd = st.number_input('Présence bloc de branche droit pre tavi ', value=0, min_value=0, max_value=1, step=1)

    submitButton = st.form_submit_button(label='Predict')
def _load_pickle(path):
    """Unpickle and return the object stored at *path*.

    *path* is resolved relative to the current working directory.
    pickle can execute arbitrary code while loading, so this must only be
    pointed at the trusted artifacts shipped with the app, never at
    user-supplied files.
    """
    with open(path, "rb") as input_file:
        return pickle.load(input_file)


# Every artifact loader is wrapped in st.cache_resource so the pickle is read
# once and the object is then reused across reruns (and shared between
# sessions) instead of being re-read from disk each time the script runs.
# Callers should treat the returned objects as read-only for that reason.


@st.cache_resource()
def load_model():
    """Return the prediction model unpickled from ``svm_pre.pkl``."""
    return _load_pickle("svm_pre.pkl")


# SHAP explainer loading (explainer_pre.pkl) was removed together with the
# SHAP plot.


@st.cache_resource()
def load_scaler():
    """Return the feature scaler unpickled from ``scaler_pre.pkl``."""
    return _load_pickle("scaler_pre.pkl")


@st.cache_resource()
def load_impute():
    """Return the imputation mapping unpickled from ``dict_impute_pre.pkl``."""
    return _load_pickle("dict_impute_pre.pkl")
# Load the cached artifacts needed for prediction.
with st.spinner("Loading Model...."):
    model = load_model()
    dict_impute = load_impute()
    scaler = load_scaler()

logger = logging.getLogger(__name__)

with st.spinner("Prediction..."):
    # The age widget is labelled in years; it is multiplied by 365 because
    # the model presumably expects age in days — confirm against training.
    age = age * 365
    pred_arr = np.array([[diff_MS, MS_d, pr_pre_tavi,
                          surface_systole, MS_s, age,
                          ncc_calcif_n, calc_risque_n,
                          syncope_Oui, lcc_calc_1, bloc_branche_pre_bbd]])
    pred_df = pd.DataFrame(pred_arr, columns=cols)
    # Imputation is disabled in this version:
    # pred_df = impute_missing_values(pred_df, dict_impute)

    # Rescale the numeric features, then re-attach the unscaled columns.
    # The scaled frame reuses pred_df's index so the concat aligns row-wise.
    df_scaled_ = scaler.transform(pred_df[to_rescale_])
    df_scaled = pd.DataFrame(data=df_scaled_, columns=to_rescale_, index=pred_df.index)
    pred_df = pd.concat([df_scaled, pred_df[to_encode_]], axis=1)

    # When the model recorded its training column order, feed it exactly that
    # order: recent scikit-learn rejects inputs whose feature names are out of
    # order. Models fitted without names lack the attribute, hence getattr.
    feature_order = getattr(model, "feature_names_in_", None)
    if feature_order is not None:
        pred_df = pred_df[list(feature_order)]

    # Probability of the positive class (column 1), as a percentage.
    pred = round(model.predict_proba(pred_df)[0][1] * 100, 2)

    # Diagnostics go to the logger instead of stdout.
    logger.debug("model features: %s", feature_order)
    logger.debug("input columns: %s", list(pred_df.columns))
    logger.debug("input row: %s", pred_df.iloc[0].to_dict())
    logger.debug("predicted probability: %s %%", pred)

    st.write("Probabilité de risque d'IPM :", pred, ' %')