File size: 3,722 Bytes
a5e59c4
 
 
 
 
 
0261be6
a5e59c4
 
0261be6
a5e59c4
 
 
0261be6
a5e59c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd022fa
a5e59c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import gradio as gr
import pandas as pd
import joblib
import numpy as np

# Pre-trained XGBoost model used by predict_fuel_burn().
xgboost_model = joblib.load('saved_models/xgboost_model.joblib')

# Aircraft reference table. drop_duplicates keeps the first row per model
# name, so each model maps to exactly one {column: value} record.
aircraft_data = pd.read_csv('datasets/aircraft_data.csv')
aircraft_data = aircraft_data.drop_duplicates(subset='model')
aircraft_dict = aircraft_data.set_index('model').to_dict(orient='index')

# Route table, re-keyed as {(origin, destination): {column: value, ...}}.
airport_data = pd.read_csv('datasets/airport_distances.csv')
airport_dict = (
    airport_data
    .set_index(['Origin_Airport', 'Destination_Airport'])
    .to_dict(orient='index')
)


def predict_fuel_burn(model_name, origin, destination, seats, distance):
    """Predict the fuel burn of a flight with the pre-trained XGBoost model.

    Args:
        model_name: Aircraft model; expected to be a key of ``aircraft_dict``.
        origin: Origin airport code selected in the UI.
        destination: Destination airport code selected in the UI.
        seats: Number of seats; ``None`` when the Seats field is left empty.
        distance: Route distance; ``None`` when no route has been looked up.

    Returns:
        The prediction formatted as ``"<value> kg"`` (two decimals), or a
        human-readable validation message when an input is missing, unknown
        or out of range.
    """
    # Empty gr.Number fields arrive as None; comparing None with a number
    # raises TypeError, so report the missing value before numeric checks.
    if seats is None:
        return "Please enter the number of seats."

    aircraft = aircraft_dict.get(model_name)
    if aircraft is None:
        return f"Unknown aircraft model: {model_name}."

    # Validate the requested seat count against the aircraft's capacity.
    max_seats = aircraft['seats']
    if seats > max_seats:
        return f"The {model_name} aircraft has a maximum of {max_seats} seats."
    if seats <= 0:
        return "The number of seats must be greater than 0."

    if distance is None:
        return "Please select an origin and destination airport."
    if distance <= 0:
        return "The distance must be greater than 0."

    # One-row frame of model inputs. The distance is passed under both
    # 'distance' and 'dist' -- presumably both columns were present when the
    # model was trained; confirm before removing either.
    features = pd.DataFrame({
        'model': [model_name],
        'Origin_Airport': [origin],
        'Destination_Airport': [destination],
        'seats': [seats],
        'distance': [distance],
        'J/T': [aircraft['J/T']],
        'CAT': [aircraft['CAT']],
        '_Manufacturer': [aircraft['_Manufacturer']],
        'dist': [distance],
    })

    prediction = xgboost_model.predict(features)
    return f"{prediction[0]:.2f} kg"


def update_fields(model_name):
    """Fill the read-only J/T, CAT and Manufacturer boxes for *model_name*.

    Returns a mapping from each output component to a ``gr.update`` carrying
    the matching column value of the model's record in ``aircraft_dict``.
    """
    specs = aircraft_dict[model_name]
    column_for = (
        (jt, 'J/T'),
        (cat, 'CAT'),
        (manufacturer, '_Manufacturer'),
    )
    return {
        component: gr.update(value=specs[column])
        for component, column in column_for
    }


def update_destination_options(origin):
    """Restrict the Destination dropdown to airports paired with *origin*.

    Args:
        origin: Selected origin airport code (may be ``None`` if cleared).

    Returns:
        A ``gr.update`` whose choices are the sorted, non-null destination
        codes listed for *origin* in ``airport_data``, with the current
        selection reset to ``None``.
    """
    routes = airport_data.loc[
        airport_data['Origin_Airport'] == origin, 'Destination_Airport'
    ]
    # Sorted to match the Origin dropdown; NaN dropped so sorting cannot
    # fail on mixed str/float values.
    destinations = sorted(routes.dropna().unique())
    # Reset the selection: otherwise a destination chosen for the previous
    # origin can stay selected, its change event does not fire, and the
    # Distance box keeps the old route's value.
    return gr.update(choices=destinations, value=None)


def update_distance(origin, destination):
    """Look up the distance of the (origin, destination) route.

    Args:
        origin: Selected origin airport code.
        destination: Selected destination airport code (``None`` if unset).

    Returns:
        A ``gr.update`` setting the Distance box to the route's ``distance``
        value, or to 0 when the route is not in ``airport_dict`` or its
        distance is missing (NaN) in the source CSV.
    """
    distance_value = airport_dict.get((origin, destination), {}).get('distance')
    # pd.isna covers both "route not found" (None) and an empty CSV cell
    # (NaN); NaN would otherwise slip past a "distance <= 0" check.
    if pd.isna(distance_value):
        return gr.update(value=0)
    return gr.update(value=distance_value)


with gr.Blocks() as demo:
    gr.Markdown("## Fuel Burn Prediction")

    model_choices = list(aircraft_dict.keys())

    with gr.Row():
        model_name = gr.Dropdown(
            label="Aircraft Model",
            choices=model_choices,
            value=model_choices[0] if model_choices else None,
        )
        origin = gr.Dropdown(
            label="Origin Airport",
            choices=sorted(airport_data['Origin_Airport'].unique())
        )
        destination = gr.Dropdown(
            label="Destination Airport",
            choices=[]
        )

    with gr.Row():
        jt = gr.Textbox(label="J/T", interactive=False)
        cat = gr.Textbox(label="CAT", interactive=False)
        manufacturer = gr.Textbox(label="Manufacturer", interactive=False)
        seats = gr.Number(label="Seats")

    distance = gr.Number(label="Distance", interactive=False)

    model_name.change(fn=update_fields, inputs=model_name, outputs=[jt, cat, manufacturer])
    origin.change(fn=update_destination_options, inputs=origin, outputs=destination)
    destination.change(fn=update_distance, inputs=[origin, destination], outputs=distance)

    # .change does not fire for the dropdown's initial value, so populate the
    # read-only aircraft fields for the preselected model on page load.
    if model_choices:
        demo.load(fn=update_fields, inputs=model_name, outputs=[jt, cat, manufacturer])

    submit_btn = gr.Button("Predict Fuel Burn")
    result = gr.Textbox(label="Fuel Burn Prediction", interactive=False)

    submit_btn.click(predict_fuel_burn, inputs=[model_name, origin, destination, seats, distance], outputs=result)

# Start the server only when run as a script, so importing this module
# (e.g. from tests or Gradio's reload mode) does not launch it.
if __name__ == "__main__":
    demo.launch()