feat/postgresql-integration (#25)
Commit: sql integration (5acf39cc95da88ec7b5ae96fd388b6b40f5a1606)

Files changed:
- app.py +4 -5
- data_sources/__init__.py +2 -1
- data_sources/connect_sql_db.py +42 -0
- data_sources/upload_file.py +3 -1
- functions/__init__.py +4 -4
- functions/chart_functions.py +22 -24
- functions/chat_functions.py +88 -10
- functions/{sqlite_functions.py → query_functions.py} +52 -4
- functions/stat_functions.py +3 -3
- requirements.txt +1 -0
- templates/__pycache__/data_file.cpython-312.pyc +0 -0
- templates/__pycache__/sql_db.cpython-312.pyc +0 -0
- data_file.py → templates/data_file.py +5 -2
- templates/sql_db.py +95 -0
- tools/chart_tools.py +6 -6
- tools/stats_tools.py +1 -1
- tools/tools.py +12 -9
app.py
CHANGED
@@ -1,6 +1,6 @@
 from utils import TEMP_DIR, message_dict
 import gradio as gr
-import data_file
+import templates.data_file as data_file, templates.sql_db as sql_db
 import os
 from getpass import getpass
@@ -18,7 +18,7 @@ def delete_db(req: gr.Request):
 if "OPENAI_API_KEY" not in os.environ:
     os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")
 
-css= ".file_marker .large{min-height:50px !important;} ."
+css= ".file_marker .large{min-height:50px !important;} .padding{padding:0;} .description_component{overflow:visible !important;}"
 head = """<meta charset="UTF-8">
 <meta name="viewport" content="width=device-width, initial-scale=1.0">
 <title>Virtual Data Analyst</title>
@@ -72,9 +72,8 @@ with gr.Blocks(theme=theme, css=css, head=head, delete_cache=(3600,3600)) as demo:
     </main>""")
     with gr.Tab("Data File"):
         data_file.demo.render()
-    with gr.Tab("SQL Database"):
-
-        # sql_db.demo.render()
+    with gr.Tab("SQL Database"):
+        sql_db.demo.render()
 
     footer = gr.HTML("""<!-- Footer -->
     <footer class="max-w-4xl mx-auto mt-12 text-center text-gray-500 text-sm">
data_sources/__init__.py
CHANGED
@@ -1,3 +1,4 @@
 from .upload_file import process_data_upload
+from .connect_sql_db import connect_sql_db
 
-__all__ = ["process_data_upload"]
+__all__ = ["process_data_upload","connect_sql_db"]
data_sources/connect_sql_db.py
ADDED
@@ -0,0 +1,42 @@
+import psycopg2
+import os
+from utils import TEMP_DIR
+
+def connect_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, session_hash):
+    try:
+        conn = psycopg2.connect(
+            database=sql_db_name,
+            user=sql_user,
+            password=sql_pass,
+            host=url,  # e.g., "localhost" or an IP address
+            port=sql_port  # default is 5432
+        )
+        print("Connected to PostgreSQL")
+
+        # Create a cursor object to execute SQL queries
+        cur = conn.cursor()
+        # List the tables available in the public schema
+        cur.execute("""SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'""")
+        table_tuples = cur.fetchall()
+        table_names = []
+        for table in table_tuples:
+            table_names.append(table[0])
+
+        print(table_names)
+
+        # Close the cursor and connection
+        cur.close()
+        conn.close()
+        print("Connection closed.")
+
+        session_path = 'sql'
+
+        dir_path = TEMP_DIR / str(session_hash) / str(session_path)
+        os.makedirs(dir_path, exist_ok=True)
+
+        return ["success","<p style='color:green;text-align:center;font-size:18px;'>SQL database connected successfully</p>", table_names]
+    except Exception as e:
+        print("CONNECTION ERROR")
+        print(e)
+        return ["error",f"<p style='color:red;text-align:center;font-size:18px;font-weight:bold;'>ERROR: {e}</p>"]
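
For context, a minimal sketch of how this new helper might be exercised on its own. The connection values below are placeholders, not values from this PR, and the call assumes the repo root is on the import path (for utils.TEMP_DIR), psycopg2-binary is installed, and a PostgreSQL server is reachable:

    from data_sources import connect_sql_db

    # Hypothetical local connection details; substitute your own.
    result = connect_sql_db(
        url="localhost",
        sql_user="postgres",
        sql_port="5432",
        sql_pass="example-password",
        sql_db_name="example_db",
        session_hash="local-test-session",
    )
    # On success: ["success", <html status string>, [table names]]
    # On failure: ["error", <html error string>]
    print(result[0], result[2] if result[0] == "success" else result[1])
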
data_sources/upload_file.py
CHANGED
@@ -74,7 +74,9 @@ def process_data_upload(data_file, session_hash):
     if df[column].dtype == 'object' and isinstance(df[column].iloc[0], list):
         df[column] = df[column].explode()
 
-    dir_path = TEMP_DIR / str(session_hash)
+    session_path = 'file_upload'
+
+    dir_path = TEMP_DIR / str(session_hash) / str(session_path)
     os.makedirs(dir_path, exist_ok=True)
 
     connection = sqlite3.connect(f'{dir_path}/data_source.db')
functions/__init__.py
CHANGED
@@ -1,9 +1,9 @@
-from .sqlite_functions import SQLiteQuery, sqlite_query_func
+from .query_functions import SQLiteQuery, sqlite_query_func, PostgreSQLQuery, sql_query_func
 from .chart_functions import table_generation_func, scatter_chart_generation_func, \
     line_chart_generation_func, bar_chart_generation_func, pie_chart_generation_func, histogram_generation_func, scatter_chart_fig
-from .chat_functions import example_question_generator, chatbot_with_fc
+from .chat_functions import sql_example_question_generator, example_question_generator, chatbot_with_fc, sql_chatbot_with_fc
 from .stat_functions import regression_func
 
-__all__ = ["SQLiteQuery","sqlite_query_func","table_generation_func","scatter_chart_generation_func",
+__all__ = ["SQLiteQuery","sqlite_query_func","sql_query_func","table_generation_func","scatter_chart_generation_func",
 "line_chart_generation_func","bar_chart_generation_func","regression_func", "pie_chart_generation_func", "histogram_generation_func",
-"scatter_chart_fig","example_question_generator","chatbot_with_fc"]
+"scatter_chart_fig","sql_example_question_generator","example_question_generator","chatbot_with_fc","sql_chatbot_with_fc"]
functions/chart_functions.py
CHANGED
@@ -92,11 +92,11 @@ def scatter_chart_fig(df, x_column: List[str], y_column: str, category: str="", ...):
 
     return fig
 
-def scatter_chart_generation_func(x_column: List[str], y_column: str, session_hash, data: List[dict]=[{}], layout: List[dict]=[{}],
+def scatter_chart_generation_func(x_column: List[str], y_column: str, session_hash, session_folder, data: List[dict]=[{}], layout: List[dict]=[{}],
                                   category: str="", trendline: str="", trendline_options: List[dict]=[{}], marginal_x: str="", marginal_y: str="",
-                                  size: str=""):
+                                  size: str="", **kwargs):
     try:
-        dir_path = TEMP_DIR / str(session_hash)
+        dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
         chart_path = f'{dir_path}/chart.html'
         csv_query_path = f'{dir_path}/query.csv'
 
@@ -129,7 +129,7 @@ def scatter_chart_generation_func(...):
 
         pio.write_html(fig, chart_path, full_html=False)
 
-        chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/chart.html'
+        chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
 
         iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
 
@@ -144,10 +144,10 @@ def scatter_chart_generation_func(...):
         """
         return {"reply": reply}
 
-def line_chart_generation_func(x_column: str, y_column: str, session_hash, data: List[dict]=[{}], layout: List[dict]=[{}],
-                               category: str=""):
+def line_chart_generation_func(x_column: str, y_column: str, session_hash, session_folder, data: List[dict]=[{}], layout: List[dict]=[{}],
+                               category: str="", **kwargs):
     try:
-        dir_path = TEMP_DIR / str(session_hash)
+        dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
         chart_path = f'{dir_path}/chart.html'
         csv_query_path = f'{dir_path}/query.csv'
 
@@ -180,7 +180,7 @@ def line_chart_generation_func(...):
 
         pio.write_html(fig, chart_path, full_html=False)
 
-        chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/chart.html'
+        chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
 
         iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
 
@@ -195,10 +195,10 @@ def line_chart_generation_func(...):
         """
         return {"reply": reply}
 
-def bar_chart_generation_func(x_column: str, y_column: str, session_hash, data: List[dict]=[{}], layout: List[dict]=[{}],
-                              category: str="", facet_row: str="", facet_col: str=""):
+def bar_chart_generation_func(x_column: str, y_column: str, session_hash, session_folder, data: List[dict]=[{}], layout: List[dict]=[{}],
+                              category: str="", facet_row: str="", facet_col: str="", **kwargs):
     try:
-        dir_path = TEMP_DIR / str(session_hash)
+        dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
         chart_path = f'{dir_path}/chart.html'
         csv_query_path = f'{dir_path}/query.csv'
 
@@ -235,7 +235,7 @@ def bar_chart_generation_func(...):
 
         pio.write_html(fig, chart_path, full_html=False)
 
-        chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/chart.html'
+        chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
 
         iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
 
@@ -250,9 +250,9 @@ def bar_chart_generation_func(...):
         """
         return {"reply": reply}
 
-def pie_chart_generation_func(values: str, names: str, session_hash, data: List[dict]=[{}], layout: List[dict]=[{}]):
+def pie_chart_generation_func(values: str, names: str, session_hash, session_folder, data: List[dict]=[{}], layout: List[dict]=[{}], **kwargs):
     try:
-        dir_path = TEMP_DIR / str(session_hash)
+        dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
         chart_path = f'{dir_path}/chart.html'
         csv_query_path = f'{dir_path}/query.csv'
 
@@ -282,7 +282,7 @@ def pie_chart_generation_func(...):
 
         pio.write_html(fig, chart_path, full_html=False)
 
-        chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/chart.html'
+        chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
 
         iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
 
@@ -297,16 +297,15 @@ def pie_chart_generation_func(...):
         """
         return {"reply": reply}
 
-def histogram_generation_func(x_column: str, session_hash, y_column: str="", data: List[dict]=[{}], layout: List[dict]=[{}], histnorm: str="", category: str="",
-                              histfunc: str=""):
+def histogram_generation_func(x_column: str, session_hash, session_folder, y_column: str="", data: List[dict]=[{}], layout: List[dict]=[{}], histnorm: str="", category: str="",
+                              histfunc: str="", **kwargs):
     try:
-        dir_path = TEMP_DIR / str(session_hash)
+        dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
         chart_path = f'{dir_path}/chart.html'
         csv_query_path = f'{dir_path}/query.csv'
 
         df = pd.read_csv(csv_query_path)
 
-        print(df)
         print(x_column)
 
         function_args = {"data_frame":df, "x":x_column}
@@ -342,7 +341,7 @@ def histogram_generation_func(...):
 
         pio.write_html(fig, chart_path, full_html=False)
 
-        chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/chart.html'
+        chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
 
         iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
 
@@ -357,15 +356,14 @@ def histogram_generation_func(...):
         """
         return {"reply": reply}
 
-def table_generation_func(session_hash):
+def table_generation_func(session_hash, session_folder, **kwargs):
     print("TABLE GENERATION")
     try:
-        dir_path = TEMP_DIR / str(session_hash)
+        dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
         csv_query_path = f'{dir_path}/query.csv'
         table_path = f'{dir_path}/table.html'
 
         df = pd.read_csv(csv_query_path)
-        print(df)
 
         html_table = df.to_html()
         print(html_table)
@@ -373,7 +371,7 @@ def table_generation_func(...):
         with open(table_path, "w") as file:
             file.write(html_table)
 
-        table_url = f'{root_url}/gradio_api/file/temp/{session_hash}/table.html'
+        table_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/table.html'
 
         iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + table_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
         print(iframe)
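
A small illustration of what the new session_folder parameter changes in practice. TEMP_DIR below is a stand-in for the repo's utils.TEMP_DIR, and the hash is a placeholder:

    from pathlib import Path

    TEMP_DIR = Path("temp")            # stand-in for utils.TEMP_DIR
    session_hash = "abc123"

    # Before this PR: every data source shared one working directory per session.
    old_dir = TEMP_DIR / session_hash                   # temp/abc123/query.csv, chart.html, ...

    # After this PR: file uploads and SQL connections each get their own folder.
    file_dir = TEMP_DIR / session_hash / "file_upload"  # temp/abc123/file_upload/...
    sql_dir = TEMP_DIR / session_hash / "sql"           # temp/abc123/sql/...
    print(old_dir, file_dir, sql_dir)
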
functions/chat_functions.py
CHANGED
@@ -35,6 +35,25 @@ def example_question_generator(session_hash):
 
     return example_response["replies"][0].text
 
+def sql_example_question_generator(session_hash, db_tables, db_name):
+    example_response = None
+    example_messages = [
+        ChatMessage.from_system(
+            f"You are a helpful and knowledgeable agent who has access to a PostgreSQL database called {db_name}."
+        )
+    ]
+
+    example_messages.append(ChatMessage.from_user(text=f"""We have a PostgreSQL database with the following tables: {db_tables}.
+    We also have an AI agent with access to the same database that will be performing data analysis.
+    Please return an array of seven strings, each one being a question for our data analysis agent
+    that we can suggest and that you believe will be insightful or helpful to a data analyst looking for
+    data insights. Return nothing more than the array of questions because I need that specific data structure
+    to process your response. No other response type or data structure will work."""))
+
+    example_response = chat_generator.run(messages=example_messages)
+
+    return example_response["replies"][0].text
+
 def chatbot_with_fc(message, history, session_hash):
     from functions import sqlite_query_func, table_generation_func, regression_func, scatter_chart_generation_func, \
     line_chart_generation_func,bar_chart_generation_func,pie_chart_generation_func,histogram_generation_func
@@ -46,8 +65,8 @@ def chatbot_with_fc(message, history, session_hash):
     "histogram_generation_func":histogram_generation_func,
     "regression_func":regression_func }
 
-    if message_dict[session_hash] != None:
-        message_dict[session_hash].append(ChatMessage.from_user(message))
+    if message_dict[session_hash]['file_upload'] != None:
+        message_dict[session_hash]['file_upload'].append(ChatMessage.from_user(message))
     else:
         messages = [
             ChatMessage.from_system(
@@ -58,35 +77,94 @@ def chatbot_with_fc(message, history, session_hash):
                 You also have access to a bar graph function, called bar_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a bar graph and returns an iframe that we can display in our chat window.
                 You also have access to a pie chart function, called pie_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a pie chart and returns an iframe that we can display in our chat window.
                 You also have access to a histogram function, called histogram_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a histogram and returns an iframe that we can display in our chat window.
-                You also have access to a linear regression function, called regression_func, that can take a query.csv file generated from our sql query and a list of column names for our independent and dependent variables and return a regression data string and a regression chart which is returned as an iframe.
+                You also have access to a linear regression function, called regression_func, that can take a query.csv file generated from our sql query and a list of column names for our independent and dependent variables and return a regression data string and a regression chart which is returned as an iframe.
+                Charts, tables, and visualizations are a very important part of your output. If you generate a chart, table, or visualization as part of your answer, please display it always."""
+            )
+        ]
+        messages.append(ChatMessage.from_user(message))
+        message_dict[session_hash]['file_upload'] = messages
+
+    response = chat_generator.run(messages=message_dict[session_hash]['file_upload'], generation_kwargs={"tools": tools.data_file_tools_call(session_hash)})
+
+    while True:
+        # if OpenAI response is a tool call
+        if response and response["replies"][0].meta["finish_reason"] == "tool_calls" or response["replies"][0].tool_calls:
+            function_calls = response["replies"][0].tool_calls
+            for function_call in function_calls:
+                message_dict[session_hash]['file_upload'].append(ChatMessage.from_assistant(tool_calls=[function_call]))
+                ## Parse function calling information
+                function_name = function_call.tool_name
+                function_args = function_call.arguments
+
+                ## Find the corresponding function and call it with the given arguments
+                function_to_call = available_functions[function_name]
+                function_response = function_to_call(**function_args, session_hash=session_hash, session_folder='file_upload')
+                print(function_name)
+                ## Append function response to the messages list using `ChatMessage.from_tool`
+                message_dict[session_hash]['file_upload'].append(ChatMessage.from_tool(tool_result=function_response['reply'], origin=function_call))
+                response = chat_generator.run(messages=message_dict[session_hash]['file_upload'], generation_kwargs={"tools": tools.data_file_tools_call(session_hash)})
+
+        # Regular Conversation
+        else:
+            message_dict[session_hash]['file_upload'].append(response["replies"][0])
+            break
+
+    return response["replies"][0].text
+
+def sql_chatbot_with_fc(message, history, session_hash, db_url, db_port, db_user, db_pass, db_name, db_tables):
+    from functions import sql_query_func, table_generation_func, regression_func, scatter_chart_generation_func, \
+    line_chart_generation_func,bar_chart_generation_func,pie_chart_generation_func,histogram_generation_func
+    import tools.tools as tools
+
+    available_functions = {"sql_query_func": sql_query_func,"table_generation_func":table_generation_func,
+    "line_chart_generation_func":line_chart_generation_func,"bar_chart_generation_func":bar_chart_generation_func,
+    "scatter_chart_generation_func":scatter_chart_generation_func, "pie_chart_generation_func":pie_chart_generation_func,
+    "histogram_generation_func":histogram_generation_func,
+    "regression_func":regression_func }
+
+    if message_dict[session_hash]['sql'] != None:
+        message_dict[session_hash]['sql'].append(ChatMessage.from_user(message))
+    else:
+        messages = [
+            ChatMessage.from_system(
+                f"""You are a helpful and knowledgeable agent who has access to a PostgreSQL database which has a series of tables called {db_tables}.
+                You also have access to a function, called table_generation_func, that can take a query.csv file generated from our sql query and returns an iframe that we can display in our chat window.
+                You also have access to a scatter plot function, called scatter_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a scatter plot and returns an iframe that we can display in our chat window.
+                You also have access to a line chart function, called line_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a line chart and returns an iframe that we can display in our chat window.
+                You also have access to a bar graph function, called bar_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a bar graph and returns an iframe that we can display in our chat window.
+                You also have access to a pie chart function, called pie_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a pie chart and returns an iframe that we can display in our chat window.
+                You also have access to a histogram function, called histogram_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a histogram and returns an iframe that we can display in our chat window.
+                You also have access to a linear regression function, called regression_func, that can take a query.csv file generated from our sql query and a list of column names for our independent and dependent variables and return a regression data string and a regression chart which is returned as an iframe.
+                Charts, tables, and visualizations are a very important part of your output. If you generate a chart, table, or visualization as part of your answer, please display it always."""
             )
         ]
         messages.append(ChatMessage.from_user(message))
-        message_dict[session_hash] = messages
+        message_dict[session_hash]['sql'] = messages
 
-    response = chat_generator.run(messages=message_dict[session_hash], generation_kwargs={"tools": tools.
+    response = chat_generator.run(messages=message_dict[session_hash]['sql'], generation_kwargs={"tools": tools.sql_tools_call(db_tables)})
 
     while True:
         # if OpenAI response is a tool call
        if response and response["replies"][0].meta["finish_reason"] == "tool_calls" or response["replies"][0].tool_calls:
            function_calls = response["replies"][0].tool_calls
            for function_call in function_calls:
-                message_dict[session_hash].append(ChatMessage.from_assistant(tool_calls=[function_call]))
+                message_dict[session_hash]['sql'].append(ChatMessage.from_assistant(tool_calls=[function_call]))
                ## Parse function calling information
                function_name = function_call.tool_name
                function_args = function_call.arguments
 
                ## Find the corresponding function and call it with the given arguments
                function_to_call = available_functions[function_name]
-                function_response = function_to_call(**function_args, session_hash=session_hash)
+                function_response = function_to_call(**function_args, session_hash=session_hash, db_url=db_url,
+                                                     db_port=db_port, db_user=db_user, db_pass=db_pass, db_name=db_name, session_folder='sql')
                print(function_name)
                ## Append function response to the messages list using `ChatMessage.from_tool`
-                message_dict[session_hash].append(ChatMessage.from_tool(tool_result=function_response['reply'], origin=function_call))
-                response = chat_generator.run(messages=message_dict[session_hash], generation_kwargs={"tools": tools.
+                message_dict[session_hash]['sql'].append(ChatMessage.from_tool(tool_result=function_response['reply'], origin=function_call))
+                response = chat_generator.run(messages=message_dict[session_hash]['sql'], generation_kwargs={"tools": tools.sql_tools_call(db_tables)})
 
        # Regular Conversation
        else:
-            message_dict[session_hash].append(response["replies"][0])
+            message_dict[session_hash]['sql'].append(response["replies"][0])
            break
 
     return response["replies"][0].text
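
To make the new keying explicit, here is a minimal sketch of the per-session structure both chatbots now expect in message_dict, and of how sql_chatbot_with_fc gets invoked (in the app, gr.ChatInterface supplies these arguments). The credentials and table string are placeholders, and the call assumes the repo's modules, an OPENAI_API_KEY, and a reachable PostgreSQL server:

    from utils import message_dict
    from functions import sql_chatbot_with_fc

    session_hash = "local-test-session"

    # Each session now holds one chat history per data source instead of a single list.
    message_dict[session_hash] = {"file_upload": None, "sql": None}

    # One chat turn on the SQL tab, with hypothetical connection values.
    answer = sql_chatbot_with_fc(
        message="Describe the dataset",
        history=[],
        session_hash=session_hash,
        db_url="localhost",
        db_port="5432",
        db_user="postgres",
        db_pass="example-password",
        db_name="example_db",
        db_tables="['film', 'rental']",
    )
    print(answer)
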
functions/{sqlite_functions.py → query_functions.py}
RENAMED
@@ -6,6 +6,7 @@ pd.set_option('display.max_columns', None)
 pd.set_option('display.width', None)
 pd.set_option('display.max_colwidth', None)
 import sqlite3
+import psycopg2
 from utils import TEMP_DIR
 
 @component
@@ -16,21 +17,21 @@ class SQLiteQuery:
 
     @component.output_types(results=List[str], queries=List[str])
     def run(self, queries: List[str], session_hash):
-        print("ATTEMPTING TO RUN QUERY")
+        print("ATTEMPTING TO RUN SQLITE QUERY")
         dir_path = TEMP_DIR / str(session_hash)
         results = []
         for query in queries:
             result = pd.read_sql(query, self.connection)
-            result.to_csv(f'{dir_path}/query.csv', index=False)
+            result.to_csv(f'{dir_path}/file_upload/query.csv', index=False)
             results.append(f"{result}")
         self.connection.close()
         return {"results": results, "queries": queries}
 
 
 
-def sqlite_query_func(queries: List[str], session_hash):
+def sqlite_query_func(queries: List[str], session_hash, **kwargs):
     dir_path = TEMP_DIR / str(session_hash)
-    sql_query = SQLiteQuery(f'{dir_path}/data_source.db')
+    sql_query = SQLiteQuery(f'{dir_path}/file_upload/data_source.db')
     try:
         result = sql_query.run(queries, session_hash)
         if len(result["results"][0]) > 1000:
@@ -45,3 +46,50 @@ def sqlite_query_func(queries: List[str], session_hash):
         You should probably try again.
         """
         return {"reply": reply}
+
+@component
+class PostgreSQLQuery:
+
+    def __init__(self, url: str, sql_port: int, sql_user: str, sql_pass: str, sql_db_name: str):
+        self.connection = psycopg2.connect(
+            database=sql_db_name,
+            user=sql_user,
+            password=sql_pass,
+            host=url,  # e.g., "localhost" or an IP address
+            port=sql_port  # default is 5432
+        )
+
+    @component.output_types(results=List[str], queries=List[str])
+    def run(self, queries: List[str], session_hash):
+        print("ATTEMPTING TO RUN POSTGRESQL QUERY")
+        dir_path = TEMP_DIR / str(session_hash)
+        results = []
+        for query in queries:
+            print(query)
+            result = pd.read_sql_query(query, self.connection)
+            result.to_csv(f'{dir_path}/sql/query.csv', index=False)
+            results.append(f"{result}")
+        self.connection.close()
+        return {"results": results, "queries": queries}
+
+
+
+def sql_query_func(queries: List[str], session_hash, db_url, db_port, db_user, db_pass, db_name, **kwargs):
+    sql_query = PostgreSQLQuery(db_url, db_port, db_user, db_pass, db_name)
+    try:
+        result = sql_query.run(queries, session_hash)
+        print("RESULT")
+        print(result)
+        if len(result["results"][0]) > 1000:
+            print("QUERY TOO LARGE")
+            return {"reply": "query result too large to be processed by llm, the query results are in our query.csv file. If you need to display the results directly, perhaps use the table_generation_func function."}
+        else:
+            return {"reply": result["results"][0]}
+
+    except Exception as e:
+        reply = f"""There was an error running the SQL Query = {queries}
+        The error is {e},
+        You should probably try again.
+        """
+        print(reply)
+        return {"reply": reply}
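
For orientation, a minimal sketch of calling the new sql_query_func outside the chatbot loop. The connection values are placeholders; note that the function writes its results to TEMP_DIR/<session_hash>/sql/query.csv, so that folder must already exist (connect_sql_db creates it during connection setup):

    from functions import sql_query_func

    # Hypothetical connection values for a local PostgreSQL instance.
    reply = sql_query_func(
        queries=["SELECT COUNT(*) FROM film"],
        session_hash="local-test-session",
        db_url="localhost",
        db_port=5432,
        db_user="postgres",
        db_pass="example-password",
        db_name="example_db",
    )
    print(reply["reply"])
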
functions/stat_functions.py
CHANGED
@@ -12,12 +12,12 @@ load_dotenv()
 
 root_url = os.getenv("ROOT_URL")
 
-def regression_func(independent_variables: List[str], dependent_variable: str, session_hash, category: str=''):
+def regression_func(independent_variables: List[str], dependent_variable: str, session_hash, session_folder, category: str='', **kwargs):
     print("LINEAR REGRESSION CALCULATION")
     print(independent_variables)
     print(dependent_variable)
     try:
-        dir_path = TEMP_DIR / str(session_hash)
+        dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
         chart_path = f'{dir_path}/chart.html'
         csv_query_path = f'{dir_path}/query.csv'
 
@@ -32,7 +32,7 @@ def regression_func(...):
 
         pio.write_html(fig, chart_path, full_html=False)
 
-        chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/chart.html'
+        chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
 
         iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
 
requirements.txt
CHANGED
@@ -6,3 +6,4 @@ plotly
 openpyxl
 statsmodels
 xlrd
+psycopg2-binary
templates/__pycache__/data_file.cpython-312.pyc
ADDED
Binary file (8.68 kB)

templates/__pycache__/sql_db.cpython-312.pyc
ADDED
Binary file (6.71 kB)
data_file.py → templates/data_file.py
RENAMED
@@ -68,7 +68,8 @@ with gr.Blocks() as demo:
     @gr.render(inputs=file_output)
     def data_options(filename, request: gr.Request):
         print(filename)
-        message_dict[request.session_hash] = None
+        message_dict[request.session_hash] = {}
+        message_dict[request.session_hash]['file_upload'] = None
         if filename:
             process_message = process_upload(filename, request.session_hash)
             gr.HTML(value=process_message[1], padding=False)
@@ -101,7 +102,9 @@ with gr.Blocks() as demo:
             ]
             for example in generated_examples:
                 example_questions.append([example])
-        except:
+        except Exception as e:
+            print("DATA FILE QUESTION GENERATION ERROR")
+            print(e)
             example_questions = [
                 ["Describe the dataset"],
                 ["List the columns in the dataset"],
templates/sql_db.py
ADDED
@@ -0,0 +1,95 @@
+import ast
+import gradio as gr
+from functions import sql_example_question_generator, sql_chatbot_with_fc
+from data_sources import connect_sql_db
+from utils import message_dict
+
+def hide_info():
+    return gr.update(visible=False)
+
+with gr.Blocks() as demo:
+    description = gr.HTML("""
+    <!-- Header -->
+    <div class="max-w-4xl mx-auto mb-12 text-center">
+        <div class="bg-blue-50 border border-blue-200 rounded-lg max-w-2xl mx-auto">
+            <p>This tool allows users to communicate with and query real-time data from a SQL DB (Postgres for now; others can be added if requested) using natural
+            language and the above features.</p>
+            <p style="font-weight:bold;">Notice: the way this system is designed, no login information is retained. Credentials are passed as session variables until the user leaves or
+            refreshes the page, at which point they disappear. They are never saved to any files. I also make use of the Pandas read_sql_query function to apply SQL
+            queries, which can't delete, drop, or add database rows, to avoid unhappy accidents or glitches.
+            That being said, it's probably not a good idea to connect a production database to a strange AI tool with an unfamiliar author.
+            This should be for demonstration purposes.</p>
+            <p>Contact me if this is something you would like built in your organization, on your infrastructure, and with the requisite privacy and control a production
+            database analytics tool requires.</p>
+        </div>
+    </div>
+    """, elem_classes="description_component")
+    sql_url = gr.Textbox(label="URL", value="virtual-data-analyst-pg.cyetm2yjzppu.us-west-1.rds.amazonaws.com")
+    with gr.Row():
+        sql_port = gr.Textbox(label="Port", value="5432")
+        sql_user = gr.Textbox(label="Username", value="postgres")
+        sql_pass = gr.Textbox(label="Password", value="Vda-1988", type="password")
+        sql_db_name = gr.Textbox(label="Database Name", value="dvdrental")
+
+    submit = gr.Button(value="Submit")
+    submit.click(fn=hide_info, outputs=description)
+
+    @gr.render(inputs=[sql_url,sql_port,sql_user,sql_pass,sql_db_name], triggers=[submit.click])
+    def sql_chat(request: gr.Request, url=sql_url.value, sql_port=sql_port.value, sql_user=sql_user.value, sql_pass=sql_pass.value, sql_db_name=sql_db_name.value):
+        message_dict[request.session_hash]['sql'] = None
+        if url:
+            print("SQL APP")
+            print(request)
+            process_message = process_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, request.session_hash)
+            gr.HTML(value=process_message[1], padding=False)
+            if process_message[0] == "success":
+                if "virtual-data-analyst-pg.cyetm2yjzppu.us-west-1.rds.amazonaws.com" in url:
+                    example_questions = [
+                        ["Describe the dataset"],
+                        ["What is the total revenue generated by each store?"],
+                        ["Can you generate and display a bar chart of film category to number of films in that category?"],
+                        ["Can you generate a pie chart showing the top 10 most rented films by revenue vs all other films?"],
+                        ["Can you generate a line chart of rental revenue over time?"],
+                        ["What is the relationship between film length and rental frequency?"]
+                    ]
+                else:
+                    try:
+                        generated_examples = ast.literal_eval(sql_example_question_generator(request.session_hash, process_message[2], sql_db_name))
+                        example_questions = [
+                            ["Describe the dataset"]
+                        ]
+                        for example in generated_examples:
+                            example_questions.append([example])
+                    except Exception as e:
+                        print("SQL QUESTION GENERATION ERROR")
+                        print(e)
+                        example_questions = [
+                            ["Describe the dataset"],
+                            ["List the columns in the dataset"],
+                            ["What could this data be used for?"],
+                        ]
+                session_hash = gr.Textbox(visible=False, value=request.session_hash)
+                db_url = gr.Textbox(visible=False, value=url)
+                db_port = gr.Textbox(visible=False, value=sql_port)
+                db_user = gr.Textbox(visible=False, value=sql_user)
+                db_pass = gr.Textbox(visible=False, value=sql_pass)
+                db_name = gr.Textbox(visible=False, value=sql_db_name)
+                db_tables = gr.Textbox(value=process_message[2], interactive=False, label="SQL Tables")
+                bot = gr.Chatbot(type='messages', label="CSV Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
+                chat = gr.ChatInterface(
+                    fn=sql_chatbot_with_fc,
+                    type='messages',
+                    chatbot=bot,
+                    title="Chat with your Database",
+                    examples=example_questions,
+                    concurrency_limit=None,
+                    additional_inputs=[session_hash, db_url, db_port, db_user, db_pass, db_name, db_tables]
+                )
+
+def process_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, session_hash):
+    if url:
+        process_message = connect_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, session_hash)
+    return process_message
+
+if __name__ == "__main__":
+    demo.launch()
tools/chart_tools.py
CHANGED
@@ -3,7 +3,7 @@ chart_tools = [
        "type": "function",
        "function": {
            "name": "scatter_chart_generation_func",
-           "description": f"""This is a scatter plot generation tool useful to generate scatter plots from queried data from our
+           "description": f"""This is a scatter plot generation tool useful to generate scatter plots from queried data from our data source that we are querying.
            The data values will come from the columns of our query.csv (the 'x' and 'y' values of each graph) file but the layout section of the plotly dictionary objects will be generated by you.
            Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
            from the scatter_chart_generation_func function in any way and always display the iframe fully to the user in the chat window. You can add your own text supplementary
@@ -108,7 +108,7 @@ chart_tools = [
        "type": "function",
        "function": {
            "name": "line_chart_generation_func",
-           "description": f"""This is a line chart generation tool useful to generate line charts from queried data from our
+           "description": f"""This is a line chart generation tool useful to generate line charts from queried data from our data source that we are querying.
            The data values will come from the columns of our query.csv (the 'x' and 'y' values of each graph) file but the layout section of the plotly dictionary objects will be generated by you.
            Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
            from the line_chart_generation_func function in any way and always display the iframe fully to the user in the chat window. You can add your own text supplementary
@@ -164,7 +164,7 @@ chart_tools = [
        "type": "function",
        "function": {
            "name": "bar_chart_generation_func",
-           "description": f"""This is a bar chart generation tool useful to generate line charts from queried data from our
+           "description": f"""This is a bar chart generation tool useful to generate bar charts from queried data from our data source that we are querying.
            The data values will come from the columns of our query.csv (the 'x' and 'y' values of each graph) file but the layout section of the plotly dictionary objects will be generated by you.
            Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
            from the bar_chart_generation_func function in any way and always display the iframe fully to the user in the chat window. You can add your own text supplementary
@@ -236,7 +236,7 @@ chart_tools = [
        "type": "function",
        "function": {
            "name": "pie_chart_generation_func",
-           "description": f"""This is a pie chart generation tool useful to generate pie charts from queried data from our
+           "description": f"""This is a pie chart generation tool useful to generate pie charts from queried data from our data source that we are querying.
            The data values will come from the columns of our query.csv (the 'values' and 'names' values of each graph) file but the layout section of the plotly dictionary objects will be generated by you.
            Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
            from the pie_chart_generation_func function in any way and always display the iframe fully to the user in the chat window. You can add your own text supplementary
@@ -285,7 +285,7 @@ chart_tools = [
        "type": "function",
        "function": {
            "name": "histogram_generation_func",
-           "description": f"""This is a histogram generation tool useful to generate histograms from queried data from our
+           "description": f"""This is a histogram generation tool useful to generate histograms from queried data from our data source that we are querying.
            The data values will come from the columns of our query.csv (the 'values' and 'names' values of each graph) file but the layout section of the plotly dictionary objects will be generated by you.
            Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
            from the histogram_generation_func function in any way and always display the iframe fully to the user in the chat window. You can add your own text supplementary
@@ -360,7 +360,7 @@ chart_tools = [
        "type": "function",
        "function": {
            "name": "table_generation_func",
-           "description": f"""This an table generation tool useful to format data as a table from queried data from our
+           "description": f"""This is a table generation tool useful to format data as a table from queried data from our data source that we are querying.
            Takes no parameters as it uses data queried in our query.csv file to build the table.
            Call this function after running our SQLite query and generating query.csv.
            Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
tools/stats_tools.py
CHANGED
@@ -3,7 +3,7 @@ stats_tools = [
        "type": "function",
        "function": {
            "name": "regression_func",
-           "description": f"""This a tool to calculate regressions on our
+           "description": f"""This is a tool to calculate regressions on our data source that we are querying.
            We can run queries with our 'sql_query_func' function and they will be available to use in this function via the query.csv file that is generated.
            Returns a dictionary of values that includes a regression_summary and a regression chart (which is an iframe displaying the
            linear regression in chart form and should be shown to the user).""",
tools/tools.py
CHANGED
@@ -1,11 +1,12 @@
 import sqlite3
+import psycopg2
 from .stats_tools import stats_tools
 from .chart_tools import chart_tools
 from utils import TEMP_DIR
 
 def data_file_tools_call(session_hash):
     dir_path = TEMP_DIR / str(session_hash)
-    connection = sqlite3.connect(f'{dir_path}/data_source.db')
+    connection = sqlite3.connect(f'{dir_path}/file_upload/data_source.db')
     print("Querying Database in Tools.py");
     cur=connection.execute('select * from data_source')
     columns = [i[0] for i in cur.description]
@@ -46,22 +47,24 @@ def data_file_tools_call(session_hash):
 
     return tools_calls
 
-def graphql_tools_call(sessions_hash):
+def sql_tools_call(db_tables):
+
+    table_string = (db_tables[:625] + '..') if len(db_tables) > 625 else db_tables
 
     tools_calls = [
        {
          "type": "function",
          "function": {
-           "name": "
-           "description": f"""This is a tool useful to query a
-           There may also be more
+           "name": "sql_query_func",
+           "description": f"""This is a tool useful to query a PostgreSQL database with the following tables, {table_string}.
+           There may also be more tables in the database if the number of columns is too large to process.
            This function also saves the results of the query to csv file called query.csv.""",
           "parameters": {
            "type": "object",
            "properties": {
              "queries": {
                "type": "array",
-               "description": "The
+               "description": "The PostgreSQL query to use in the search. Infer this from the user's message. It should be a question or a statement",
                "items": {
                  "type": "string",
                }
@@ -73,7 +76,7 @@ def graphql_tools_call(sessions_hash):
          },
       ]
 
-    tools_calls.
-    tools_calls.
+    tools_calls.extend(chart_tools)
+    tools_calls.extend(stats_tools)
 
-    return
+    return tools_calls
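
As a quick sanity check of the new tool schema, a minimal sketch of inspecting what sql_tools_call hands to the model (assumes the repo root is the working directory and its dependencies, including psycopg2-binary, are installed; the table string is illustrative):

    import json
    from tools.tools import sql_tools_call

    # Build the OpenAI-style tool specs for a hypothetical pair of tables.
    specs = sql_tools_call("['film', 'rental']")
    print(len(specs))                      # sql_query_func plus the chart_tools and stats_tools entries
    print(json.dumps(specs[0], indent=2))  # the PostgreSQL query tool definition

The same list is what sql_chatbot_with_fc passes to chat_generator.run via generation_kwargs={"tools": ...} in functions/chat_functions.py.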