Spaces:
Build error
Build error
remove animation
Browse files- .gitignore +1 -0
- app.py +288 -191
.gitignore
CHANGED
@@ -1,2 +1,3 @@
|
|
1 |
aifs-single-mse-1.0.ckpt
|
2 |
flagged/
|
|
|
|
1 |
aifs-single-mse-1.0.ckpt
|
2 |
flagged/
|
3 |
+
gradio_temp/*
|
app.py
CHANGED
@@ -17,6 +17,16 @@ from ecmwf.opendata import Client as OpendataClient
|
|
17 |
import earthkit.data as ekd
|
18 |
import earthkit.regrid as ekr
|
19 |
import matplotlib.animation as animation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
# Define parameters (updating to match notebook.py)
|
22 |
PARAM_SFC = ["10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "lsm", "z", "slor", "sdor"]
|
@@ -74,15 +84,92 @@ TEMP_DIR = Path("./gradio_temp")
|
|
74 |
TEMP_DIR.mkdir(exist_ok=True)
|
75 |
os.environ['GRADIO_TEMP_DIR'] = str(TEMP_DIR)
|
76 |
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
fields = {}
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
for
|
83 |
-
|
84 |
-
|
85 |
-
data = ekd.from_source("ecmwf-open-data", date=date, param=param, levelist=levelist)
|
86 |
for f in data:
|
87 |
assert f.to_numpy().shape == (721, 1440)
|
88 |
values = np.roll(f.to_numpy(), -f.shape[1] // 2, axis=1)
|
@@ -98,99 +185,71 @@ def get_open_data(param, levelist=[]):
|
|
98 |
|
99 |
return fields
|
100 |
|
101 |
-
def
|
|
|
|
|
102 |
# Setup the figure and axis
|
103 |
fig = plt.figure(figsize=(15, 8))
|
104 |
ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=0))
|
105 |
|
106 |
-
# Get the
|
107 |
-
|
108 |
-
latitudes, longitudes = first_state["latitudes"], first_state["longitudes"]
|
109 |
fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes)
|
110 |
triangulation = tri.Triangulation(fixed_lons, latitudes)
|
111 |
|
112 |
-
#
|
113 |
-
|
114 |
-
|
115 |
|
116 |
-
#
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
119 |
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
# Set map features
|
125 |
-
ax.set_global()
|
126 |
-
ax.set_extent([-180, 180, -85, 85], crs=ccrs.PlateCarree())
|
127 |
-
ax.coastlines(resolution='50m')
|
128 |
-
ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5)
|
129 |
-
ax.gridlines(draw_labels=True)
|
130 |
-
|
131 |
-
state = states[frame]
|
132 |
-
values = state["fields"][selected_variable]
|
133 |
-
|
134 |
-
# Clear the previous colorbar axis if it exists
|
135 |
-
if cbar_ax:
|
136 |
-
cbar_ax.remove()
|
137 |
-
|
138 |
-
# Create new contour plot
|
139 |
-
contour = ax.tricontourf(triangulation, values,
|
140 |
-
levels=20, transform=ccrs.PlateCarree(),
|
141 |
-
cmap='RdBu_r', vmin=vmin, vmax=vmax)
|
142 |
-
|
143 |
-
# Create new colorbar
|
144 |
-
cbar_ax = fig.add_axes([0.1, 0.05, 0.8, 0.03]) # [left, bottom, width, height]
|
145 |
-
plt.colorbar(contour, cax=cbar_ax, orientation='horizontal')
|
146 |
-
|
147 |
-
# Format the date string properly
|
148 |
-
forecast_time = state["date"]
|
149 |
-
if isinstance(forecast_time, str):
|
150 |
-
try:
|
151 |
-
forecast_time = datetime.datetime.strptime(forecast_time, "%Y-%m-%d %H:%M:%S")
|
152 |
-
except ValueError:
|
153 |
-
try:
|
154 |
-
forecast_time = datetime.datetime.strptime(forecast_time, "%Y-%m-%d %H:%M:%S.%f")
|
155 |
-
except ValueError:
|
156 |
-
forecast_time = DEFAULT_DATE
|
157 |
-
|
158 |
-
time_str = forecast_time.strftime("%Y-%m-%d %H:%M UTC")
|
159 |
-
|
160 |
-
# Get variable description from VARIABLE_GROUPS
|
161 |
-
var_desc = None
|
162 |
-
for group in VARIABLE_GROUPS.values():
|
163 |
-
if selected_variable in group:
|
164 |
-
var_desc = group[selected_variable]
|
165 |
-
break
|
166 |
-
var_name = var_desc if var_desc else selected_variable
|
167 |
-
|
168 |
-
ax.set_title(f"{var_name} - {time_str}")
|
169 |
|
170 |
-
#
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
)
|
178 |
|
179 |
-
#
|
180 |
-
|
181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
plt.close()
|
183 |
|
184 |
return temp_file
|
185 |
|
186 |
-
def run_forecast(date, lead_time, device):
|
187 |
# Get all required fields
|
188 |
fields = {}
|
|
|
189 |
|
190 |
# Get surface fields
|
|
|
191 |
fields.update(get_open_data(param=PARAM_SFC))
|
192 |
|
193 |
# Get soil fields and rename them
|
|
|
194 |
soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS)
|
195 |
mapping = {
|
196 |
'sot_1': 'stl1', 'sot_2': 'stl2',
|
@@ -200,6 +259,7 @@ def run_forecast(date, lead_time, device):
|
|
200 |
fields[mapping[k]] = v
|
201 |
|
202 |
# Get pressure level fields
|
|
|
203 |
fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))
|
204 |
|
205 |
# Convert geopotential height to geopotential
|
@@ -211,125 +271,162 @@ def run_forecast(date, lead_time, device):
|
|
211 |
|
212 |
# Use the global model instance
|
213 |
global MODEL
|
214 |
-
# If device preference changed, move model to new device
|
215 |
if device != MODEL.device:
|
216 |
MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)
|
217 |
|
218 |
-
#
|
219 |
-
|
220 |
for state in MODEL.run(input_state=input_state, lead_time=lead_time):
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
def update_plot(lead_time, variable):
|
225 |
-
cleanup_old_files() # Clean up old files before creating new ones
|
226 |
-
states = run_forecast(DEFAULT_DATE, lead_time, "cuda")
|
227 |
-
return plot_forecast_animation(states, variable)
|
228 |
-
|
229 |
-
# Add cleanup function for old files
|
230 |
-
def cleanup_old_files():
|
231 |
-
# Remove files older than 1 hour
|
232 |
-
current_time = datetime.datetime.now().timestamp()
|
233 |
-
for file in TEMP_DIR.glob("*.mp4"): # Changed from *.gif to *.mp4
|
234 |
-
if current_time - file.stat().st_mtime > 3600: # 1 hour in seconds
|
235 |
-
file.unlink(missing_ok=True)
|
236 |
-
|
237 |
-
# Create dropdown choices with groups
|
238 |
-
DROPDOWN_CHOICES = []
|
239 |
-
for group_name, variables in VARIABLE_GROUPS.items():
|
240 |
-
# Add group separator
|
241 |
-
DROPDOWN_CHOICES.append((f"ββ {group_name} ββ", None))
|
242 |
-
# Add variables in this group
|
243 |
-
for var_id, desc in sorted(variables.items()):
|
244 |
-
DROPDOWN_CHOICES.append((f"{desc} ({var_id})", var_id))
|
245 |
-
|
246 |
-
with gr.Blocks(css="""
|
247 |
-
.centered-header {
|
248 |
-
text-align: center;
|
249 |
-
margin-bottom: 20px;
|
250 |
-
}
|
251 |
-
.subtitle {
|
252 |
-
font-size: 1.2em;
|
253 |
-
line-height: 1.5;
|
254 |
-
margin: 20px 0;
|
255 |
-
}
|
256 |
-
.footer {
|
257 |
-
text-align: center;
|
258 |
-
padding: 20px;
|
259 |
-
margin-top: 20px;
|
260 |
-
border-top: 1px solid #eee;
|
261 |
-
}
|
262 |
-
""") as demo:
|
263 |
-
# Header section
|
264 |
-
gr.Markdown(f"""
|
265 |
-
# AIFS Weather Forecast
|
266 |
-
|
267 |
-
<div class="subtitle">
|
268 |
-
Interactive visualization of ECMWF AIFS weather forecasts.<br>
|
269 |
-
Starting from the latest available data ({DEFAULT_DATE.strftime('%Y-%m-%d %H:%M UTC')}),<br>
|
270 |
-
select how many hours ahead you want to forecast and which meteorological variable to visualize.
|
271 |
-
</div>
|
272 |
-
""")
|
273 |
-
|
274 |
-
# Main content
|
275 |
-
with gr.Row():
|
276 |
-
with gr.Column(scale=1):
|
277 |
-
lead_time = gr.Slider(
|
278 |
-
minimum=6,
|
279 |
-
maximum=48,
|
280 |
-
step=6,
|
281 |
-
value=12,
|
282 |
-
label="Forecast Hours Ahead"
|
283 |
-
)
|
284 |
-
variable = gr.Dropdown(
|
285 |
-
choices=DROPDOWN_CHOICES,
|
286 |
-
value="2t",
|
287 |
-
label="Select Variable to Plot"
|
288 |
-
)
|
289 |
-
with gr.Row():
|
290 |
-
clear_btn = gr.Button("Clear")
|
291 |
-
submit_btn = gr.Button("Submit", variant="primary")
|
292 |
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
1. Visit <a href="https://huggingface.co/spaces/geobase/aifs-forecast" target="_blank">https://huggingface.co/spaces/geobase/aifs-forecast</a>\n
|
303 |
-
2. Click the "Duplicate this Space" button in the top right\n
|
304 |
-
3. Select your hardware requirements (GPU recommended)\n
|
305 |
-
4. Wait for your copy to deploy
|
306 |
-
|
307 |
-
<h3>Model Information</h3>
|
308 |
-
This demo uses the <a href="https://huggingface.co/ecmwf/aifs-single-1.0" target="_blank">AIFS Single 1.0</a> model from ECMWF,
|
309 |
-
which is their first operationally supported Artificial Intelligence Forecasting System. The model produces highly skilled forecasts
|
310 |
-
for upper-air variables, surface weather parameters, and tropical cyclone tracks.
|
311 |
|
312 |
-
|
313 |
-
|
314 |
-
|
|
|
|
|
|
|
315 |
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
|
|
|
|
|
|
|
|
322 |
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
|
|
|
|
|
335 |
demo.launch()
|
|
|
17 |
import earthkit.data as ekd
|
18 |
import earthkit.regrid as ekr
|
19 |
import matplotlib.animation as animation
|
20 |
+
from functools import lru_cache
|
21 |
+
import hashlib
|
22 |
+
import pickle
|
23 |
+
import json
|
24 |
+
from typing import List, Dict, Any
|
25 |
+
import logging
|
26 |
+
|
27 |
+
# Configure logging
|
28 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
29 |
+
logger = logging.getLogger(__name__)
|
30 |
|
31 |
# Define parameters (updating to match notebook.py)
|
32 |
PARAM_SFC = ["10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "lsm", "z", "slor", "sdor"]
|
|
|
84 |
TEMP_DIR.mkdir(exist_ok=True)
|
85 |
os.environ['GRADIO_TEMP_DIR'] = str(TEMP_DIR)
|
86 |
|
87 |
+
# Add these cache-related functions after the MODEL initialization
|
88 |
+
def get_cache_key(date: datetime.datetime, params: List[str], levellist: List[int]) -> str:
|
89 |
+
"""Create a unique cache key based on the request parameters"""
|
90 |
+
key_parts = [
|
91 |
+
date.isoformat(),
|
92 |
+
",".join(sorted(params)),
|
93 |
+
",".join(str(x) for x in sorted(levellist)) if levellist else "no_levels"
|
94 |
+
]
|
95 |
+
key_string = "_".join(key_parts)
|
96 |
+
cache_key = hashlib.md5(key_string.encode()).hexdigest()
|
97 |
+
logger.info(f"Generated cache key: {cache_key} for {key_string}")
|
98 |
+
return cache_key
|
99 |
+
|
100 |
+
def get_cache_path(cache_key: str) -> Path:
|
101 |
+
"""Get the path to the cache file"""
|
102 |
+
cache_dir = TEMP_DIR / "data_cache"
|
103 |
+
cache_dir.mkdir(exist_ok=True)
|
104 |
+
return cache_dir / f"{cache_key}.pkl"
|
105 |
+
|
106 |
+
def save_to_cache(cache_key: str, data: Dict[str, Any]) -> None:
|
107 |
+
"""Save data to disk cache"""
|
108 |
+
cache_file = get_cache_path(cache_key)
|
109 |
+
try:
|
110 |
+
with open(cache_file, 'wb') as f:
|
111 |
+
pickle.dump(data, f)
|
112 |
+
logger.info(f"Successfully saved data to cache: {cache_file}")
|
113 |
+
except Exception as e:
|
114 |
+
logger.error(f"Failed to save to cache: {e}")
|
115 |
+
|
116 |
+
def load_from_cache(cache_key: str) -> Dict[str, Any]:
|
117 |
+
"""Load data from disk cache"""
|
118 |
+
cache_file = get_cache_path(cache_key)
|
119 |
+
if cache_file.exists():
|
120 |
+
try:
|
121 |
+
with open(cache_file, 'rb') as f:
|
122 |
+
data = pickle.load(f)
|
123 |
+
logger.info(f"Successfully loaded data from cache: {cache_file}")
|
124 |
+
return data
|
125 |
+
except Exception as e:
|
126 |
+
logger.error(f"Failed to load from cache: {e}")
|
127 |
+
cache_file.unlink(missing_ok=True)
|
128 |
+
logger.info(f"No cache file found: {cache_file}")
|
129 |
+
return None
|
130 |
+
|
131 |
+
# Modify the get_open_data function to use caching
|
132 |
+
@lru_cache(maxsize=32)
|
133 |
+
def get_cached_data(date_str: str, param_tuple: tuple, levelist_tuple: tuple) -> Dict[str, Any]:
|
134 |
+
"""Memory cache wrapper for get_open_data"""
|
135 |
+
return get_open_data_impl(
|
136 |
+
datetime.datetime.fromisoformat(date_str),
|
137 |
+
list(param_tuple),
|
138 |
+
list(levelist_tuple) if levelist_tuple else []
|
139 |
+
)
|
140 |
+
|
141 |
+
def get_open_data(param: List[str], levelist: List[int] = None) -> Dict[str, Any]:
|
142 |
+
"""Main function to get data with caching"""
|
143 |
+
if levelist is None:
|
144 |
+
levelist = []
|
145 |
+
|
146 |
+
# Try disk cache first (more persistent than memory cache)
|
147 |
+
cache_key = get_cache_key(DEFAULT_DATE, param, levelist)
|
148 |
+
logger.info(f"Checking cache for key: {cache_key}")
|
149 |
+
|
150 |
+
cached_data = load_from_cache(cache_key)
|
151 |
+
if cached_data is not None:
|
152 |
+
logger.info(f"Cache hit for {cache_key}")
|
153 |
+
return cached_data
|
154 |
+
|
155 |
+
# If not in cache, download and process the data
|
156 |
+
logger.info(f"Cache miss for {cache_key}, downloading fresh data")
|
157 |
+
fields = get_open_data_impl(DEFAULT_DATE, param, levelist)
|
158 |
+
|
159 |
+
# Save to disk cache
|
160 |
+
save_to_cache(cache_key, fields)
|
161 |
+
|
162 |
+
return fields
|
163 |
+
|
164 |
+
def get_open_data_impl(date: datetime.datetime, param: List[str], levelist: List[int]) -> Dict[str, Any]:
|
165 |
+
"""Implementation of data download and processing"""
|
166 |
fields = {}
|
167 |
+
myiterable = [date - datetime.timedelta(hours=6), date]
|
168 |
+
logger.info(f"Downloading data for dates: {myiterable}")
|
169 |
+
|
170 |
+
for current_date in myiterable:
|
171 |
+
logger.info(f"Fetching data for {current_date}")
|
172 |
+
data = ekd.from_source("ecmwf-open-data", date=current_date, param=param, levelist=levelist)
|
|
|
173 |
for f in data:
|
174 |
assert f.to_numpy().shape == (721, 1440)
|
175 |
values = np.roll(f.to_numpy(), -f.shape[1] // 2, axis=1)
|
|
|
185 |
|
186 |
return fields
|
187 |
|
188 |
+
def plot_forecast(state, selected_variable):
|
189 |
+
logger.info(f"Plotting forecast for {selected_variable} at time {state['date']}")
|
190 |
+
|
191 |
# Setup the figure and axis
|
192 |
fig = plt.figure(figsize=(15, 8))
|
193 |
ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=0))
|
194 |
|
195 |
+
# Get the coordinates
|
196 |
+
latitudes, longitudes = state["latitudes"], state["longitudes"]
|
|
|
197 |
fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes)
|
198 |
triangulation = tri.Triangulation(fixed_lons, latitudes)
|
199 |
|
200 |
+
# Get the values
|
201 |
+
values = state["fields"][selected_variable]
|
202 |
+
logger.info(f"Value range: min={np.min(values):.2f}, max={np.max(values):.2f}")
|
203 |
|
204 |
+
# Set map features
|
205 |
+
ax.set_global()
|
206 |
+
ax.set_extent([-180, 180, -85, 85], crs=ccrs.PlateCarree())
|
207 |
+
ax.coastlines(resolution='50m')
|
208 |
+
ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5)
|
209 |
+
ax.gridlines(draw_labels=True)
|
210 |
|
211 |
+
# Create contour plot
|
212 |
+
contour = ax.tricontourf(triangulation, values,
|
213 |
+
levels=20, transform=ccrs.PlateCarree(),
|
214 |
+
cmap='RdBu_r')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
|
216 |
+
# Add colorbar
|
217 |
+
plt.colorbar(contour, ax=ax, orientation='horizontal', pad=0.05)
|
218 |
+
|
219 |
+
# Format the date string
|
220 |
+
forecast_time = state["date"]
|
221 |
+
if isinstance(forecast_time, str):
|
222 |
+
forecast_time = datetime.datetime.fromisoformat(forecast_time)
|
223 |
+
time_str = forecast_time.strftime("%Y-%m-%d %H:%M UTC")
|
224 |
|
225 |
+
# Get variable description
|
226 |
+
var_desc = None
|
227 |
+
for group in VARIABLE_GROUPS.values():
|
228 |
+
if selected_variable in group:
|
229 |
+
var_desc = group[selected_variable]
|
230 |
+
break
|
231 |
+
var_name = var_desc if var_desc else selected_variable
|
232 |
+
|
233 |
+
ax.set_title(f"{var_name} - {time_str}")
|
234 |
+
|
235 |
+
# Save as PNG
|
236 |
+
temp_file = str(TEMP_DIR / f"forecast_{datetime.datetime.now().timestamp()}.png")
|
237 |
+
plt.savefig(temp_file, bbox_inches='tight', dpi=100)
|
238 |
plt.close()
|
239 |
|
240 |
return temp_file
|
241 |
|
242 |
+
def run_forecast(date: datetime.datetime, lead_time: int, device: str) -> Dict[str, Any]:
|
243 |
# Get all required fields
|
244 |
fields = {}
|
245 |
+
logger.info(f"Starting forecast for lead_time: {lead_time} hours")
|
246 |
|
247 |
# Get surface fields
|
248 |
+
logger.info("Getting surface fields...")
|
249 |
fields.update(get_open_data(param=PARAM_SFC))
|
250 |
|
251 |
# Get soil fields and rename them
|
252 |
+
logger.info("Getting soil fields...")
|
253 |
soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS)
|
254 |
mapping = {
|
255 |
'sot_1': 'stl1', 'sot_2': 'stl2',
|
|
|
259 |
fields[mapping[k]] = v
|
260 |
|
261 |
# Get pressure level fields
|
262 |
+
logger.info("Getting pressure level fields...")
|
263 |
fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))
|
264 |
|
265 |
# Convert geopotential height to geopotential
|
|
|
271 |
|
272 |
# Use the global model instance
|
273 |
global MODEL
|
|
|
274 |
if device != MODEL.device:
|
275 |
MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)
|
276 |
|
277 |
+
# Run the model and get the final state
|
278 |
+
final_state = None
|
279 |
for state in MODEL.run(input_state=input_state, lead_time=lead_time):
|
280 |
+
logger.info(f"\nπ date={state['date']} latitudes={state['latitudes'].shape} "
|
281 |
+
f"longitudes={state['longitudes'].shape} fields={len(state['fields'])}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
282 |
|
283 |
+
# Log a few example variables to show we have all fields
|
284 |
+
for var in ['2t', 'msl', 't_1000', 'z_850']:
|
285 |
+
if var in state['fields']:
|
286 |
+
values = state['fields'][var]
|
287 |
+
logger.info(f" {var:<6} shape={values.shape} "
|
288 |
+
f"min={np.min(values):.6f} "
|
289 |
+
f"max={np.max(values):.6f}")
|
290 |
+
|
291 |
+
final_state = state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
292 |
|
293 |
+
logger.info(f"Final state contains {len(final_state['fields'])} variables")
|
294 |
+
return final_state
|
295 |
+
|
296 |
+
def get_available_variables(state):
|
297 |
+
"""Get available variables from the state and organize them into groups"""
|
298 |
+
available_vars = set(state['fields'].keys())
|
299 |
|
300 |
+
# Create dropdown choices only for available variables
|
301 |
+
choices = []
|
302 |
+
for group_name, variables in VARIABLE_GROUPS.items():
|
303 |
+
group_vars = [(f"{desc} ({var_id})", var_id)
|
304 |
+
for var_id, desc in variables.items()
|
305 |
+
if var_id in available_vars]
|
306 |
+
|
307 |
+
if group_vars: # Only add group if it has available variables
|
308 |
+
choices.append((f"ββ {group_name} ββ", None))
|
309 |
+
choices.extend(group_vars)
|
310 |
|
311 |
+
return choices
|
312 |
+
|
313 |
+
def update_interface():
|
314 |
+
with gr.Blocks(css="""
|
315 |
+
.centered-header {
|
316 |
+
text-align: center;
|
317 |
+
margin-bottom: 20px;
|
318 |
+
}
|
319 |
+
.subtitle {
|
320 |
+
font-size: 1.2em;
|
321 |
+
line-height: 1.5;
|
322 |
+
margin: 20px 0;
|
323 |
+
}
|
324 |
+
.footer {
|
325 |
+
text-align: center;
|
326 |
+
padding: 20px;
|
327 |
+
margin-top: 20px;
|
328 |
+
border-top: 1px solid #eee;
|
329 |
+
}
|
330 |
+
""") as demo:
|
331 |
+
state = gr.State(None)
|
332 |
+
|
333 |
+
with gr.Row():
|
334 |
+
with gr.Column(scale=1):
|
335 |
+
lead_time = gr.Slider(
|
336 |
+
minimum=6,
|
337 |
+
maximum=48,
|
338 |
+
step=6,
|
339 |
+
value=12,
|
340 |
+
label="Forecast Hours Ahead"
|
341 |
+
)
|
342 |
+
variable = gr.Dropdown(
|
343 |
+
choices=[], # Start empty
|
344 |
+
value=None,
|
345 |
+
label="Select Variable to Plot"
|
346 |
+
)
|
347 |
+
with gr.Row():
|
348 |
+
clear_btn = gr.Button("Clear")
|
349 |
+
run_btn = gr.Button("Run Forecast", variant="primary")
|
350 |
+
|
351 |
+
with gr.Row():
|
352 |
+
download_json = gr.Button("Download JSON")
|
353 |
+
download_nc = gr.Button("Download NetCDF")
|
354 |
+
|
355 |
+
with gr.Column(scale=2):
|
356 |
+
forecast_output = gr.Image()
|
357 |
+
|
358 |
+
def run_and_store(lead_time):
|
359 |
+
"""Run forecast and store state"""
|
360 |
+
state = run_forecast(DEFAULT_DATE, lead_time, "cuda")
|
361 |
+
|
362 |
+
# Get available variables
|
363 |
+
choices = get_available_variables(state)
|
364 |
+
|
365 |
+
# Select first real variable as default
|
366 |
+
default_var = next((var_id for _, var_id in choices if var_id is not None), None)
|
367 |
+
|
368 |
+
# Generate initial plot
|
369 |
+
plot = plot_forecast(state, default_var) if default_var else None
|
370 |
+
|
371 |
+
return [state, gr.Dropdown(choices=choices), default_var, plot]
|
372 |
+
|
373 |
+
def update_plot_from_state(state, variable):
|
374 |
+
"""Update plot using stored state"""
|
375 |
+
if state is None or variable is None:
|
376 |
+
return None
|
377 |
+
try:
|
378 |
+
return plot_forecast(state, variable)
|
379 |
+
except KeyError as e:
|
380 |
+
logger.error(f"Variable {variable} not found in state: {e}")
|
381 |
+
return None
|
382 |
+
|
383 |
+
def clear():
|
384 |
+
"""Clear everything"""
|
385 |
+
return [None, None, gr.Dropdown(choices=[]), None]
|
386 |
+
|
387 |
+
def save_json(state):
|
388 |
+
if state is None:
|
389 |
+
return None
|
390 |
+
return save_forecast_data(state, 'json')
|
391 |
+
|
392 |
+
def save_netcdf(state):
|
393 |
+
if state is None:
|
394 |
+
return None
|
395 |
+
return save_forecast_data(state, 'netcdf')
|
396 |
+
|
397 |
+
# Connect the components
|
398 |
+
run_btn.click(
|
399 |
+
fn=run_and_store,
|
400 |
+
inputs=[lead_time],
|
401 |
+
outputs=[state, variable, variable, forecast_output]
|
402 |
+
)
|
403 |
+
|
404 |
+
variable.change(
|
405 |
+
fn=update_plot_from_state,
|
406 |
+
inputs=[state, variable],
|
407 |
+
outputs=forecast_output
|
408 |
+
)
|
409 |
+
|
410 |
+
clear_btn.click(
|
411 |
+
fn=clear,
|
412 |
+
inputs=[],
|
413 |
+
outputs=[state, forecast_output, variable, variable]
|
414 |
+
)
|
415 |
+
|
416 |
+
download_json.click(
|
417 |
+
fn=save_json,
|
418 |
+
inputs=[state],
|
419 |
+
outputs=gr.File()
|
420 |
+
)
|
421 |
+
|
422 |
+
download_nc.click(
|
423 |
+
fn=save_netcdf,
|
424 |
+
inputs=[state],
|
425 |
+
outputs=gr.File()
|
426 |
+
)
|
427 |
+
|
428 |
+
return demo
|
429 |
|
430 |
+
# Create and launch the interface
|
431 |
+
demo = update_interface()
|
432 |
demo.launch()
|