saburq commited on
Commit
45b15ae
·
1 Parent(s): 62a6171
Files changed (1) hide show
  1. app.py +32 -12
app.py CHANGED
@@ -63,10 +63,17 @@ for var in ["t", "u", "v", "w", "q", "z"]:
63
  var_id = f"{var}_{level}"
64
  VARIABLE_GROUPS["Pressure Level Variables"][var_id] = f"{var_name} at {level}hPa"
65
 
 
 
 
66
  def get_open_data(param, levelist=[]):
67
  fields = {}
68
  # Get the data for the current date and the previous date
 
 
69
  for date in [DEFAULT_DATE - datetime.timedelta(hours=6), DEFAULT_DATE]:
 
 
70
  data = ekd.from_source("ecmwf-open-data", date=date, param=param, levelist=levelist)
71
  for f in data:
72
  assert f.to_numpy().shape == (721, 1440)
@@ -108,9 +115,15 @@ def run_forecast(date, lead_time, device):
108
  fields[f"z_{level}"] = gh * 9.80665
109
 
110
  input_state = dict(date=date, fields=fields)
111
- runner = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)
 
 
 
 
 
 
112
  results = []
113
- for state in runner.run(input_state=input_state, lead_time=lead_time):
114
  results.append(state)
115
  return results[-1]
116
 
@@ -139,20 +152,23 @@ for group_name, variables in VARIABLE_GROUPS.items():
139
  for var_id, desc in sorted(variables.items()):
140
  DROPDOWN_CHOICES.append((f"{desc} ({var_id})", var_id))
141
 
142
- def gradio_interface(date_str, lead_time, device, selected_variable):
143
- try:
144
- date = datetime.datetime.strptime(date_str, "%Y-%m-%d")
145
- except ValueError:
146
- raise gr.Error("Please enter a valid date in YYYY-MM-DD format")
147
- state = run_forecast(date, lead_time, device)
148
  return plot_forecast(state, selected_variable)
149
 
150
  demo = gr.Interface(
151
  fn=gradio_interface,
152
  inputs=[
153
- gr.Textbox(value=DEFAULT_DATE.strftime("%Y-%m-%d"), label="Forecast Date (YYYY-MM-DD)"),
154
- gr.Slider(minimum=6, maximum=48, step=6, value=12, label="Lead Time (Hours)"),
155
- gr.Radio(choices=["cuda", "cpu"], value="cuda", label="Compute Device"),
 
 
 
 
 
156
  gr.Dropdown(
157
  choices=DROPDOWN_CHOICES,
158
  value="2t", # Default to 2m temperature
@@ -162,7 +178,11 @@ demo = gr.Interface(
162
  ],
163
  outputs=gr.Plot(),
164
  title="AIFS Weather Forecast",
165
- description="Interactive visualization of ECMWF AIFS weather forecasts. Select a date, forecast lead time, and meteorological variable to plot."
 
 
 
 
166
  )
167
 
168
  demo.launch()
 
63
  var_id = f"{var}_{level}"
64
  VARIABLE_GROUPS["Pressure Level Variables"][var_id] = f"{var_name} at {level}hPa"
65
 
66
+ # Load the model once at startup
67
+ MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device="cuda") # Default to CUDA
68
+
69
  def get_open_data(param, levelist=[]):
70
  fields = {}
71
  # Get the data for the current date and the previous date
72
+ myiterable = [DEFAULT_DATE - datetime.timedelta(hours=6), DEFAULT_DATE]
73
+ print(myiterable)
74
  for date in [DEFAULT_DATE - datetime.timedelta(hours=6), DEFAULT_DATE]:
75
+ print(f"Fetching data for {date}")
76
+ # sources can be seen https://earthkit-data.readthedocs.io/en/latest/guide/sources.html#id57
77
  data = ekd.from_source("ecmwf-open-data", date=date, param=param, levelist=levelist)
78
  for f in data:
79
  assert f.to_numpy().shape == (721, 1440)
 
115
  fields[f"z_{level}"] = gh * 9.80665
116
 
117
  input_state = dict(date=date, fields=fields)
118
+
119
+ # Use the global model instance
120
+ global MODEL
121
+ # If device preference changed, move model to new device
122
+ if device != MODEL.device:
123
+ MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)
124
+
125
  results = []
126
+ for state in MODEL.run(input_state=input_state, lead_time=lead_time):
127
  results.append(state)
128
  return results[-1]
129
 
 
152
  for var_id, desc in sorted(variables.items()):
153
  DROPDOWN_CHOICES.append((f"{desc} ({var_id})", var_id))
154
 
155
+ def gradio_interface(lead_time, selected_variable):
156
+ # Use the global latest date
157
+ global DEFAULT_DATE
158
+ state = run_forecast(DEFAULT_DATE, lead_time, "cuda") # Always use CUDA
 
 
159
  return plot_forecast(state, selected_variable)
160
 
161
  demo = gr.Interface(
162
  fn=gradio_interface,
163
  inputs=[
164
+ gr.Slider(
165
+ minimum=6,
166
+ maximum=48,
167
+ step=6,
168
+ value=12,
169
+ label="Forecast Hours Ahead",
170
+ info=f"Latest data available from: {DEFAULT_DATE.strftime('%Y-%m-%d %H:%M UTC')}"
171
+ ),
172
  gr.Dropdown(
173
  choices=DROPDOWN_CHOICES,
174
  value="2t", # Default to 2m temperature
 
178
  ],
179
  outputs=gr.Plot(),
180
  title="AIFS Weather Forecast",
181
+ description=f"""
182
+ Interactive visualization of ECMWF AIFS weather forecasts.
183
+ Starting from the latest available data ({DEFAULT_DATE.strftime('%Y-%m-%d %H:%M UTC')}),
184
+ select how many hours ahead you want to forecast and which meteorological variable to visualize.
185
+ """
186
  )
187
 
188
  demo.launch()