File size: 7,035 Bytes
2301fd0
 
 
 
 
 
475a89a
 
 
 
 
 
 
 
 
 
 
 
 
6e48224
3efc156
914085c
 
2301fd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a18a12
475a89a
2301fd0
 
 
 
3a18a12
9f12a38
 
 
 
3a18a12
 
2301fd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e517d0e
2301fd0
 
 
 
 
 
 
 
 
 
 
475a89a
 
 
 
 
 
 
 
 
 
 
2301fd0
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import gradio as gr
import requests
import json
import plotly

def predict_fraud(selected_model, selected_interpretability_method, step, transaction_type, amount, oldbalanceOrg):
    # Validation checks
    if not selected_model:
        return "Model Selection is required.", None, None, None, None, None, None, None, None, None, None
    if not selected_interpretability_method:
        return "Interpretability Technique is required.", None, None, None, None, None, None, None, None, None, None
    if step == 0:  # Assuming step is a numerical value, check for None explicitly
        return "Step (Transaction Time) is required.", None, None, None, None, None, None, None, None, None, None
    if not transaction_type:
        return "Transaction Type is required.", None, None, None, None, None, None, None, None, None, None
    if amount == 0:  # Assuming amount is a numerical value, check for None explicitly
        return "Transaction Amount is required.", None, None, None, None, None, None, None, None, None, None
    if oldbalanceOrg is None:  # Assuming oldbalanceOrg is a numerical value, check for None explicitly
         return "Old Balance Org is required.", None, None, None, None, None, None, None, None, None, None

    #url = "https://fraud-sense-16a8ed5f96b5.herokuapp.com/predict_and_explain"
    #url = "https://fraudsense-02168c9829aa.herokuapp.com/predict_and_explain"
    url = "https://xaifraudsense-48ebac2f952e.herokuapp.com/predict_and_explain"
    data = {
        'selected_model': selected_model,
        'selected_interpretability_method': selected_interpretability_method,
        'step': step,
        'transaction_type': transaction_type,
        'amount': amount,
        'oldbalanceOrg': oldbalanceOrg
    }

    response = requests.post(url, json=data)
    if response.status_code == 200:
        result = response.json()

        # Directly use the base64-encoded image string for the network graph
        network_graph = result['network_graph']

        # Ensure other data is handled correctly
        prediction_text = result['prediction_text']
        model_explanation = result['model_explanation']
       


        mod_plot_json = result['mod_plot']        
        # Parse the JSON strings back into Plotly figures
        mod_plot = plotly.graph_objs.Figure(json.loads(mod_plot_json))        
        features_influence = result['features_influence']
        
        network_graph_json = result['network_graph']   #graph_objects 
        # Parse the JSON strings back into Plotly figures        
        network_graph = plotly.graph_objs.Figure(json.loads(network_graph_json))        
        network_explainer = result['network_explainer']
        top_main_effect = result['top_main_effect']
        top_interaction = result['top_interaction']        
        
        # Parse the JSON strings back into Plotly figures        
        radial_plot_json = result['radial_plot']
        bar_chart_json = result['bar_chart']        
        radial_plot = plotly.graph_objs.Figure(json.loads(radial_plot_json))
        bar_chart = plotly.graph_objs.Figure(json.loads(bar_chart_json))
                        
        narrative = result.get('narrative', "")

      # Return the results
        return prediction_text, model_explanation, mod_plot, features_influence, network_graph, network_explainer, top_main_effect, top_interaction,radial_plot, bar_chart, narrative
    else:
        # Handle error scenario by returning placeholders for each expected output
        return "Error: " + response.text, None, None, None, None,None, None, None, None, None, None
   
    
# Gradio UI. The first row holds the input form plus the prediction and the
# model-interpretation panels; the second row holds the feature-interaction and
# counterfactual panels. Every output component is filled from the 11-tuple
# returned by predict_fraud when "Analyze" is clicked.
with gr.Blocks() as app:
    gr.Markdown("<h2 style='text-align: center; font-weight: bold;'>FraudSenseXAI - Advanced Fraud Detection</h2>")
    gr.Markdown("<p style='text-align: center;'>Predict and analyze fraudulent transactions.</p>", elem_id="description")
    gr.Markdown("<p style='text-align: center;'>This app utilizes financial synthetic dataset from kaggle, and it is structured to adapt to the nature of the dataset .</p>", elem_id="description")
    # NOTE(review): this hidden textbox is never made visible or updated by any
    # event handler; validation and service errors are shown in the
    # "Prediction" box (the first output of predict_fraud) instead.
    error_message = gr.Textbox(label="Error Message", visible=False, lines=2, interactive=False)

    with gr.Row():
        with gr.Column():
            gr.Markdown("#### INPUT PARAMETERS: All fields are required")
            model_selection = gr.Dropdown(choices=['Random Forest', 'Gradient Boost', 'Neural Network'], label="Model Selection")
            interpretability_selection = gr.Dropdown(choices=['LIME', 'SHAP'], label="Interpretability Technique")
            step = gr.Number(label="Step(Transaction Time)")
            transaction_type = gr.Dropdown(choices=['Transfer', 'Payment', 'Cash Out', 'Cash In'], label="Transaction Type")
            transaction_amount = gr.Number(label="Transaction Amount:")
            old_balance_org = gr.Number(label="Old Balance: total account balance prior to transaction initiation ")
            submit_btn = gr.Button("Analyze")

            # Prediction output sits under the form in the same column.
            gr.Markdown("#### PREDICTION RESULT")
            prediction_text = gr.Textbox(label="Prediction", lines=7)

        with gr.Column():
            gr.Markdown("#### MODEL INTERPRETATIONS")
            model_explanation = gr.Textbox(label="Model Explanation", lines=7)
            mod_plot = gr.Plot(label="Model Plot")
            features_influence = gr.Textbox(label="Features Influence", lines=7)

    with gr.Row():
        with gr.Column():
            gr.Markdown("#### FEATURE INTERACTIONS: Note that this function only supports SHAP. LIME & Neural Network are not supported")
            network_graph = gr.Plot(label="Network Graph")
            network_explainer = gr.Textbox(label="Network Graph Explanation")
            top_main_effect = gr.Textbox(label="Top Main Effect", lines=7)
            top_interaction = gr.Textbox(label="Top Interaction", lines=7)

        with gr.Column():
            gr.Markdown("#### COUNTERFACTUAL EXPLANATIONS")
            radial_plot = gr.Plot(label="Radial Plot")
            bar_chart = gr.Plot(label="Bar Chart")
            narrative = gr.Textbox(label="Narrative")

    # The order of `outputs` must match the tuple returned by predict_fraud.
    submit_btn.click(
        predict_fraud,
        inputs=[model_selection, interpretability_selection, step, transaction_type, transaction_amount, old_balance_org],
        outputs=[prediction_text, model_explanation, mod_plot, features_influence, network_graph, network_explainer, top_main_effect, top_interaction, radial_plot, bar_chart, narrative],
    )

app.launch(share=True)