import os
import tempfile
from pathlib import Path
# Set memory optimization environment variables
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
os.environ['ANEMOI_INFERENCE_NUM_CHUNKS'] = '16'
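# expandable_segments reduces CUDA allocator fragmentation; the chunk count is
# read by anemoi-inference to trade some speed for lower peak GPU memory.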

import gradio as gr
import datetime
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.tri as tri
from anemoi.inference.runners.simple import SimpleRunner
from ecmwf.opendata import Client as OpendataClient
import earthkit.data as ekd
import earthkit.regrid as ekr
import matplotlib.animation as animation
from functools import lru_cache
import hashlib
import pickle
import json
from typing import List, Dict, Any
import logging
import xarray as xr
import pandas as pd

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Define the input parameters requested from ECMWF open data (surface, soil, and pressure-level fields)
PARAM_SFC = ["10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "lsm", "z", "slor", "sdor"]
PARAM_SOIL = ["vsw", "sot"]
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]
LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50]
SOIL_LEVELS = [1, 2]
DEFAULT_DATE = OpendataClient().latest()

# First organize variables into categories
VARIABLE_GROUPS = {
    "Surface Variables": {
        "10u": "10m U Wind Component",
        "10v": "10m V Wind Component",
        "2d": "2m Dewpoint Temperature",
        "2t": "2m Temperature",
        "msl": "Mean Sea Level Pressure",
        "skt": "Skin Temperature",
        "sp": "Surface Pressure",
        "tcw": "Total Column Water",
        "lsm": "Land-Sea Mask",
        "z": "Surface Geopotential",
        "slor": "Slope of Sub-gridscale Orography",
        "sdor": "Standard Deviation of Orography",
    },
    "Soil Variables": {
        "stl1": "Soil Temperature Level 1",
        "stl2": "Soil Temperature Level 2",
        "swvl1": "Soil Water Volume Level 1",
        "swvl2": "Soil Water Volume Level 2",
    },
    "Pressure Level Variables": {}  # Will fill this dynamically
}

# Add pressure level variables dynamically
for var in ["t", "u", "v", "w", "q", "z"]:
    var_name = {
        "t": "Temperature",
        "u": "U Wind Component",
        "v": "V Wind Component",
        "w": "Vertical Velocity",
        "q": "Specific Humidity",
        "z": "Geopotential"
    }[var]

    for level in LEVELS:
        var_id = f"{var}_{level}"
        VARIABLE_GROUPS["Pressure Level Variables"][var_id] = f"{var_name} at {level}hPa"

# Load the model once at startup
MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device="cuda")  # Default to CUDA

# Create and set custom temp directory
TEMP_DIR = Path("./gradio_temp")
TEMP_DIR.mkdir(exist_ok=True)
os.environ['GRADIO_TEMP_DIR'] = str(TEMP_DIR)

# Disk-cache helpers, keyed by request date, parameters, and levels
def get_cache_key(date: datetime.datetime, params: List[str], levellist: List[int]) -> str:
    """Create a unique cache key based on the request parameters"""
    key_parts = [
        date.isoformat(),
        ",".join(sorted(params)),
        ",".join(str(x) for x in sorted(levellist)) if levellist else "no_levels"
    ]
    key_string = "_".join(key_parts)
    cache_key = hashlib.md5(key_string.encode()).hexdigest()
    logger.info(f"Generated cache key: {cache_key} for {key_string}")
    return cache_key

def get_cache_path(cache_key: str) -> Path:
    """Get the path to the cache file"""
    return TEMP_DIR / "data_cache" / f"{cache_key}.pkl"

def save_to_cache(cache_key: str, data: Dict[str, Any]) -> None:
    """Save data to disk cache"""
    cache_file = get_cache_path(cache_key)
    try:
        with open(cache_file, 'wb') as f:
            pickle.dump(data, f)
        logger.info(f"Successfully saved data to cache: {cache_file}")
    except Exception as e:
        logger.error(f"Failed to save to cache: {e}")

def load_from_cache(cache_key: str) -> Dict[str, Any]:
    """Load data from disk cache"""
    cache_file = get_cache_path(cache_key)
    if cache_file.exists():
        try:
            with open(cache_file, 'rb') as f:
                data = pickle.load(f)
            logger.info(f"Successfully loaded data from cache: {cache_file}")
            return data
        except Exception as e:
            logger.error(f"Failed to load from cache: {e}")
            cache_file.unlink(missing_ok=True)
    logger.info(f"No cache file found: {cache_file}")
    return None

# Optional in-memory caching layer on top of the disk cache
@lru_cache(maxsize=32)
def get_cached_data(date_str: str, param_tuple: tuple, levelist_tuple: tuple) -> Dict[str, Any]:
    """Memory cache wrapper for get_open_data"""
    return get_open_data_impl(
        datetime.datetime.fromisoformat(date_str),
        list(param_tuple),
        list(levelist_tuple) if levelist_tuple else []
    )
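
# Note: this in-memory LRU wrapper is defined but not called anywhere below;
# the disk cache used by get_open_data is the active caching path.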

def get_open_data(param: List[str], levelist: List[int] = None) -> Dict[str, Any]:
    """Main function to get data with caching"""
    if levelist is None:
        levelist = []

    # Try disk cache first (more persistent than memory cache)
    cache_key = get_cache_key(DEFAULT_DATE, param, levelist)
    logger.info(f"Checking cache for key: {cache_key}")

    cached_data = load_from_cache(cache_key)
    if cached_data is not None:
        logger.info(f"Cache hit for {cache_key}")
        return cached_data

    # If not in cache, download and process the data
    logger.info(f"Cache miss for {cache_key}, downloading fresh data")
    fields = get_open_data_impl(DEFAULT_DATE, param, levelist)

    # Save to disk cache
    save_to_cache(cache_key, fields)

    return fields
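
# Illustrative usage (shapes follow from the assert and the stacking in get_open_data_impl):
#   get_open_data(param=["2t"]) -> {"2t": ndarray of shape (2, n_grid_points)}
# covering DEFAULT_DATE - 6h and DEFAULT_DATE.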

def get_open_data_impl(date: datetime.datetime, param: List[str], levelist: List[int]) -> Dict[str, Any]:
    """Implementation of data download and processing"""
    fields = {}
    myiterable = [date - datetime.timedelta(hours=6), date]
    logger.info(f"Downloading data for dates: {myiterable}")

    for current_date in myiterable:
        logger.info(f"Fetching data for {current_date}")
        data = ekd.from_source("ecmwf-open-data", date=current_date, param=param, levelist=levelist)
        for f in data:
            assert f.to_numpy().shape == (721, 1440)
            values = np.roll(f.to_numpy(), -f.shape[1] // 2, axis=1)
            values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})
            name = f"{f.metadata('param')}_{f.metadata('levelist')}" if levelist else f.metadata("param")
            if name not in fields:
                fields[name] = []
            fields[name].append(values)

    # Create a single matrix for each parameter
    for param, values in fields.items():
        fields[param] = np.stack(values)

    return fields

def plot_forecast(state, selected_variable):
    logger.info(f"Plotting forecast for {selected_variable} at time {state['date']}")

    # Setup the figure and axis
    fig = plt.figure(figsize=(15, 8))
    ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=0))

    # Get the coordinates
    latitudes, longitudes = state["latitudes"], state["longitudes"]
    fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes)
    triangulation = tri.Triangulation(fixed_lons, latitudes)

    # Get the values
    values = state["fields"][selected_variable]
    logger.info(f"Value range: min={np.min(values):.2f}, max={np.max(values):.2f}")

    # Set map features
    ax.set_global()
    ax.set_extent([-180, 180, -85, 85], crs=ccrs.PlateCarree())
    ax.coastlines(resolution='50m')
    ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5)
    ax.gridlines(draw_labels=True)

    # Create contour plot
    contour = ax.tricontourf(triangulation, values,
                            levels=20, transform=ccrs.PlateCarree(),
                            cmap='RdBu_r')

    # Add colorbar
    plt.colorbar(contour, ax=ax, orientation='horizontal', pad=0.05)

    # Format the date string
    forecast_time = state["date"]
    if isinstance(forecast_time, str):
        forecast_time = datetime.datetime.fromisoformat(forecast_time)
    time_str = forecast_time.strftime("%Y-%m-%d %H:%M UTC")

    # Get variable description
    var_desc = None
    for group in VARIABLE_GROUPS.values():
        if selected_variable in group:
            var_desc = group[selected_variable]
            break
    var_name = var_desc if var_desc else selected_variable

    ax.set_title(f"{var_name} - {time_str}")

    # Save as PNG under the forecasts dir so cleanup_old_files can purge old plots
    temp_file = str(TEMP_DIR / "forecasts" / f"forecast_{datetime.datetime.now().timestamp()}.png")
    plt.savefig(temp_file, bbox_inches='tight', dpi=100)
    plt.close()

    return temp_file

def run_forecast(date: datetime.datetime, lead_time: int, device: str) -> Dict[str, Any]:
    # Get all required fields
    fields = {}
    logger.info(f"Starting forecast for lead_time: {lead_time} hours")

    # Get surface fields
    logger.info("Getting surface fields...")
    fields.update(get_open_data(param=PARAM_SFC))

    # Get soil fields and rename them
    logger.info("Getting soil fields...")
    soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS)
    mapping = {
        'sot_1': 'stl1', 'sot_2': 'stl2',
        'vsw_1': 'swvl1', 'vsw_2': 'swvl2'
    }
    for k, v in soil.items():
        fields[mapping[k]] = v

    # Get pressure level fields
    logger.info("Getting pressure level fields...")
    fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))

    # Convert geopotential height (gh, in m) to geopotential (z = gh * g, with g = 9.80665 m s^-2)
    for level in LEVELS:
        gh = fields.pop(f"gh_{level}")
        fields[f"z_{level}"] = gh * 9.80665

    input_state = dict(date=date, fields=fields)

    # Use the global model instance
    global MODEL
    if device != MODEL.device:
        MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)

    # Run the model and get the final state
    final_state = None
    for state in MODEL.run(input_state=input_state, lead_time=lead_time):
        logger.info(f"\nπŸ˜€ date={state['date']} latitudes={state['latitudes'].shape} "
                   f"longitudes={state['longitudes'].shape} fields={len(state['fields'])}")

        # Log a few example variables to show we have all fields
        for var in ['2t', 'msl', 't_1000', 'z_850']:
            if var in state['fields']:
                values = state['fields'][var]
                logger.info(f"    {var:<6} shape={values.shape} "
                          f"min={np.min(values):.6f} "
                          f"max={np.max(values):.6f}")

        final_state = state

    logger.info(f"Final state contains {len(final_state['fields'])} variables")
    return final_state

def get_available_variables(state):
    """Get available variables from the state and organize them into groups"""
    available_vars = set(state['fields'].keys())

    # Create dropdown choices only for available variables
    choices = []
    for group_name, variables in VARIABLE_GROUPS.items():
        group_vars = [(f"{desc} ({var_id})", var_id)
                     for var_id, desc in variables.items()
                     if var_id in available_vars]

        if group_vars:  # Only add group if it has available variables
            choices.append((f"── {group_name} ──", None))
            choices.extend(group_vars)

    return choices
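
# Mirrors DROPDOWN_CHOICES but filtered to the fields actually present in a
# forecast state; currently a standalone helper, not wired into the UI below.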

def save_forecast_data(state, format='json'):
    """Save forecast data in specified format"""
    if state is None:
        raise ValueError("No forecast data available. Please run a forecast first.")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    forecast_time = state['date'].strftime("%Y%m%d_%H") if isinstance(state['date'], datetime.datetime) else state['date']

    # Use forecasts directory for all outputs
    output_dir = TEMP_DIR / "forecasts"

    if format == 'json':
        # Create a JSON-serializable dictionary
        data = {
            'metadata': {
                'forecast_date': forecast_time,
                'export_date': datetime.datetime.now().isoformat(),
                'total_points': len(state['latitudes']),
                'total_variables': len(state['fields'])
            },
            'coordinates': {
                'latitudes': state['latitudes'].tolist(),
                'longitudes': state['longitudes'].tolist()
            },
            'fields': {
                var_name: {
                    'values': values.tolist(),
                    'statistics': {
                        'min': float(np.min(values)),
                        'max': float(np.max(values)),
                        'mean': float(np.mean(values)),
                        'std': float(np.std(values))
                    }
                }
                for var_name, values in state['fields'].items()
            }
        }

        output_file = output_dir / f"forecast_{forecast_time}_{timestamp}.json"
        with open(output_file, 'w') as f:
            json.dump(data, f, indent=2)

        return str(output_file)

    elif format == 'netcdf':
        # Create an xarray Dataset
        data_vars = {}
        coords = {
            'point': np.arange(len(state['latitudes'])),
            'latitude': ('point', state['latitudes']),
            'longitude': ('point', state['longitudes']),
        }

        # Add each field as a variable
        for var_name, values in state['fields'].items():
            data_vars[var_name] = (['point'], values)

        # Create the dataset
        ds = xr.Dataset(
            data_vars=data_vars,
            coords=coords,
            attrs={
                'forecast_date': forecast_time,
                'export_date': datetime.datetime.now().isoformat(),
                'description': 'AIFS Weather Forecast Data'
            }
        )

        output_file = output_dir / f"forecast_{forecast_time}_{timestamp}.nc"
        ds.to_netcdf(output_file)

        return str(output_file)

    elif format == 'csv':
        # Create a DataFrame with lat/lon and all variables
        df = pd.DataFrame({
            'latitude': state['latitudes'],
            'longitude': state['longitudes']
        })

        # Add each field as a column
        for var_name, values in state['fields'].items():
            df[var_name] = values

        output_file = output_dir / f"forecast_{forecast_time}_{timestamp}.csv"
        df.to_csv(output_file, index=False)

        return str(output_file)

    else:
        raise ValueError(f"Unsupported format: {format}")

# Create dropdown choices with groups
DROPDOWN_CHOICES = []
for group_name, variables in VARIABLE_GROUPS.items():
    # Add group separator
    DROPDOWN_CHOICES.append((f"── {group_name} ──", None))
    # Add variables in this group
    for var_id, desc in sorted(variables.items()):
        DROPDOWN_CHOICES.append((f"{desc} ({var_id})", var_id))

def update_interface():
    with gr.Blocks(css="""
        .centered-header {
            text-align: center;
            margin-bottom: 20px;
        }
        .subtitle {
            font-size: 1.2em;
            line-height: 1.5;
            margin: 20px 0;
        }
        .footer {
            text-align: center;
            padding: 20px;
            margin-top: 20px;
            border-top: 1px solid #eee;
        }
    """) as demo:
        forecast_state = gr.State(None)
        
        # Header section
        gr.Markdown(f"""
        # AIFS Weather Forecast
        
        <div class="subtitle">
        Interactive visualization of ECMWF AIFS weather forecasts.<br>
        Starting from the latest available data ({DEFAULT_DATE.strftime('%Y-%m-%d %H:%M UTC')}),<br>
        select how many hours ahead you want to forecast and which meteorological variable to visualize.
        </div>
        """)
        
        with gr.Row():
            with gr.Column(scale=1):
                lead_time = gr.Slider(
                    minimum=6, 
                    maximum=48, 
                    step=6, 
                    value=12, 
                    label="Forecast Hours Ahead"
                )
                # Start with the original DROPDOWN_CHOICES
                variable = gr.Dropdown(
                    choices=DROPDOWN_CHOICES,  # Use original choices at startup
                    value="2t",
                    label="Select Variable to Plot"
                )
                with gr.Row():
                    clear_btn = gr.Button("Clear")
                    run_btn = gr.Button("Run Forecast", variant="primary")
                
                download_nc = gr.Button("Download Forecast (NetCDF)")
                download_output = gr.File(label="Download Output")
            
            with gr.Column(scale=2):
                forecast_output = gr.Image()
        
        def run_and_store(lead_time):
            """Run forecast and store state"""
            forecast_state = run_forecast(DEFAULT_DATE, lead_time, "cuda")
            plot = plot_forecast(forecast_state, "2t")  # Default to 2t
            return forecast_state, plot
        
        def update_plot_from_state(forecast_state, variable):
            """Update plot using stored state"""
            if forecast_state is None or variable is None:
                return None
            try:
                return plot_forecast(forecast_state, variable)
            except KeyError as e:
                logger.error(f"Variable {variable} not found in state: {e}")
                return None
        
        def clear():
            """Clear everything"""
            return [None, None, 12, "2t"]
        
        def save_netcdf(forecast_state):
            """Save forecast data as NetCDF"""
            if forecast_state is None:
                raise ValueError("No forecast data available. Please run a forecast first.")
                
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            forecast_time = forecast_state['date'].strftime("%Y%m%d_%H") if isinstance(forecast_state['date'], datetime.datetime) else forecast_state['date']
            
            # Create an xarray Dataset
            data_vars = {}
            coords = {
                'point': np.arange(len(forecast_state['latitudes'])),
                'latitude': ('point', forecast_state['latitudes']),
                'longitude': ('point', forecast_state['longitudes']),
            }
            
            # Add each field as a variable
            for var_name, values in forecast_state['fields'].items():
                data_vars[var_name] = (['point'], values)
            
            # Create the dataset
            ds = xr.Dataset(
                data_vars=data_vars,
                coords=coords,
                attrs={
                    'forecast_date': forecast_time,
                    'export_date': datetime.datetime.now().isoformat(),
                    'description': 'AIFS Weather Forecast Data'
                }
            )
            
            output_file = TEMP_DIR / "forecasts" / f"forecast_{forecast_time}_{timestamp}.nc"
            ds.to_netcdf(output_file)
            
            return str(output_file)
        
        # Connect the components
        run_btn.click(
            fn=run_and_store,
            inputs=[lead_time],
            outputs=[forecast_state, forecast_output]
        )
        
        variable.change(
            fn=update_plot_from_state,
            inputs=[forecast_state, variable],
            outputs=forecast_output
        )
        
        clear_btn.click(
            fn=clear,
            inputs=[],
            outputs=[forecast_state, forecast_output, lead_time, variable]
        )
        
        download_nc.click(
            fn=save_netcdf,
            inputs=[forecast_state],
            outputs=[download_output]
        )
        
        return demo


def setup_directories():
    """Create necessary directories with .keep files"""
    # Define all required directories
    directories = {
        TEMP_DIR / "data_cache": "Cache directory for downloaded weather data",
        TEMP_DIR / "forecasts": "Directory for forecast outputs (plots and data files)",
    }

    # Create directories and .keep files
    for directory, description in directories.items():
        directory.mkdir(parents=True, exist_ok=True)
        keep_file = directory / ".keep"
        if not keep_file.exists():
            keep_file.write_text(f"# {description}\n# This file ensures the directory is tracked in git\n")
            logger.info(f"Created directory and .keep file: {directory}")

# Call it during initialization
setup_directories()

def cleanup_old_files():
    """Remove old temporary and cache files"""
    current_time = datetime.datetime.now().timestamp()

    # Clean up forecast files (1 hour old)
    forecast_dir = TEMP_DIR / "forecasts"
    for file in forecast_dir.glob("*.*"):
        if file.name == ".keep":
            continue
        if current_time - file.stat().st_mtime > 3600:
            logger.info(f"Removing old forecast file: {file}")
            file.unlink(missing_ok=True)

    # Clean up cache files (24 hours old)
    cache_dir = TEMP_DIR / "data_cache"
    for file in cache_dir.glob("*.pkl"):
        if file.name == ".keep":
            continue
        if current_time - file.stat().st_mtime > 86400:
            logger.info(f"Removing old cache file: {file}")
            file.unlink(missing_ok=True)

# Purge stale outputs, then create and launch the interface.
# Launching last ensures the cache/forecast directories exist before the first request.
cleanup_old_files()
demo = update_interface()
demo.launch()