poudel's picture
Upload app.py
0261be6 verified
raw
history blame
No virus
4.12 kB
import gradio as gr
import pandas as pd
import joblib
import numpy as np
from models.neural_network.inference import load_model_and_preprocessor
# Trained predictors: the neural network ships with its fitted preprocessor,
# the XGBoost model is loaded as a single joblib artifact.
nn_model, nn_preprocessor = load_model_and_preprocessor(
    'saved_models/nn_model.keras',
    'saved_models/nn_preprocessor.pkl',
)
xgboost_model = joblib.load('saved_models/xgboost_model.joblib')

# Aircraft reference table, reduced to one row per model and exposed as
# {model: {column: value}} for direct lookups.
aircraft_data = pd.read_csv('datasets/aircraft_data.csv')
aircraft_data = aircraft_data.drop_duplicates(subset='model')
aircraft_dict = aircraft_data.set_index('model').to_dict(orient='index')

# Route table, exposed as {(origin, destination): {column: value}}.
airport_data = pd.read_csv('datasets/airport_distances.csv')
_routes_by_pair = airport_data.set_index(['Origin_Airport', 'Destination_Airport'])
airport_dict = _routes_by_pair.to_dict(orient='index')
def predict_fuel_burn(model_name, origin, destination, seats, distance):
    """Predict fuel burn for a flight with the neural network and XGBoost models.

    Args:
        model_name: Aircraft model; looked up in ``aircraft_dict``.
        origin: Origin airport, passed to the models as ``Origin_Airport``.
        destination: Destination airport, passed as ``Destination_Airport``.
        seats: Number of seats; must be greater than 0 and no more than the
            ``seats`` value recorded for the aircraft model.
        distance: Flight distance; must be greater than 0.

    Returns:
        A message string: either both predictions formatted in kg, or an
        explanation of which input failed validation.
    """
    # An empty numeric field reaches this function as None; comparing None
    # with a number raises TypeError, so reject it up front with a message.
    if seats is None:
        return "Please enter the number of seats."
    if distance is None:
        return "The distance is missing; select an origin and a destination airport."

    aircraft = aircraft_dict.get(model_name)
    if aircraft is None:
        return f"Unknown aircraft model: {model_name}."

    # Validate the seat count against the aircraft's capacity, then the ranges.
    max_seats = aircraft['seats']
    if seats > max_seats:
        return f"The {model_name} aircraft has a maximum of {max_seats} seats."
    if seats <= 0:
        return "The number of seats must be greater than 0."
    if distance <= 0:
        return "The distance must be greater than 0."

    # Single-row feature frame shared by both models.
    # NOTE(review): 'dist' repeats 'distance' -- presumably the two models
    # were trained with different column names; confirm before removing either.
    features = pd.DataFrame({
        'model': [model_name],
        'Origin_Airport': [origin],
        'Destination_Airport': [destination],
        'seats': [seats],
        'distance': [distance],
        'J/T': [aircraft['J/T']],
        'CAT': [aircraft['CAT']],
        '_Manufacturer': [aircraft['_Manufacturer']],
        'dist': [distance],
    })

    # The neural network needs its fitted preprocessor; XGBoost takes the raw frame.
    fuel_burn_prediction_nn = nn_model.predict(nn_preprocessor.transform(features))[0]
    fuel_burn_prediction_xgboost = xgboost_model.predict(features)
    return (
        f"Neural Network: {fuel_burn_prediction_nn[0]:.2f} kg, "
        f"XGBoost: {fuel_burn_prediction_xgboost[0]:.2f} kg"
    )
def update_fields(model_name):
    """Build updates for the read-only aircraft fields of the chosen model.

    Args:
        model_name: Aircraft model; must be a key of ``aircraft_dict``.

    Returns:
        A dict mapping the ``jt``, ``cat`` and ``manufacturer`` components to
        updates carrying that model's 'J/T', 'CAT' and '_Manufacturer' values.
    """
    specs = aircraft_dict[model_name]
    column_by_component = (
        (jt, 'J/T'),
        (cat, 'CAT'),
        (manufacturer, '_Manufacturer'),
    )
    return {
        component: gr.update(value=specs[column])
        for component, column in column_by_component
    }
def update_destination_options(origin):
    """Restrict the destination dropdown to airports reachable from ``origin``.

    Args:
        origin: Origin airport, matched against the ``Origin_Airport`` column
            of ``airport_data``.

    Returns:
        A component update with the matching destinations as choices and the
        current selection cleared.
    """
    from_origin = airport_data['Origin_Airport'] == origin
    destinations = airport_data.loc[from_origin, 'Destination_Airport'].unique()
    # Clear the selection as well: a destination picked for the previous
    # origin may not be valid for the new one, and leaving it in place would
    # keep the previously displayed distance for a route that no longer
    # matches the inputs.
    return gr.update(choices=list(destinations), value=None)
def update_distance(origin, destination):
    """Build an update showing the distance for the selected route.

    Args:
        origin: Origin airport of the route.
        destination: Destination airport of the route.

    Returns:
        A component update whose value is the 'distance' entry stored in
        ``airport_dict`` for ``(origin, destination)``, or 0 when the route
        or its distance is not available.
    """
    not_found = 'Distance not found'
    route = airport_dict.get((origin, destination), {})
    looked_up = route.get('distance', not_found)
    # Fall back to 0 so the numeric field never receives the marker text.
    shown = 0 if looked_up == not_found else looked_up
    return gr.update(value=shown)
# Gradio interface: pick an aircraft and a route, enter the seat count, and
# request a fuel burn prediction from both models.
# NOTE(review): the nesting of this section was reconstructed; confirm which
# inputs are meant to share the second row.
with gr.Blocks() as demo:
    gr.Markdown("## Fuel Burn Prediction")

    # Selection row: aircraft model and route.
    with gr.Row():
        model_name = gr.Dropdown(
            label="Aircraft Model",
            choices=list(aircraft_dict.keys()),
            value=list(aircraft_dict.keys())[0],
        )
        origin = gr.Dropdown(
            label="Origin Airport",
            choices=sorted(airport_data['Origin_Airport'].unique()),
        )
        # Choices are filled in once an origin has been selected.
        destination = gr.Dropdown(
            label="Destination Airport",
            choices=[],
        )

    # Detail row: read-only aircraft attributes, the seat count, and the
    # distance derived from the selected route.
    with gr.Row():
        jt = gr.Textbox(label="J/T", interactive=False)
        cat = gr.Textbox(label="CAT", interactive=False)
        manufacturer = gr.Textbox(label="Manufacturer", interactive=False)
        seats = gr.Number(label="Seats")
        distance = gr.Number(label="Distance", interactive=False)

    # Keep the dependent fields in sync with the selections.
    model_name.change(fn=update_fields, inputs=model_name, outputs=[jt, cat, manufacturer])
    origin.change(fn=update_destination_options, inputs=origin, outputs=destination)
    destination.change(fn=update_distance, inputs=[origin, destination], outputs=distance)

    # The aircraft dropdown starts with a preselected model, but no change
    # event fires for an initial value; populate the read-only fields once
    # when the page loads so they match the preselected model.
    demo.load(fn=update_fields, inputs=model_name, outputs=[jt, cat, manufacturer])

    submit_btn = gr.Button("Predict Fuel Burn")
    result = gr.Textbox(label="Fuel Burn Prediction", interactive=False)
    submit_btn.click(
        predict_fuel_burn,
        inputs=[model_name, origin, destination, seats, distance],
        outputs=result,
    )

demo.launch()