Spaces:
Build error
Build error
import os | |
# Set memory optimization environment variables | |
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' | |
os.environ['ANEMOI_INFERENCE_NUM_CHUNKS'] = '16' | |
import gradio as gr | |
import datetime | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import cartopy.crs as ccrs | |
import cartopy.feature as cfeature | |
import matplotlib.tri as tri | |
from anemoi.inference.runners.simple import SimpleRunner | |
from ecmwf.opendata import Client as OpendataClient | |
import earthkit.data as ekd | |
import earthkit.regrid as ekr | |
# Parameter short names requested from the ECMWF open-data service
# (kept in sync with notebook.py).

# Single-level (surface) parameters.
PARAM_SFC = [
    "10u", "10v", "2d", "2t", "msl", "skt",
    "sp", "tcw", "lsm", "z", "slor", "sdor",
]

# Soil parameters, requested once per entry in SOIL_LEVELS.
PARAM_SOIL = ["vsw", "sot"]
SOIL_LEVELS = [1, 2]

# Pressure-level parameters, requested once per entry in LEVELS (hPa).
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]
LEVELS = [
    1000, 925, 850, 700, 600, 500, 400,
    300, 250, 200, 150, 100, 50,
]

# Most recent date reported by the open-data client; queried once at import
# time and used as the default forecast date.
DEFAULT_DATE = OpendataClient().latest()
# Plot-able variables, organised by category.
# Each category maps a variable id to its human-readable description.
_SURFACE_VARIABLES = {
    "10u": "10m U Wind Component",
    "10v": "10m V Wind Component",
    "2d": "2m Dewpoint Temperature",
    "2t": "2m Temperature",
    "msl": "Mean Sea Level Pressure",
    "skt": "Skin Temperature",
    "sp": "Surface Pressure",
    "tcw": "Total Column Water",
    "lsm": "Land-Sea Mask",
    "z": "Surface Geopotential",
    "slor": "Slope of Sub-gridscale Orography",
    "sdor": "Standard Deviation of Orography",
}

_SOIL_VARIABLES = {
    "stl1": "Soil Temperature Level 1",
    "stl2": "Soil Temperature Level 2",
    "swvl1": "Soil Water Volume Level 1",
    "swvl2": "Soil Water Volume Level 2",
}

VARIABLE_GROUPS = {
    "Surface Variables": _SURFACE_VARIABLES,
    "Soil Variables": _SOIL_VARIABLES,
    # Starts empty; populated programmatically right after this definition.
    "Pressure Level Variables": {},
}
# Register one entry per (pressure-level variable, level) pair, e.g.
# "t_850" -> "Temperature at 850hPa". Variables are visited in the order
# listed here and, within each variable, levels follow the order of LEVELS.
_PRESSURE_LEVEL_NAMES = {
    "t": "Temperature",
    "u": "U Wind Component",
    "v": "V Wind Component",
    "w": "Vertical Velocity",
    "q": "Specific Humidity",
    "z": "Geopotential",
}
for short_name, long_name in _PRESSURE_LEVEL_NAMES.items():
    for hpa in LEVELS:
        VARIABLE_GROUPS["Pressure Level Variables"][f"{short_name}_{hpa}"] = (
            f"{long_name} at {hpa}hPa"
        )
def get_open_data(param, levelist=None, date=None):
    """Download open-data fields for ``date`` and the time 6 hours before it.

    Args:
        param: Parameter short name(s) to request.
        levelist: Levels to request. ``None`` or an empty list requests no
            levels, in which case fields are keyed by parameter name only.
        date: Reference ``datetime`` of the most recent of the two times.
            Defaults to ``DEFAULT_DATE`` when omitted.

    Returns:
        dict mapping ``"<param>"`` (no levels) or ``"<param>_<level>"`` to a
        NumPy array made by stacking the regridded values of each retrieved
        field along a new first axis, earlier time first.

    Raises:
        ValueError: If a retrieved field is not on the expected 721 x 1440
            grid.
    """
    # Avoid a shared mutable default; an empty list is still what gets sent.
    levelist = [] if levelist is None else levelist
    if date is None:
        date = DEFAULT_DATE

    fields = {}
    # Earlier time first so the stacked arrays are ordered oldest -> newest.
    for step_date in (date - datetime.timedelta(hours=6), date):
        data = ekd.from_source("ecmwf-open-data", date=step_date, param=param, levelist=levelist)
        for f in data:
            grid = f.to_numpy()
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if grid.shape != (721, 1440):
                raise ValueError(
                    f"Unexpected grid shape {grid.shape} for {f.metadata('param')}; "
                    "expected (721, 1440)"
                )
            # Rotate the grid by half its width along axis 1.
            # NOTE(review): presumably moves the longitude origin from -180 to 0
            # to match what the model expects - confirm against notebook.py.
            values = np.roll(grid, -(grid.shape[1] // 2), axis=1)
            # Regrid from the regular 0.25 degree grid to N320.
            values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})
            name = f"{f.metadata('param')}_{f.metadata('levelist')}" if levelist else f.metadata("param")
            fields.setdefault(name, []).append(values)

    # Collapse the per-time lists into a single array per field.
    return {name: np.stack(values) for name, values in fields.items()}


def run_forecast(date, lead_time, device):
    """Run the AIFS checkpoint from open-data initial conditions.

    The initial conditions are downloaded for ``date`` (and 6 hours earlier),
    so the fields match the date the input state is labelled with.

    Args:
        date: ``datetime`` of the initial state.
        lead_time: Forecast lead time, passed through to the runner.
        device: Device passed to ``SimpleRunner`` (e.g. ``"cuda"`` or ``"cpu"``).

    Returns:
        The last state yielded by the runner.

    Raises:
        RuntimeError: If the runner yields no state at all.
    """
    fields = {}

    # Surface fields.
    fields.update(get_open_data(param=PARAM_SFC, date=date))

    # Soil fields, renamed from "<param>_<level>" to the names used in the
    # input state.
    soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS, date=date)
    mapping = {
        'sot_1': 'stl1', 'sot_2': 'stl2',
        'vsw_1': 'swvl1', 'vsw_2': 'swvl2'
    }
    for k, v in soil.items():
        fields[mapping[k]] = v

    # Pressure-level fields.
    fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS, date=date))

    # Convert geopotential height to geopotential (scale by 9.80665).
    for level in LEVELS:
        gh = fields.pop(f"gh_{level}")
        fields[f"z_{level}"] = gh * 9.80665

    input_state = dict(date=date, fields=fields)
    runner = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)

    # Only the final state is returned, so do not keep the earlier ones alive.
    final_state = None
    for state in runner.run(input_state=input_state, lead_time=lead_time):
        final_state = state
    if final_state is None:
        raise RuntimeError(f"Forecast produced no output for lead_time={lead_time!r}")
    return final_state
def plot_forecast(state, selected_variable):
    """Draw a filled-contour map of one field of a forecast state.

    Args:
        state: Mapping with ``"latitudes"``, ``"longitudes"``, ``"date"`` and
            a ``"fields"`` mapping keyed by variable id.
        selected_variable: Key into ``state["fields"]`` of the field to plot.

    Returns:
        The Matplotlib figure containing the map.

    Raises:
        KeyError: If ``selected_variable`` is not in ``state["fields"]``.
    """
    latitudes, longitudes = state["latitudes"], state["longitudes"]
    values = state["fields"][selected_variable]

    # NOTE(review): the state longitudes are presumably on a 0-360 grid; wrap
    # them to [-180, 180] so points are not placed outside the PlateCarree map.
    # This is a no-op when they are already within that range - confirm.
    longitudes = np.asarray(longitudes)
    longitudes = np.where(longitudes > 180, longitudes - 360, longitudes)

    fig, ax = plt.subplots(figsize=(11, 6), subplot_kw={"projection": ccrs.PlateCarree()})
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, linestyle=":")

    triangulation = tri.Triangulation(longitudes, latitudes)
    # 'RdBu_r' is the reversed 'RdBu' colour map.
    contour = ax.tricontourf(
        triangulation,
        values,
        levels=20,
        transform=ccrs.PlateCarree(),
        cmap='RdBu_r',
    )

    # Address this figure and axes explicitly rather than through pyplot's
    # implicit "current figure", which is shared global state and may belong
    # to another request when the app serves several users at once.
    ax.set_title(f"{selected_variable} at {state['date']}")
    fig.colorbar(contour, ax=ax)
    return fig
# Build the (label, value) pairs shown in the variable dropdown, one section
# per variable group.
DROPDOWN_CHOICES = []
for group_name, variables in VARIABLE_GROUPS.items():
    # Section header row. Its value is None so it can never be mistaken for a
    # real variable id. The label uses box-drawing dashes; the previous text
    # was a mis-decoded rendering of them.
    DROPDOWN_CHOICES.append((f"── {group_name} ──", None))
    # Variables of this group, sorted by variable id.
    DROPDOWN_CHOICES.extend(
        (f"{desc} ({var_id})", var_id) for var_id, desc in sorted(variables.items())
    )
def gradio_interface(date_str, lead_time, device, selected_variable):
    """Gradio callback: validate the inputs, run the forecast and plot it.

    Args:
        date_str: Forecast date as ``YYYY-MM-DD`` text.
        lead_time: Forecast lead time, passed through to ``run_forecast``.
        device: Compute device, passed through to ``run_forecast``.
        selected_variable: Variable id to plot.

    Returns:
        The figure produced by ``plot_forecast``.

    Raises:
        gr.Error: If no variable is selected or the date cannot be parsed.
    """
    # The group-header rows of the dropdown carry the value None. Reject them
    # (and an empty selection) up front, before the expensive forecast runs,
    # instead of failing with a KeyError at plotting time.
    if not selected_variable:
        raise gr.Error("Please select a variable to plot")

    try:
        # Tolerate a missing value and surrounding whitespace in the textbox.
        date = datetime.datetime.strptime((date_str or "").strip(), "%Y-%m-%d")
    except ValueError:
        raise gr.Error("Please enter a valid date in YYYY-MM-DD format") from None

    state = run_forecast(date, lead_time, device)
    return plot_forecast(state, selected_variable)
# Input widgets, in the positional order expected by gradio_interface.
date_input = gr.Textbox(
    value=DEFAULT_DATE.strftime("%Y-%m-%d"),
    label="Forecast Date (YYYY-MM-DD)",
)
lead_time_input = gr.Slider(
    minimum=6,
    maximum=48,
    step=6,
    value=12,
    label="Lead Time (Hours)",
)
device_input = gr.Radio(
    choices=["cuda", "cpu"],
    value="cuda",
    label="Compute Device",
)
variable_input = gr.Dropdown(
    choices=DROPDOWN_CHOICES,
    value="2t",  # 2m temperature is pre-selected
    label="Select Variable to Plot",
    info="Choose a meteorological variable to visualize",
)

# Assemble the app: four inputs in, one plot out.
demo = gr.Interface(
    fn=gradio_interface,
    inputs=[date_input, lead_time_input, device_input, variable_input],
    outputs=gr.Plot(),
    title="AIFS Weather Forecast",
    description="Interactive visualization of ECMWF AIFS weather forecasts. Select a date, forecast lead time, and meteorological variable to plot.",
)

# Start serving as soon as the module is executed.
demo.launch()