Table generation and download (#12)
Browse files- Table generation and download (e7b4bfb12b525c5770dd25408a1a433d15cc5200)
- app.py +1 -1
- data_sources/upload_file.py +6 -2
- functions/__init__.py +2 -2
- functions/chart_functions.py +15 -0
- functions/chat_functions.py +13 -8
- functions/sqlite_functions.py +3 -1
- tools.py +34 -3
- utils.py +5 -0
app.py
CHANGED
@@ -10,4 +10,4 @@ if "OPENAI_API_KEY" not in os.environ:
|
|
10 |
os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")
|
11 |
|
12 |
## Uncomment the line below to launch the chat app with UI
|
13 |
-
demo.launch(debug=True)
|
|
|
10 |
os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")
|
11 |
|
12 |
## Uncomment the line below to launch the chat app with UI
|
13 |
+
demo.launch(debug=True, allowed_paths=["temp/"])
|
data_sources/upload_file.py
CHANGED
@@ -3,6 +3,8 @@ import sqlite3
|
|
3 |
import csv
|
4 |
import json
|
5 |
import time
|
|
|
|
|
6 |
|
7 |
def is_file_done_saving(file_path):
|
8 |
try:
|
@@ -65,8 +67,10 @@ def process_data_upload(data_file, session_hash):
|
|
65 |
if df[column].dtype == 'object' and isinstance(df[column].iloc[0], list):
|
66 |
df[column] = df[column].explode()
|
67 |
|
68 |
-
|
69 |
-
|
|
|
|
|
70 |
print("Opened database successfully");
|
71 |
print(df.columns)
|
72 |
|
|
|
3 |
import csv
|
4 |
import json
|
5 |
import time
|
6 |
+
import os
|
7 |
+
from utils import TEMP_DIR
|
8 |
|
9 |
def is_file_done_saving(file_path):
|
10 |
try:
|
|
|
67 |
if df[column].dtype == 'object' and isinstance(df[column].iloc[0], list):
|
68 |
df[column] = df[column].explode()
|
69 |
|
70 |
+
dir_path = TEMP_DIR / str(session_hash)
|
71 |
+
os.makedirs(dir_path, exist_ok=True)
|
72 |
+
|
73 |
+
connection = sqlite3.connect(f'{dir_path}/data_source.db')
|
74 |
print("Opened database successfully");
|
75 |
print(df.columns)
|
76 |
|
functions/__init__.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from .sqlite_functions import SQLiteQuery, sqlite_query_func
|
2 |
-
from .chart_functions import chart_generation_func
|
3 |
from .chat_functions import demo
|
4 |
|
5 |
-
__all__ = ["SQLiteQuery","sqlite_query_func","chart_generation_func","demo"]
|
|
|
1 |
from .sqlite_functions import SQLiteQuery, sqlite_query_func
|
2 |
+
from .chart_functions import chart_generation_func, table_generation_func
|
3 |
from .chat_functions import demo
|
4 |
|
5 |
+
__all__ = ["SQLiteQuery","sqlite_query_func","chart_generation_func","table_generation_func","demo"]
|
functions/chart_functions.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
from typing import List
|
2 |
from quickchart import QuickChart
|
|
|
|
|
3 |
|
4 |
def chart_generation_func(queries: List[str], session_hash):
|
5 |
print("CHART GENERATION")
|
@@ -22,3 +24,16 @@ def chart_generation_func(queries: List[str], session_hash):
|
|
22 |
iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + interactive_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n <p>Edit, share, and download this graph <a target="_blank" href="' + edit_url + '">here</a></p></div>'
|
23 |
|
24 |
return {"reply": iframe}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from typing import List
|
2 |
from quickchart import QuickChart
|
3 |
+
import pandas as pd
|
4 |
+
from utils import TEMP_DIR
|
5 |
|
6 |
def chart_generation_func(queries: List[str], session_hash):
|
7 |
print("CHART GENERATION")
|
|
|
24 |
iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + interactive_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n <p>Edit, share, and download this graph <a target="_blank" href="' + edit_url + '">here</a></p></div>'
|
25 |
|
26 |
return {"reply": iframe}
|
27 |
+
|
28 |
+
def table_generation_func(data: List[str], session_hash):
|
29 |
+
dir_path = TEMP_DIR / str(session_hash)
|
30 |
+
print("TABLE GENERATION")
|
31 |
+
print(data)
|
32 |
+
df = pd.DataFrame(data)
|
33 |
+
csv_path = f'{dir_path}/data.csv'
|
34 |
+
df.to_csv(csv_path)
|
35 |
+
download_path = f'gradio_api/file/temp/{session_hash}/data.csv'
|
36 |
+
html_table = df.to_html() + f'<p>Download as a <a href="{download_path}">CSV</a></p>'
|
37 |
+
print(html_table)
|
38 |
+
|
39 |
+
return {"reply": html_table}
|
functions/chat_functions.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from data_sources import process_data_upload
|
|
|
2 |
|
3 |
import gradio as gr
|
4 |
|
@@ -27,7 +28,8 @@ def example_question_generator(session_hash):
|
|
27 |
"You are a helpful and knowledgeable agent who has access to an SQLite database which has a table called 'data_source'."
|
28 |
)
|
29 |
]
|
30 |
-
|
|
|
31 |
print("Querying questions");
|
32 |
cur=connection.execute('select * from data_source')
|
33 |
columns = [i[0] for i in cur.description]
|
@@ -48,11 +50,11 @@ def example_question_generator(session_hash):
|
|
48 |
return example_response["replies"][0].text
|
49 |
|
50 |
def chatbot_with_fc(message, history, session_hash):
|
51 |
-
from functions import sqlite_query_func, chart_generation_func
|
52 |
from pipelines import rag_pipeline_func
|
53 |
import tools
|
54 |
|
55 |
-
available_functions = {"sql_query_func": sqlite_query_func, "rag_pipeline_func": rag_pipeline_func, "chart_generation_func": chart_generation_func}
|
56 |
|
57 |
if message_dict[session_hash] != None:
|
58 |
message_dict[session_hash].append(ChatMessage.from_user(message))
|
@@ -92,9 +94,10 @@ def chatbot_with_fc(message, history, session_hash):
|
|
92 |
return response["replies"][0].text
|
93 |
|
94 |
def delete_db(req: gr.Request):
|
95 |
-
|
96 |
-
|
97 |
-
|
|
|
98 |
message_dict[req.session_hash] = None
|
99 |
|
100 |
def run_example(input):
|
@@ -139,14 +142,16 @@ with gr.Blocks(css=css, delete_cache=(3600,3600)) as demo:
|
|
139 |
["Describe the dataset"],
|
140 |
["What levels of education have the highest and lowest average balance?"],
|
141 |
["What job is most and least common for a yes response from the individuals, not counting 'unknown'?"],
|
142 |
-
["Can you generate a bar chart of education vs. average balance?"]
|
|
|
143 |
]
|
144 |
elif "online_retail_data" in filename:
|
145 |
example_questions = [
|
146 |
["Describe the dataset"],
|
147 |
["What month had the highest revenue?"],
|
148 |
["Is revenue higher in the morning or afternoon?"],
|
149 |
-
["Can you generate a line graph of revenue per month?"]
|
|
|
150 |
]
|
151 |
else:
|
152 |
try:
|
|
|
1 |
from data_sources import process_data_upload
|
2 |
+
from utils import TEMP_DIR
|
3 |
|
4 |
import gradio as gr
|
5 |
|
|
|
28 |
"You are a helpful and knowledgeable agent who has access to an SQLite database which has a table called 'data_source'."
|
29 |
)
|
30 |
]
|
31 |
+
dir_path = TEMP_DIR / str(session_hash)
|
32 |
+
connection = sqlite3.connect(f'{dir_path}/data_source.db')
|
33 |
print("Querying questions");
|
34 |
cur=connection.execute('select * from data_source')
|
35 |
columns = [i[0] for i in cur.description]
|
|
|
50 |
return example_response["replies"][0].text
|
51 |
|
52 |
def chatbot_with_fc(message, history, session_hash):
|
53 |
+
from functions import sqlite_query_func, chart_generation_func, table_generation_func
|
54 |
from pipelines import rag_pipeline_func
|
55 |
import tools
|
56 |
|
57 |
+
available_functions = {"sql_query_func": sqlite_query_func, "rag_pipeline_func": rag_pipeline_func, "chart_generation_func": chart_generation_func, "table_generation_func":table_generation_func }
|
58 |
|
59 |
if message_dict[session_hash] != None:
|
60 |
message_dict[session_hash].append(ChatMessage.from_user(message))
|
|
|
94 |
return response["replies"][0].text
|
95 |
|
96 |
def delete_db(req: gr.Request):
|
97 |
+
import shutil
|
98 |
+
dir_path = TEMP_DIR / str(req.session_hash)
|
99 |
+
if os.path.exists(dir_path):
|
100 |
+
shutil.rmtree(dir_path)
|
101 |
message_dict[req.session_hash] = None
|
102 |
|
103 |
def run_example(input):
|
|
|
142 |
["Describe the dataset"],
|
143 |
["What levels of education have the highest and lowest average balance?"],
|
144 |
["What job is most and least common for a yes response from the individuals, not counting 'unknown'?"],
|
145 |
+
["Can you generate a bar chart of education vs. average balance?"],
|
146 |
+
["Can you generate a table of levels of education versus average balance, percent married, percent with a loan, and percent in default?"]
|
147 |
]
|
148 |
elif "online_retail_data" in filename:
|
149 |
example_questions = [
|
150 |
["Describe the dataset"],
|
151 |
["What month had the highest revenue?"],
|
152 |
["Is revenue higher in the morning or afternoon?"],
|
153 |
+
["Can you generate a line graph of revenue per month?"],
|
154 |
+
["Can you generate a table of revenue per month?"]
|
155 |
]
|
156 |
else:
|
157 |
try:
|
functions/sqlite_functions.py
CHANGED
@@ -6,6 +6,7 @@ pd.set_option('display.max_columns', None)
|
|
6 |
pd.set_option('display.width', None)
|
7 |
pd.set_option('display.max_colwidth', None)
|
8 |
import sqlite3
|
|
|
9 |
|
10 |
@component
|
11 |
class SQLiteQuery:
|
@@ -26,7 +27,8 @@ class SQLiteQuery:
|
|
26 |
|
27 |
|
28 |
def sqlite_query_func(queries: List[str], session_hash):
|
29 |
-
|
|
|
30 |
try:
|
31 |
result = sql_query.run(queries)
|
32 |
return {"reply": result["results"][0]}
|
|
|
6 |
pd.set_option('display.width', None)
|
7 |
pd.set_option('display.max_colwidth', None)
|
8 |
import sqlite3
|
9 |
+
from utils import TEMP_DIR
|
10 |
|
11 |
@component
|
12 |
class SQLiteQuery:
|
|
|
27 |
|
28 |
|
29 |
def sqlite_query_func(queries: List[str], session_hash):
|
30 |
+
dir_path = TEMP_DIR / str(session_hash)
|
31 |
+
sql_query = SQLiteQuery(f'{dir_path}/data_source.db')
|
32 |
try:
|
33 |
result = sql_query.run(queries)
|
34 |
return {"reply": result["results"][0]}
|
tools.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
import sqlite3
|
|
|
2 |
|
3 |
def tools_call(session_hash):
|
4 |
-
|
|
|
5 |
print("Querying Database in Tools.py");
|
6 |
cur=connection.execute('select * from data_source')
|
7 |
columns = [i[0] for i in cur.description]
|
@@ -35,7 +37,7 @@ def tools_call(session_hash):
|
|
35 |
"type": "function",
|
36 |
"function": {
|
37 |
"name": "chart_generation_func",
|
38 |
-
"description": f"This
|
39 |
"parameters": {
|
40 |
"type": "object",
|
41 |
"properties": {
|
@@ -43,7 +45,36 @@ def tools_call(session_hash):
|
|
43 |
"type": "array",
|
44 |
"description": """The data points to use in the chart generation. Infer this from the user's message.
|
45 |
Send a chart.js dictionary with options that correspond to the users request. But also format this dictionary as a string as this will allow javascript to be interpreted by the API we are using.
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
"items": {
|
48 |
"type": "string",
|
49 |
}
|
|
|
1 |
import sqlite3
|
2 |
+
from utils import TEMP_DIR
|
3 |
|
4 |
def tools_call(session_hash):
|
5 |
+
dir_path = TEMP_DIR / str(session_hash)
|
6 |
+
connection = sqlite3.connect(f'{dir_path}/data_source.db')
|
7 |
print("Querying Database in Tools.py");
|
8 |
cur=connection.execute('select * from data_source')
|
9 |
columns = [i[0] for i in cur.description]
|
|
|
37 |
"type": "function",
|
38 |
"function": {
|
39 |
"name": "chart_generation_func",
|
40 |
+
"description": f"This a chart generation tool useful to generate charts and graphs from queried data from our SQL table called 'data_source with the following Columns: {columns}. Returns an iframe string which will be displayed inline in our chat window. Do not edit the string returned from the chart_generation_func function in any way and display it fully to the user in the chat window. You can add your own text supplementary to it for context if desired.",
|
41 |
"parameters": {
|
42 |
"type": "object",
|
43 |
"properties": {
|
|
|
45 |
"type": "array",
|
46 |
"description": """The data points to use in the chart generation. Infer this from the user's message.
|
47 |
Send a chart.js dictionary with options that correspond to the users request. But also format this dictionary as a string as this will allow javascript to be interpreted by the API we are using.
|
48 |
+
Send nothing else.""",
|
49 |
+
"items": {
|
50 |
+
"type": "string",
|
51 |
+
}
|
52 |
+
}
|
53 |
+
},
|
54 |
+
"required": ["question"],
|
55 |
+
},
|
56 |
+
},
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"type": "function",
|
60 |
+
"function": {
|
61 |
+
"name": "table_generation_func",
|
62 |
+
"description": f"""This an table generation tool useful to format data as a table from queried data from our SQL table called
|
63 |
+
'data_source with the following Columns: {columns}. Returns an html string generated from the pandas library and pandas.to_html()
|
64 |
+
function which will be displayed inline in our chat window. There will also be a link to download the CSV included in the HTML string.
|
65 |
+
Do not change or edit this link in any way. Do not edit the HTML returned by the function in any way.
|
66 |
+
Do not edit the string returned from the table_generation_func function in any way and display it fully to the user in the chat window.
|
67 |
+
You can add your own text next to the returned string for context if desired.""",
|
68 |
+
"parameters": {
|
69 |
+
"type": "object",
|
70 |
+
"properties": {
|
71 |
+
"data": {
|
72 |
+
"type": "array",
|
73 |
+
"description": """The data points to use in the table generation. Infer this from the user's message.
|
74 |
+
Send a python dictionary object with query data that correspond to data that will be converted into a pandas DataFrame and that correspond to the users request.
|
75 |
+
The keys of this python dictionary object will be the names of the columns and values will be a list of values for each object.
|
76 |
+
Make sure this is a dictionary object and not a string or an array.
|
77 |
+
Send nothing else.""",
|
78 |
"items": {
|
79 |
"type": "string",
|
80 |
}
|
utils.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
current_dir = Path(__file__).parent
|
4 |
+
|
5 |
+
TEMP_DIR = current_dir / 'temp'
|