timeki commited on
Commit
86e24a2
·
2 Parent(s): e209431 c06b6ff

Merge branch 'dev' of https://bitbucket.org/ekimetrics/climate_qa into dev

Browse files
app.py CHANGED
@@ -9,14 +9,13 @@ from climateqa.engine.embeddings import get_embeddings_function
9
  from climateqa.engine.llm import get_llm
10
  from climateqa.engine.vectorstore import get_pinecone_vectorstore
11
  from climateqa.engine.reranker import get_reranker
12
- from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
13
  from climateqa.engine.chains.retrieve_papers import find_papers
14
  from climateqa.chat import start_chat, chat_stream, finish_chat
15
- from climateqa.engine.talk_to_data.main import ask_vanna
16
- from climateqa.engine.talk_to_data.myVanna import MyVanna
17
 
18
- from front.tabs import (create_config_modal, cqa_tab, create_about_tab)
19
- from front.tabs import (MainTabPanel, ConfigPanel)
 
20
  from front.utils import process_figures
21
  from gradio_modal import Modal
22
 
@@ -25,14 +24,14 @@ from utils import create_user_id
25
  import logging
26
 
27
  logging.basicConfig(level=logging.WARNING)
28
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppresses INFO and WARNING logs
29
  logging.getLogger().setLevel(logging.WARNING)
30
 
31
 
32
-
33
  # Load environment variables in local mode
34
  try:
35
  from dotenv import load_dotenv
 
36
  load_dotenv()
37
  except Exception as e:
38
  pass
@@ -63,39 +62,103 @@ share_client = service.get_share_client(file_share_name)
63
  user_id = create_user_id()
64
 
65
 
66
-
67
  # Create vectorstore and retriever
68
  embeddings_function = get_embeddings_function()
69
- vectorstore = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX"))
70
- vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
71
- vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2"))
 
 
 
 
 
 
 
 
72
 
73
- llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
74
  if os.environ["GRADIO_ENV"] == "local":
75
  reranker = get_reranker("nano")
76
- else :
77
  reranker = get_reranker("large")
78
 
79
- agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
80
- agent_poc = make_graph_agent_poc(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0, version="v4")#TODO put back default 0.2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
- #Vanna object
 
83
 
84
- vn = MyVanna(config = {"temperature": 0, "api_key": os.getenv('THEO_API_KEY'), 'model': os.getenv('VANNA_MODEL'), 'pc_api_key': os.getenv('VANNA_PINECONE_API_KEY'), 'index_name': os.getenv('VANNA_INDEX_NAME'), "top_k" : 4})
85
- db_vanna_path = os.path.join(os.getcwd(), "data/drias/drias.db")
86
- vn.connect_to_sqlite(db_vanna_path)
87
 
88
- def ask_vanna_query(query):
89
- return ask_vanna(vn, db_vanna_path, query)
90
 
91
- async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
 
 
 
 
 
 
 
 
 
92
  print("chat cqa - message received")
93
- async for event in chat_stream(agent, query, history, audience, sources, reports, relevant_content_sources_selection, search_only, share_client, user_id):
 
 
 
 
 
 
 
 
 
 
 
94
  yield event
95
-
96
- async def chat_poc(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
 
 
 
 
 
 
 
 
 
97
  print("chat poc - message received")
98
- async for event in chat_stream(agent_poc, query, history, audience, sources, reports, relevant_content_sources_selection, search_only, share_client, user_id):
 
 
 
 
 
 
 
 
 
 
 
99
  yield event
100
 
101
 
@@ -103,14 +166,17 @@ async def chat_poc(query, history, audience, sources, reports, relevant_content_
103
  # Gradio
104
  # --------------------------------------------------------------------
105
 
 
106
  # Function to update modal visibility
107
  def update_config_modal_visibility(config_open):
108
  print(config_open)
109
  new_config_visibility_status = not config_open
110
  return Modal(visible=new_config_visibility_status), new_config_visibility_status
111
-
112
 
113
- def update_sources_number_display(sources_textbox, figures_cards, current_graphs, papers_html):
 
 
 
114
  sources_number = sources_textbox.count("<h2>")
115
  figures_number = figures_cards.count("<h2>")
116
  graphs_number = current_graphs.count("<iframe")
@@ -119,42 +185,40 @@ def update_sources_number_display(sources_textbox, figures_cards, current_graphs
119
  figures_notif_label = f"Figures ({figures_number})"
120
  graphs_notif_label = f"Graphs ({graphs_number})"
121
  papers_notif_label = f"Papers ({papers_number})"
122
- recommended_content_notif_label = f"Recommended content ({figures_number + graphs_number + papers_number})"
123
-
124
- return gr.update(label=recommended_content_notif_label), gr.update(label=sources_notif_label), gr.update(label=figures_notif_label), gr.update(label=graphs_notif_label), gr.update(label=papers_notif_label)
125
-
126
- def create_drias_tab():
127
- with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab_vanna:
128
- vanna_direct_question = gr.Textbox(label="Direct Question", placeholder="You can write direct question here",elem_id="direct-question", interactive=True)
129
- with gr.Accordion("Details",elem_id = 'vanna-details', open=False) as vanna_details :
130
- vanna_sql_query = gr.Textbox(label="SQL Query Used", elem_id="sql-query", interactive=False)
131
- show_vanna_table = gr.Button("Show Table", elem_id="show-table")
132
- with Modal(visible=False) as vanna_table_modal:
133
- vanna_table = gr.DataFrame([], elem_id="vanna-table")
134
- close_vanna_modal = gr.Button("Close", elem_id="close-vanna-modal")
135
- close_vanna_modal.click(lambda: Modal(visible=False),None, [vanna_table_modal])
136
- show_vanna_table.click(lambda: Modal(visible=True),None ,[vanna_table_modal])
137
-
138
- vanna_display = gr.Plot()
139
- vanna_direct_question.submit(ask_vanna_query, [vanna_direct_question], [vanna_sql_query ,vanna_table, vanna_display])
140
-
141
-
142
- def config_event_handling(main_tabs_components : list[MainTabPanel], config_componenets : ConfigPanel):
143
  config_open = config_componenets.config_open
144
  config_modal = config_componenets.config_modal
145
  close_config_modal = config_componenets.close_config_modal_button
146
-
147
- for button in [close_config_modal] + [main_tab_component.config_button for main_tab_component in main_tabs_components]:
 
 
148
  button.click(
149
  fn=update_config_modal_visibility,
150
  inputs=[config_open],
151
- outputs=[config_modal, config_open]
152
- )
153
-
 
154
  def event_handling(
155
- main_tab_components : MainTabPanel,
156
- config_components : ConfigPanel,
157
- tab_name="ClimateQ&A"
158
  ):
159
  chatbot = main_tab_components.chatbot
160
  textbox = main_tab_components.textbox
@@ -178,7 +242,7 @@ def event_handling(
178
  graphs_container = main_tab_components.graph_container
179
  follow_up_examples = main_tab_components.follow_up_examples
180
  follow_up_examples_hidden = main_tab_components.follow_up_examples_hidden
181
-
182
  dropdown_sources = config_components.dropdown_sources
183
  dropdown_reports = config_components.dropdown_reports
184
  dropdown_external_sources = config_components.dropdown_external_sources
@@ -187,91 +251,302 @@ def event_handling(
187
  after = config_components.after
188
  output_query = config_components.output_query
189
  output_language = config_components.output_language
190
-
191
  new_sources_hmtl = gr.State([])
192
  ttd_data = gr.State([])
193
 
194
-
195
  if tab_name == "ClimateQ&A":
196
  print("chat cqa - message sent")
197
 
198
  # Event for textbox
199
- (textbox
200
- .submit(start_chat, [textbox, chatbot, search_only], [textbox, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{textbox.elem_id}")
201
- .then(chat, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs, follow_up_examples.dataset], concurrency_limit=8, api_name=f"chat_{textbox.elem_id}")
202
- .then(finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  )
204
  # Event for examples_hidden
205
- (examples_hidden
206
- .change(start_chat, [examples_hidden, chatbot, search_only], [examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
207
- .then(chat, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs,follow_up_examples.dataset], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
208
- .then(finish_chat, None, [textbox], api_name=f"finish_chat_{examples_hidden.elem_id}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  )
210
- (follow_up_examples_hidden
211
- .change(start_chat, [follow_up_examples_hidden, chatbot, search_only], [follow_up_examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
212
- .then(chat, [follow_up_examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs,follow_up_examples.dataset], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
213
- .then(finish_chat, None, [textbox], api_name=f"finish_chat_{follow_up_examples_hidden.elem_id}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  )
215
-
216
  elif tab_name == "Beta - POC Adapt'Action":
217
  print("chat poc - message sent")
218
  # Event for textbox
219
- (textbox
220
- .submit(start_chat, [textbox, chatbot, search_only], [textbox, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{textbox.elem_id}")
221
- .then(chat_poc, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{textbox.elem_id}")
222
- .then(finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  )
224
  # Event for examples_hidden
225
- (examples_hidden
226
- .change(start_chat, [examples_hidden, chatbot, search_only], [examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
227
- .then(chat_poc, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
228
- .then(finish_chat, None, [textbox], api_name=f"finish_chat_{examples_hidden.elem_id}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  )
230
- (follow_up_examples_hidden
231
- .change(start_chat, [follow_up_examples_hidden, chatbot, search_only], [follow_up_examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
232
- .then(chat, [follow_up_examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs,follow_up_examples.dataset], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
233
- .then(finish_chat, None, [textbox], api_name=f"finish_chat_{follow_up_examples_hidden.elem_id}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  )
235
-
236
-
237
- new_sources_hmtl.change(lambda x : x, inputs = [new_sources_hmtl], outputs = [sources_textbox])
238
- current_graphs.change(lambda x: x, inputs=[current_graphs], outputs=[graphs_container])
239
- new_figures.change(process_figures, inputs=[sources_raw, new_figures], outputs=[sources_raw, figures_cards, gallery_component])
 
 
 
 
 
 
 
240
 
241
  # Update sources numbers
242
  for component in [sources_textbox, figures_cards, current_graphs, papers_html]:
243
- component.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs, papers_html], [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
244
-
 
 
 
 
245
  # Search for papers
246
  for component in [textbox, examples_hidden, papers_direct_search]:
247
- component.submit(find_papers, [component, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])
248
-
 
 
 
249
 
250
  # if tab_name == "Beta - POC Adapt'Action": # Not untill results are good enough
251
  # # Drias search
252
  # textbox.submit(ask_vanna, [textbox], [vanna_sql_query ,vanna_table, vanna_display])
253
 
 
254
  def main_ui():
255
  # config_open = gr.State(True)
256
- with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=theme, elem_id="main-component") as demo:
257
- config_components = create_config_modal()
258
-
 
 
 
 
 
259
  with gr.Tabs():
260
- cqa_components = cqa_tab(tab_name = "ClimateQ&A")
261
- local_cqa_components = cqa_tab(tab_name = "Beta - POC Adapt'Action")
262
  create_drias_tab()
263
-
264
  create_about_tab()
265
-
266
- event_handling(cqa_components, config_components, tab_name = 'ClimateQ&A')
267
- event_handling(local_cqa_components, config_components, tab_name = "Beta - POC Adapt'Action")
268
-
269
- config_event_handling([cqa_components,local_cqa_components] ,config_components)
270
-
 
 
271
  demo.queue()
272
-
273
  return demo
274
 
275
-
276
  demo = main_ui()
277
  demo.launch(ssr_mode=False)
 
9
  from climateqa.engine.llm import get_llm
10
  from climateqa.engine.vectorstore import get_pinecone_vectorstore
11
  from climateqa.engine.reranker import get_reranker
12
+ from climateqa.engine.graph import make_graph_agent, make_graph_agent_poc
13
  from climateqa.engine.chains.retrieve_papers import find_papers
14
  from climateqa.chat import start_chat, chat_stream, finish_chat
 
 
15
 
16
+ from front.tabs import create_config_modal, cqa_tab, create_about_tab
17
+ from front.tabs import MainTabPanel, ConfigPanel
18
+ from front.tabs.tab_drias import create_drias_tab
19
  from front.utils import process_figures
20
  from gradio_modal import Modal
21
 
 
24
  import logging
25
 
26
  logging.basicConfig(level=logging.WARNING)
27
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Suppresses INFO and WARNING logs
28
  logging.getLogger().setLevel(logging.WARNING)
29
 
30
 
 
31
  # Load environment variables in local mode
32
  try:
33
  from dotenv import load_dotenv
34
+
35
  load_dotenv()
36
  except Exception as e:
37
  pass
 
62
  user_id = create_user_id()
63
 
64
 
 
65
  # Create vectorstore and retriever
66
  embeddings_function = get_embeddings_function()
67
+ vectorstore = get_pinecone_vectorstore(
68
+ embeddings_function, index_name=os.getenv("PINECONE_API_INDEX")
69
+ )
70
+ vectorstore_graphs = get_pinecone_vectorstore(
71
+ embeddings_function,
72
+ index_name=os.getenv("PINECONE_API_INDEX_OWID"),
73
+ text_key="description",
74
+ )
75
+ vectorstore_region = get_pinecone_vectorstore(
76
+ embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2")
77
+ )
78
 
79
+ llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
80
  if os.environ["GRADIO_ENV"] == "local":
81
  reranker = get_reranker("nano")
82
+ else:
83
  reranker = get_reranker("large")
84
 
85
+ agent = make_graph_agent(
86
+ llm=llm,
87
+ vectorstore_ipcc=vectorstore,
88
+ vectorstore_graphs=vectorstore_graphs,
89
+ vectorstore_region=vectorstore_region,
90
+ reranker=reranker,
91
+ threshold_docs=0.2,
92
+ )
93
+ agent_poc = make_graph_agent_poc(
94
+ llm=llm,
95
+ vectorstore_ipcc=vectorstore,
96
+ vectorstore_graphs=vectorstore_graphs,
97
+ vectorstore_region=vectorstore_region,
98
+ reranker=reranker,
99
+ threshold_docs=0,
100
+ version="v4",
101
+ ) # TODO put back default 0.2
102
+
103
+ # Vanna object
104
+
105
+ # vn = MyVanna(config = {"temperature": 0, "api_key": os.getenv('THEO_API_KEY'), 'model': os.getenv('VANNA_MODEL'), 'pc_api_key': os.getenv('VANNA_PINECONE_API_KEY'), 'index_name': os.getenv('VANNA_INDEX_NAME'), "top_k" : 4})
106
+ # db_vanna_path = os.path.join(os.getcwd(), "data/drias/drias.db")
107
+ # vn.connect_to_sqlite(db_vanna_path)
108
 
109
+ # def ask_vanna_query(query):
110
+ # return ask_vanna(vn, db_vanna_path, query)
111
 
 
 
 
112
 
 
 
113
 
114
+
115
+ async def chat(
116
+ query,
117
+ history,
118
+ audience,
119
+ sources,
120
+ reports,
121
+ relevant_content_sources_selection,
122
+ search_only,
123
+ ):
124
  print("chat cqa - message received")
125
+ async for event in chat_stream(
126
+ agent,
127
+ query,
128
+ history,
129
+ audience,
130
+ sources,
131
+ reports,
132
+ relevant_content_sources_selection,
133
+ search_only,
134
+ share_client,
135
+ user_id,
136
+ ):
137
  yield event
138
+
139
+
140
+ async def chat_poc(
141
+ query,
142
+ history,
143
+ audience,
144
+ sources,
145
+ reports,
146
+ relevant_content_sources_selection,
147
+ search_only,
148
+ ):
149
  print("chat poc - message received")
150
+ async for event in chat_stream(
151
+ agent_poc,
152
+ query,
153
+ history,
154
+ audience,
155
+ sources,
156
+ reports,
157
+ relevant_content_sources_selection,
158
+ search_only,
159
+ share_client,
160
+ user_id,
161
+ ):
162
  yield event
163
 
164
 
 
166
  # Gradio
167
  # --------------------------------------------------------------------
168
 
169
+
170
  # Function to update modal visibility
171
  def update_config_modal_visibility(config_open):
172
  print(config_open)
173
  new_config_visibility_status = not config_open
174
  return Modal(visible=new_config_visibility_status), new_config_visibility_status
 
175
 
176
+
177
+ def update_sources_number_display(
178
+ sources_textbox, figures_cards, current_graphs, papers_html
179
+ ):
180
  sources_number = sources_textbox.count("<h2>")
181
  figures_number = figures_cards.count("<h2>")
182
  graphs_number = current_graphs.count("<iframe")
 
185
  figures_notif_label = f"Figures ({figures_number})"
186
  graphs_notif_label = f"Graphs ({graphs_number})"
187
  papers_notif_label = f"Papers ({papers_number})"
188
+ recommended_content_notif_label = (
189
+ f"Recommended content ({figures_number + graphs_number + papers_number})"
190
+ )
191
+
192
+ return (
193
+ gr.update(label=recommended_content_notif_label),
194
+ gr.update(label=sources_notif_label),
195
+ gr.update(label=figures_notif_label),
196
+ gr.update(label=graphs_notif_label),
197
+ gr.update(label=papers_notif_label),
198
+ )
199
+
200
+
201
+ def config_event_handling(
202
+ main_tabs_components: list[MainTabPanel], config_componenets: ConfigPanel
203
+ ):
 
 
 
 
 
204
  config_open = config_componenets.config_open
205
  config_modal = config_componenets.config_modal
206
  close_config_modal = config_componenets.close_config_modal_button
207
+
208
+ for button in [close_config_modal] + [
209
+ main_tab_component.config_button for main_tab_component in main_tabs_components
210
+ ]:
211
  button.click(
212
  fn=update_config_modal_visibility,
213
  inputs=[config_open],
214
+ outputs=[config_modal, config_open],
215
+ )
216
+
217
+
218
  def event_handling(
219
+ main_tab_components: MainTabPanel,
220
+ config_components: ConfigPanel,
221
+ tab_name="ClimateQ&A",
222
  ):
223
  chatbot = main_tab_components.chatbot
224
  textbox = main_tab_components.textbox
 
242
  graphs_container = main_tab_components.graph_container
243
  follow_up_examples = main_tab_components.follow_up_examples
244
  follow_up_examples_hidden = main_tab_components.follow_up_examples_hidden
245
+
246
  dropdown_sources = config_components.dropdown_sources
247
  dropdown_reports = config_components.dropdown_reports
248
  dropdown_external_sources = config_components.dropdown_external_sources
 
251
  after = config_components.after
252
  output_query = config_components.output_query
253
  output_language = config_components.output_language
254
+
255
  new_sources_hmtl = gr.State([])
256
  ttd_data = gr.State([])
257
 
 
258
  if tab_name == "ClimateQ&A":
259
  print("chat cqa - message sent")
260
 
261
  # Event for textbox
262
+ (
263
+ textbox.submit(
264
+ start_chat,
265
+ [textbox, chatbot, search_only],
266
+ [textbox, tabs, chatbot, sources_raw],
267
+ queue=False,
268
+ api_name=f"start_chat_{textbox.elem_id}",
269
+ )
270
+ .then(
271
+ chat,
272
+ [
273
+ textbox,
274
+ chatbot,
275
+ dropdown_audience,
276
+ dropdown_sources,
277
+ dropdown_reports,
278
+ dropdown_external_sources,
279
+ search_only,
280
+ ],
281
+ [
282
+ chatbot,
283
+ new_sources_hmtl,
284
+ output_query,
285
+ output_language,
286
+ new_figures,
287
+ current_graphs,
288
+ follow_up_examples.dataset,
289
+ ],
290
+ concurrency_limit=8,
291
+ api_name=f"chat_{textbox.elem_id}",
292
+ )
293
+ .then(
294
+ finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}"
295
+ )
296
  )
297
  # Event for examples_hidden
298
+ (
299
+ examples_hidden.change(
300
+ start_chat,
301
+ [examples_hidden, chatbot, search_only],
302
+ [examples_hidden, tabs, chatbot, sources_raw],
303
+ queue=False,
304
+ api_name=f"start_chat_{examples_hidden.elem_id}",
305
+ )
306
+ .then(
307
+ chat,
308
+ [
309
+ examples_hidden,
310
+ chatbot,
311
+ dropdown_audience,
312
+ dropdown_sources,
313
+ dropdown_reports,
314
+ dropdown_external_sources,
315
+ search_only,
316
+ ],
317
+ [
318
+ chatbot,
319
+ new_sources_hmtl,
320
+ output_query,
321
+ output_language,
322
+ new_figures,
323
+ current_graphs,
324
+ follow_up_examples.dataset,
325
+ ],
326
+ concurrency_limit=8,
327
+ api_name=f"chat_{examples_hidden.elem_id}",
328
+ )
329
+ .then(
330
+ finish_chat,
331
+ None,
332
+ [textbox],
333
+ api_name=f"finish_chat_{examples_hidden.elem_id}",
334
+ )
335
  )
336
+ (
337
+ follow_up_examples_hidden.change(
338
+ start_chat,
339
+ [follow_up_examples_hidden, chatbot, search_only],
340
+ [follow_up_examples_hidden, tabs, chatbot, sources_raw],
341
+ queue=False,
342
+ api_name=f"start_chat_{examples_hidden.elem_id}",
343
+ )
344
+ .then(
345
+ chat,
346
+ [
347
+ follow_up_examples_hidden,
348
+ chatbot,
349
+ dropdown_audience,
350
+ dropdown_sources,
351
+ dropdown_reports,
352
+ dropdown_external_sources,
353
+ search_only,
354
+ ],
355
+ [
356
+ chatbot,
357
+ new_sources_hmtl,
358
+ output_query,
359
+ output_language,
360
+ new_figures,
361
+ current_graphs,
362
+ follow_up_examples.dataset,
363
+ ],
364
+ concurrency_limit=8,
365
+ api_name=f"chat_{examples_hidden.elem_id}",
366
+ )
367
+ .then(
368
+ finish_chat,
369
+ None,
370
+ [textbox],
371
+ api_name=f"finish_chat_{follow_up_examples_hidden.elem_id}",
372
+ )
373
  )
374
+
375
  elif tab_name == "Beta - POC Adapt'Action":
376
  print("chat poc - message sent")
377
  # Event for textbox
378
+ (
379
+ textbox.submit(
380
+ start_chat,
381
+ [textbox, chatbot, search_only],
382
+ [textbox, tabs, chatbot, sources_raw],
383
+ queue=False,
384
+ api_name=f"start_chat_{textbox.elem_id}",
385
+ )
386
+ .then(
387
+ chat_poc,
388
+ [
389
+ textbox,
390
+ chatbot,
391
+ dropdown_audience,
392
+ dropdown_sources,
393
+ dropdown_reports,
394
+ dropdown_external_sources,
395
+ search_only,
396
+ ],
397
+ [
398
+ chatbot,
399
+ new_sources_hmtl,
400
+ output_query,
401
+ output_language,
402
+ new_figures,
403
+ current_graphs,
404
+ ],
405
+ concurrency_limit=8,
406
+ api_name=f"chat_{textbox.elem_id}",
407
+ )
408
+ .then(
409
+ finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}"
410
+ )
411
  )
412
  # Event for examples_hidden
413
+ (
414
+ examples_hidden.change(
415
+ start_chat,
416
+ [examples_hidden, chatbot, search_only],
417
+ [examples_hidden, tabs, chatbot, sources_raw],
418
+ queue=False,
419
+ api_name=f"start_chat_{examples_hidden.elem_id}",
420
+ )
421
+ .then(
422
+ chat_poc,
423
+ [
424
+ examples_hidden,
425
+ chatbot,
426
+ dropdown_audience,
427
+ dropdown_sources,
428
+ dropdown_reports,
429
+ dropdown_external_sources,
430
+ search_only,
431
+ ],
432
+ [
433
+ chatbot,
434
+ new_sources_hmtl,
435
+ output_query,
436
+ output_language,
437
+ new_figures,
438
+ current_graphs,
439
+ ],
440
+ concurrency_limit=8,
441
+ api_name=f"chat_{examples_hidden.elem_id}",
442
+ )
443
+ .then(
444
+ finish_chat,
445
+ None,
446
+ [textbox],
447
+ api_name=f"finish_chat_{examples_hidden.elem_id}",
448
+ )
449
  )
450
+ (
451
+ follow_up_examples_hidden.change(
452
+ start_chat,
453
+ [follow_up_examples_hidden, chatbot, search_only],
454
+ [follow_up_examples_hidden, tabs, chatbot, sources_raw],
455
+ queue=False,
456
+ api_name=f"start_chat_{examples_hidden.elem_id}",
457
+ )
458
+ .then(
459
+ chat,
460
+ [
461
+ follow_up_examples_hidden,
462
+ chatbot,
463
+ dropdown_audience,
464
+ dropdown_sources,
465
+ dropdown_reports,
466
+ dropdown_external_sources,
467
+ search_only,
468
+ ],
469
+ [
470
+ chatbot,
471
+ new_sources_hmtl,
472
+ output_query,
473
+ output_language,
474
+ new_figures,
475
+ current_graphs,
476
+ follow_up_examples.dataset,
477
+ ],
478
+ concurrency_limit=8,
479
+ api_name=f"chat_{examples_hidden.elem_id}",
480
+ )
481
+ .then(
482
+ finish_chat,
483
+ None,
484
+ [textbox],
485
+ api_name=f"finish_chat_{follow_up_examples_hidden.elem_id}",
486
+ )
487
  )
488
+
489
+ new_sources_hmtl.change(
490
+ lambda x: x, inputs=[new_sources_hmtl], outputs=[sources_textbox]
491
+ )
492
+ current_graphs.change(
493
+ lambda x: x, inputs=[current_graphs], outputs=[graphs_container]
494
+ )
495
+ new_figures.change(
496
+ process_figures,
497
+ inputs=[sources_raw, new_figures],
498
+ outputs=[sources_raw, figures_cards, gallery_component],
499
+ )
500
 
501
  # Update sources numbers
502
  for component in [sources_textbox, figures_cards, current_graphs, papers_html]:
503
+ component.change(
504
+ update_sources_number_display,
505
+ [sources_textbox, figures_cards, current_graphs, papers_html],
506
+ [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers],
507
+ )
508
+
509
  # Search for papers
510
  for component in [textbox, examples_hidden, papers_direct_search]:
511
+ component.submit(
512
+ find_papers,
513
+ [component, after, dropdown_external_sources],
514
+ [papers_html, citations_network, papers_summary],
515
+ )
516
 
517
  # if tab_name == "Beta - POC Adapt'Action": # Not untill results are good enough
518
  # # Drias search
519
  # textbox.submit(ask_vanna, [textbox], [vanna_sql_query ,vanna_table, vanna_display])
520
 
521
+
522
  def main_ui():
523
  # config_open = gr.State(True)
524
+ with gr.Blocks(
525
+ title="Climate Q&A",
526
+ css_paths=os.getcwd() + "/style.css",
527
+ theme=theme,
528
+ elem_id="main-component",
529
+ ) as demo:
530
+ config_components = create_config_modal()
531
+
532
  with gr.Tabs():
533
+ cqa_components = cqa_tab(tab_name="ClimateQ&A")
534
+ local_cqa_components = cqa_tab(tab_name="Beta - POC Adapt'Action")
535
  create_drias_tab()
536
+
537
  create_about_tab()
538
+
539
+ event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
540
+ event_handling(
541
+ local_cqa_components, config_components, tab_name="Beta - POC Adapt'Action"
542
+ )
543
+
544
+ config_event_handling([cqa_components, local_cqa_components], config_components)
545
+
546
  demo.queue()
547
+
548
  return demo
549
 
550
+
551
  demo = main_ui()
552
  demo.launch(ssr_mode=False)
climateqa/engine/talk_to_data/config.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DRIAS_TABLES = [
2
+ "total_winter_precipitation",
3
+ "total_summer_precipiation",
4
+ "total_annual_precipitation",
5
+ "total_remarkable_daily_precipitation",
6
+ "frequency_of_remarkable_daily_precipitation",
7
+ "extreme_precipitation_intensity",
8
+ "mean_winter_temperature",
9
+ "mean_summer_temperature",
10
+ "mean_annual_temperature",
11
+ "number_of_tropical_nights",
12
+ "maximum_summer_temperature",
13
+ "number_of_days_with_tx_above_30",
14
+ "number_of_days_with_tx_above_35",
15
+ "number_of_days_with_a_dry_ground",
16
+ ]
17
+
18
+ INDICATOR_COLUMNS_PER_TABLE = {
19
+ "total_winter_precipitation": "total_winter_precipitation",
20
+ "total_summer_precipiation": "total_summer_precipitation",
21
+ "total_annual_precipitation": "total_annual_precipitation",
22
+ "total_remarkable_daily_precipitation": "total_remarkable_daily_precipitation",
23
+ "frequency_of_remarkable_daily_precipitation": "frequency_of_remarkable_daily_precipitation",
24
+ "extreme_precipitation_intensity": "extreme_precipitation_intensity",
25
+ "mean_winter_temperature": "mean_winter_temperature",
26
+ "mean_summer_temperature": "mean_summer_temperature",
27
+ "mean_annual_temperature": "mean_annual_temperature",
28
+ "number_of_tropical_nights": "number_tropical_nights",
29
+ "maximum_summer_temperature": "maximum_summer_temperature",
30
+ "number_of_days_with_tx_above_30": "number_of_days_with_tx_above_30",
31
+ "number_of_days_with_tx_above_35": "number_of_days_with_tx_above_35",
32
+ "number_of_days_with_a_dry_ground": "number_of_days_with_dry_ground"
33
+ }
34
+
35
+ DRIAS_MODELS = [
36
+ 'ALL',
37
+ 'RegCM4-6_MPI-ESM-LR',
38
+ 'RACMO22E_EC-EARTH',
39
+ 'RegCM4-6_HadGEM2-ES',
40
+ 'HadREM3-GA7_EC-EARTH',
41
+ 'HadREM3-GA7_CNRM-CM5',
42
+ 'REMO2015_NorESM1-M',
43
+ 'SMHI-RCA4_EC-EARTH',
44
+ 'WRF381P_NorESM1-M',
45
+ 'ALADIN63_CNRM-CM5',
46
+ 'CCLM4-8-17_MPI-ESM-LR',
47
+ 'HIRHAM5_IPSL-CM5A-MR',
48
+ 'HadREM3-GA7_HadGEM2-ES',
49
+ 'SMHI-RCA4_IPSL-CM5A-MR',
50
+ 'HIRHAM5_NorESM1-M',
51
+ 'REMO2009_MPI-ESM-LR',
52
+ 'CCLM4-8-17_HadGEM2-ES'
53
+ ]
54
+ # Mapping between indicator columns and their units
55
+ INDICATOR_TO_UNIT = {
56
+ "total_winter_precipitation": "mm",
57
+ "total_summer_precipitation": "mm",
58
+ "total_annual_precipitation": "mm",
59
+ "total_remarkable_daily_precipitation": "mm",
60
+ "frequency_of_remarkable_daily_precipitation": "days",
61
+ "extreme_precipitation_intensity": "mm",
62
+ "mean_winter_temperature": "°C",
63
+ "mean_summer_temperature": "°C",
64
+ "mean_annual_temperature": "°C",
65
+ "number_tropical_nights": "days",
66
+ "maximum_summer_temperature": "°C",
67
+ "number_of_days_with_tx_above_30": "days",
68
+ "number_of_days_with_tx_above_35": "days",
69
+ "number_of_days_with_dry_ground": "days"
70
+ }
71
+
72
+ DRIAS_UI_TEXT = """
73
+ Hi, I'm **Talk to Drias**, designed to answer your questions using [**DRIAS - TRACC 2023**](https://www.drias-climat.fr/accompagnement/sections/401) data.
74
+ I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.
75
+
76
+ ❓ **How to use?**
77
+ You can ask me anything about these climate indicators: **temperature**, **precipitation** or **drought**.
78
+ You can specify **location** and/or **year**.
79
+ You can choose from a list of climate models. By default, we take the **average of each model**.
80
+
81
+ For example, you can ask:
82
+ - What will the temperature be like in Paris?
83
+ - What will be the total rainfall in France in 2030?
84
+ - How frequent will extreme events be in Lyon?
85
+
86
+ **Example of indicators in the data**:
87
+ - Mean temperature (annual, winter, summer)
88
+ - Total precipitation (annual, winter, summer)
89
+ - Number of days with remarkable precipitations, with dry ground, with temperature above 30°C
90
+
91
+ ⚠️ **Limitations**:
92
+ - You can't ask anything that isn't related to **DRIAS - TRACC 2023** data.
93
+ - You can only ask about **locations in France**.
94
+ - If you specify a year, there may be **no data for that year for some models**.
95
+ - You **cannot compare two models**.
96
+
97
+ 🛈 **Information**
98
+ Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
99
+ """
climateqa/engine/talk_to_data/main.py CHANGED
@@ -1,47 +1,115 @@
1
- from climateqa.engine.talk_to_data.myVanna import MyVanna
2
- from climateqa.engine.talk_to_data.utils import loc2coords, detect_location_with_openai, detectTable, nearestNeighbourSQL, detect_relevant_tables, replace_coordonates
3
- import sqlite3
4
- import os
5
- import pandas as pd
6
  from climateqa.engine.llm import get_llm
7
  import ast
8
 
9
-
10
-
11
  llm = get_llm(provider="openai")
12
 
13
- def ask_llm_to_add_table_names(sql_query, llm):
 
 
 
 
 
 
 
 
 
 
 
 
14
  sql_with_table_names = llm.invoke(f"Make the following sql query display the source table in the rows {sql_query}. Just answer the query. The answer should not include ```sql\n").content
15
  return sql_with_table_names
16
 
17
- def ask_llm_column_names(sql_query, llm):
 
 
 
 
 
 
 
 
 
 
 
 
18
  columns = llm.invoke(f"From the given sql query, list the columns that are being selected. The answer should only be a python list. Just answer the list. The SQL query : {sql_query}").content
19
  columns_list = ast.literal_eval(columns.strip("```python\n").strip())
20
  return columns_list
21
 
22
- def ask_vanna(vn,db_vanna_path, query):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- try :
25
- location = detect_location_with_openai(query)
26
- if location:
27
 
28
- coords = loc2coords(location)
29
- user_input = query.lower().replace(location.lower(), f"lat, long : {coords}")
 
 
 
 
 
 
30
 
31
- relevant_tables = detect_relevant_tables(user_input, llm)
32
- coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, relevant_tables[i]) for i in range(len(relevant_tables))]
33
- user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)
34
-
35
- sql_query, result_dataframe, figure = vn.ask(user_input_with_coords, print_results=False, allow_llm_to_see_data=True, auto_train=False)
36
-
37
- return sql_query, result_dataframe, figure
38
-
39
- else :
40
- empty_df = pd.DataFrame()
41
- empty_fig = None
42
- return "", empty_df, empty_fig
43
- except Exception as e:
44
- print(f"Error: {e}")
45
- empty_df = pd.DataFrame()
46
- empty_fig = None
47
- return "", empty_df, empty_fig
 
1
+ from climateqa.engine.talk_to_data.workflow import drias_workflow
 
 
 
 
2
  from climateqa.engine.llm import get_llm
3
  import ast
4
 
 
 
5
  llm = get_llm(provider="openai")
6
 
7
def ask_llm_to_add_table_names(sql_query: str, llm) -> str:
    """Rewrite *sql_query* so that every result row also carries its source table.

    Delegates the rewrite to the LLM, which makes it easier to track which
    data came from which table when several queries are merged.

    Args:
        sql_query (str): The original SQL query to modify.
        llm: Language-model client exposing ``invoke``.

    Returns:
        str: The rewritten SQL query returned by the LLM.
    """
    prompt = (
        "Make the following sql query display the source table in the rows "
        f"{sql_query}. Just answer the query. The answer should not include ```sql\n"
    )
    return llm.invoke(prompt).content
22
 
23
def ask_llm_to_strip():  # placeholder removed
    pass
39
 
40
async def ask_drias(query: str, index_state: int = 0) -> tuple:
    """Run the DRIAS workflow for *query* and return the selected result.

    Orchestrates the DRIAS workflow, collects every successful SQL query,
    dataframe and figure, and returns the result at *index_state* together
    with the full collections so the caller can paginate through them.

    Args:
        query (str): The user's question about climate data.
        index_state (int, optional): Index of the result to surface. Defaults to 0.

    Returns:
        tuple: A 9-tuple containing:
            - sql_query (str | None): The SQL query used
            - dataframe (pd.DataFrame | None): The resulting data
            - figure: The generated visualization
            - sql_queries (list): All generated SQL queries
            - result_dataframes (list): All resulting dataframes
            - figures (list): All figure generation functions
            - index_state (int): Current result index
            - table_list (list): List of table names used
            - error (str): Error message, "" on success
    """
    final_state = await drias_workflow(query)

    sql_queries = []
    result_dataframes = []
    figures = []
    table_list = []

    for plot_state in final_state["plot_states"].values():
        for table_state in plot_state["table_states"].values():
            if table_state["status"] != "OK":
                continue
            if "table_name" in table_state:
                # "mean_annual_temperature" -> "Mean annual temperature"
                table_list.append(" ".join(table_state["table_name"].capitalize().split("_")))
            if table_state.get("sql_query") is not None:
                sql_queries.append(table_state["sql_query"])
            if table_state.get("dataframe") is not None:
                result_dataframes.append(table_state["dataframe"])
            if table_state.get("figure") is not None:
                figures.append(table_state["figure"])

    if final_state.get("error"):
        # Bug fix: the error path previously returned 8 elements while the
        # success path returned 9 (table_list was missing), breaking callers
        # that unpack the tuple unconditionally.
        return None, None, None, [], [], [], 0, [], final_state["error"]

    # Bug fix: guard against an empty result set (or a stale pagination index)
    # instead of raising IndexError below.
    n_results = min(len(sql_queries), len(result_dataframes), len(figures))
    if index_state >= n_results:
        return None, None, None, [], [], [], 0, [], "No result found for this query"

    sql_query = sql_queries[index_state]
    dataframe = result_dataframes[index_state]
    figure = figures[index_state](dataframe)

    return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state, table_list, ""
 
 
90
 
91
+ # def ask_vanna(vn,db_vanna_path, query):
92
+
93
+ # try :
94
+ # location = detect_location_with_openai(query)
95
+ # if location:
96
+
97
+ # coords = loc2coords(location)
98
+ # user_input = query.lower().replace(location.lower(), f"lat, long : {coords}")
99
 
100
+ # relevant_tables = detect_relevant_tables(db_vanna_path, user_input, llm)
101
+ # coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, relevant_tables[i]) for i in range(len(relevant_tables))]
102
+ # user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)
103
+
104
+ # sql_query, result_dataframe, figure = vn.ask(user_input_with_coords, print_results=False, allow_llm_to_see_data=True, auto_train=False)
105
+
106
+ # return sql_query, result_dataframe, figure
107
+ # else :
108
+ # empty_df = pd.DataFrame()
109
+ # empty_fig = None
110
+ # return "", empty_df, empty_fig
111
+ # except Exception as e:
112
+ # print(f"Error: {e}")
113
+ # empty_df = pd.DataFrame()
114
+ # empty_fig = None
115
+ # return "", empty_df, empty_fig
 
climateqa/engine/talk_to_data/plot.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, TypedDict
2
+ from matplotlib.figure import figaspect
3
+ import pandas as pd
4
+ from plotly.graph_objects import Figure
5
+ import plotly.graph_objects as go
6
+ import plotly.express as px
7
+
8
+ from climateqa.engine.talk_to_data.sql_query import (
9
+ indicator_for_given_year_query,
10
+ indicator_per_year_at_location_query,
11
+ )
12
+ from climateqa.engine.talk_to_data.config import INDICATOR_TO_UNIT
13
+
14
+
15
+
16
+
17
class Plot(TypedDict):
    """Configuration of one plot type available in the DRIAS system.

    Bundles everything the workflow needs to produce a figure: the metadata
    used to pick a plot, the figure factory, and the SQL-query builder.

    Attributes:
        name (str): The name of the plot type
        description (str): A description of what the plot shows
        params (list[str]): List of required parameters for the plot
        plot_function (Callable[..., Callable[..., Figure]]): Factory returning the figure builder
        sql_query (Callable[..., str]): Function building the SQL query feeding the plot
    """
    name: str
    description: str
    params: list[str]
    plot_function: Callable[..., Callable[..., Figure]]
    sql_query: Callable[..., str]
35
+
36
+
37
def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
    """Build a plot function for the evolution of an indicator at a location.

    The returned callable renders a line plot of the yearly indicator values
    together with a 10-year rolling average.

    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - location (str): The location to plot
            - model (str): The climate model to use

    Returns:
        Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure

    Example:
        >>> plot_func = plot_indicator_evolution_at_location({
        ...     'indicator_column': 'mean_temperature',
        ...     'location': 'Paris',
        ...     'model': 'ALL'
        ... })
        >>> fig = plot_func(df)
    """
    indicator = params["indicator_column"]
    location = params["location"]
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    unit = INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        """Render the line plot from the prepared DataFrame."""
        # The two branches previously duplicated the whole series-extraction
        # logic; they only differ by the frame used and the legend label.
        if df["model"].nunique() != 1:
            # Average over models when several are present.
            data = df.groupby("year", as_index=False)[indicator].mean()
            model_label = "Model Average"
        else:
            data = df
            model_label = f"Model : {df['model'].unique()[0]}"

        # Plain Python lists keep pandas dtypes out of the plotly payload.
        indicators = data[indicator].astype(float).tolist()
        years = data["year"].astype(int).tolist()

        # 10-year rolling average of the indicator.
        sliding_averages = (
            data[indicator]
            .rolling(window=10, min_periods=1)
            .mean()
            .astype(float)
            .tolist()
        )

        fig = go.Figure()

        # Yearly values
        fig.add_scatter(
            x=years,
            y=indicators,
            name=f"Yearly {indicator_label}",
            mode="lines",
            marker=dict(color="#1f77b4"),
            hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
        )

        # Rolling average as a dashed line
        fig.add_scatter(
            x=years,
            y=sliding_averages,
            mode="lines",
            name="10 years rolling average",
            line=dict(dash="dash"),
            marker=dict(color="#d62728"),
            hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
        )
        fig.update_layout(
            title=f"Plot of {indicator_label} in {location} ({model_label})",
            xaxis_title="Year",
            yaxis_title=f"{indicator_label} ({unit})",
            template="plotly_white",
        )
        return fig

    return plot_data


indicator_evolution_at_location: Plot = {
    "name": "Indicator evolution at location",
    "description": "Plot an evolution of the indicator at a certain location",
    "params": ["indicator_column", "location", "model"],
    "plot_function": plot_indicator_evolution_at_location,
    "sql_query": indicator_per_year_at_location_query,
}
149
+
150
+
151
def plot_indicator_number_of_days_per_year_at_location(
    params: dict,
) -> Callable[..., Figure]:
    """Build a bar-chart function for a frequency indicator at a location.

    The returned callable renders the number of days per year of a climate
    event (e.g. days above a temperature threshold) at a specific location.

    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - location (str): The location to plot
            - model (str): The climate model to use

    Returns:
        Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
    """
    indicator = params["indicator_column"]
    location = params["location"]
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    unit = INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        """Render the bar chart from the prepared DataFrame."""
        if df["model"].nunique() != 1:
            # Average over models when several are present.
            data = df.groupby("year", as_index=False)[indicator].mean()
            model_label = "Model Average"
        else:
            data = df
            model_label = f"Model : {df['model'].unique()[0]}"

        # Plain Python lists keep pandas dtypes out of the plotly payload.
        indicators = data[indicator].astype(float).tolist()
        years = data["year"].astype(int).tolist()

        fig = go.Figure()
        fig.add_trace(
            go.Bar(
                x=years,
                y=indicators,
                width=0.5,
                marker=dict(color="#1f77b4"),
                hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
            )
        )

        layout = dict(
            title=f"{indicator_label} in {location} ({model_label})",
            xaxis_title="Year",
            yaxis_title=f"{indicator_label} ({unit})",
            bargap=0.5,
            template="plotly_white",
        )
        # Bug fix: max() raised ValueError on an empty result set; only pin
        # the y-axis range when there is data.
        if indicators:
            layout["yaxis"] = dict(range=[0, max(indicators)])
        fig.update_layout(**layout)

        return fig

    return plot_data


indicator_number_of_days_per_year_at_location: Plot = {
    "name": "Indicator number of days per year at location",
    "description": "Plot a barchart of the number of days per year of a certain indicator at a certain location. It is appropriate for frequency indicator.",
    "params": ["indicator_column", "location", "model"],
    "plot_function": plot_indicator_number_of_days_per_year_at_location,
    "sql_query": indicator_per_year_at_location_query,
}
231
+
232
+
233
def plot_distribution_of_indicator_for_given_year(
    params: dict,
) -> Callable[..., Figure]:
    """Build a histogram function for the distribution of an indicator in a year.

    The returned callable renders the distribution of a climate indicator
    across locations for a specific year.

    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - year (str): The year to plot
            - model (str): The climate model to use

    Returns:
        Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
    """
    indicator = params["indicator_column"]
    year = params["year"]
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    unit = INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        """Render the histogram from the prepared DataFrame."""
        if df["model"].nunique() != 1:
            # Average over models, per grid point.
            data = df.groupby(["latitude", "longitude"], as_index=False)[indicator].mean()
            model_label = "Model Average"
        else:
            data = df
            model_label = f"Model : {df['model'].unique()[0]}"

        # Plain Python floats keep pandas dtypes out of the plotly payload.
        indicators = data[indicator].astype(float).tolist()

        fig = go.Figure()
        fig.add_trace(
            go.Histogram(
                x=indicators,
                opacity=0.8,
                histnorm="percent",
                marker=dict(color="#1f77b4"),
                hovertemplate=f"{indicator_label}: %{{x:.2f}} {unit}<br>Frequency: %{{y:.2f}}%<extra></extra>"
            )
        )

        fig.update_layout(
            title=f"Distribution of {indicator_label} in {year} ({model_label})",
            xaxis_title=f"{indicator_label} ({unit})",
            yaxis_title="Frequency (%)",
            plot_bgcolor="rgba(0, 0, 0, 0)",
            showlegend=False,
        )

        return fig

    return plot_data


distribution_of_indicator_for_given_year: Plot = {
    "name": "Distribution of an indicator for a given year",
    "description": "Plot an histogram of the distribution for a given year of the values of an indicator",
    "params": ["indicator_column", "model", "year"],
    "plot_function": plot_distribution_of_indicator_for_given_year,
    "sql_query": indicator_for_given_year_query,
}
312
+
313
+
314
def plot_map_of_france_of_indicator_for_given_year(
    params: dict,
) -> Callable[..., Figure]:
    """Build a map function showing an indicator across France for one year.

    The returned callable renders a scatter-mapbox of France with markers
    colored by the indicator value at each grid point.

    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - year (str): The year to plot
            - model (str): The climate model to use

    Returns:
        Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
    """
    indicator = params["indicator_column"]
    year = params["year"]
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    unit = INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        """Render the scatter-mapbox from the prepared DataFrame."""
        if df["model"].nunique() != 1:
            # Average over models, per grid point.
            data = df.groupby(["latitude", "longitude"], as_index=False)[indicator].mean()
            model_label = "Model Average"
        else:
            data = df
            model_label = f"Model : {df['model'].unique()[0]}"

        # Plain Python lists keep pandas dtypes out of the plotly payload.
        indicators = data[indicator].astype(float).tolist()
        latitudes = data["latitude"].astype(float).tolist()
        longitudes = data["longitude"].astype(float).tolist()

        fig = go.Figure()
        fig.add_trace(
            go.Scattermapbox(
                lat=latitudes,
                lon=longitudes,
                mode="markers",
                marker=dict(
                    size=10,
                    color=indicators,  # Color mapped to values
                    colorscale="Turbo",
                    cmin=min(indicators),  # Color range bounds
                    cmax=max(indicators),
                    showscale=True,
                    # Bug fix: the colorbar title must be set on the marker —
                    # layout.coloraxis_colorbar has no effect on a trace that
                    # does not use the shared coloraxis, so the legend title
                    # was silently dropped before.
                    colorbar=dict(title=f"{indicator_label} ({unit})"),
                ),
                # Hover text showing the indicator value at each point.
                text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators],
                hoverinfo="text"
            )
        )

        fig.update_layout(
            mapbox_style="open-street-map",  # Use OpenStreetMap tiles
            mapbox_zoom=3,
            mapbox_center={"lat": 46.6, "lon": 2.0},
            title=f"{indicator_label} in {year} in France ({model_label}) "
        )
        return fig

    return plot_data


map_of_france_of_indicator_for_given_year: Plot = {
    "name": "Map of France of an indicator for a given year",
    "description": "Heatmap on the map of France of the values of an in indicator for a given year",
    "params": ["indicator_column", "year", "model"],
    "plot_function": plot_map_of_france_of_indicator_for_given_year,
    "sql_query": indicator_for_given_year_query,
}
395
+
396
+
397
# Registry of every plot type the DRIAS workflow can choose from.
PLOTS = [
    indicator_evolution_at_location,
    indicator_number_of_days_per_year_at_location,
    distribution_of_indicator_for_given_year,
    map_of_france_of_indicator_for_given_year,
]
climateqa/engine/talk_to_data/sql_query.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from concurrent.futures import ThreadPoolExecutor
3
+ from typing import TypedDict
4
+ import duckdb
5
+ import pandas as pd
6
+
7
async def execute_sql_query(sql_query: str) -> pd.DataFrame:
    """Execute a SQL query on the DRIAS data and return the results.

    The blocking DuckDB call is pushed onto a worker thread so the asyncio
    event loop stays responsive while the query runs.

    Args:
        sql_query (str): The SQL query to execute

    Returns:
        pd.DataFrame: A DataFrame containing the query results

    Raises:
        duckdb.Error: If there is an error executing the SQL query
    """
    def _execute_query() -> pd.DataFrame:
        # Execute the query and fetch the data as a DataFrame.
        return duckdb.sql(sql_query).fetchdf()

    # asyncio.to_thread reuses the loop's default executor instead of paying
    # for a fresh ThreadPoolExecutor on every call, and avoids the
    # get_event_loop() pattern deprecated inside running coroutines.
    return await asyncio.to_thread(_execute_query)
33
+
34
+
35
class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
    """Parameters for querying an indicator's values over time at a location.

    Attributes:
        indicator_column (str): The column name for the climate indicator
        latitude (str): The latitude coordinate of the location
        longitude (str): The longitude coordinate of the location
        model (str): The climate model to use (optional)
    """
    indicator_column: str
    latitude: str
    longitude: str
    model: str


def indicator_per_year_at_location_query(
    table: str, params: IndicatorPerYearAtLocationQueryParams
) -> str:
    """Build the SQL query for the yearly evolution of an indicator at a point.

    Args:
        table (str): sql table of the indicator
        params (IndicatorPerYearAtLocationQueryParams): parameters of the query

    Returns:
        str: the sql query, or "" when a required parameter is missing
    """
    indicator_column = params.get("indicator_column")
    latitude = params.get("latitude")
    longitude = params.get("longitude")

    # All three parameters are mandatory; return an empty query otherwise.
    # NOTE(review): 'model' is accepted but not used to filter — confirm intended.
    if indicator_column is None or latitude is None or longitude is None:
        return ""

    parquet_file = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
    return (
        f"SELECT year, {indicator_column}, model\n"
        f"FROM {parquet_file}\n"
        f"WHERE latitude = {latitude} \n"
        f"And longitude = {longitude} \n"
        "Order by Year"
    )
77
+
78
class IndicatorForGivenYearQueryParams(TypedDict, total=False):
    """Parameters for querying an indicator across locations for one year.

    Attributes:
        indicator_column (str): The column name for the climate indicator
        year (str): The year to query
        model (str): The climate model to use (optional)
    """
    indicator_column: str
    year: str
    model: str


def indicator_for_given_year_query(
    table: str, params: IndicatorForGivenYearQueryParams
) -> str:
    """Build the SQL query returning an indicator with coordinates and models for a year.

    Args:
        table (str): sql table of the indicator
        params (IndicatorForGivenYearQueryParams): parameters of the query

    Returns:
        str: the sql query, or "" when a required parameter is missing
    """
    indicator_column = params.get("indicator_column")
    year = params.get("year")

    # Both parameters are mandatory; return an empty query otherwise.
    if year is None or indicator_column is None:
        return ""

    parquet_file = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
    return (
        f"Select {indicator_column}, latitude, longitude, model\n"
        f"From {parquet_file}\n"
        f"Where year = {year}"
    )
climateqa/engine/talk_to_data/utils.py CHANGED
@@ -1,12 +1,15 @@
1
  import re
2
- import openai
3
- import pandas as pd
4
  from geopy.geocoders import Nominatim
5
- import sqlite3
6
  import ast
7
  from climateqa.engine.llm import get_llm
 
 
 
8
 
9
- def detect_location_with_openai(sentence):
 
10
  """
11
  Detects locations in a sentence using OpenAI's API via LangChain.
12
  """
@@ -19,74 +22,260 @@ def detect_location_with_openai(sentence):
19
  Sentence: "{sentence}"
20
  """
21
 
22
- response = llm.invoke(prompt)
23
  location_list = ast.literal_eval(response.content.strip("```python\n").strip())
24
  if location_list:
25
  return location_list[0]
26
  else:
27
  return ""
28
 
29
- def detectTable(sql_query):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  pattern = r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
31
  matches = re.findall(pattern, sql_query)
32
  return matches
33
 
34
 
35
-
36
- def loc2coords(location : str):
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  geolocator = Nominatim(user_agent="city_to_latlong")
38
- location = geolocator.geocode(location)
39
- return (location.latitude, location.longitude)
40
 
41
 
42
- def coords2loc(coords : tuple):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  geolocator = Nominatim(user_agent="coords_to_city")
44
  try:
45
  location = geolocator.reverse(coords)
46
  return location.address
47
  except Exception as e:
48
  print(f"Error: {e}")
49
- return "Unknown Location"
50
 
51
 
52
- def nearestNeighbourSQL(db: str, location: tuple, table : str):
53
- conn = sqlite3.connect(db)
54
  long = round(location[1], 3)
55
  lat = round(location[0], 3)
56
- cursor = conn.cursor()
57
- cursor.execute(f"SELECT lat, lon FROM {table} WHERE lat BETWEEN {lat - 0.3} AND {lat + 0.3} AND lon BETWEEN {long - 0.3} AND {long + 0.3}")
58
- results = cursor.fetchall()
59
- return results[0]
60
-
61
- def detect_relevant_tables(user_question, llm):
62
- table_names_list = [
63
- "Frequency_of_rainy_days_index",
64
- "Winter_precipitation_total",
65
- "Summer_precipitation_total",
66
- "Annual_precipitation_total",
67
- # "Remarkable_daily_precipitation_total_(Q99)",
68
- "Frequency_of_remarkable_daily_precipitation",
69
- "Extreme_precipitation_intensity",
70
- "Mean_winter_temperature",
71
- "Mean_summer_temperature",
72
- "Number_of_tropical_nights",
73
- "Maximum_summer_temperature",
74
- "Number_of_days_with_Tx_above_30C",
75
- "Number_of_days_with_Tx_above_35C",
76
- "Drought_index"
77
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  prompt = (
79
- f"You are helping to build a sql query to retrieve relevant data for a user question."
80
- f"The different tables are {table_names_list}."
81
- f"The user question is {user_question}. Write the relevant tables to use. Answer only a python list of table name."
 
 
 
 
 
 
 
 
82
  )
83
- table_names = ast.literal_eval(llm.invoke(prompt).content.strip("```python\n").strip())
84
  return table_names
85
 
 
86
  def replace_coordonates(coords, query, coords_tables):
87
  n = query.count(str(coords[0]))
88
 
89
  for i in range(n):
90
- query = query.replace(str(coords[0]), str(coords_tables[i][0]),1)
91
- query = query.replace(str(coords[1]), str(coords_tables[i][1]),1)
92
- return query
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import re
2
+ from typing import Annotated, TypedDict
3
+ import duckdb
4
  from geopy.geocoders import Nominatim
 
5
  import ast
6
  from climateqa.engine.llm import get_llm
7
+ from climateqa.engine.talk_to_data.config import DRIAS_TABLES
8
+ from climateqa.engine.talk_to_data.plot import PLOTS, Plot
9
+ from langchain_core.prompts import ChatPromptTemplate
10
 
11
+
12
+ async def detect_location_with_openai(sentence):
13
  """
14
  Detects locations in a sentence using OpenAI's API via LangChain.
15
  """
 
22
  Sentence: "{sentence}"
23
  """
24
 
25
+ response = await llm.ainvoke(prompt)
26
  location_list = ast.literal_eval(response.content.strip("```python\n").strip())
27
  if location_list:
28
  return location_list[0]
29
  else:
30
  return ""
31
 
32
class ArrayOutput(TypedDict):
    """Structured-output schema holding a single stringified Python array.

    Used with ``llm.with_structured_output`` so the model returns an array
    serialized as a string that the caller then parses.

    Attributes:
        array (str): A syntactically valid Python array string
    """
    array: Annotated[str, "Syntactically valid python array."]
42
+
43
async def detect_year_with_openai(sentence: str) -> str:
    """Detect the first year mentioned in a sentence using the LLM.

    Args:
        sentence (str): The text to scan for years.

    Returns:
        str: The first year found, or "" when none is mentioned.
    """
    llm = get_llm()

    prompt = """
    Extract all years mentioned in the following sentence.
    Return the result as a Python list. If no year are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    prompt = ChatPromptTemplate.from_template(prompt)
    structured_llm = llm.with_structured_output(ArrayOutput)
    chain = prompt | structured_llm
    response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
    # Bug fix: parse the model reply with ast.literal_eval instead of eval()
    # — eval() executes arbitrary code embedded in untrusted LLM output.
    years_list = ast.literal_eval(response["array"])
    if len(years_list) > 0:
        return years_list[0]
    return ""
65
+
66
+
67
def detectTable(sql_query: str) -> list[str]:
    """Return the table names referenced by the FROM clauses of a SQL query.

    Matches plain, dotted, and quoted (backtick / double / single quote)
    identifiers following a FROM keyword, case-insensitively.

    Args:
        sql_query (str): The SQL query to analyze

    Returns:
        list[str]: The table names, in order of appearance

    Example:
        >>> detectTable("SELECT * FROM temperature_data WHERE year > 2000")
        ['temperature_data']
    """
    from_clause_pattern = (
        r'(?i)\bFROM\s+'
        r'((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)'
        r'(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
    )
    return re.findall(from_clause_pattern, sql_query)
86
 
87
 
88
+ def loc2coords(location: str) -> tuple[float, float]:
89
+ """Converts a location name to geographic coordinates.
90
+
91
+ This function uses the Nominatim geocoding service to convert
92
+ a location name (e.g., city name) to its latitude and longitude.
93
+
94
+ Args:
95
+ location (str): The name of the location to geocode
96
+
97
+ Returns:
98
+ tuple[float, float]: A tuple containing (latitude, longitude)
99
+
100
+ Raises:
101
+ AttributeError: If the location cannot be found
102
+ """
103
  geolocator = Nominatim(user_agent="city_to_latlong")
104
+ coords = geolocator.geocode(location)
105
+ return (coords.latitude, coords.longitude)
106
 
107
 
108
+ def coords2loc(coords: tuple[float, float]) -> str:
109
+ """Converts geographic coordinates to a location name.
110
+
111
+ This function uses the Nominatim reverse geocoding service to convert
112
+ latitude and longitude coordinates to a human-readable location name.
113
+
114
+ Args:
115
+ coords (tuple[float, float]): A tuple containing (latitude, longitude)
116
+
117
+ Returns:
118
+ str: The address of the location, or "Unknown Location" if not found
119
+
120
+ Example:
121
+ >>> coords2loc((48.8566, 2.3522))
122
+ 'Paris, France'
123
+ """
124
  geolocator = Nominatim(user_agent="coords_to_city")
125
  try:
126
  location = geolocator.reverse(coords)
127
  return location.address
128
  except Exception as e:
129
  print(f"Error: {e}")
130
+ return "Unknown Location"
131
 
132
 
133
+ def nearestNeighbourSQL(location: tuple, table: str) -> tuple[str, str]:
 
134
  long = round(location[1], 3)
135
  lat = round(location[0], 3)
136
+
137
+ table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
138
+
139
+ results = duckdb.sql(
140
+ f"SELECT latitude, longitude FROM {table} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}"
141
+ ).fetchdf()
142
+
143
+ if len(results) == 0:
144
+ return "", ""
145
+ # cursor.execute(f"SELECT latitude, longitude FROM {table} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}")
146
+ return results['latitude'].iloc[0], results['longitude'].iloc[0]
147
+
148
+
149
+ async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
150
+ """Identifies relevant tables for a plot based on user input.
151
+
152
+ This function uses an LLM to analyze the user's question and the plot
153
+ description to determine which tables in the DRIAS database would be
154
+ most relevant for generating the requested visualization.
155
+
156
+ Args:
157
+ user_question (str): The user's question about climate data
158
+ plot (Plot): The plot configuration object
159
+ llm: The language model instance to use for analysis
160
+
161
+ Returns:
162
+ list[str]: A list of table names that are relevant for the plot
163
+
164
+ Example:
165
+ >>> detect_relevant_tables(
166
+ ... "What will the temperature be like in Paris?",
167
+ ... indicator_evolution_at_location,
168
+ ... llm
169
+ ... )
170
+ ['mean_annual_temperature', 'mean_summer_temperature']
171
+ """
172
+ # Get all table names
173
+ table_names_list = DRIAS_TABLES
174
+
175
  prompt = (
176
+ f"You are helping to build a plot following this description : {plot['description']}."
177
+ f"You are given a list of tables and a user question."
178
+ f"Based on the description of the plot, which table are appropriate for that kind of plot."
179
+ f"Write the 3 most relevant tables to use. Answer only a python list of table name."
180
+ f"### List of tables : {table_names_list}"
181
+ f"### User question : {user_question}"
182
+ f"### List of table name : "
183
+ )
184
+
185
+ table_names = ast.literal_eval(
186
+ (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
187
  )
 
188
  return table_names
189
 
190
+
191
  def replace_coordonates(coords, query, coords_tables):
192
  n = query.count(str(coords[0]))
193
 
194
  for i in range(n):
195
+ query = query.replace(str(coords[0]), str(coords_tables[i][0]), 1)
196
+ query = query.replace(str(coords[1]), str(coords_tables[i][1]), 1)
197
+ return query
198
+
199
+
200
+ async def detect_relevant_plots(user_question: str, llm):
201
+ plots_description = ""
202
+ for plot in PLOTS:
203
+ plots_description += "Name: " + plot["name"]
204
+ plots_description += " - Description: " + plot["description"] + "\n"
205
+
206
+ prompt = (
207
+ f"You are helping to answer a quesiton with insightful visualizations."
208
+ f"You are given an user question and a list of plots with their name and description."
209
+ f"Based on the descriptions of the plots, which plot is appropriate to answer to this question."
210
+ f"Write the most relevant tables to use. Answer only a python list of plot name."
211
+ f"### Descriptions of the plots : {plots_description}"
212
+ f"### User question : {user_question}"
213
+ f"### Name of the plot : "
214
+ )
215
+ # prompt = (
216
+ # f"You are helping to answer a question with insightful visualizations. "
217
+ # f"Given a list of plots with their name and description: "
218
+ # f"{plots_description} "
219
+ # f"The user question is: {user_question}. "
220
+ # f"Choose the most relevant plots to answer the question. "
221
+ # f"The answer must be a Python list with the names of the relevant plots, and nothing else. "
222
+ # f"Ensure the response is in the exact format: ['PlotName1', 'PlotName2']."
223
+ # )
224
+
225
+ plot_names = ast.literal_eval(
226
+ (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
227
+ )
228
+ return plot_names
229
+
230
+
231
+ # Next Version
232
+ # class QueryOutput(TypedDict):
233
+ # """Generated SQL query."""
234
+
235
+ # query: Annotated[str, ..., "Syntactically valid SQL query."]
236
+
237
+
238
+ # class PlotlyCodeOutput(TypedDict):
239
+ # """Generated Plotly code"""
240
+
241
+ # code: Annotated[str, ..., "Synatically valid Plotly python code."]
242
+ # def write_sql_query(user_input: str, db: SQLDatabase, relevant_tables: list[str], llm):
243
+ # """Generate SQL query to fetch information."""
244
+ # prompt_params = {
245
+ # "dialect": db.dialect,
246
+ # "table_info": db.get_table_info(),
247
+ # "input": user_input,
248
+ # "relevant_tables": relevant_tables,
249
+ # "model": "ALADIN63_CNRM-CM5",
250
+ # }
251
+
252
+ # prompt = ChatPromptTemplate.from_template(query_prompt_template)
253
+ # structured_llm = llm.with_structured_output(QueryOutput)
254
+ # chain = prompt | structured_llm
255
+ # result = chain.invoke(prompt_params)
256
+
257
+ # return result["query"]
258
+
259
+
260
+ # def fetch_data_from_sql_query(db: str, sql_query: str):
261
+ # conn = sqlite3.connect(db)
262
+ # cursor = conn.cursor()
263
+ # cursor.execute(sql_query)
264
+ # column_names = [desc[0] for desc in cursor.description]
265
+ # values = cursor.fetchall()
266
+ # return {"column_names": column_names, "data": values}
267
+
268
+
269
+ # def generate_chart_code(user_input: str, sql_query: list[str], llm):
270
+ # """ "Generate plotly python code for the chart based on the sql query and the user question"""
271
+
272
+ # class PlotlyCodeOutput(TypedDict):
273
+ # """Generated Plotly code"""
274
+
275
+ # code: Annotated[str, ..., "Synatically valid Plotly python code."]
276
+
277
+ # prompt = ChatPromptTemplate.from_template(plot_prompt_template)
278
+ # structured_llm = llm.with_structured_output(PlotlyCodeOutput)
279
+ # chain = prompt | structured_llm
280
+ # result = chain.invoke({"input": user_input, "sql_query": sql_query})
281
+ # return result["code"]
climateqa/engine/talk_to_data/workflow.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from typing import Any, Callable, NotRequired, TypedDict
4
+ import pandas as pd
5
+
6
+ from plotly.graph_objects import Figure
7
+ from climateqa.engine.llm import get_llm
8
+ from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
9
+ from climateqa.engine.talk_to_data.plot import PLOTS, Plot
10
+ from climateqa.engine.talk_to_data.sql_query import execute_sql_query
11
+ from climateqa.engine.talk_to_data.utils import (
12
+ detect_relevant_plots,
13
+ detect_year_with_openai,
14
+ loc2coords,
15
+ detect_location_with_openai,
16
+ nearestNeighbourSQL,
17
+ detect_relevant_tables,
18
+ )
19
+
20
+ ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
21
+
22
+ class TableState(TypedDict):
23
+ """Represents the state of a table in the DRIAS workflow.
24
+
25
+ This class defines the structure for tracking the state of a table during the
26
+ data processing workflow, including its name, parameters, SQL query, and results.
27
+
28
+ Attributes:
29
+ table_name (str): The name of the table in the database
30
+ params (dict[str, Any]): Parameters used for querying the table
31
+ sql_query (str, optional): The SQL query used to fetch data
32
+ dataframe (pd.DataFrame | None, optional): The resulting data
33
+ figure (Callable[..., Figure], optional): Function to generate visualization
34
+ status (str): The current status of the table processing ('OK' or 'ERROR')
35
+ """
36
+ table_name: str
37
+ params: dict[str, Any]
38
+ sql_query: NotRequired[str]
39
+ dataframe: NotRequired[pd.DataFrame | None]
40
+ figure: NotRequired[Callable[..., Figure]]
41
+ status: str
42
+
43
+ class PlotState(TypedDict):
44
+ """Represents the state of a plot in the DRIAS workflow.
45
+
46
+ This class defines the structure for tracking the state of a plot during the
47
+ data processing workflow, including its name and associated tables.
48
+
49
+ Attributes:
50
+ plot_name (str): The name of the plot
51
+ tables (list[str]): List of tables used in the plot
52
+ table_states (dict[str, TableState]): States of the tables used in the plot
53
+ """
54
+ plot_name: str
55
+ tables: list[str]
56
+ table_states: dict[str, TableState]
57
+
58
+ class State(TypedDict):
59
+ user_input: str
60
+ plots: list[str]
61
+ plot_states: dict[str, PlotState]
62
+ error: NotRequired[str]
63
+
64
+ async def drias_workflow(user_input: str) -> State:
65
+ """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
66
+
67
+ Args:
68
+ user_input (str): initial user input
69
+
70
+ Returns:
71
+ State: Final state with all the results
72
+ """
73
+ state: State = {
74
+ 'user_input': user_input,
75
+ 'plots': [],
76
+ 'plot_states': {}
77
+ }
78
+
79
+ llm = get_llm(provider="openai")
80
+
81
+ plots = await find_relevant_plots(state, llm)
82
+ state['plots'] = plots
83
+
84
+ if not state['plots']:
85
+ state['error'] = 'There is no plot to answer to the question'
86
+ return state
87
+
88
+ have_relevant_table = False
89
+ have_sql_query = False
90
+ have_dataframe = False
91
+ for plot_name in state['plots']:
92
+
93
+ plot = next((p for p in PLOTS if p['name'] == plot_name), None) # Find the associated plot object
94
+ if plot is None:
95
+ continue
96
+
97
+ plot_state: PlotState = {
98
+ 'plot_name': plot_name,
99
+ 'tables': [],
100
+ 'table_states': {}
101
+ }
102
+
103
+ plot_state['plot_name'] = plot_name
104
+
105
+ relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)
106
+ if len(relevant_tables) > 0 :
107
+ have_relevant_table = True
108
+
109
+ plot_state['tables'] = relevant_tables
110
+
111
+ params = {}
112
+ for param_name in plot['params']:
113
+ param = await find_param(state, param_name, relevant_tables[0])
114
+ if param:
115
+ params.update(param)
116
+
117
+ for n, table in enumerate(plot_state['tables']):
118
+ if n > 2:
119
+ break
120
+
121
+ table_state: TableState = {
122
+ 'table_name': table,
123
+ 'params': params,
124
+ 'status': 'OK'
125
+ }
126
+
127
+ table_state["params"]['indicator_column'] = find_indicator_column(table)
128
+
129
+ sql_query = plot['sql_query'](table, table_state['params'])
130
+
131
+ if sql_query == "":
132
+ table_state['status'] = 'ERROR'
133
+ continue
134
+ else :
135
+ have_sql_query = True
136
+
137
+ table_state['sql_query'] = sql_query
138
+ df = await execute_sql_query(sql_query)
139
+
140
+ if len(df) > 0:
141
+ have_dataframe = True
142
+
143
+ figure = plot['plot_function'](table_state['params'])
144
+ table_state['dataframe'] = df
145
+ table_state['figure'] = figure
146
+ plot_state['table_states'][table] = table_state
147
+
148
+ state['plot_states'][plot_name] = plot_state
149
+
150
+ if not have_relevant_table:
151
+ state['error'] = "There is no relevant table in the our database to answer your question"
152
+ elif not have_sql_query:
153
+ state['error'] = "There is no relevant sql query on our database that can help to answer your question"
154
+ elif not have_dataframe:
155
+ state['error'] = "There is no data in our table that can answer to your question"
156
+
157
+ return state
158
+
159
+ async def find_relevant_plots(state: State, llm) -> list[str]:
160
+ print("---- Find relevant plots ----")
161
+ relevant_plots = await detect_relevant_plots(state['user_input'], llm)
162
+ return relevant_plots
163
+
164
+ async def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
165
+ print(f"---- Find relevant tables for {plot['name']} ----")
166
+ relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm)
167
+ return relevant_tables
168
+
169
+ async def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | None:
170
+ """Perform the good method to retrieve the desired parameter
171
+
172
+ Args:
173
+ state (State): state of the workflow
174
+ param_name (str): name of the desired parameter
175
+ table (str): name of the table
176
+
177
+ Returns:
178
+ dict[str, Any] | None:
179
+ """
180
+ if param_name == 'location':
181
+ location = await find_location(state['user_input'], table)
182
+ return location
183
+ if param_name == 'year':
184
+ year = await find_year(state['user_input'])
185
+ return {'year': year}
186
+ return None
187
+
188
+ class Location(TypedDict):
189
+ location: str
190
+ latitude: NotRequired[str]
191
+ longitude: NotRequired[str]
192
+
193
+ async def find_location(user_input: str, table: str) -> Location:
194
+ print(f"---- Find location in table {table} ----")
195
+ location = await detect_location_with_openai(user_input)
196
+ output: Location = {'location' : location}
197
+ if location:
198
+ coords = loc2coords(location)
199
+ neighbour = nearestNeighbourSQL(coords, table)
200
+ output.update({
201
+ "latitude": neighbour[0],
202
+ "longitude": neighbour[1],
203
+ })
204
+ return output
205
+
206
+ async def find_year(user_input: str) -> str:
207
+ """Extracts year information from user input using LLM.
208
+
209
+ This function uses an LLM to identify and extract year information from the
210
+ user's query, which is used to filter data in subsequent queries.
211
+
212
+ Args:
213
+ user_input (str): The user's query text
214
+
215
+ Returns:
216
+ str: The extracted year, or empty string if no year found
217
+ """
218
+ print(f"---- Find year ---")
219
+ year = await detect_year_with_openai(user_input)
220
+ return year
221
+
222
+ def find_indicator_column(table: str) -> str:
223
+ """Retrieves the name of the indicator column within a table.
224
+
225
+ This function maps table names to their corresponding indicator columns
226
+ using the predefined mapping in INDICATOR_COLUMNS_PER_TABLE.
227
+
228
+ Args:
229
+ table (str): Name of the table in the database
230
+
231
+ Returns:
232
+ str: Name of the indicator column for the specified table
233
+
234
+ Raises:
235
+ KeyError: If the table name is not found in the mapping
236
+ """
237
+ print(f"---- Find indicator column in table {table} ----")
238
+ return INDICATOR_COLUMNS_PER_TABLE[table]
239
+
240
+
241
+ # def make_write_query_node():
242
+
243
+ # def write_query(state):
244
+ # print("---- Write query ----")
245
+ # for table in state["tables"]:
246
+ # sql_query = QUERIES[state[table]['query_type']](
247
+ # table=table,
248
+ # indicator_column=state[table]["columns"],
249
+ # longitude=state[table]["longitude"],
250
+ # latitude=state[table]["latitude"],
251
+ # )
252
+ # state[table].update({"sql_query": sql_query})
253
+
254
+ # return state
255
+
256
+ # return write_query
257
+
258
+ # def make_fetch_data_node(db_path):
259
+
260
+ # def fetch_data(state):
261
+ # print("---- Fetch data ----")
262
+ # for table in state["tables"]:
263
+ # results = execute_sql_query(db_path, state[table]['sql_query'])
264
+ # state[table].update(results)
265
+
266
+ # return state
267
+
268
+ # return fetch_data
269
+
270
+
271
+
272
+ ## V2
273
+
274
+
275
+ # def make_fetch_data_node(db_path: str, llm):
276
+ # def fetch_data(state):
277
+ # print("---- Fetch data ----")
278
+ # db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
279
+ # output = {}
280
+ # sql_query = write_sql_query(state["query"], db, state["tables"], llm)
281
+ # # TO DO : Add query checker
282
+ # print(f"SQL query : {sql_query}")
283
+ # output["sql_query"] = sql_query
284
+ # output.update(fetch_data_from_sql_query(db_path, sql_query))
285
+ # return output
286
+
287
+ # return fetch_data
front/tabs/tab_drias.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from typing import TypedDict, List, Optional
3
+
4
+ from climateqa.engine.talk_to_data.main import ask_drias
5
+ from climateqa.engine.talk_to_data.config import DRIAS_MODELS, DRIAS_UI_TEXT
6
+
7
+
8
+ class DriasUIElements(TypedDict):
9
+ tab: gr.Tab
10
+ details_accordion: gr.Accordion
11
+ examples_hidden: gr.Textbox
12
+ examples: gr.Examples
13
+ drias_direct_question: gr.Textbox
14
+ result_text: gr.Textbox
15
+ table_names_display: gr.DataFrame
16
+ query_accordion: gr.Accordion
17
+ drias_sql_query: gr.Textbox
18
+ chart_accordion: gr.Accordion
19
+ model_selection: gr.Dropdown
20
+ drias_display: gr.Plot
21
+ table_accordion: gr.Accordion
22
+ drias_table: gr.DataFrame
23
+ pagination_display: gr.Markdown
24
+ prev_button: gr.Button
25
+ next_button: gr.Button
26
+
27
+
28
+ async def ask_drias_query(query: str, index_state: int):
29
+ return await ask_drias(query, index_state)
30
+
31
+
32
+ def show_results(sql_queries_state, dataframes_state, plots_state):
33
+ if not sql_queries_state or not dataframes_state or not plots_state:
34
+ # If all results are empty, show "No result"
35
+ return (
36
+ gr.update(visible=True),
37
+ gr.update(visible=False),
38
+ gr.update(visible=False),
39
+ gr.update(visible=False),
40
+ gr.update(visible=False),
41
+ gr.update(visible=False),
42
+ gr.update(visible=False),
43
+ gr.update(visible=False),
44
+ )
45
+ else:
46
+ # Show the appropriate components with their data
47
+ return (
48
+ gr.update(visible=False),
49
+ gr.update(visible=True),
50
+ gr.update(visible=True),
51
+ gr.update(visible=True),
52
+ gr.update(visible=True),
53
+ gr.update(visible=True),
54
+ gr.update(visible=True),
55
+ gr.update(visible=True),
56
+ )
57
+
58
+
59
+ def filter_by_model(dataframes, figures, index_state, model_selection):
60
+ df = dataframes[index_state]
61
+ if df.empty:
62
+ return df, None
63
+ if "model" not in df.columns:
64
+ return df, figures[index_state](df)
65
+ if model_selection != "ALL":
66
+ df = df[df["model"] == model_selection]
67
+ if df.empty:
68
+ return df, None
69
+ figure = figures[index_state](df)
70
+ return df, figure
71
+
72
+
73
+ def update_pagination(index, sql_queries):
74
+ pagination = f"{index + 1}/{len(sql_queries)}" if sql_queries else ""
75
+ return pagination
76
+
77
+
78
+ def show_previous(index, sql_queries, dataframes, plots):
79
+ if index > 0:
80
+ index -= 1
81
+ return (
82
+ sql_queries[index],
83
+ dataframes[index],
84
+ plots[index](dataframes[index]),
85
+ index,
86
+ )
87
+
88
+
89
+ def show_next(index, sql_queries, dataframes, plots):
90
+ if index < len(sql_queries) - 1:
91
+ index += 1
92
+ return (
93
+ sql_queries[index],
94
+ dataframes[index],
95
+ plots[index](dataframes[index]),
96
+ index,
97
+ )
98
+
99
+
100
+ def display_table_names(table_names):
101
+ return [table_names]
102
+
103
+
104
+ def on_table_click(evt: gr.SelectData, table_names, sql_queries, dataframes, plots):
105
+ index = evt.index[1]
106
+ figure = plots[index](dataframes[index])
107
+ return (
108
+ sql_queries[index],
109
+ dataframes[index],
110
+ figure,
111
+ index,
112
+ )
113
+
114
+
115
+ def create_drias_ui() -> DriasUIElements:
116
+ """Create and return all UI elements for the DRIAS tab."""
117
+ with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab:
118
+ with gr.Accordion(label="Details") as details_accordion:
119
+ gr.Markdown(DRIAS_UI_TEXT)
120
+
121
+ # Add examples for common questions
122
+ examples_hidden = gr.Textbox(visible=False, elem_id="drias-examples-hidden")
123
+ examples = gr.Examples(
124
+ examples=[
125
+ ["What will the temperature be like in Paris?"],
126
+ ["What will be the total rainfall in France in 2030?"],
127
+ ["How frequent will extreme events be in Lyon?"],
128
+ ["Comment va évoluer la température en France entre 2030 et 2050 ?"]
129
+ ],
130
+ label="Example Questions",
131
+ inputs=[examples_hidden],
132
+ outputs=[examples_hidden],
133
+ )
134
+
135
+ with gr.Row():
136
+ drias_direct_question = gr.Textbox(
137
+ label="Direct Question",
138
+ placeholder="You can write direct question here",
139
+ elem_id="direct-question",
140
+ interactive=True,
141
+ )
142
+
143
+ result_text = gr.Textbox(
144
+ label="", elem_id="no-result-label", interactive=False, visible=True
145
+ )
146
+
147
+ table_names_display = gr.DataFrame(
148
+ [], label="List of relevant indicators", headers=None, interactive=False, elem_id="table-names", visible=False
149
+ )
150
+
151
+ with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
152
+ drias_sql_query = gr.Textbox(
153
+ label="", elem_id="sql-query", interactive=False
154
+ )
155
+
156
+ with gr.Accordion(label="Chart", visible=False) as chart_accordion:
157
+ model_selection = gr.Dropdown(
158
+ label="Model", choices=DRIAS_MODELS, value="ALL", interactive=True
159
+ )
160
+ drias_display = gr.Plot(elem_id="vanna-plot")
161
+
162
+ with gr.Accordion(
163
+ label="Data used", open=False, visible=False
164
+ ) as table_accordion:
165
+ drias_table = gr.DataFrame([], elem_id="vanna-table")
166
+
167
+ pagination_display = gr.Markdown(
168
+ value="", visible=False, elem_id="pagination-display"
169
+ )
170
+
171
+ with gr.Row():
172
+ prev_button = gr.Button("Previous", visible=False)
173
+ next_button = gr.Button("Next", visible=False)
174
+
175
+ return DriasUIElements(
176
+ tab=tab,
177
+ details_accordion=details_accordion,
178
+ examples_hidden=examples_hidden,
179
+ examples=examples,
180
+ drias_direct_question=drias_direct_question,
181
+ result_text=result_text,
182
+ table_names_display=table_names_display,
183
+ query_accordion=query_accordion,
184
+ drias_sql_query=drias_sql_query,
185
+ chart_accordion=chart_accordion,
186
+ model_selection=model_selection,
187
+ drias_display=drias_display,
188
+ table_accordion=table_accordion,
189
+ drias_table=drias_table,
190
+ pagination_display=pagination_display,
191
+ prev_button=prev_button,
192
+ next_button=next_button
193
+ )
194
+
195
+ def setup_drias_events(ui_elements: DriasUIElements) -> None:
196
+ """Set up all event handlers for the DRIAS tab."""
197
+ # Create state variables
198
+ sql_queries_state = gr.State([])
199
+ dataframes_state = gr.State([])
200
+ plots_state = gr.State([])
201
+ index_state = gr.State(0)
202
+ table_names_list = gr.State([])
203
+
204
+ # Handle example selection
205
+ ui_elements["examples_hidden"].change(
206
+ lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
207
+ inputs=[ui_elements["examples_hidden"]],
208
+ outputs=[ui_elements["details_accordion"], ui_elements["drias_direct_question"]]
209
+ ).then(
210
+ ask_drias_query,
211
+ inputs=[ui_elements["examples_hidden"], index_state],
212
+ outputs=[
213
+ ui_elements["drias_sql_query"],
214
+ ui_elements["drias_table"],
215
+ ui_elements["drias_display"],
216
+ sql_queries_state,
217
+ dataframes_state,
218
+ plots_state,
219
+ index_state,
220
+ table_names_list,
221
+ ui_elements["result_text"],
222
+ ],
223
+ ).then(
224
+ show_results,
225
+ inputs=[sql_queries_state, dataframes_state, plots_state],
226
+ outputs=[
227
+ ui_elements["result_text"],
228
+ ui_elements["query_accordion"],
229
+ ui_elements["table_accordion"],
230
+ ui_elements["chart_accordion"],
231
+ ui_elements["prev_button"],
232
+ ui_elements["next_button"],
233
+ ui_elements["pagination_display"],
234
+ ui_elements["table_names_display"],
235
+ ],
236
+ ).then(
237
+ update_pagination,
238
+ inputs=[index_state, sql_queries_state],
239
+ outputs=[ui_elements["pagination_display"]],
240
+ ).then(
241
+ display_table_names,
242
+ inputs=[table_names_list],
243
+ outputs=[ui_elements["table_names_display"]],
244
+ )
245
+
246
+ # Handle direct question submission
247
+ ui_elements["drias_direct_question"].submit(
248
+ lambda: gr.Accordion(open=False),
249
+ inputs=None,
250
+ outputs=[ui_elements["details_accordion"]]
251
+ ).then(
252
+ ask_drias_query,
253
+ inputs=[ui_elements["drias_direct_question"], index_state],
254
+ outputs=[
255
+ ui_elements["drias_sql_query"],
256
+ ui_elements["drias_table"],
257
+ ui_elements["drias_display"],
258
+ sql_queries_state,
259
+ dataframes_state,
260
+ plots_state,
261
+ index_state,
262
+ table_names_list,
263
+ ui_elements["result_text"],
264
+ ],
265
+ ).then(
266
+ show_results,
267
+ inputs=[sql_queries_state, dataframes_state, plots_state],
268
+ outputs=[
269
+ ui_elements["result_text"],
270
+ ui_elements["query_accordion"],
271
+ ui_elements["table_accordion"],
272
+ ui_elements["chart_accordion"],
273
+ ui_elements["prev_button"],
274
+ ui_elements["next_button"],
275
+ ui_elements["pagination_display"],
276
+ ui_elements["table_names_display"],
277
+ ],
278
+ ).then(
279
+ update_pagination,
280
+ inputs=[index_state, sql_queries_state],
281
+ outputs=[ui_elements["pagination_display"]],
282
+ ).then(
283
+ display_table_names,
284
+ inputs=[table_names_list],
285
+ outputs=[ui_elements["table_names_display"]],
286
+ )
287
+
288
+ # Handle model selection change
289
+ ui_elements["model_selection"].change(
290
+ filter_by_model,
291
+ inputs=[dataframes_state, plots_state, index_state, ui_elements["model_selection"]],
292
+ outputs=[ui_elements["drias_table"], ui_elements["drias_display"]],
293
+ )
294
+
295
+ # Handle pagination buttons
296
+ ui_elements["prev_button"].click(
297
+ show_previous,
298
+ inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
299
+ outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
300
+ ).then(
301
+ update_pagination,
302
+ inputs=[index_state, sql_queries_state],
303
+ outputs=[ui_elements["pagination_display"]],
304
+ )
305
+
306
+ ui_elements["next_button"].click(
307
+ show_next,
308
+ inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
309
+ outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
310
+ ).then(
311
+ update_pagination,
312
+ inputs=[index_state, sql_queries_state],
313
+ outputs=[ui_elements["pagination_display"]],
314
+ )
315
+
316
+ # Handle table selection
317
+ ui_elements["table_names_display"].select(
318
+ fn=on_table_click,
319
+ inputs=[table_names_list, sql_queries_state, dataframes_state, plots_state],
320
+ outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
321
+ ).then(
322
+ update_pagination,
323
+ inputs=[index_state, sql_queries_state],
324
+ outputs=[ui_elements["pagination_display"]],
325
+ )
326
+
327
+ def create_drias_tab():
328
+ """Main function to create the DRIAS tab with UI and event handling."""
329
+ ui_elements = create_drias_ui()
330
+ setup_drias_events(ui_elements)
331
+
332
+
style.css CHANGED
@@ -520,7 +520,6 @@ a {
520
  height: calc(100vh - 190px) !important;
521
  overflow-y: scroll !important;
522
  }
523
- div#tab-vanna,
524
  div#sources-figures,
525
  div#graphs-container,
526
  div#tab-citations {
@@ -653,14 +652,61 @@ a {
653
  }
654
 
655
  #vanna-display {
656
- max-height: 300px;
657
  /* overflow-y: scroll; */
658
  }
659
  #sql-query{
660
- max-height: 100px;
661
  overflow-y:scroll;
662
  }
663
- #vanna-details{
664
- max-height: 500px;
665
- overflow-y:scroll;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
666
  }
 
520
  height: calc(100vh - 190px) !important;
521
  overflow-y: scroll !important;
522
  }
 
523
  div#sources-figures,
524
  div#graphs-container,
525
  div#tab-citations {
 
652
  }
653
 
654
  #vanna-display {
655
+ max-height: 200px;
656
  /* overflow-y: scroll; */
657
  }
658
  #sql-query{
659
+ max-height: 300px;
660
  overflow-y:scroll;
661
  }
662
+
663
+ #sql-query textarea{
664
+ min-height: 100px !important;
665
+ }
666
+
667
+ #sql-query span{
668
+ display: none;
669
+ }
670
+ div#tab-vanna{
671
+ max-height: 100¨vh;
672
+ overflow-y: hidden;
673
+ }
674
+ #vanna-plot{
675
+ max-height:500px
676
+ }
677
+
678
+ #pagination-display{
679
+ text-align: center;
680
+ font-weight: bold;
681
+ font-size: 16px;
682
+ }
683
+
684
+ #table-names table{
685
+ overflow: hidden;
686
+ }
687
+ #table-names thead{
688
+ display: none;
689
+ }
690
+
691
+ /* DRIAS Data Table Styles */
692
+ #vanna-table {
693
+ height: 400px !important;
694
+ overflow-y: auto !important;
695
+ }
696
+
697
+ #vanna-table > div[class*="table"] {
698
+ height: 400px !important;
699
+ overflow-y: None !important;
700
+ }
701
+
702
+ #vanna-table .table-wrap {
703
+ height: 400px !important;
704
+ overflow-y: None !important;
705
+ }
706
+
707
+ #vanna-table thead {
708
+ position: sticky;
709
+ top: 0;
710
+ background: white;
711
+ z-index: 1;
712
  }