Praveen76 commited on
Commit
f151c4b
1 Parent(s): e9767e9

Upload 2 files

Browse files
Files changed (2) hide show
  1. app/__init__.py +0 -0
  2. app/main.py +125 -0
app/__init__.py ADDED
File without changes
app/main.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+ file = Path(__file__).resolve()
4
+ parent, root = file.parent, file.parents[1]
5
+ sys.path.append(str(root))
6
+
7
+ import gradio
8
+ from fastapi import FastAPI, Request, Response
9
+
10
+ import random
11
+ import numpy as np
12
+ import pandas as pd
13
+ from titanic_model.processing.data_manager import load_dataset, load_pipeline
14
+ from titanic_model import __version__ as _version
15
+ from titanic_model.config.core import config
16
+ from sklearn.model_selection import train_test_split
17
+ from titanic_model.predict import make_prediction
18
+
19
+ from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
20
+
21
+
22
+ # FastAPI object
23
+ app = FastAPI()
24
+
25
+
26
+ ################################# Prometheus related code START ######################################################
27
+ import prometheus_client as prom
28
+
29
+ acc_metric = prom.Gauge('titanic_accuracy_score', 'Accuracy score for few random 100 test samples')
30
+ f1_metric = prom.Gauge('titanic_f1_score', 'F1 score for few random 100 test samples')
31
+ precision_metric = prom.Gauge('titanic_precision_score', 'Precision score for few random 100 test samples')
32
+ recall_metric = prom.Gauge('titanic_recall_score', 'Recall score for few random 100 test samples')
33
+
34
+ # LOAD TEST DATA
35
+ pipeline_file_name = f"{config.app_config.pipeline_save_file}{_version}.pkl"
36
+ titanic_pipe= load_pipeline(file_name=pipeline_file_name)
37
+ data = load_dataset(file_name=config.app_config.training_data_file) # read complete data
38
+
39
+ X_train, X_test, y_train, y_test = train_test_split( # divide into train and test set
40
+ data[config.model_config.features],
41
+ data[config.model_config.target],
42
+ test_size=config.model_config.test_size,
43
+ random_state=config.model_config.random_state,
44
+ )
45
+ test_data = X_test.copy()
46
+ test_data['target'] = y_test.values
47
+
48
+
49
+ # Function for updating metrics
50
+ def update_metrics():
51
+ global test_data
52
+ # Performance on test set
53
+ size = random.randint(100, 130)
54
+ test = test_data.sample(size, random_state = random.randint(0, 1e6)) # sample few 100 rows randomly
55
+ y_pred = titanic_pipe.predict(test.iloc[:, :-1]) # prediction
56
+ acc = accuracy_score(test['target'], y_pred).round(3) # accuracy score
57
+ f1 = f1_score(test['target'], y_pred).round(3) # F1 score
58
+ precision = precision_score(test['target'], y_pred).round(3) # Precision score
59
+ recall = recall_score(test['target'], y_pred).round(3) # Recall score
60
+
61
+ acc_metric.set(acc)
62
+ f1_metric.set(f1)
63
+ precision_metric.set(precision)
64
+ recall_metric.set(recall)
65
+
66
+ @app.get("/metrics")
67
+ async def get_metrics():
68
+ update_metrics()
69
+ return Response(media_type="text/plain", content= prom.generate_latest())
70
+
71
+ ################################# Prometheus related code END ######################################################
72
+
73
+
74
+ # UI - Input components
75
+ in_Pid = gradio.Textbox(lines=1, placeholder=None, value="79", label='Passenger Id')
76
+ in_Pclass = gradio.Radio(['1', '2', '3'], type="value", label='Passenger class')
77
+ in_Pname = gradio.Textbox(lines=1, placeholder=None, value="Caldwell, Master. Alden Gates", label='Passenger Name')
78
+ in_sex = gradio.Radio(["Male", "Female"], type="value", label='Gender')
79
+ in_age = gradio.Textbox(lines=1, placeholder=None, value="14", label='Age of the passenger in yrs')
80
+ in_sibsp = gradio.Textbox(lines=1, placeholder=None, value="0", label='No. of siblings/spouse of the passenger aboard')
81
+ in_parch = gradio.Textbox(lines=1, placeholder=None, value="2", label='No. of parents/children of the passenger aboard')
82
+ in_ticket = gradio.Textbox(lines=1, placeholder=None, value="248738", label='Ticket number')
83
+ in_cabin = gradio.Textbox(lines=1, placeholder=None, value="A5", label='Cabin number')
84
+ in_embarked = gradio.Radio(["Southampton", "Cherbourg", "Queenstown"], type="value", label='Port of Embarkation')
85
+ in_fare = gradio.Textbox(lines=1, placeholder=None, value="29", label='Passenger fare')
86
+
87
+ # UI - Output component
88
+ out_label = gradio.Textbox(type="text", label='Prediction', elem_id="out_textbox")
89
+
90
+ # Label prediction function
91
+ def get_output_label(in_Pid, in_Pclass, in_Pname, in_sex, in_age, in_sibsp, in_parch, in_ticket, in_cabin, in_embarked, in_fare):
92
+
93
+ input_df = pd.DataFrame({"PassengerId": [in_Pid],
94
+ "Pclass": [int(in_Pclass)],
95
+ "Name": [in_Pname],
96
+ "Sex": [in_sex.lower()],
97
+ "Age": [float(in_age)],
98
+ "SibSp": [int(in_sibsp)],
99
+ "Parch": [int(in_parch)],
100
+ "Ticket": [in_ticket],
101
+ "Cabin": [in_cabin],
102
+ "Embarked": [in_embarked[0]],
103
+ "Fare": [float(in_fare)]})
104
+
105
+ result = make_prediction(input_data=input_df.replace({np.nan: None}))["predictions"]
106
+ label = "Survived" if result[0]==1 else "Not Survived"
107
+ return label
108
+
109
+
110
+ # Create Gradio interface object
111
+ iface = gradio.Interface(fn = get_output_label,
112
+ inputs = [in_Pid, in_Pclass, in_Pname, in_sex, in_age, in_sibsp, in_parch, in_ticket, in_cabin, in_embarked, in_fare],
113
+ outputs = [out_label],
114
+ title="Titanic Survival Prediction API ⛴",
115
+ description="Predictive model that answers the question: “What sort of people were more likely to survive?”",
116
+ allow_flagging='never',
117
+ )
118
+
119
+ # Mount gradio interface object on FastAPI app at endpoint = '/'
120
+ app = gradio.mount_gradio_app(app, iface, path="/")
121
+
122
+
123
+ if __name__ == "__main__":
124
+ import uvicorn
125
+ uvicorn.run(app, host="0.0.0.0", port=8001)