fmegahed commited on
Commit
d88cede
·
verified ·
1 Parent(s): 4caae33

trying to add tabs to separate the validation and forecasts (also adding the actual forecasts)

Browse files
Files changed (1) hide show
  1. app.py +87 -38
app.py CHANGED
@@ -34,7 +34,7 @@ def load_data(file):
34
  return None, f"Error loading data: {str(e)}"
35
 
36
  # Function to generate and return a plot
37
- def create_forecast_plot(forecast_df, original_df):
38
  plt.figure(figsize=(10, 6))
39
  unique_ids = forecast_df['unique_id'].unique()
40
  forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff']]
@@ -47,7 +47,32 @@ def create_forecast_plot(forecast_df, original_df):
47
  if col in forecast_data.columns:
48
  plt.plot(forecast_data['ds'], forecast_data[col], label=col)
49
 
50
- plt.title('Results')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  plt.xlabel('Date')
52
  plt.ylabel('Value')
53
  plt.legend()
@@ -72,11 +97,12 @@ def run_forecast(
72
  use_seasonal_window_avg,
73
  seasonal_window_size,
74
  use_autoets,
75
- use_autoarima
 
76
  ):
77
  df, message = load_data(file)
78
  if df is None:
79
- return None, None, None, message
80
 
81
  models = []
82
  model_aliases = []
@@ -104,27 +130,33 @@ def run_forecast(
104
  model_aliases.append('autoarima')
105
 
106
  if not models:
107
- return None, None, None, "Please select at least one forecasting model"
108
 
109
  sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)
110
 
111
  try:
 
112
  if eval_strategy == "Cross Validation":
113
  cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
114
  evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
115
  eval_df = pd.DataFrame(evaluation).reset_index()
116
- fig_forecast = create_forecast_plot(cv_results, df)
117
- return eval_df, cv_results, fig_forecast, "Cross validation completed successfully!"
118
-
119
  else: # Fixed window
120
- cv_results = sf.cross_validation(df=df, h=horizon, step_size=10, n_windows=1) # any step size will do since it is only 1 window
121
  evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
122
  eval_df = pd.DataFrame(evaluation).reset_index()
123
- fig_forecast = create_forecast_plot(cv_results, df)
124
- return eval_df, cv_results, fig_forecast, "Fixed window evaluation completed successfully!"
 
 
 
 
 
 
 
125
 
126
  except Exception as e:
127
- return None, None, None, f"Error during forecasting: {str(e)}"
128
 
129
  # Sample CSV file generation
130
  def download_sample():
@@ -163,32 +195,49 @@ with gr.Blocks(title="StatsForecast Demo") as app:
163
  download_output = gr.File(label="Click to download", visible=True)
164
  download_btn.click(fn=download_sample, outputs=download_output)
165
 
166
- frequency = gr.Dropdown(choices=["H", "D", "WS", "MS", "QS", "YS"], label="Frequency", value="D")
167
- eval_strategy = gr.Radio(choices=["Fixed Window", "Cross Validation"], label="Evaluation Strategy", value="Cross Validation")
168
- horizon = gr.Slider(1, 100, value=10, step=1, label="Horizon")
169
- step_size = gr.Slider(1, 50, value=10, step=1, label="Step Size")
170
- num_windows = gr.Slider(1, 20, value=3, step=1, label="Number of Windows")
171
-
172
-
173
- gr.Markdown("### Model Configuration")
174
- use_historical_avg = gr.Checkbox(label="Use Historical Average", value=True)
175
- use_naive = gr.Checkbox(label="Use Naive", value=True)
176
- use_seasonal_naive = gr.Checkbox(label="Use Seasonal Naive")
177
- seasonality = gr.Number(label="Seasonality", value=10)
178
- use_window_avg = gr.Checkbox(label="Use Window Average")
179
- window_size = gr.Number(label="Window Size", value=3)
180
- use_seasonal_window_avg = gr.Checkbox(label="Use Seasonal Window Average")
181
- seasonal_window_size = gr.Number(label="Seasonal Window Size", value=2)
182
- use_autoets = gr.Checkbox(label="Use AutoETS")
183
- use_autoarima = gr.Checkbox(label="Use AutoARIMA")
184
-
185
- submit_btn = gr.Button("Run Forecast")
 
 
 
 
 
 
 
 
 
 
186
 
187
  with gr.Column(scale=3):
188
- eval_output = gr.Dataframe(label="Evaluation Results")
189
- forecast_output = gr.Dataframe(label="Detailed Evaluation Results")
190
- plot_output = gr.Plot(label="Plotting the Actual and the Evaluation Results")
191
- message_output = gr.Textbox(label="Message")
 
 
 
 
 
 
 
192
 
193
  submit_btn.click(
194
  fn=run_forecast,
@@ -196,9 +245,9 @@ with gr.Blocks(title="StatsForecast Demo") as app:
196
  file_input, frequency, eval_strategy, horizon, step_size, num_windows,
197
  use_historical_avg, use_naive, use_seasonal_naive, seasonality,
198
  use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
199
- use_autoets, use_autoarima
200
  ],
201
- outputs=[eval_output, forecast_output, plot_output, message_output]
202
  )
203
 
204
  if __name__ == "__main__":
 
34
  return None, f"Error loading data: {str(e)}"
35
 
36
  # Function to generate and return a plot
37
+ def create_forecast_plot(forecast_df, original_df, title="Forecasting Results"):
38
  plt.figure(figsize=(10, 6))
39
  unique_ids = forecast_df['unique_id'].unique()
40
  forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff']]
 
47
  if col in forecast_data.columns:
48
  plt.plot(forecast_data['ds'], forecast_data[col], label=col)
49
 
50
+ plt.title(title)
51
+ plt.xlabel('Date')
52
+ plt.ylabel('Value')
53
+ plt.legend()
54
+ plt.grid(True)
55
+ fig = plt.gcf()
56
+ return fig
57
+
58
+ # Function to create a plot for future forecasts
59
+ def create_future_forecast_plot(forecast_df, original_df):
60
+ plt.figure(figsize=(10, 6))
61
+ unique_ids = forecast_df['unique_id'].unique()
62
+ forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds']]
63
+
64
+ for unique_id in unique_ids:
65
+ # Plot historical data
66
+ original_data = original_df[original_df['unique_id'] == unique_id]
67
+ plt.plot(original_data['ds'], original_data['y'], 'k-', label='Historical')
68
+
69
+ # Plot forecast data
70
+ forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
71
+ for col in forecast_cols:
72
+ if col in forecast_data.columns:
73
+ plt.plot(forecast_data['ds'], forecast_data[col], label=col)
74
+
75
+ plt.title('Future Forecast')
76
  plt.xlabel('Date')
77
  plt.ylabel('Value')
78
  plt.legend()
 
97
  use_seasonal_window_avg,
98
  seasonal_window_size,
99
  use_autoets,
100
+ use_autoarima,
101
+ future_horizon
102
  ):
103
  df, message = load_data(file)
104
  if df is None:
105
+ return None, None, None, None, None, message
106
 
107
  models = []
108
  model_aliases = []
 
130
  model_aliases.append('autoarima')
131
 
132
  if not models:
133
+ return None, None, None, None, None, "Please select at least one forecasting model"
134
 
135
  sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)
136
 
137
  try:
138
+ # Run cross-validation
139
  if eval_strategy == "Cross Validation":
140
  cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
141
  evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
142
  eval_df = pd.DataFrame(evaluation).reset_index()
143
+ fig_validation = create_forecast_plot(cv_results, df, "Cross Validation Results")
 
 
144
  else: # Fixed window
145
+ cv_results = sf.cross_validation(df=df, h=horizon, step_size=10, n_windows=1) # any step size for 1 window
146
  evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
147
  eval_df = pd.DataFrame(evaluation).reset_index()
148
+ fig_validation = create_forecast_plot(cv_results, df, "Fixed Window Validation Results")
149
+
150
+ # Generate future forecasts
151
+ fitted_sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)
152
+ fitted_sf.fit(df)
153
+ future_forecasts = fitted_sf.forecast(h=future_horizon)
154
+ fig_future = create_future_forecast_plot(future_forecasts, df)
155
+
156
+ return eval_df, cv_results, fig_validation, future_forecasts, fig_future, "Analysis completed successfully!"
157
 
158
  except Exception as e:
159
+ return None, None, None, None, None, f"Error during forecasting: {str(e)}"
160
 
161
  # Sample CSV file generation
162
  def download_sample():
 
195
  download_output = gr.File(label="Click to download", visible=True)
196
  download_btn.click(fn=download_sample, outputs=download_output)
197
 
198
+ with gr.Accordion("Data & Validation Settings", open=True):
199
+ frequency = gr.Dropdown(choices=["H", "D", "WS", "MS", "QS", "YS"], label="Frequency", value="D")
200
+ eval_strategy = gr.Radio(choices=["Fixed Window", "Cross Validation"], label="Evaluation Strategy", value="Cross Validation")
201
+ horizon = gr.Slider(1, 100, value=10, step=1, label="Validation Horizon")
202
+ step_size = gr.Slider(1, 50, value=10, step=1, label="Step Size")
203
+ num_windows = gr.Slider(1, 20, value=3, step=1, label="Number of Windows")
204
+
205
+ with gr.Accordion("Forecast Settings", open=True):
206
+ future_horizon = gr.Slider(1, 100, value=20, step=1, label="Future Forecast Horizon")
207
+
208
+ with gr.Accordion("Model Configuration", open=True):
209
+ use_historical_avg = gr.Checkbox(label="Use Historical Average", value=True)
210
+ use_naive = gr.Checkbox(label="Use Naive", value=True)
211
+
212
+ with gr.Row():
213
+ use_seasonal_naive = gr.Checkbox(label="Use Seasonal Naive")
214
+ seasonality = gr.Number(label="Seasonality", value=10)
215
+
216
+ with gr.Row():
217
+ use_window_avg = gr.Checkbox(label="Use Window Average")
218
+ window_size = gr.Number(label="Window Size", value=3)
219
+
220
+ with gr.Row():
221
+ use_seasonal_window_avg = gr.Checkbox(label="Use Seasonal Window Average")
222
+ seasonal_window_size = gr.Number(label="Seasonal Window Size", value=2)
223
+
224
+ use_autoets = gr.Checkbox(label="Use AutoETS")
225
+ use_autoarima = gr.Checkbox(label="Use AutoARIMA")
226
+
227
+ submit_btn = gr.Button("Run Forecast", variant="primary")
228
 
229
  with gr.Column(scale=3):
230
+ message_output = gr.Textbox(label="Status Message")
231
+
232
+ with gr.Tabs() as tabs:
233
+ with gr.TabItem("Validation Results"):
234
+ eval_output = gr.Dataframe(label="Evaluation Metrics")
235
+ validation_output = gr.Dataframe(label="Validation Data")
236
+ validation_plot = gr.Plot(label="Validation Plot")
237
+
238
+ with gr.TabItem("Future Forecast"):
239
+ forecast_output = gr.Dataframe(label="Future Forecast Data")
240
+ forecast_plot = gr.Plot(label="Future Forecast Plot")
241
 
242
  submit_btn.click(
243
  fn=run_forecast,
 
245
  file_input, frequency, eval_strategy, horizon, step_size, num_windows,
246
  use_historical_avg, use_naive, use_seasonal_naive, seasonality,
247
  use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
248
+ use_autoets, use_autoarima, future_horizon
249
  ],
250
+ outputs=[eval_output, validation_output, validation_plot, forecast_output, forecast_plot, message_output]
251
  )
252
 
253
  if __name__ == "__main__":