saburq commited on
Commit
efc88b8
Β·
1 Parent(s): b9b6166

remove animation

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. app.py +288 -191
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  aifs-single-mse-1.0.ckpt
2
  flagged/
 
 
1
  aifs-single-mse-1.0.ckpt
2
  flagged/
3
+ gradio_temp/*
app.py CHANGED
@@ -17,6 +17,16 @@ from ecmwf.opendata import Client as OpendataClient
17
  import earthkit.data as ekd
18
  import earthkit.regrid as ekr
19
  import matplotlib.animation as animation
 
 
 
 
 
 
 
 
 
 
20
 
21
  # Define parameters (updating to match notebook.py)
22
  PARAM_SFC = ["10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "lsm", "z", "slor", "sdor"]
@@ -74,15 +84,92 @@ TEMP_DIR = Path("./gradio_temp")
74
  TEMP_DIR.mkdir(exist_ok=True)
75
  os.environ['GRADIO_TEMP_DIR'] = str(TEMP_DIR)
76
 
77
- def get_open_data(param, levelist=[]):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  fields = {}
79
- # Get the data for the current date and the previous date
80
- myiterable = [DEFAULT_DATE - datetime.timedelta(hours=6), DEFAULT_DATE]
81
- print(myiterable)
82
- for date in [DEFAULT_DATE - datetime.timedelta(hours=6), DEFAULT_DATE]:
83
- print(f"Fetching data for {date}")
84
- # sources can be seen https://earthkit-data.readthedocs.io/en/latest/guide/sources.html#id57
85
- data = ekd.from_source("ecmwf-open-data", date=date, param=param, levelist=levelist)
86
  for f in data:
87
  assert f.to_numpy().shape == (721, 1440)
88
  values = np.roll(f.to_numpy(), -f.shape[1] // 2, axis=1)
@@ -98,99 +185,71 @@ def get_open_data(param, levelist=[]):
98
 
99
  return fields
100
 
101
- def plot_forecast_animation(states, selected_variable):
 
 
102
  # Setup the figure and axis
103
  fig = plt.figure(figsize=(15, 8))
104
  ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=0))
105
 
106
- # Get the first state to setup the plot
107
- first_state = states[0]
108
- latitudes, longitudes = first_state["latitudes"], first_state["longitudes"]
109
  fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes)
110
  triangulation = tri.Triangulation(fixed_lons, latitudes)
111
 
112
- # Find global min/max for consistent colorbar
113
- all_values = [state["fields"][selected_variable] for state in states]
114
- vmin, vmax = np.min(all_values), np.max(all_values)
115
 
116
- # Create a single colorbar that will be reused
117
- contour = None
118
- cbar_ax = None
 
 
 
119
 
120
- def update(frame):
121
- nonlocal contour, cbar_ax
122
- ax.clear()
123
-
124
- # Set map features
125
- ax.set_global()
126
- ax.set_extent([-180, 180, -85, 85], crs=ccrs.PlateCarree())
127
- ax.coastlines(resolution='50m')
128
- ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5)
129
- ax.gridlines(draw_labels=True)
130
-
131
- state = states[frame]
132
- values = state["fields"][selected_variable]
133
-
134
- # Clear the previous colorbar axis if it exists
135
- if cbar_ax:
136
- cbar_ax.remove()
137
-
138
- # Create new contour plot
139
- contour = ax.tricontourf(triangulation, values,
140
- levels=20, transform=ccrs.PlateCarree(),
141
- cmap='RdBu_r', vmin=vmin, vmax=vmax)
142
-
143
- # Create new colorbar
144
- cbar_ax = fig.add_axes([0.1, 0.05, 0.8, 0.03]) # [left, bottom, width, height]
145
- plt.colorbar(contour, cax=cbar_ax, orientation='horizontal')
146
-
147
- # Format the date string properly
148
- forecast_time = state["date"]
149
- if isinstance(forecast_time, str):
150
- try:
151
- forecast_time = datetime.datetime.strptime(forecast_time, "%Y-%m-%d %H:%M:%S")
152
- except ValueError:
153
- try:
154
- forecast_time = datetime.datetime.strptime(forecast_time, "%Y-%m-%d %H:%M:%S.%f")
155
- except ValueError:
156
- forecast_time = DEFAULT_DATE
157
-
158
- time_str = forecast_time.strftime("%Y-%m-%d %H:%M UTC")
159
-
160
- # Get variable description from VARIABLE_GROUPS
161
- var_desc = None
162
- for group in VARIABLE_GROUPS.values():
163
- if selected_variable in group:
164
- var_desc = group[selected_variable]
165
- break
166
- var_name = var_desc if var_desc else selected_variable
167
-
168
- ax.set_title(f"{var_name} - {time_str}")
169
 
170
- # Create animation
171
- anim = animation.FuncAnimation(
172
- fig, update,
173
- frames=len(states),
174
- interval=1000, # 1 second between frames
175
- repeat=True,
176
- blit=False # Must be False to update the colorbar
177
- )
178
 
179
- # Save as MP4
180
- temp_file = str(TEMP_DIR / f"forecast_{datetime.datetime.now().timestamp()}.mp4")
181
- anim.save(temp_file, writer='ffmpeg', fps=1)
 
 
 
 
 
 
 
 
 
 
182
  plt.close()
183
 
184
  return temp_file
185
 
186
- def run_forecast(date, lead_time, device):
187
  # Get all required fields
188
  fields = {}
 
189
 
190
  # Get surface fields
 
191
  fields.update(get_open_data(param=PARAM_SFC))
192
 
193
  # Get soil fields and rename them
 
194
  soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS)
195
  mapping = {
196
  'sot_1': 'stl1', 'sot_2': 'stl2',
@@ -200,6 +259,7 @@ def run_forecast(date, lead_time, device):
200
  fields[mapping[k]] = v
201
 
202
  # Get pressure level fields
 
203
  fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))
204
 
205
  # Convert geopotential height to geopotential
@@ -211,125 +271,162 @@ def run_forecast(date, lead_time, device):
211
 
212
  # Use the global model instance
213
  global MODEL
214
- # If device preference changed, move model to new device
215
  if device != MODEL.device:
216
  MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)
217
 
218
- # Collect all states instead of just the last one
219
- states = []
220
  for state in MODEL.run(input_state=input_state, lead_time=lead_time):
221
- states.append(state)
222
- return states
223
-
224
- def update_plot(lead_time, variable):
225
- cleanup_old_files() # Clean up old files before creating new ones
226
- states = run_forecast(DEFAULT_DATE, lead_time, "cuda")
227
- return plot_forecast_animation(states, variable)
228
-
229
- # Add cleanup function for old files
230
- def cleanup_old_files():
231
- # Remove files older than 1 hour
232
- current_time = datetime.datetime.now().timestamp()
233
- for file in TEMP_DIR.glob("*.mp4"): # Changed from *.gif to *.mp4
234
- if current_time - file.stat().st_mtime > 3600: # 1 hour in seconds
235
- file.unlink(missing_ok=True)
236
-
237
- # Create dropdown choices with groups
238
- DROPDOWN_CHOICES = []
239
- for group_name, variables in VARIABLE_GROUPS.items():
240
- # Add group separator
241
- DROPDOWN_CHOICES.append((f"── {group_name} ──", None))
242
- # Add variables in this group
243
- for var_id, desc in sorted(variables.items()):
244
- DROPDOWN_CHOICES.append((f"{desc} ({var_id})", var_id))
245
-
246
- with gr.Blocks(css="""
247
- .centered-header {
248
- text-align: center;
249
- margin-bottom: 20px;
250
- }
251
- .subtitle {
252
- font-size: 1.2em;
253
- line-height: 1.5;
254
- margin: 20px 0;
255
- }
256
- .footer {
257
- text-align: center;
258
- padding: 20px;
259
- margin-top: 20px;
260
- border-top: 1px solid #eee;
261
- }
262
- """) as demo:
263
- # Header section
264
- gr.Markdown(f"""
265
- # AIFS Weather Forecast
266
-
267
- <div class="subtitle">
268
- Interactive visualization of ECMWF AIFS weather forecasts.<br>
269
- Starting from the latest available data ({DEFAULT_DATE.strftime('%Y-%m-%d %H:%M UTC')}),<br>
270
- select how many hours ahead you want to forecast and which meteorological variable to visualize.
271
- </div>
272
- """)
273
-
274
- # Main content
275
- with gr.Row():
276
- with gr.Column(scale=1):
277
- lead_time = gr.Slider(
278
- minimum=6,
279
- maximum=48,
280
- step=6,
281
- value=12,
282
- label="Forecast Hours Ahead"
283
- )
284
- variable = gr.Dropdown(
285
- choices=DROPDOWN_CHOICES,
286
- value="2t",
287
- label="Select Variable to Plot"
288
- )
289
- with gr.Row():
290
- clear_btn = gr.Button("Clear")
291
- submit_btn = gr.Button("Submit", variant="primary")
292
 
293
- with gr.Column(scale=2):
294
- animation_output = gr.Video()
295
-
296
- # Footer with fork instructions and model reference
297
- gr.Markdown("""
298
- <div class="footer">
299
- <h3>Want to run this on your own?</h3>
300
- You can fork this space and run it yourself:
301
-
302
- 1. Visit <a href="https://huggingface.co/spaces/geobase/aifs-forecast" target="_blank">https://huggingface.co/spaces/geobase/aifs-forecast</a>\n
303
- 2. Click the "Duplicate this Space" button in the top right\n
304
- 3. Select your hardware requirements (GPU recommended)\n
305
- 4. Wait for your copy to deploy
306
-
307
- <h3>Model Information</h3>
308
- This demo uses the <a href="https://huggingface.co/ecmwf/aifs-single-1.0" target="_blank">AIFS Single 1.0</a> model from ECMWF,
309
- which is their first operationally supported Artificial Intelligence Forecasting System. The model produces highly skilled forecasts
310
- for upper-air variables, surface weather parameters, and tropical cyclone tracks.
311
 
312
- Note: If you encounter any issues with this demo, trying your own fork might work better!
313
- </div>
314
- """)
 
 
 
315
 
316
- def clear():
317
- return [
318
- 12,
319
- "2t",
320
- None
321
- ]
 
 
 
 
322
 
323
- # Connect the inputs to the forecast function
324
- submit_btn.click(
325
- fn=update_plot,
326
- inputs=[lead_time, variable],
327
- outputs=animation_output
328
- )
329
- clear_btn.click(
330
- fn=clear,
331
- inputs=[],
332
- outputs=[lead_time, variable, animation_output]
333
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
 
 
 
335
  demo.launch()
 
17
  import earthkit.data as ekd
18
  import earthkit.regrid as ekr
19
  import matplotlib.animation as animation
20
+ from functools import lru_cache
21
+ import hashlib
22
+ import pickle
23
+ import json
24
+ from typing import List, Dict, Any
25
+ import logging
26
+
27
+ # Configure logging
28
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
29
+ logger = logging.getLogger(__name__)
30
 
31
  # Define parameters (updating to match notebook.py)
32
  PARAM_SFC = ["10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "lsm", "z", "slor", "sdor"]
 
84
  TEMP_DIR.mkdir(exist_ok=True)
85
  os.environ['GRADIO_TEMP_DIR'] = str(TEMP_DIR)
86
 
87
# Add these cache-related functions after the MODEL initialization
def get_cache_key(date: datetime.datetime, params: List[str], levellist: List[int]) -> str:
    """Create a unique cache key based on the request parameters."""
    if levellist:
        level_part = ",".join(str(level) for level in sorted(levellist))
    else:
        level_part = "no_levels"
    key_string = "_".join([date.isoformat(), ",".join(sorted(params)), level_part])
    # md5 is fine here: the digest is a cache-file name, not a security token.
    digest = hashlib.md5(key_string.encode()).hexdigest()
    logger.info(f"Generated cache key: {digest} for {key_string}")
    return digest
99
+
100
def get_cache_path(cache_key: str) -> Path:
    """Return the cache-file path for *cache_key*, ensuring the cache dir exists."""
    cache_root = TEMP_DIR / "data_cache"
    cache_root.mkdir(exist_ok=True)
    return cache_root.joinpath(cache_key + ".pkl")
105
+
106
def save_to_cache(cache_key: str, data: Dict[str, Any]) -> None:
    """Persist *data* to the disk cache under *cache_key* (best effort)."""
    target = get_cache_path(cache_key)
    try:
        with target.open('wb') as handle:
            pickle.dump(data, handle)
        logger.info(f"Successfully saved data to cache: {target}")
    except Exception as e:
        # Best-effort cache: a failed write is logged, never fatal.
        logger.error(f"Failed to save to cache: {e}")
115
+
116
def load_from_cache(cache_key: str) -> Dict[str, Any]:
    """Load cached data for *cache_key*; return None on a miss or unreadable file."""
    cache_file = get_cache_path(cache_key)
    if cache_file.exists():
        try:
            with cache_file.open('rb') as handle:
                payload = pickle.load(handle)
                logger.info(f"Successfully loaded data from cache: {cache_file}")
                return payload
        except Exception as e:
            logger.error(f"Failed to load from cache: {e}")
            # Drop the unreadable file so the next call re-downloads fresh data.
            cache_file.unlink(missing_ok=True)
    # Reached on a true miss AND after a corrupt-file cleanup (matches original flow).
    logger.info(f"No cache file found: {cache_file}")
    return None
130
+
131
# Modify the get_open_data function to use caching
@lru_cache(maxsize=32)
def get_cached_data(date_str: str, param_tuple: tuple, levelist_tuple: tuple) -> Dict[str, Any]:
    """In-memory (LRU) cache wrapper around get_open_data_impl.

    Arguments are str/tuple so the lru_cache key stays hashable.
    """
    date = datetime.datetime.fromisoformat(date_str)
    levels = list(levelist_tuple) if levelist_tuple else []
    return get_open_data_impl(date, list(param_tuple), levels)
140
+
141
def get_open_data(param: List[str], levelist: List[int] = None) -> Dict[str, Any]:
    """Fetch open data for *param*/*levelist* at DEFAULT_DATE, backed by the disk cache."""
    if levelist is None:
        levelist = []

    # Disk cache first: it survives restarts, unlike the in-memory LRU.
    cache_key = get_cache_key(DEFAULT_DATE, param, levelist)
    logger.info(f"Checking cache for key: {cache_key}")

    cached = load_from_cache(cache_key)
    if cached is not None:
        logger.info(f"Cache hit for {cache_key}")
        return cached

    # Miss: download, then persist for subsequent calls.
    logger.info(f"Cache miss for {cache_key}, downloading fresh data")
    fields = get_open_data_impl(DEFAULT_DATE, param, levelist)
    save_to_cache(cache_key, fields)
    return fields
163
+
164
+ def get_open_data_impl(date: datetime.datetime, param: List[str], levelist: List[int]) -> Dict[str, Any]:
165
+ """Implementation of data download and processing"""
166
  fields = {}
167
+ myiterable = [date - datetime.timedelta(hours=6), date]
168
+ logger.info(f"Downloading data for dates: {myiterable}")
169
+
170
+ for current_date in myiterable:
171
+ logger.info(f"Fetching data for {current_date}")
172
+ data = ekd.from_source("ecmwf-open-data", date=current_date, param=param, levelist=levelist)
 
173
  for f in data:
174
  assert f.to_numpy().shape == (721, 1440)
175
  values = np.roll(f.to_numpy(), -f.shape[1] // 2, axis=1)
 
185
 
186
  return fields
187
 
188
def plot_forecast(state, selected_variable):
    """Render *selected_variable* from a forecast *state* to a PNG and return its path.

    Parameters
    ----------
    state : dict
        Forecast state carrying 'latitudes', 'longitudes', 'fields' (name -> values)
        and 'date' (datetime or ISO string).
    selected_variable : str
        Key into state['fields'] to plot.

    Returns
    -------
    str
        Path of the PNG written under TEMP_DIR.
    """
    logger.info(f"Plotting forecast for {selected_variable} at time {state['date']}")

    # Setup the figure and axis
    fig = plt.figure(figsize=(15, 8))
    ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=0))

    # Get the coordinates
    latitudes, longitudes = state["latitudes"], state["longitudes"]
    # Wrap longitudes into [-180, 180] so the triangulation doesn't span the date seam.
    fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes)
    triangulation = tri.Triangulation(fixed_lons, latitudes)

    # Get the values
    values = state["fields"][selected_variable]
    logger.info(f"Value range: min={np.min(values):.2f}, max={np.max(values):.2f}")

    # Set map features
    ax.set_global()
    ax.set_extent([-180, 180, -85, 85], crs=ccrs.PlateCarree())
    ax.coastlines(resolution='50m')
    ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5)
    ax.gridlines(draw_labels=True)

    # Create contour plot
    contour = ax.tricontourf(triangulation, values,
                             levels=20, transform=ccrs.PlateCarree(),
                             cmap='RdBu_r')

    # Add colorbar
    plt.colorbar(contour, ax=ax, orientation='horizontal', pad=0.05)

    # Format the date string; model states may carry datetimes or ISO strings.
    forecast_time = state["date"]
    if isinstance(forecast_time, str):
        try:
            forecast_time = datetime.datetime.fromisoformat(forecast_time)
        except ValueError:
            # Unparseable date string: fall back to the run's base date
            # instead of crashing the plot (restores the pre-refactor fallback).
            forecast_time = DEFAULT_DATE
    time_str = forecast_time.strftime("%Y-%m-%d %H:%M UTC")

    # Get variable description
    var_desc = None
    for group in VARIABLE_GROUPS.values():
        if selected_variable in group:
            var_desc = group[selected_variable]
            break
    var_name = var_desc if var_desc else selected_variable

    ax.set_title(f"{var_name} - {time_str}")

    # Save as PNG
    temp_file = str(TEMP_DIR / f"forecast_{datetime.datetime.now().timestamp()}.png")
    plt.savefig(temp_file, bbox_inches='tight', dpi=100)
    plt.close()

    return temp_file
241
 
242
+ def run_forecast(date: datetime.datetime, lead_time: int, device: str) -> Dict[str, Any]:
243
  # Get all required fields
244
  fields = {}
245
+ logger.info(f"Starting forecast for lead_time: {lead_time} hours")
246
 
247
  # Get surface fields
248
+ logger.info("Getting surface fields...")
249
  fields.update(get_open_data(param=PARAM_SFC))
250
 
251
  # Get soil fields and rename them
252
+ logger.info("Getting soil fields...")
253
  soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS)
254
  mapping = {
255
  'sot_1': 'stl1', 'sot_2': 'stl2',
 
259
  fields[mapping[k]] = v
260
 
261
  # Get pressure level fields
262
+ logger.info("Getting pressure level fields...")
263
  fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))
264
 
265
  # Convert geopotential height to geopotential
 
271
 
272
  # Use the global model instance
273
  global MODEL
 
274
  if device != MODEL.device:
275
  MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)
276
 
277
+ # Run the model and get the final state
278
+ final_state = None
279
  for state in MODEL.run(input_state=input_state, lead_time=lead_time):
280
+ logger.info(f"\nπŸ˜€ date={state['date']} latitudes={state['latitudes'].shape} "
281
+ f"longitudes={state['longitudes'].shape} fields={len(state['fields'])}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
+ # Log a few example variables to show we have all fields
284
+ for var in ['2t', 'msl', 't_1000', 'z_850']:
285
+ if var in state['fields']:
286
+ values = state['fields'][var]
287
+ logger.info(f" {var:<6} shape={values.shape} "
288
+ f"min={np.min(values):.6f} "
289
+ f"max={np.max(values):.6f}")
290
+
291
+ final_state = state
 
 
 
 
 
 
 
 
 
292
 
293
+ logger.info(f"Final state contains {len(final_state['fields'])} variables")
294
+ return final_state
295
+
296
def get_available_variables(state):
    """Build grouped dropdown choices for the variables present in *state*."""
    present = set(state['fields'].keys())

    choices = []
    for group_name, variables in VARIABLE_GROUPS.items():
        group_vars = []
        for var_id, desc in variables.items():
            if var_id in present:
                group_vars.append((f"{desc} ({var_id})", var_id))

        # Skip groups with nothing available; otherwise prepend a
        # non-selectable separator row (value=None) before the entries.
        if group_vars:
            choices.append((f"── {group_name} ──", None))
            choices.extend(group_vars)

    return choices
312
+
313
def update_interface():
    """Build and return the Gradio Blocks UI.

    The forecast result lives in a gr.State so that changing the variable
    dropdown re-plots from the stored state without re-running the model.
    """
    with gr.Blocks(css="""
    .centered-header {
        text-align: center;
        margin-bottom: 20px;
    }
    .subtitle {
        font-size: 1.2em;
        line-height: 1.5;
        margin: 20px 0;
    }
    .footer {
        text-align: center;
        padding: 20px;
        margin-top: 20px;
        border-top: 1px solid #eee;
    }
    """) as demo:
        state = gr.State(None)

        with gr.Row():
            with gr.Column(scale=1):
                lead_time = gr.Slider(
                    minimum=6,
                    maximum=48,
                    step=6,
                    value=12,
                    label="Forecast Hours Ahead"
                )
                variable = gr.Dropdown(
                    choices=[],  # populated after the first forecast run
                    value=None,
                    label="Select Variable to Plot"
                )
                with gr.Row():
                    clear_btn = gr.Button("Clear")
                    run_btn = gr.Button("Run Forecast", variant="primary")

                with gr.Row():
                    download_json = gr.Button("Download JSON")
                    download_nc = gr.Button("Download NetCDF")

                # Rendered file component for downloads. The original passed an
                # unrendered gr.File() straight to `outputs=`, so the produced
                # file never appeared anywhere in the UI.
                download_file = gr.File(label="Forecast download")

            with gr.Column(scale=2):
                forecast_output = gr.Image()

        def run_and_store(lead_time):
            """Run the forecast, store its final state, and draw a default plot."""
            new_state = run_forecast(DEFAULT_DATE, lead_time, "cuda")
            choices = get_available_variables(new_state)
            # First selectable entry (group separators carry value=None).
            default_var = next((var_id for _, var_id in choices if var_id is not None), None)
            plot = plot_forecast(new_state, default_var) if default_var else None
            # One Dropdown update carries both choices and value; listing the
            # same component twice in `outputs` (as before) is unsupported.
            return [new_state, gr.Dropdown(choices=choices, value=default_var), plot]

        def update_plot_from_state(state, variable):
            """Re-plot the stored state for a newly selected variable."""
            if state is None or variable is None:
                return None
            try:
                return plot_forecast(state, variable)
            except KeyError as e:
                logger.error(f"Variable {variable} not found in state: {e}")
                return None

        def clear():
            """Reset the stored state, dropdown, and plot."""
            return [None, gr.Dropdown(choices=[], value=None), None]

        def save_json(state):
            # NOTE(review): save_forecast_data is not defined in this file's
            # visible scope — confirm it exists elsewhere in app.py.
            if state is None:
                return None
            return save_forecast_data(state, 'json')

        def save_netcdf(state):
            if state is None:
                return None
            return save_forecast_data(state, 'netcdf')

        # Connect the components
        run_btn.click(
            fn=run_and_store,
            inputs=[lead_time],
            outputs=[state, variable, forecast_output]
        )

        variable.change(
            fn=update_plot_from_state,
            inputs=[state, variable],
            outputs=forecast_output
        )

        clear_btn.click(
            fn=clear,
            inputs=[],
            outputs=[state, variable, forecast_output]
        )

        download_json.click(
            fn=save_json,
            inputs=[state],
            outputs=download_file
        )

        download_nc.click(
            fn=save_netcdf,
            inputs=[state],
            outputs=download_file
        )

    return demo
429
 
430
# Create and launch the interface
demo = update_interface()

if __name__ == "__main__":
    # Guard the launch so importing this module (tests, `gradio` reload mode,
    # which serves the top-level `demo` itself) doesn't start a second server.
    demo.launch()