manjunathainti commited on
Commit
2f82680
·
verified ·
1 Parent(s): f8d0f44

update app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -85
app.py CHANGED
@@ -2,108 +2,191 @@ import gradio as gr
2
  import matplotlib.pyplot as plt
3
  import pandas as pd
4
  import joblib
 
 
5
 
6
- # Load the dataset
 
7
  data_file = "webtraffic.csv"
8
- webtraffic_data = pd.read_csv(data_file)
9
 
10
- # Verify if 'Datetime' exists, or create it
 
 
 
 
 
 
 
11
  if "Datetime" not in webtraffic_data.columns:
12
- print("Datetime column missing. Attempting to create from 'Hour Index'.")
13
  start_date = pd.Timestamp("2024-01-01 00:00:00")
14
- webtraffic_data["Datetime"] = start_date + pd.to_timedelta(
15
- webtraffic_data["Hour Index"], unit="h"
16
- )
17
  else:
18
  webtraffic_data["Datetime"] = pd.to_datetime(webtraffic_data["Datetime"])
19
 
20
- # Ensure 'Datetime' column is sorted
21
  webtraffic_data.sort_values("Datetime", inplace=True)
22
 
23
- # Load the SARIMA model
24
- sarima_model = joblib.load("sarima_model.pkl")
25
-
26
- # Define future periods for evaluation
27
- future_periods = 48
28
-
29
- # Dummy values for metrics (if needed)
30
- mae_sarima_future = 100
31
- rmse_sarima_future = 150
32
-
33
-
34
- # Function to generate plot based on SARIMA model
35
- def generate_plot():
36
- future_dates = pd.date_range(
37
- start=webtraffic_data["Datetime"].iloc[-1], periods=future_periods + 1, freq="H"
38
- )[1:]
39
-
40
- sarima_predictions = sarima_model.predict(n_periods=future_periods)
41
- future_predictions = pd.DataFrame(
42
- {"Datetime": future_dates, "SARIMA_Predicted": sarima_predictions}
43
- )
44
- plt.figure(figsize=(15, 6))
45
- plt.plot(
46
- webtraffic_data["Datetime"],
47
- webtraffic_data["Sessions"],
48
- label="Actual Traffic",
49
- color="black",
50
- linestyle="dotted",
51
- linewidth=2,
52
- )
53
- plt.plot(
54
- future_predictions["Datetime"],
55
- future_predictions["SARIMA_Predicted"],
56
- label="SARIMA Predicted",
57
- color="blue",
58
- linewidth=2,
59
- )
60
-
61
- plt.title("SARIMA Predictions vs Actual Traffic", fontsize=16)
62
- plt.xlabel("Datetime", fontsize=12)
63
- plt.ylabel("Sessions", fontsize=12)
64
- plt.legend(loc="upper left")
65
- plt.grid(True)
66
- plt.tight_layout()
67
-
68
- plot_path = "sarima_prediction_plot.png"
69
- plt.savefig(plot_path)
70
- plt.close()
71
- return plot_path
72
-
73
-
74
- # Function to display SARIMA metrics
75
- def display_metrics():
76
- metrics = {
77
- "Model": ["SARIMA"],
78
- "Mean Absolute Error (MAE)": [mae_sarima_future],
79
- "Root Mean Squared Error (RMSE)": [rmse_sarima_future],
80
- }
81
- return pd.DataFrame(metrics)
82
-
83
-
84
- # Gradio interface function
85
- def dashboard_interface():
86
- plot_path = generate_plot()
87
- metrics_df = display_metrics()
88
- return plot_path, metrics_df.to_string()
89
-
90
-
91
- # Build the Gradio interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  with gr.Blocks() as dashboard:
93
- gr.Markdown("## Interactive SARIMA Web Traffic Prediction Dashboard")
94
- gr.Markdown(
95
- "This dashboard shows SARIMA model predictions vs actual traffic along with performance metrics."
96
- )
 
 
97
 
98
- plot_output = gr.Image(label="Prediction Plot")
99
- metrics_output = gr.Textbox(label="Metrics", lines=15)
100
 
 
101
  gr.Button("Generate Predictions").click(
102
- fn=dashboard_interface,
103
  inputs=[],
104
  outputs=[plot_output, metrics_output],
105
  )
106
 
107
- # Launch the Gradio dashboard
 
 
 
 
 
 
108
  if __name__ == "__main__":
 
109
  dashboard.launch()
 
2
  import matplotlib.pyplot as plt
3
  import pandas as pd
4
  import joblib
5
+ from sklearn.metrics import mean_absolute_error, mean_squared_error
6
+ from math import sqrt
7
 
8
+ # Step 1: Load the Dataset
9
+ print("Loading Dataset...")
10
  data_file = "webtraffic.csv"
 
11
 
12
+ try:
13
+ webtraffic_data = pd.read_csv(data_file)
14
+ print("Dataset loaded successfully!")
15
+ except Exception as e:
16
+ print(f"Error loading dataset: {e}")
17
+ exit()
18
+
19
+ # Step 2: Ensure 'Datetime' column exists or create it
20
  if "Datetime" not in webtraffic_data.columns:
21
+ print("Datetime column missing. Creating from 'Hour Index'.")
22
  start_date = pd.Timestamp("2024-01-01 00:00:00")
23
+ webtraffic_data["Datetime"] = start_date + pd.to_timedelta(webtraffic_data["Hour Index"], unit="h")
 
 
24
  else:
25
  webtraffic_data["Datetime"] = pd.to_datetime(webtraffic_data["Datetime"])
26
 
 
27
  webtraffic_data.sort_values("Datetime", inplace=True)
28
 
29
+ # Step 3: Load SARIMA Model
30
+ print("Loading SARIMA Model...")
31
+ try:
32
+ sarima_model = joblib.load("sarima_model.pkl")
33
+ print("SARIMA model loaded successfully!")
34
+ except Exception as e:
35
+ print(f"Error loading SARIMA model: {e}")
36
+ exit()
37
+
38
+ # Step 4: Define Functions for Gradio Dashboard
39
+ future_periods = 48 # Number of hours to predict
40
+
41
+ def generate_sarima_plot():
42
+ """Generate SARIMA predictions and return a detailed plot with metrics."""
43
+ try:
44
+ # Generate future dates for predictions
45
+ future_dates = pd.date_range(
46
+ start=webtraffic_data["Datetime"].iloc[-1],
47
+ periods=future_periods + 1,
48
+ freq="H"
49
+ )[1:]
50
+
51
+ # Generate SARIMA predictions
52
+ sarima_predictions = sarima_model.predict(n_periods=future_periods)
53
+
54
+ # Extract actual data for the last 'future_periods' hours
55
+ actual_sessions = webtraffic_data["Sessions"].iloc[-future_periods:].values
56
+
57
+ # Calculate metrics
58
+ mae_sarima = mean_absolute_error(actual_sessions, sarima_predictions[:len(actual_sessions)])
59
+ rmse_sarima = sqrt(mean_squared_error(actual_sessions, sarima_predictions[:len(actual_sessions)]))
60
+
61
+ # Combine predictions into a DataFrame for plotting
62
+ future_predictions = pd.DataFrame({
63
+ "Datetime": future_dates,
64
+ "SARIMA_Predicted": sarima_predictions
65
+ })
66
+
67
+ # Plot Actual Traffic vs SARIMA Predictions
68
+ plt.figure(figsize=(15, 6))
69
+ plt.plot(
70
+ webtraffic_data["Datetime"],
71
+ webtraffic_data["Sessions"],
72
+ label="Actual Traffic",
73
+ color="black",
74
+ linestyle="dotted",
75
+ linewidth=2,
76
+ )
77
+ plt.plot(
78
+ future_predictions["Datetime"],
79
+ future_predictions["SARIMA_Predicted"],
80
+ label="SARIMA Predicted",
81
+ color="blue",
82
+ linewidth=2,
83
+ )
84
+
85
+ plt.title("SARIMA Predictions vs Actual Traffic", fontsize=16)
86
+ plt.xlabel("Datetime", fontsize=12)
87
+ plt.ylabel("Sessions", fontsize=12)
88
+ plt.legend(loc="upper left")
89
+ plt.grid(True)
90
+ plt.tight_layout()
91
+
92
+ # Save the plot
93
+ plot_path = "sarima_prediction_plot.png"
94
+ plt.savefig(plot_path)
95
+ plt.close()
96
+
97
+ # Return plot path and metrics
98
+ metrics = f"""
99
+ SARIMA Model Metrics:
100
+ - Mean Absolute Error (MAE): {mae_sarima:.2f}
101
+ - Root Mean Squared Error (RMSE): {rmse_sarima:.2f}
102
+ """
103
+ return plot_path, metrics
104
+
105
+ except Exception as e:
106
+ print(f"Error generating SARIMA plot: {e}")
107
+ return None, "Error in generating output. Please check the data and model."
108
+
109
+ def generate_zoomed_plot():
110
+ """Generate a zoomed-in SARIMA prediction plot."""
111
+ try:
112
+ # Generate future dates for predictions
113
+ future_dates = pd.date_range(
114
+ start=webtraffic_data["Datetime"].iloc[-1],
115
+ periods=future_periods + 1,
116
+ freq="H"
117
+ )[1:]
118
+
119
+ # Generate SARIMA predictions
120
+ sarima_predictions = sarima_model.predict(n_periods=future_periods)
121
+
122
+ # Combine predictions into a DataFrame for plotting
123
+ future_predictions = pd.DataFrame({
124
+ "Datetime": future_dates,
125
+ "SARIMA_Predicted": sarima_predictions
126
+ })
127
+
128
+ # Zoomed-in view of the plot (recent data only)
129
+ plt.figure(figsize=(15, 6))
130
+ plt.plot(
131
+ webtraffic_data["Datetime"].iloc[-future_periods:],
132
+ webtraffic_data["Sessions"].iloc[-future_periods:],
133
+ label="Actual Traffic (Zoomed)",
134
+ color="black",
135
+ linestyle="dotted",
136
+ linewidth=2,
137
+ )
138
+ plt.plot(
139
+ future_predictions["Datetime"],
140
+ future_predictions["SARIMA_Predicted"],
141
+ label="SARIMA Predicted (Zoomed)",
142
+ color="green",
143
+ linewidth=2,
144
+ )
145
+
146
+ plt.title("Zoomed-In SARIMA Predictions vs Actual Traffic", fontsize=16)
147
+ plt.xlabel("Datetime", fontsize=12)
148
+ plt.ylabel("Sessions", fontsize=12)
149
+ plt.legend(loc="upper left")
150
+ plt.grid(True)
151
+ plt.tight_layout()
152
+
153
+ # Save the zoomed plot
154
+ zoomed_plot_path = "sarima_zoomed_plot.png"
155
+ plt.savefig(zoomed_plot_path)
156
+ plt.close()
157
+
158
+ return zoomed_plot_path
159
+
160
+ except Exception as e:
161
+ print(f"Error generating zoomed plot: {e}")
162
+ return None
163
+
164
+ # Step 5: Gradio Dashboard with Two Tiles and Metrics
165
  with gr.Blocks() as dashboard:
166
+ gr.Markdown("## Enhanced SARIMA Web Traffic Prediction Dashboard")
167
+ gr.Markdown("This dashboard includes SARIMA predictions, performance metrics, and a zoomed-in view of recent data.")
168
+
169
+ # Outputs: Main Plot and Metrics
170
+ plot_output = gr.Image(label="SARIMA Prediction Plot")
171
+ metrics_output = gr.Textbox(label="Model Metrics", lines=6)
172
 
173
+ # Outputs: Zoomed Plot
174
+ zoomed_plot_output = gr.Image(label="Zoomed-In Prediction Plot")
175
 
176
+ # Button to Generate Results
177
  gr.Button("Generate Predictions").click(
178
+ fn=generate_sarima_plot,
179
  inputs=[],
180
  outputs=[plot_output, metrics_output],
181
  )
182
 
183
+ gr.Button("Generate Zoomed-In Plot").click(
184
+ fn=generate_zoomed_plot,
185
+ inputs=[],
186
+ outputs=[zoomed_plot_output],
187
+ )
188
+
189
+ # Launch the Gradio Dashboard
190
  if __name__ == "__main__":
191
+ print("\nLaunching Enhanced Gradio Dashboard...")
192
  dashboard.launch()