nolanzandi committed on
Commit
c101c53
·
verified ·
1 Parent(s): a928bec

feat/postgresql-integration (#25)

Browse files

- sql integration (5acf39cc95da88ec7b5ae96fd388b6b40f5a1606)

app.py CHANGED
@@ -1,6 +1,6 @@
1
  from utils import TEMP_DIR, message_dict
2
  import gradio as gr
3
- import data_file
4
 
5
  import os
6
  from getpass import getpass
@@ -18,7 +18,7 @@ def delete_db(req: gr.Request):
18
  if "OPENAI_API_KEY" not in os.environ:
19
  os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")
20
 
21
- css= ".file_marker .large{min-height:50px !important;} .example_btn{max-width:300px;} .padding{padding:0;} .description_component{overflow:visible !important;}"
22
  head = """<meta charset="UTF-8">
23
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
24
  <title>Virtual Data Analyst</title>
@@ -72,9 +72,8 @@ with gr.Blocks(theme=theme, css=css, head=head, delete_cache=(3600,3600)) as dem
72
  </main>""")
73
  with gr.Tab("Data File"):
74
  data_file.demo.render()
75
- with gr.Tab("SQL Database - Coming Soon", interactive=False):
76
- gr.Text("COMING SOON")
77
- # sql_db.demo.render()
78
 
79
  footer = gr.HTML("""<!-- Footer -->
80
  <footer class="max-w-4xl mx-auto mt-12 text-center text-gray-500 text-sm">
 
1
  from utils import TEMP_DIR, message_dict
2
  import gradio as gr
3
+ import templates.data_file as data_file, templates.sql_db as sql_db
4
 
5
  import os
6
  from getpass import getpass
 
18
  if "OPENAI_API_KEY" not in os.environ:
19
  os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")
20
 
21
+ css= ".file_marker .large{min-height:50px !important;} .padding{padding:0;} .description_component{overflow:visible !important;}"
22
  head = """<meta charset="UTF-8">
23
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
24
  <title>Virtual Data Analyst</title>
 
72
  </main>""")
73
  with gr.Tab("Data File"):
74
  data_file.demo.render()
75
+ with gr.Tab("SQL Database"):
76
+ sql_db.demo.render()
 
77
 
78
  footer = gr.HTML("""<!-- Footer -->
79
  <footer class="max-w-4xl mx-auto mt-12 text-center text-gray-500 text-sm">
data_sources/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
  from .upload_file import process_data_upload
 
2
 
3
- __all__ = ["process_data_upload"]
 
1
  from .upload_file import process_data_upload
2
+ from .connect_sql_db import connect_sql_db
3
 
4
+ __all__ = ["process_data_upload","connect_sql_db"]
data_sources/connect_sql_db.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import psycopg2
2
+ import os
3
+ from utils import TEMP_DIR
4
+
5
def connect_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, session_hash):
    """Check PostgreSQL credentials and list the tables in the ``public`` schema.

    Opens a short-lived connection, reads the table names from
    ``information_schema.tables``, and creates the per-session ``sql``
    working directory under ``TEMP_DIR``.

    Args:
        url: Database host (e.g. "localhost" or an IP address).
        sql_user: Database user name.
        sql_port: Database port (PostgreSQL's default is 5432).
        sql_pass: Database password.
        sql_db_name: Name of the database to connect to.
        session_hash: Session identifier; used as a directory name under
            ``TEMP_DIR``.

    Returns:
        ``["success", html_message, table_names]`` when the connection and
        the table listing succeed, otherwise ``["error", html_message]``.
        No exception is propagated to the caller.
    """
    # Imported here so the block does not rely on a new file-level import.
    import html

    try:
        conn = psycopg2.connect(
            database=sql_db_name,
            user=sql_user,
            password=sql_pass,
            host=url,  # e.g., "localhost" or an IP address
            port=sql_port,  # default is 5432
        )
        print("Connected to PostgreSQL")

        try:
            cur = conn.cursor()
            try:
                cur.execute("""SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'""")
                # Each row is a 1-tuple holding the table name.
                table_names = [row[0] for row in cur.fetchall()]
            finally:
                cur.close()

            print(table_names)
        finally:
            # Release the connection even when the query above fails, so a
            # bad query does not leak a server-side session.
            conn.close()
            print("Connection closed.")

        dir_path = TEMP_DIR / str(session_hash) / 'sql'
        os.makedirs(dir_path, exist_ok=True)

        return ["success","<p style='color:green;text-align:center;font-size:18px;'>SQL database connected successful</p>", table_names]
    except Exception as e:
        print("UPLOAD ERROR")
        print(e)
        # The exception text can echo user-supplied values (host, database
        # name, ...); escape it before embedding it in the HTML snippet.
        return ["error",f"<p style='color:red;text-align:center;font-size:18px;font-weight:bold;'>ERROR: {html.escape(str(e))}</p>"]
42
+
data_sources/upload_file.py CHANGED
@@ -74,7 +74,9 @@ def process_data_upload(data_file, session_hash):
74
  if df[column].dtype == 'object' and isinstance(df[column].iloc[0], list):
75
  df[column] = df[column].explode()
76
 
77
- dir_path = TEMP_DIR / str(session_hash)
 
 
78
  os.makedirs(dir_path, exist_ok=True)
79
 
80
  connection = sqlite3.connect(f'{dir_path}/data_source.db')
 
74
  if df[column].dtype == 'object' and isinstance(df[column].iloc[0], list):
75
  df[column] = df[column].explode()
76
 
77
+ session_path = 'file_upload'
78
+
79
+ dir_path = TEMP_DIR / str(session_hash) / str(session_path)
80
  os.makedirs(dir_path, exist_ok=True)
81
 
82
  connection = sqlite3.connect(f'{dir_path}/data_source.db')
functions/__init__.py CHANGED
@@ -1,9 +1,9 @@
1
- from .sqlite_functions import SQLiteQuery, sqlite_query_func
2
  from .chart_functions import table_generation_func, scatter_chart_generation_func, \
3
  line_chart_generation_func, bar_chart_generation_func, pie_chart_generation_func, histogram_generation_func, scatter_chart_fig
4
- from .chat_functions import example_question_generator, chatbot_with_fc
5
  from .stat_functions import regression_func
6
 
7
- __all__ = ["SQLiteQuery","sqlite_query_func","table_generation_func","scatter_chart_generation_func",
8
  "line_chart_generation_func","bar_chart_generation_func","regression_func", "pie_chart_generation_func", "histogram_generation_func",
9
- "scatter_chart_fig","example_question_generator","chatbot_with_fc"]
 
1
+ from .query_functions import SQLiteQuery, sqlite_query_func, PostgreSQLQuery, sql_query_func
2
  from .chart_functions import table_generation_func, scatter_chart_generation_func, \
3
  line_chart_generation_func, bar_chart_generation_func, pie_chart_generation_func, histogram_generation_func, scatter_chart_fig
4
+ from .chat_functions import sql_example_question_generator, example_question_generator, chatbot_with_fc, sql_chatbot_with_fc
5
  from .stat_functions import regression_func
6
 
7
+ __all__ = ["SQLiteQuery","sqlite_query_func","sql_query_func","table_generation_func","scatter_chart_generation_func",
8
  "line_chart_generation_func","bar_chart_generation_func","regression_func", "pie_chart_generation_func", "histogram_generation_func",
9
+ "scatter_chart_fig","sql_example_question_generator","example_question_generator","chatbot_with_fc","sql_chatbot_with_fc"]
functions/chart_functions.py CHANGED
@@ -92,11 +92,11 @@ def scatter_chart_fig(df, x_column: List[str], y_column: str, category: str="",
92
 
93
  return fig
94
 
95
- def scatter_chart_generation_func(x_column: List[str], y_column: str, session_hash, data: List[dict]=[{}], layout: List[dict]=[{}],
96
  category: str="", trendline: str="", trendline_options: List[dict]=[{}], marginal_x: str="", marginal_y: str="",
97
- size: str=""):
98
  try:
99
- dir_path = TEMP_DIR / str(session_hash)
100
  chart_path = f'{dir_path}/chart.html'
101
  csv_query_path = f'{dir_path}/query.csv'
102
 
@@ -129,7 +129,7 @@ def scatter_chart_generation_func(x_column: List[str], y_column: str, session_ha
129
 
130
  pio.write_html(fig, chart_path, full_html=False)
131
 
132
- chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/chart.html'
133
 
134
  iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
135
 
@@ -144,10 +144,10 @@ def scatter_chart_generation_func(x_column: List[str], y_column: str, session_ha
144
  """
145
  return {"reply": reply}
146
 
147
- def line_chart_generation_func(x_column: str, y_column: str, session_hash, data: List[dict]=[{}], layout: List[dict]=[{}],
148
- category: str=""):
149
  try:
150
- dir_path = TEMP_DIR / str(session_hash)
151
  chart_path = f'{dir_path}/chart.html'
152
  csv_query_path = f'{dir_path}/query.csv'
153
 
@@ -180,7 +180,7 @@ def line_chart_generation_func(x_column: str, y_column: str, session_hash, data:
180
 
181
  pio.write_html(fig, chart_path, full_html=False)
182
 
183
- chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/chart.html'
184
 
185
  iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
186
 
@@ -195,10 +195,10 @@ def line_chart_generation_func(x_column: str, y_column: str, session_hash, data:
195
  """
196
  return {"reply": reply}
197
 
198
- def bar_chart_generation_func(x_column: str, y_column: str, session_hash, data: List[dict]=[{}], layout: List[dict]=[{}],
199
- category: str="", facet_row: str="", facet_col: str=""):
200
  try:
201
- dir_path = TEMP_DIR / str(session_hash)
202
  chart_path = f'{dir_path}/chart.html'
203
  csv_query_path = f'{dir_path}/query.csv'
204
 
@@ -235,7 +235,7 @@ def bar_chart_generation_func(x_column: str, y_column: str, session_hash, data:
235
 
236
  pio.write_html(fig, chart_path, full_html=False)
237
 
238
- chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/chart.html'
239
 
240
  iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
241
 
@@ -250,9 +250,9 @@ def bar_chart_generation_func(x_column: str, y_column: str, session_hash, data:
250
  """
251
  return {"reply": reply}
252
 
253
- def pie_chart_generation_func(values: str, names: str, session_hash, data: List[dict]=[{}], layout: List[dict]=[{}]):
254
  try:
255
- dir_path = TEMP_DIR / str(session_hash)
256
  chart_path = f'{dir_path}/chart.html'
257
  csv_query_path = f'{dir_path}/query.csv'
258
 
@@ -282,7 +282,7 @@ def pie_chart_generation_func(values: str, names: str, session_hash, data: List[
282
 
283
  pio.write_html(fig, chart_path, full_html=False)
284
 
285
- chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/chart.html'
286
 
287
  iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
288
 
@@ -297,16 +297,15 @@ def pie_chart_generation_func(values: str, names: str, session_hash, data: List[
297
  """
298
  return {"reply": reply}
299
 
300
- def histogram_generation_func(x_column: str, session_hash, y_column: str="", data: List[dict]=[{}], layout: List[dict]=[{}], histnorm: str="", category: str="",
301
- histfunc: str=""):
302
  try:
303
- dir_path = TEMP_DIR / str(session_hash)
304
  chart_path = f'{dir_path}/chart.html'
305
  csv_query_path = f'{dir_path}/query.csv'
306
 
307
  df = pd.read_csv(csv_query_path)
308
 
309
- print(df)
310
  print(x_column)
311
 
312
  function_args = {"data_frame":df, "x":x_column}
@@ -342,7 +341,7 @@ def histogram_generation_func(x_column: str, session_hash, y_column: str="", dat
342
 
343
  pio.write_html(fig, chart_path, full_html=False)
344
 
345
- chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/chart.html'
346
 
347
  iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
348
 
@@ -357,15 +356,14 @@ def histogram_generation_func(x_column: str, session_hash, y_column: str="", dat
357
  """
358
  return {"reply": reply}
359
 
360
- def table_generation_func(session_hash):
361
  print("TABLE GENERATION")
362
  try:
363
- dir_path = TEMP_DIR / str(session_hash)
364
  csv_query_path = f'{dir_path}/query.csv'
365
  table_path = f'{dir_path}/table.html'
366
 
367
  df = pd.read_csv(csv_query_path)
368
- print(df)
369
 
370
  html_table = df.to_html()
371
  print(html_table)
@@ -373,7 +371,7 @@ def table_generation_func(session_hash):
373
  with open(table_path, "w") as file:
374
  file.write(html_table)
375
 
376
- table_url = f'{root_url}/gradio_api/file/temp/{session_hash}/table.html'
377
 
378
  iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + table_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
379
  print(iframe)
 
92
 
93
  return fig
94
 
95
+ def scatter_chart_generation_func(x_column: List[str], y_column: str, session_hash, session_folder, data: List[dict]=[{}], layout: List[dict]=[{}],
96
  category: str="", trendline: str="", trendline_options: List[dict]=[{}], marginal_x: str="", marginal_y: str="",
97
+ size: str="", **kwargs):
98
  try:
99
+ dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
100
  chart_path = f'{dir_path}/chart.html'
101
  csv_query_path = f'{dir_path}/query.csv'
102
 
 
129
 
130
  pio.write_html(fig, chart_path, full_html=False)
131
 
132
+ chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
133
 
134
  iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
135
 
 
144
  """
145
  return {"reply": reply}
146
 
147
+ def line_chart_generation_func(x_column: str, y_column: str, session_hash, session_folder, data: List[dict]=[{}], layout: List[dict]=[{}],
148
+ category: str="", **kwargs):
149
  try:
150
+ dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
151
  chart_path = f'{dir_path}/chart.html'
152
  csv_query_path = f'{dir_path}/query.csv'
153
 
 
180
 
181
  pio.write_html(fig, chart_path, full_html=False)
182
 
183
+ chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
184
 
185
  iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
186
 
 
195
  """
196
  return {"reply": reply}
197
 
198
+ def bar_chart_generation_func(x_column: str, y_column: str, session_hash, session_folder, data: List[dict]=[{}], layout: List[dict]=[{}],
199
+ category: str="", facet_row: str="", facet_col: str="", **kwargs):
200
  try:
201
+ dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
202
  chart_path = f'{dir_path}/chart.html'
203
  csv_query_path = f'{dir_path}/query.csv'
204
 
 
235
 
236
  pio.write_html(fig, chart_path, full_html=False)
237
 
238
+ chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
239
 
240
  iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
241
 
 
250
  """
251
  return {"reply": reply}
252
 
253
+ def pie_chart_generation_func(values: str, names: str, session_hash, session_folder, data: List[dict]=[{}], layout: List[dict]=[{}], **kwargs):
254
  try:
255
+ dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
256
  chart_path = f'{dir_path}/chart.html'
257
  csv_query_path = f'{dir_path}/query.csv'
258
 
 
282
 
283
  pio.write_html(fig, chart_path, full_html=False)
284
 
285
+ chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
286
 
287
  iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
288
 
 
297
  """
298
  return {"reply": reply}
299
 
300
+ def histogram_generation_func(x_column: str, session_hash, session_folder, y_column: str="", data: List[dict]=[{}], layout: List[dict]=[{}], histnorm: str="", category: str="",
301
+ histfunc: str="", **kwargs):
302
  try:
303
+ dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
304
  chart_path = f'{dir_path}/chart.html'
305
  csv_query_path = f'{dir_path}/query.csv'
306
 
307
  df = pd.read_csv(csv_query_path)
308
 
 
309
  print(x_column)
310
 
311
  function_args = {"data_frame":df, "x":x_column}
 
341
 
342
  pio.write_html(fig, chart_path, full_html=False)
343
 
344
+ chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
345
 
346
  iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
347
 
 
356
  """
357
  return {"reply": reply}
358
 
359
+ def table_generation_func(session_hash, session_folder, **kwargs):
360
  print("TABLE GENERATION")
361
  try:
362
+ dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
363
  csv_query_path = f'{dir_path}/query.csv'
364
  table_path = f'{dir_path}/table.html'
365
 
366
  df = pd.read_csv(csv_query_path)
 
367
 
368
  html_table = df.to_html()
369
  print(html_table)
 
371
  with open(table_path, "w") as file:
372
  file.write(html_table)
373
 
374
+ table_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/table.html'
375
 
376
  iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + table_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
377
  print(iframe)
functions/chat_functions.py CHANGED
@@ -35,6 +35,25 @@ def example_question_generator(session_hash):
35
 
36
  return example_response["replies"][0].text
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def chatbot_with_fc(message, history, session_hash):
39
  from functions import sqlite_query_func, table_generation_func, regression_func, scatter_chart_generation_func, \
40
  line_chart_generation_func,bar_chart_generation_func,pie_chart_generation_func,histogram_generation_func
@@ -46,8 +65,8 @@ def chatbot_with_fc(message, history, session_hash):
46
  "histogram_generation_func":histogram_generation_func,
47
  "regression_func":regression_func }
48
 
49
- if message_dict[session_hash] != None:
50
- message_dict[session_hash].append(ChatMessage.from_user(message))
51
  else:
52
  messages = [
53
  ChatMessage.from_system(
@@ -58,35 +77,94 @@ def chatbot_with_fc(message, history, session_hash):
58
  You also have access to a bar graph function, called line_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a bar graph and returns an iframe that we can display in our chat window.
59
  You also have access to a pie chart function, called pie_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a pie chart and returns an iframe that we can display in our chat window.
60
  You also have access to a histogram function, called histogram_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a histogram and returns an iframe that we can display in our chat window.
61
- You also have access to a linear regression function, called regression_func, that can take a query.csv file generated from our sql query and a list of column names for our independent and dependent variables and return a regression data string and a regression chart which is returned as an iframe."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  )
63
  ]
64
  messages.append(ChatMessage.from_user(message))
65
- message_dict[session_hash] = messages
66
 
67
- response = chat_generator.run(messages=message_dict[session_hash], generation_kwargs={"tools": tools.data_file_tools_call(session_hash)})
68
 
69
  while True:
70
  # if OpenAI response is a tool call
71
  if response and response["replies"][0].meta["finish_reason"] == "tool_calls" or response["replies"][0].tool_calls:
72
  function_calls = response["replies"][0].tool_calls
73
  for function_call in function_calls:
74
- message_dict[session_hash].append(ChatMessage.from_assistant(tool_calls=[function_call]))
75
  ## Parse function calling information
76
  function_name = function_call.tool_name
77
  function_args = function_call.arguments
78
 
79
  ## Find the corresponding function and call it with the given arguments
80
  function_to_call = available_functions[function_name]
81
- function_response = function_to_call(**function_args, session_hash=session_hash)
 
82
  print(function_name)
83
  ## Append function response to the messages list using `ChatMessage.from_tool`
84
- message_dict[session_hash].append(ChatMessage.from_tool(tool_result=function_response['reply'], origin=function_call))
85
- response = chat_generator.run(messages=message_dict[session_hash], generation_kwargs={"tools": tools.data_file_tools_call(session_hash)})
86
 
87
  # Regular Conversation
88
  else:
89
- message_dict[session_hash].append(response["replies"][0])
90
  break
91
 
92
  return response["replies"][0].text
 
35
 
36
  return example_response["replies"][0].text
37
 
38
def sql_example_question_generator(session_hash, db_tables, db_name):
    """Ask the chat model for seven suggested analysis questions.

    Builds a two-message conversation (system + user) describing the
    PostgreSQL database and returns the text of the model's first reply.

    Args:
        session_hash: Accepted for signature parity with the caller; not
            read by this function.
        db_tables: Table names, interpolated verbatim into the user prompt.
        db_name: Database name, interpolated into the system prompt.

    Returns:
        The ``text`` of the first entry in the generator's ``"replies"``.
    """
    system_prompt = ChatMessage.from_system(
        f"You are a helpful and knowledgeable agent who has access to an PostgreSQL database called {db_name}."
    )

    user_prompt = ChatMessage.from_user(text=f"""We have a PostgreSQL database with the following tables: {db_tables}.
    We also have an AI agent with access to the same database that will be performing data analysis.
    Please return an array of seven strings, each one being a question for our data analysis agent
    that we can suggest that you believe will be insightful or helpful to a data analysis looking for
    data insights. Return nothing more than the array of questions because I need that specific data structure
    to process your response. No other response type or data structure will work.""")

    first_reply = chat_generator.run(messages=[system_prompt, user_prompt])["replies"][0]
    return first_reply.text
56
+
57
  def chatbot_with_fc(message, history, session_hash):
58
  from functions import sqlite_query_func, table_generation_func, regression_func, scatter_chart_generation_func, \
59
  line_chart_generation_func,bar_chart_generation_func,pie_chart_generation_func,histogram_generation_func
 
65
  "histogram_generation_func":histogram_generation_func,
66
  "regression_func":regression_func }
67
 
68
+ if message_dict[session_hash]['file_upload'] != None:
69
+ message_dict[session_hash]['file_upload'].append(ChatMessage.from_user(message))
70
  else:
71
  messages = [
72
  ChatMessage.from_system(
 
77
  You also have access to a bar graph function, called line_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a bar graph and returns an iframe that we can display in our chat window.
78
  You also have access to a pie chart function, called pie_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a pie chart and returns an iframe that we can display in our chat window.
79
  You also have access to a histogram function, called histogram_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a histogram and returns an iframe that we can display in our chat window.
80
+ You also have access to a linear regression function, called regression_func, that can take a query.csv file generated from our sql query and a list of column names for our independent and dependent variables and return a regression data string and a regression chart which is returned as an iframe.
81
+ Charts, tables, and visualizations are a very important part of your output. If you generate a chart, table, or visualization as part of your answer, please display it always."""
82
+ )
83
+ ]
84
+ messages.append(ChatMessage.from_user(message))
85
+ message_dict[session_hash]['file_upload'] = messages
86
+
87
+ response = chat_generator.run(messages=message_dict[session_hash]['file_upload'], generation_kwargs={"tools": tools.data_file_tools_call(session_hash)})
88
+
89
+ while True:
90
+ # if OpenAI response is a tool call
91
+ if response and response["replies"][0].meta["finish_reason"] == "tool_calls" or response["replies"][0].tool_calls:
92
+ function_calls = response["replies"][0].tool_calls
93
+ for function_call in function_calls:
94
+ message_dict[session_hash]['file_upload'].append(ChatMessage.from_assistant(tool_calls=[function_call]))
95
+ ## Parse function calling information
96
+ function_name = function_call.tool_name
97
+ function_args = function_call.arguments
98
+
99
+ ## Find the corresponding function and call it with the given arguments
100
+ function_to_call = available_functions[function_name]
101
+ function_response = function_to_call(**function_args, session_hash=session_hash, session_folder='file_upload')
102
+ print(function_name)
103
+ ## Append function response to the messages list using `ChatMessage.from_tool`
104
+ message_dict[session_hash]['file_upload'].append(ChatMessage.from_tool(tool_result=function_response['reply'], origin=function_call))
105
+ response = chat_generator.run(messages=message_dict[session_hash]['file_upload'], generation_kwargs={"tools": tools.data_file_tools_call(session_hash)})
106
+
107
+ # Regular Conversation
108
+ else:
109
+ message_dict[session_hash]['file_upload'].append(response["replies"][0])
110
+ break
111
+
112
+ return response["replies"][0].text
113
+
114
def sql_chatbot_with_fc(message, history, session_hash, db_url, db_port, db_user, db_pass, db_name, db_tables):
    """Run one chat turn against the PostgreSQL-backed agent, executing tool calls.

    The per-session conversation is kept in ``message_dict[session_hash]['sql']``.
    On the first turn it is seeded with the system prompt; afterwards the user
    message is appended. The model is then called repeatedly: while its reply
    contains tool calls, each call is dispatched to the matching function and
    its result is appended to the conversation; the first reply without tool
    calls ends the turn.

    Args:
        message: The user's message for this turn.
        history: Chat history supplied by the caller; not read here because the
            conversation is tracked in ``message_dict``.
        session_hash: Session identifier used to look up the conversation and
            forwarded to every tool function.
        db_url, db_port, db_user, db_pass, db_name: Connection settings
            forwarded as keyword arguments to every tool function.
        db_tables: Table names, interpolated into the system prompt and passed
            to ``tools.sql_tools_call``.

    Returns:
        The ``text`` of the model's final (non-tool-call) reply.
    """
    from functions import sql_query_func, table_generation_func, regression_func, scatter_chart_generation_func, \
        line_chart_generation_func,bar_chart_generation_func,pie_chart_generation_func,histogram_generation_func
    import tools.tools as tools

    available_functions = {"sql_query_func": sql_query_func,"table_generation_func":table_generation_func,
                           "line_chart_generation_func":line_chart_generation_func,"bar_chart_generation_func":bar_chart_generation_func,
                           "scatter_chart_generation_func":scatter_chart_generation_func, "pie_chart_generation_func":pie_chart_generation_func,
                           "histogram_generation_func":histogram_generation_func,
                           "regression_func":regression_func }

    if message_dict[session_hash]['sql'] is not None:
        message_dict[session_hash]['sql'].append(ChatMessage.from_user(message))
    else:
        # First turn of the session: seed the conversation with the system prompt.
        # The bar graph tool is bar_chart_generation_func (the prompt previously
        # pointed the model at line_chart_generation_func for bar graphs).
        messages = [
            ChatMessage.from_system(
                f"""You are a helpful and knowledgeable agent who has access to an PostgreSQL database which has a series of tables called {db_tables}.
    You also have access to a function, called table_generation_func, that can take a query.csv file generated from our sql query and returns an iframe that we can display in our chat window.
    You also have access to a scatter plot function, called scatter_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a scatter plot and returns an iframe that we can display in our chat window.
    You also have access to a line chart function, called line_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a line chart and returns an iframe that we can display in our chat window.
    You also have access to a bar graph function, called bar_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a bar graph and returns an iframe that we can display in our chat window.
    You also have access to a pie chart function, called pie_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a pie chart and returns an iframe that we can display in our chat window.
    You also have access to a histogram function, called histogram_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a histogram and returns an iframe that we can display in our chat window.
    You also have access to a linear regression function, called regression_func, that can take a query.csv file generated from our sql query and a list of column names for our independent and dependent variables and return a regression data string and a regression chart which is returned as an iframe.
    Charts, tables, and visualizations are a very important part of your output. If you generate a chart, table, or visualization as part of your answer, please display it always."""
            )
        ]
        messages.append(ChatMessage.from_user(message))
        message_dict[session_hash]['sql'] = messages

    response = chat_generator.run(messages=message_dict[session_hash]['sql'], generation_kwargs={"tools": tools.sql_tools_call(db_tables)})

    while True:
        reply = response["replies"][0]

        # Regular conversation: nothing left to execute, so record the reply and stop.
        # Keying on the tool-call list itself (rather than finish_reason) means an
        # empty list can never keep the loop spinning.
        if not reply.tool_calls:
            message_dict[session_hash]['sql'].append(reply)
            break

        for function_call in reply.tool_calls:
            message_dict[session_hash]['sql'].append(ChatMessage.from_assistant(tool_calls=[function_call]))
            ## Parse function calling information
            function_name = function_call.tool_name
            function_args = function_call.arguments

            ## Find the corresponding function and call it with the given arguments
            function_to_call = available_functions.get(function_name)
            if function_to_call is None:
                # The model asked for a tool we do not expose; tell it so instead
                # of crashing the whole chat turn with a KeyError.
                tool_result = f"Unknown function '{function_name}'. Available functions: {', '.join(available_functions)}."
            else:
                function_response = function_to_call(**function_args, session_hash=session_hash, db_url=db_url,
                                                     db_port=db_port, db_user=db_user, db_pass=db_pass, db_name=db_name, session_folder='sql')
                tool_result = function_response['reply']
            print(function_name)
            ## Append function response to the messages list using `ChatMessage.from_tool`
            message_dict[session_hash]['sql'].append(ChatMessage.from_tool(tool_result=tool_result, origin=function_call))

        # Ask the model again once every requested tool result is in the conversation.
        response = chat_generator.run(messages=message_dict[session_hash]['sql'], generation_kwargs={"tools": tools.sql_tools_call(db_tables)})

    return response["replies"][0].text
functions/{sqlite_functions.py → query_functions.py} RENAMED
@@ -6,6 +6,7 @@ pd.set_option('display.max_columns', None)
6
  pd.set_option('display.width', None)
7
  pd.set_option('display.max_colwidth', None)
8
  import sqlite3
 
9
  from utils import TEMP_DIR
10
 
11
  @component
@@ -16,21 +17,21 @@ class SQLiteQuery:
16
 
17
  @component.output_types(results=List[str], queries=List[str])
18
  def run(self, queries: List[str], session_hash):
19
- print("ATTEMPTING TO RUN QUERY")
20
  dir_path = TEMP_DIR / str(session_hash)
21
  results = []
22
  for query in queries:
23
  result = pd.read_sql(query, self.connection)
24
- result.to_csv(f'{dir_path}/query.csv', index=False)
25
  results.append(f"{result}")
26
  self.connection.close()
27
  return {"results": results, "queries": queries}
28
 
29
 
30
 
31
- def sqlite_query_func(queries: List[str], session_hash):
32
  dir_path = TEMP_DIR / str(session_hash)
33
- sql_query = SQLiteQuery(f'{dir_path}/data_source.db')
34
  try:
35
  result = sql_query.run(queries, session_hash)
36
  if len(result["results"][0]) > 1000:
@@ -45,3 +46,50 @@ def sqlite_query_func(queries: List[str], session_hash):
45
  You should probably try again.
46
  """
47
  return {"reply": reply}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  pd.set_option('display.width', None)
7
  pd.set_option('display.max_colwidth', None)
8
  import sqlite3
9
+ import psycopg2
10
  from utils import TEMP_DIR
11
 
12
  @component
 
17
 
18
  @component.output_types(results=List[str], queries=List[str])
19
  def run(self, queries: List[str], session_hash):
20
+ print("ATTEMPTING TO RUN SQLITE QUERY")
21
  dir_path = TEMP_DIR / str(session_hash)
22
  results = []
23
  for query in queries:
24
  result = pd.read_sql(query, self.connection)
25
+ result.to_csv(f'{dir_path}/file_upload/query.csv', index=False)
26
  results.append(f"{result}")
27
  self.connection.close()
28
  return {"results": results, "queries": queries}
29
 
30
 
31
 
32
+ def sqlite_query_func(queries: List[str], session_hash, **kwargs):
33
  dir_path = TEMP_DIR / str(session_hash)
34
+ sql_query = SQLiteQuery(f'{dir_path}/file_upload/data_source.db')
35
  try:
36
  result = sql_query.run(queries, session_hash)
37
  if len(result["results"][0]) > 1000:
 
46
  You should probably try again.
47
  """
48
  return {"reply": reply}
49
+
50
@component
class PostgreSQLQuery:
    """Pipeline component that runs SQL queries against a PostgreSQL database.

    A psycopg2 connection is opened when the component is constructed and is
    closed at the end of `run`, so each instance is good for a single `run`
    call.
    """

    def __init__(self, url: str, sql_port: int, sql_user: str, sql_pass: str, sql_db_name: str):
        """Open the database connection.

        Args:
            url: Database host, e.g. "localhost" or an IP address.
            sql_port: Port the server listens on (PostgreSQL default is 5432).
            sql_user: User name to authenticate with.
            sql_pass: Password for `sql_user`.
            sql_db_name: Name of the database to connect to.
        """
        self.connection = psycopg2.connect(
            database=sql_db_name,
            user=sql_user,
            password=sql_pass,
            host=url,  # e.g., "localhost" or an IP address
            port=sql_port  # default is 5432
        )

    @component.output_types(results=List[str], queries=List[str])
    def run(self, queries: List[str], session_hash):
        """Execute each query and return the stringified result frames.

        Every query's result is also written to `<TEMP_DIR>/<session_hash>/sql/query.csv`;
        the file is overwritten per query, so it ends up holding the last result.

        Args:
            queries: SQL statements to execute, in order.
            session_hash: Session identifier used to locate the session folder.

        Returns:
            A dict with "results" (one string per executed query) and
            "queries" (the input list, unchanged).

        Raises:
            Whatever pandas/psycopg2 raise for a failing query; the connection
            is closed before the exception propagates.
        """
        print("ATTEMPTING TO RUN POSTGRESQL QUERY")
        dir_path = TEMP_DIR / str(session_hash)
        results = []
        try:
            for query in queries:
                print(query)
                result = pd.read_sql_query(query, self.connection)
                result.to_csv(f'{dir_path}/sql/query.csv', index=False)
                results.append(f"{result}")
        finally:
            # Close on the failure path too; previously a bad query leaked the
            # server-side connection because close() was only reached on success.
            self.connection.close()
        return {"results": results, "queries": queries}
74
+
75
+
76
+
77
def sql_query_func(queries: List[str], session_hash, db_url, db_port, db_user, db_pass, db_name, **kwargs):
    """Run LLM-supplied queries against the user's PostgreSQL database.

    This is a tool-call entry point: it never raises, and always returns a
    dict with a single "reply" key that is fed back to the model.

    Args:
        queries: SQL statements to run; only the first result is returned in
            the reply (every result is still written to query.csv by the
            query component).
        session_hash: Session identifier used to locate the session folder.
        db_url: Database host.
        db_port: Database port.
        db_user: Database user name.
        db_pass: Database password.
        db_name: Database name.
        **kwargs: Extra arguments passed by the generic tool dispatcher
            (e.g. `session_folder`); ignored here.

    Returns:
        {"reply": <text>} where the text is the first query result, a notice
        that the result is too large, or a description of the error.
    """
    if not queries:
        # Nothing to execute; answer explicitly instead of failing with an
        # opaque IndexError on result["results"][0] below.
        return {"reply": "No SQL query was provided. You should probably try again."}

    try:
        # Connecting happens in the constructor, so it must sit inside the try:
        # wrong credentials or an unreachable host are reported back to the
        # model as a reply instead of crashing the chat handler.
        sql_query = PostgreSQLQuery(db_url, db_port, db_user, db_pass, db_name)
        result = sql_query.run(queries, session_hash)
        print("RESULT")
        print(result)
        if len(result["results"][0]) > 1000:
            print("QUERY TOO LARGE")
            return {"reply": "query result too large to be processed by llm, the query results are in our query.csv file. If you need to display the results directly, perhaps use the table_generation_func function."}
        else:
            return {"reply": result["results"][0]}

    except Exception as e:
        # Deliberately broad: any failure is turned into feedback for the LLM.
        reply = f"""There was an error running the SQL Query = {queries}
            The error is {e},
            You should probably try again.
            """
        print(reply)
        return {"reply": reply}
functions/stat_functions.py CHANGED
@@ -12,12 +12,12 @@ load_dotenv()
12
 
13
  root_url = os.getenv("ROOT_URL")
14
 
15
- def regression_func(independent_variables: List[str], dependent_variable: str, session_hash, category: str=''):
16
  print("LINEAR REGRESSION CALCULATION")
17
  print(independent_variables)
18
  print(dependent_variable)
19
  try:
20
- dir_path = TEMP_DIR / str(session_hash)
21
  chart_path = f'{dir_path}/chart.html'
22
  csv_query_path = f'{dir_path}/query.csv'
23
 
@@ -32,7 +32,7 @@ def regression_func(independent_variables: List[str], dependent_variable: str, s
32
 
33
  pio.write_html(fig, chart_path, full_html=False)
34
 
35
- chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/chart.html'
36
 
37
  iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
38
 
 
12
 
13
  root_url = os.getenv("ROOT_URL")
14
 
15
+ def regression_func(independent_variables: List[str], dependent_variable: str, session_hash, session_folder, category: str='', **kwargs):
16
  print("LINEAR REGRESSION CALCULATION")
17
  print(independent_variables)
18
  print(dependent_variable)
19
  try:
20
+ dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
21
  chart_path = f'{dir_path}/chart.html'
22
  csv_query_path = f'{dir_path}/query.csv'
23
 
 
32
 
33
  pio.write_html(fig, chart_path, full_html=False)
34
 
35
+ chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
36
 
37
  iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
38
 
requirements.txt CHANGED
@@ -6,3 +6,4 @@ plotly
6
  openpyxl
7
  statsmodels
8
  xlrd
 
 
6
  openpyxl
7
  statsmodels
8
  xlrd
9
+ psycopg2-binary
templates/__pycache__/data_file.cpython-312.pyc ADDED
Binary file (8.68 kB). View file
 
templates/__pycache__/sql_db.cpython-312.pyc ADDED
Binary file (6.71 kB). View file
 
data_file.py → templates/data_file.py RENAMED
@@ -68,7 +68,8 @@ with gr.Blocks() as demo:
68
  @gr.render(inputs=file_output)
69
  def data_options(filename, request: gr.Request):
70
  print(filename)
71
- message_dict[request.session_hash] = None
 
72
  if filename:
73
  process_message = process_upload(filename, request.session_hash)
74
  gr.HTML(value=process_message[1], padding=False)
@@ -101,7 +102,9 @@ with gr.Blocks() as demo:
101
  ]
102
  for example in generated_examples:
103
  example_questions.append([example])
104
- except:
 
 
105
  example_questions = [
106
  ["Describe the dataset"],
107
  ["List the columns in the dataset"],
 
68
  @gr.render(inputs=file_output)
69
  def data_options(filename, request: gr.Request):
70
  print(filename)
71
+ message_dict[request.session_hash] = {}
72
+ message_dict[request.session_hash]['file_upload'] = None
73
  if filename:
74
  process_message = process_upload(filename, request.session_hash)
75
  gr.HTML(value=process_message[1], padding=False)
 
102
  ]
103
  for example in generated_examples:
104
  example_questions.append([example])
105
+ except Exception as e:
106
+ print("DATA FILE QUESTION GENERATION ERROR")
107
+ print(e)
108
  example_questions = [
109
  ["Describe the dataset"],
110
  ["List the columns in the dataset"],
templates/sql_db.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import gradio as gr
3
+ from functions import sql_example_question_generator, sql_chatbot_with_fc
4
+ from data_sources import connect_sql_db
5
+ from utils import message_dict
6
+
7
def hide_info():
    """Return a Gradio update that hides the component bound as this callback's output."""
    hidden_state = gr.update(visible=False)
    return hidden_state
9
+
10
+ with gr.Blocks() as demo:
11
+ description = gr.HTML("""
12
+ <!-- Header -->
13
+ <div class="max-w-4xl mx-auto mb-12 text-center">
14
+ <div class="bg-blue-50 border border-blue-200 rounded-lg max-w-2xl mx-auto">
15
+ <p>This tool allows users to communicate with and query real time data from a SQL DB (postgres for now, others can be added if requested) using natural
16
+ language and the above features.</p>
17
+ <p style="font-weight:bold;">Notice: the way this system is designed, no login information is retained and credentials are passed as session variables until the user leaves or
18
+ refreshes the page in which they disappear. They are never saved to any files. I also make use of the Pandas read_sql_query function to apply SQL
19
+ queries, which can't delete, drop, or add database lines to avoid unhappy accidents or glitches.
20
+ That being said, it's probably not a good idea to connect a production database to a strange AI tool with an unfamiliar author.
21
+ This should be for demonstration purposes.</p>
22
+ <p>Contact me if this is something you would like built in your organization, on your infrastructure, and with the requisite privacy and control a production
23
+ database analytics tool requires.</p>
24
+ </div>
25
+ </div>
26
+ """, elem_classes="description_component")
27
+ sql_url = gr.Textbox(label="URL", value="virtual-data-analyst-pg.cyetm2yjzppu.us-west-1.rds.amazonaws.com")
28
+ with gr.Row():
29
+ sql_port = gr.Textbox(label="Port", value="5432")
30
+ sql_user = gr.Textbox(label="Username", value="postgres")
31
+ sql_pass = gr.Textbox(label="Password", value="Vda-1988", type="password")
32
+ sql_db_name = gr.Textbox(label="Database Name", value="dvdrental")
33
+
34
+ submit = gr.Button(value="Submit")
35
+ submit.click(fn=hide_info, outputs=description)
36
+
37
@gr.render(inputs=[sql_url,sql_port,sql_user,sql_pass,sql_db_name], triggers=[submit.click])
def sql_chat(request: gr.Request, url=sql_url.value, sql_port=sql_port.value, sql_user=sql_user.value, sql_pass=sql_pass.value, sql_db_name=sql_db_name.value):
    """Render the SQL chat UI after the connection form is submitted.

    Resets this session's SQL chat history, attempts the database connection
    through `process_sql_db`, shows the returned status HTML and, on success,
    builds example questions plus a ChatInterface bound to `sql_chatbot_with_fc`.

    Args:
        request: Gradio request; supplies `session_hash`.
        url: Database host from the form.
        sql_port: Database port from the form.
        sql_user: Database user from the form.
        sql_pass: Database password from the form.
        sql_db_name: Database name from the form.
    """
    # setdefault: the per-session dict may not exist yet (e.g. when this module
    # is launched on its own), and plain indexing would raise KeyError.
    message_dict.setdefault(request.session_hash, {})['sql'] = None
    if url:
        print("SQL APP")
        print(request)
        # process_message is indexed as [0] status, [1] HTML message, [2] table listing.
        process_message = process_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, request.session_hash)
        gr.HTML(value=process_message[1], padding=False)
        if process_message[0] == "success":
            if "virtual-data-analyst-pg.cyetm2yjzppu.us-west-1.rds.amazonaws.com" in url:
                # Curated questions for the bundled demo database.
                example_questions = [
                    ["Describe the dataset"],
                    ["What is the total revenue generated by each store?"],
                    ["Can you generate and display a bar chart of film category to number of films in that category?"],
                    ["Can you generate a pie chart showing the top 10 most rented films by revenue vs all other films?"],
                    ["Can you generate a line chart of rental revenue over time?"],
                    ["What is the relationship between film length and rental frequency?"]
                ]
            else:
                try:
                    # The generator returns a string that is parsed as a Python
                    # literal list; literal_eval does not execute code.
                    generated_examples = ast.literal_eval(sql_example_question_generator(request.session_hash, process_message[2], sql_db_name))
                    example_questions = [
                        ["Describe the dataset"]
                    ]
                    for example in generated_examples:
                        example_questions.append([example])
                except Exception as e:
                    # Best effort: fall back to generic questions if generation
                    # or parsing fails.
                    print("SQL QUESTION GENERATION ERROR")
                    print(e)
                    example_questions = [
                        ["Describe the dataset"],
                        ["List the columns in the dataset"],
                        ["What could this data be used for?"],
                    ]
            # Hidden fields carry the connection details into every chat call.
            # NOTE(review): this places the password in a component value;
            # confirm that is acceptable for the deployment.
            session_hash = gr.Textbox(visible=False, value=request.session_hash)
            db_url = gr.Textbox(visible=False, value=url)
            db_port = gr.Textbox(visible=False, value=sql_port)
            db_user = gr.Textbox(visible=False, value=sql_user)
            db_pass = gr.Textbox(visible=False, value=sql_pass)
            db_name = gr.Textbox(visible=False, value=sql_db_name)
            db_tables = gr.Textbox(value=process_message[2], interactive=False, label="SQL Tables")
            # Label fixed: it was copy-pasted as "CSV Chat Window" from the file-upload tab.
            bot = gr.Chatbot(type='messages', label="SQL Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
            chat = gr.ChatInterface(
                fn=sql_chatbot_with_fc,
                type='messages',
                chatbot=bot,
                title="Chat with your Database",
                examples=example_questions,
                concurrency_limit=None,
                additional_inputs=[session_hash, db_url, db_port, db_user, db_pass, db_name, db_tables]
            )
88
+
89
def process_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, session_hash):
    """Connect to the SQL database when a host is given.

    Returns whatever `connect_sql_db` returns, or None when `url` is empty.
    """
    if not url:
        return None
    return connect_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, session_hash)
93
+
94
+ if __name__ == "__main__":
95
+ demo.launch()
tools/chart_tools.py CHANGED
@@ -3,7 +3,7 @@ chart_tools = [
3
  "type": "function",
4
  "function": {
5
  "name": "scatter_chart_generation_func",
6
- "description": f"""This is a scatter plot generation tool useful to generate scatter plots from queried data from our SQL table called 'data_source'.
7
  The data values will come from the columns of our query.csv (the 'x' and 'y' values of each graph) file but the layout section of the plotly dictionary objects will be generated by you.
8
  Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
9
  from the scatter_chart_generation_func function in any way and always display the iframe fully to the user in the chat window. You can add your own text supplementary
@@ -108,7 +108,7 @@ chart_tools = [
108
  "type": "function",
109
  "function": {
110
  "name": "line_chart_generation_func",
111
- "description": f"""This is a line chart generation tool useful to generate line charts from queried data from our SQL table called 'data_source'.
112
  The data values will come from the columns of our query.csv (the 'x' and 'y' values of each graph) file but the layout section of the plotly dictionary objects will be generated by you.
113
  Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
114
  from the line_chart_generation_func function in any way and always display the iframe fully to the user in the chat window. You can add your own text supplementary
@@ -164,7 +164,7 @@ chart_tools = [
164
  "type": "function",
165
  "function": {
166
  "name": "bar_chart_generation_func",
167
- "description": f"""This is a bar chart generation tool useful to generate line charts from queried data from our SQL table called 'data_source'.
168
  The data values will come from the columns of our query.csv (the 'x' and 'y' values of each graph) file but the layout section of the plotly dictionary objects will be generated by you.
169
  Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
170
  from the bar_chart_generation_func function in any way and always display the iframe fully to the user in the chat window. You can add your own text supplementary
@@ -236,7 +236,7 @@ chart_tools = [
236
  "type": "function",
237
  "function": {
238
  "name": "pie_chart_generation_func",
239
- "description": f"""This is a pie chart generation tool useful to generate pie charts from queried data from our SQL table called 'data_source'.
240
  The data values will come from the columns of our query.csv (the 'values' and 'names' values of each graph) file but the layout section of the plotly dictionary objects will be generated by you.
241
  Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
242
  from the pie_chart_generation_func function in any way and always display the iframe fully to the user in the chat window. You can add your own text supplementary
@@ -285,7 +285,7 @@ chart_tools = [
285
  "type": "function",
286
  "function": {
287
  "name": "histogram_generation_func",
288
- "description": f"""This is a histogram generation tool useful to generate histograms from queried data from our SQL table called 'data_source'.
289
  The data values will come from the columns of our query.csv (the 'values' and 'names' values of each graph) file but the layout section of the plotly dictionary objects will be generated by you.
290
  Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
291
  from the histogram_generation_func function in any way and always display the iframe fully to the user in the chat window. You can add your own text supplementary
@@ -360,7 +360,7 @@ chart_tools = [
360
  "type": "function",
361
  "function": {
362
  "name": "table_generation_func",
363
- "description": f"""This an table generation tool useful to format data as a table from queried data from our SQL table called 'data_source'.
364
  Takes no parameters as it uses data queried in our query.csv file to build the table.
365
  Call this function after running our SQLite query and generating query.csv.
366
  Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
 
3
  "type": "function",
4
  "function": {
5
  "name": "scatter_chart_generation_func",
6
+ "description": f"""This is a scatter plot generation tool useful to generate scatter plots from queried data from our data source that we are querying.
7
  The data values will come from the columns of our query.csv (the 'x' and 'y' values of each graph) file but the layout section of the plotly dictionary objects will be generated by you.
8
  Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
9
  from the scatter_chart_generation_func function in any way and always display the iframe fully to the user in the chat window. You can add your own text supplementary
 
108
  "type": "function",
109
  "function": {
110
  "name": "line_chart_generation_func",
111
+ "description": f"""This is a line chart generation tool useful to generate line charts from queried data from our data source that we are querying.
112
  The data values will come from the columns of our query.csv (the 'x' and 'y' values of each graph) file but the layout section of the plotly dictionary objects will be generated by you.
113
  Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
114
  from the line_chart_generation_func function in any way and always display the iframe fully to the user in the chat window. You can add your own text supplementary
 
164
  "type": "function",
165
  "function": {
166
  "name": "bar_chart_generation_func",
167
+ "description": f"""This is a bar chart generation tool useful to generate line charts from queried data from our data source that we are querying.
168
  The data values will come from the columns of our query.csv (the 'x' and 'y' values of each graph) file but the layout section of the plotly dictionary objects will be generated by you.
169
  Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
170
  from the bar_chart_generation_func function in any way and always display the iframe fully to the user in the chat window. You can add your own text supplementary
 
236
  "type": "function",
237
  "function": {
238
  "name": "pie_chart_generation_func",
239
+ "description": f"""This is a pie chart generation tool useful to generate pie charts from queried data from our data source that we are querying.
240
  The data values will come from the columns of our query.csv (the 'values' and 'names' values of each graph) file but the layout section of the plotly dictionary objects will be generated by you.
241
  Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
242
  from the pie_chart_generation_func function in any way and always display the iframe fully to the user in the chat window. You can add your own text supplementary
 
285
  "type": "function",
286
  "function": {
287
  "name": "histogram_generation_func",
288
+ "description": f"""This is a histogram generation tool useful to generate histograms from queried data from our data source that we are querying.
289
  The data values will come from the columns of our query.csv (the 'values' and 'names' values of each graph) file but the layout section of the plotly dictionary objects will be generated by you.
290
  Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
291
  from the histogram_generation_func function in any way and always display the iframe fully to the user in the chat window. You can add your own text supplementary
 
360
  "type": "function",
361
  "function": {
362
  "name": "table_generation_func",
363
+ "description": f"""This an table generation tool useful to format data as a table from queried data from our data source that we are querying.
364
  Takes no parameters as it uses data queried in our query.csv file to build the table.
365
  Call this function after running our SQLite query and generating query.csv.
366
  Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
tools/stats_tools.py CHANGED
@@ -3,7 +3,7 @@ stats_tools = [
3
  "type": "function",
4
  "function": {
5
  "name": "regression_func",
6
- "description": f"""This a tool to calculate regressions on our SQLite table called 'data_source'.
7
  We can run queries with our 'sql_query_func' function and they will be available to use in this function via the query.csv file that is generated.
8
  Returns a dictionary of values that includes a regression_summary and a regression chart (which is an iframe displaying the
9
  linear regression in chart form and should be shown to the user).""",
 
3
  "type": "function",
4
  "function": {
5
  "name": "regression_func",
6
+ "description": f"""This a tool to calculate regressions on our data source that we are querying.
7
  We can run queries with our 'sql_query_func' function and they will be available to use in this function via the query.csv file that is generated.
8
  Returns a dictionary of values that includes a regression_summary and a regression chart (which is an iframe displaying the
9
  linear regression in chart form and should be shown to the user).""",
tools/tools.py CHANGED
@@ -1,11 +1,12 @@
1
  import sqlite3
 
2
  from .stats_tools import stats_tools
3
  from .chart_tools import chart_tools
4
  from utils import TEMP_DIR
5
 
6
  def data_file_tools_call(session_hash):
7
  dir_path = TEMP_DIR / str(session_hash)
8
- connection = sqlite3.connect(f'{dir_path}/data_source.db')
9
  print("Querying Database in Tools.py");
10
  cur=connection.execute('select * from data_source')
11
  columns = [i[0] for i in cur.description]
@@ -46,22 +47,24 @@ def data_file_tools_call(session_hash):
46
 
47
  return tools_calls
48
 
49
- def graphql_tools_call(sessions_hash):
 
 
50
 
51
  tools_calls = [
52
  {
53
  "type": "function",
54
  "function": {
55
- "name": "graphql_query_func",
56
- "description": f"""This is a tool useful to query a GraphQL endpoint with the following Columns: {column_string}.
57
- There may also be more columns in the table if the number of columns is too large to process.
58
  This function also saves the results of the query to csv file called query.csv.""",
59
  "parameters": {
60
  "type": "object",
61
  "properties": {
62
  "queries": {
63
  "type": "array",
64
- "description": "The graphQL query to use in the search. Infer this from the user's message. It should be a question or a statement",
65
  "items": {
66
  "type": "string",
67
  }
@@ -73,7 +76,7 @@ def graphql_tools_call(sessions_hash):
73
  },
74
  ]
75
 
76
- tools_calls.append(chart_tools)
77
- tools_calls.append(stats_tools)
78
 
79
- return
 
1
  import sqlite3
2
+ import psycopg2
3
  from .stats_tools import stats_tools
4
  from .chart_tools import chart_tools
5
  from utils import TEMP_DIR
6
 
7
  def data_file_tools_call(session_hash):
8
  dir_path = TEMP_DIR / str(session_hash)
9
+ connection = sqlite3.connect(f'{dir_path}/file_upload/data_source.db')
10
  print("Querying Database in Tools.py");
11
  cur=connection.execute('select * from data_source')
12
  columns = [i[0] for i in cur.description]
 
47
 
48
  return tools_calls
49
 
50
+ def sql_tools_call(db_tables):
51
+
52
+ table_string = (db_tables[:625] + '..') if len(db_tables) > 625 else db_tables
53
 
54
  tools_calls = [
55
  {
56
  "type": "function",
57
  "function": {
58
+ "name": "sql_query_func",
59
+ "description": f"""This is a tool useful to query a PostgreSQL database with the following tables, {table_string}.
60
+ There may also be more tables in the database if the number of columns is too large to process.
61
  This function also saves the results of the query to csv file called query.csv.""",
62
  "parameters": {
63
  "type": "object",
64
  "properties": {
65
  "queries": {
66
  "type": "array",
67
+ "description": "The PostgreSQL query to use in the search. Infer this from the user's message. It should be a question or a statement",
68
  "items": {
69
  "type": "string",
70
  }
 
76
  },
77
  ]
78
 
79
+ tools_calls.extend(chart_tools)
80
+ tools_calls.extend(stats_tools)
81
 
82
+ return tools_calls