Spaces:
Sleeping
Sleeping
File size: 7,035 Bytes
2301fd0 475a89a 6e48224 3efc156 914085c 2301fd0 3a18a12 475a89a 2301fd0 3a18a12 9f12a38 3a18a12 2301fd0 e517d0e 2301fd0 475a89a 2301fd0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import gradio as gr
import requests
import json
import plotly
def predict_fraud(selected_model, selected_interpretability_method, step, transaction_type, amount, oldbalanceOrg):
# Validation checks
if not selected_model:
return "Model Selection is required.", None, None, None, None, None, None, None, None, None, None
if not selected_interpretability_method:
return "Interpretability Technique is required.", None, None, None, None, None, None, None, None, None, None
if step == 0: # Assuming step is a numerical value, check for None explicitly
return "Step (Transaction Time) is required.", None, None, None, None, None, None, None, None, None, None
if not transaction_type:
return "Transaction Type is required.", None, None, None, None, None, None, None, None, None, None
if amount == 0: # Assuming amount is a numerical value, check for None explicitly
return "Transaction Amount is required.", None, None, None, None, None, None, None, None, None, None
if oldbalanceOrg is None: # Assuming oldbalanceOrg is a numerical value, check for None explicitly
return "Old Balance Org is required.", None, None, None, None, None, None, None, None, None, None
#url = "https://fraud-sense-16a8ed5f96b5.herokuapp.com/predict_and_explain"
#url = "https://fraudsense-02168c9829aa.herokuapp.com/predict_and_explain"
url = "https://xaifraudsense-48ebac2f952e.herokuapp.com/predict_and_explain"
data = {
'selected_model': selected_model,
'selected_interpretability_method': selected_interpretability_method,
'step': step,
'transaction_type': transaction_type,
'amount': amount,
'oldbalanceOrg': oldbalanceOrg
}
response = requests.post(url, json=data)
if response.status_code == 200:
result = response.json()
# Directly use the base64-encoded image string for the network graph
network_graph = result['network_graph']
# Ensure other data is handled correctly
prediction_text = result['prediction_text']
model_explanation = result['model_explanation']
mod_plot_json = result['mod_plot']
# Parse the JSON strings back into Plotly figures
mod_plot = plotly.graph_objs.Figure(json.loads(mod_plot_json))
features_influence = result['features_influence']
network_graph_json = result['network_graph'] #graph_objects
# Parse the JSON strings back into Plotly figures
network_graph = plotly.graph_objs.Figure(json.loads(network_graph_json))
network_explainer = result['network_explainer']
top_main_effect = result['top_main_effect']
top_interaction = result['top_interaction']
# Parse the JSON strings back into Plotly figures
radial_plot_json = result['radial_plot']
bar_chart_json = result['bar_chart']
radial_plot = plotly.graph_objs.Figure(json.loads(radial_plot_json))
bar_chart = plotly.graph_objs.Figure(json.loads(bar_chart_json))
narrative = result.get('narrative', "")
# Return the results
return prediction_text, model_explanation, mod_plot, features_influence, network_graph, network_explainer, top_main_effect, top_interaction,radial_plot, bar_chart, narrative
else:
# Handle error scenario by returning placeholders for each expected output
return "Error: " + response.text, None, None, None, None,None, None, None, None, None, None
# Gradio user interface: input form, prediction, and explanation panels.
with gr.Blocks() as app:
    gr.Markdown("<h2 style='text-align: center; font-weight: bold;'>FraudSenseXAI - Advanced Fraud Detection</h2>")
    gr.Markdown("<p style='text-align: center;'>Predict and analyze fraudulent transactions.</p>", elem_id="description")
    # Distinct elem_id: HTML ids must be unique within the page.
    gr.Markdown(
        "<p style='text-align: center;'>This app utilizes a synthetic financial dataset from Kaggle, "
        "and it is structured to adapt to the nature of the dataset.</p>",
        elem_id="dataset-description",
    )

    with gr.Row():
        # Left column: input form and the headline prediction.
        with gr.Column():
            gr.Markdown("#### INPUT PARAMETERS: All fields are required")
            model_selection = gr.Dropdown(choices=['Random Forest', 'Gradient Boost', 'Neural Network'], label="Model Selection")
            interpretability_selection = gr.Dropdown(choices=['LIME', 'SHAP'], label="Interpretability Technique")
            step = gr.Number(label="Step(Transaction Time)")
            transaction_type = gr.Dropdown(choices=['Transfer', 'Payment', 'Cash Out', 'Cash In'], label="Transaction Type")
            transaction_amount = gr.Number(label="Transaction Amount:")
            old_balance_org = gr.Number(label="Old Balance: total account balance prior to transaction initiation")
            submit_btn = gr.Button("Analyze")
            gr.Markdown("#### PREDICTION RESULT")
            # Also shows validation and service error messages from predict_fraud.
            prediction_text = gr.Textbox(label="Prediction", lines=7)
        # Right column: model-level interpretation of the prediction.
        with gr.Column():
            gr.Markdown("#### MODEL INTERPRETATIONS")
            model_explanation = gr.Textbox(label="Model Explanation", lines=7)
            mod_plot = gr.Plot(label="Model Plot")
            features_influence = gr.Textbox(label="Features Influence", lines=7)

    with gr.Row():
        with gr.Column():
            gr.Markdown("#### FEATURE INTERACTIONS: Note that this function only supports SHAP. LIME & Neural Network are not supported")
            network_graph = gr.Plot(label="Network Graph")
            network_explainer = gr.Textbox(label="Network Graph Explanation")
            top_main_effect = gr.Textbox(label="Top Main Effect", lines=7)
            top_interaction = gr.Textbox(label="Top Interaction", lines=7)
        with gr.Column():
            gr.Markdown("#### COUNTERFACTUAL EXPLANATIONS")
            radial_plot = gr.Plot(label="Radial Plot")
            bar_chart = gr.Plot(label="Bar Chart")
            narrative = gr.Textbox(label="Narrative")

    # The order of `outputs` must match the 11-tuple returned by predict_fraud.
    submit_btn.click(
        predict_fraud,
        inputs=[model_selection, interpretability_selection, step, transaction_type, transaction_amount, old_balance_org],
        outputs=[
            prediction_text, model_explanation, mod_plot, features_influence,
            network_graph, network_explainer, top_main_effect, top_interaction,
            radial_plot, bar_chart, narrative,
        ],
    )

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    app.launch(share=True)
|