nolanzandi commited on
Commit
dd82d0a
·
verified ·
1 Parent(s): 95c52e2

Table generation and download (#12)

Browse files

- Table generation and download (e7b4bfb12b525c5770dd25408a1a433d15cc5200)

app.py CHANGED
@@ -10,4 +10,4 @@ if "OPENAI_API_KEY" not in os.environ:
10
  os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")
11
 
12
  ## Uncomment the line below to launch the chat app with UI
13
- demo.launch(debug=True)
 
10
  os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")
11
 
12
  ## Uncomment the line below to launch the chat app with UI
13
+ demo.launch(debug=True, allowed_paths=["temp/"])
data_sources/upload_file.py CHANGED
@@ -3,6 +3,8 @@ import sqlite3
3
  import csv
4
  import json
5
  import time
 
 
6
 
7
  def is_file_done_saving(file_path):
8
  try:
@@ -65,8 +67,10 @@ def process_data_upload(data_file, session_hash):
65
  if df[column].dtype == 'object' and isinstance(df[column].iloc[0], list):
66
  df[column] = df[column].explode()
67
 
68
-
69
- connection = sqlite3.connect(f'data_source_{session_hash}.db')
 
 
70
  print("Opened database successfully");
71
  print(df.columns)
72
 
 
3
  import csv
4
  import json
5
  import time
6
+ import os
7
+ from utils import TEMP_DIR
8
 
9
  def is_file_done_saving(file_path):
10
  try:
 
67
  if df[column].dtype == 'object' and isinstance(df[column].iloc[0], list):
68
  df[column] = df[column].explode()
69
 
70
+ dir_path = TEMP_DIR / str(session_hash)
71
+ os.makedirs(dir_path, exist_ok=True)
72
+
73
+ connection = sqlite3.connect(f'{dir_path}/data_source.db')
74
  print("Opened database successfully");
75
  print(df.columns)
76
 
functions/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
  from .sqlite_functions import SQLiteQuery, sqlite_query_func
2
- from .chart_functions import chart_generation_func
3
  from .chat_functions import demo
4
 
5
- __all__ = ["SQLiteQuery","sqlite_query_func","chart_generation_func","demo"]
 
1
  from .sqlite_functions import SQLiteQuery, sqlite_query_func
2
+ from .chart_functions import chart_generation_func, table_generation_func
3
  from .chat_functions import demo
4
 
5
+ __all__ = ["SQLiteQuery","sqlite_query_func","chart_generation_func","table_generation_func","demo"]
functions/chart_functions.py CHANGED
@@ -1,5 +1,7 @@
1
  from typing import List
2
  from quickchart import QuickChart
 
 
3
 
4
  def chart_generation_func(queries: List[str], session_hash):
5
  print("CHART GENERATION")
@@ -22,3 +24,16 @@ def chart_generation_func(queries: List[str], session_hash):
22
  iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + interactive_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n <p>Edit, share, and download this graph <a target="_blank" href="' + edit_url + '">here</a></p></div>'
23
 
24
  return {"reply": iframe}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import List
2
  from quickchart import QuickChart
3
+ import pandas as pd
4
+ from utils import TEMP_DIR
5
 
6
  def chart_generation_func(queries: List[str], session_hash):
7
  print("CHART GENERATION")
 
24
  iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + interactive_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n <p>Edit, share, and download this graph <a target="_blank" href="' + edit_url + '">here</a></p></div>'
25
 
26
  return {"reply": iframe}
27
+
28
+ def table_generation_func(data: List[str], session_hash):
29
+ dir_path = TEMP_DIR / str(session_hash)
30
+ print("TABLE GENERATION")
31
+ print(data)
32
+ df = pd.DataFrame(data)
33
+ csv_path = f'{dir_path}/data.csv'
34
+ df.to_csv(csv_path)
35
+ download_path = f'gradio_api/file/temp/{session_hash}/data.csv'
36
+ html_table = df.to_html() + f'<p>Download as a <a href="{download_path}">CSV</a></p>'
37
+ print(html_table)
38
+
39
+ return {"reply": html_table}
functions/chat_functions.py CHANGED
@@ -1,4 +1,5 @@
1
  from data_sources import process_data_upload
 
2
 
3
  import gradio as gr
4
 
@@ -27,7 +28,8 @@ def example_question_generator(session_hash):
27
  "You are a helpful and knowledgeable agent who has access to an SQLite database which has a table called 'data_source'."
28
  )
29
  ]
30
- connection = sqlite3.connect(f'data_source_{session_hash}.db')
 
31
  print("Querying questions");
32
  cur=connection.execute('select * from data_source')
33
  columns = [i[0] for i in cur.description]
@@ -48,11 +50,11 @@ def example_question_generator(session_hash):
48
  return example_response["replies"][0].text
49
 
50
  def chatbot_with_fc(message, history, session_hash):
51
- from functions import sqlite_query_func, chart_generation_func
52
  from pipelines import rag_pipeline_func
53
  import tools
54
 
55
- available_functions = {"sql_query_func": sqlite_query_func, "rag_pipeline_func": rag_pipeline_func, "chart_generation_func": chart_generation_func}
56
 
57
  if message_dict[session_hash] != None:
58
  message_dict[session_hash].append(ChatMessage.from_user(message))
@@ -92,9 +94,10 @@ def chatbot_with_fc(message, history, session_hash):
92
  return response["replies"][0].text
93
 
94
  def delete_db(req: gr.Request):
95
- db_file_path = f'data_source_{req.session_hash}.db'
96
- if os.path.exists(db_file_path):
97
- os.remove(db_file_path)
 
98
  message_dict[req.session_hash] = None
99
 
100
  def run_example(input):
@@ -139,14 +142,16 @@ with gr.Blocks(css=css, delete_cache=(3600,3600)) as demo:
139
  ["Describe the dataset"],
140
  ["What levels of education have the highest and lowest average balance?"],
141
  ["What job is most and least common for a yes response from the individuals, not counting 'unknown'?"],
142
- ["Can you generate a bar chart of education vs. average balance?"]
 
143
  ]
144
  elif "online_retail_data" in filename:
145
  example_questions = [
146
  ["Describe the dataset"],
147
  ["What month had the highest revenue?"],
148
  ["Is revenue higher in the morning or afternoon?"],
149
- ["Can you generate a line graph of revenue per month?"]
 
150
  ]
151
  else:
152
  try:
 
1
  from data_sources import process_data_upload
2
+ from utils import TEMP_DIR
3
 
4
  import gradio as gr
5
 
 
28
  "You are a helpful and knowledgeable agent who has access to an SQLite database which has a table called 'data_source'."
29
  )
30
  ]
31
+ dir_path = TEMP_DIR / str(session_hash)
32
+ connection = sqlite3.connect(f'{dir_path}/data_source.db')
33
  print("Querying questions");
34
  cur=connection.execute('select * from data_source')
35
  columns = [i[0] for i in cur.description]
 
50
  return example_response["replies"][0].text
51
 
52
  def chatbot_with_fc(message, history, session_hash):
53
+ from functions import sqlite_query_func, chart_generation_func, table_generation_func
54
  from pipelines import rag_pipeline_func
55
  import tools
56
 
57
+ available_functions = {"sql_query_func": sqlite_query_func, "rag_pipeline_func": rag_pipeline_func, "chart_generation_func": chart_generation_func, "table_generation_func":table_generation_func }
58
 
59
  if message_dict[session_hash] != None:
60
  message_dict[session_hash].append(ChatMessage.from_user(message))
 
94
  return response["replies"][0].text
95
 
96
  def delete_db(req: gr.Request):
97
+ import shutil
98
+ dir_path = TEMP_DIR / str(req.session_hash)
99
+ if os.path.exists(dir_path):
100
+ shutil.rmtree(dir_path)
101
  message_dict[req.session_hash] = None
102
 
103
  def run_example(input):
 
142
  ["Describe the dataset"],
143
  ["What levels of education have the highest and lowest average balance?"],
144
  ["What job is most and least common for a yes response from the individuals, not counting 'unknown'?"],
145
+ ["Can you generate a bar chart of education vs. average balance?"],
146
+ ["Can you generate a table of levels of education versus average balance, percent married, percent with a loan, and percent in default?"]
147
  ]
148
  elif "online_retail_data" in filename:
149
  example_questions = [
150
  ["Describe the dataset"],
151
  ["What month had the highest revenue?"],
152
  ["Is revenue higher in the morning or afternoon?"],
153
+ ["Can you generate a line graph of revenue per month?"],
154
+ ["Can you generate a table of revenue per month?"]
155
  ]
156
  else:
157
  try:
functions/sqlite_functions.py CHANGED
@@ -6,6 +6,7 @@ pd.set_option('display.max_columns', None)
6
  pd.set_option('display.width', None)
7
  pd.set_option('display.max_colwidth', None)
8
  import sqlite3
 
9
 
10
  @component
11
  class SQLiteQuery:
@@ -26,7 +27,8 @@ class SQLiteQuery:
26
 
27
 
28
  def sqlite_query_func(queries: List[str], session_hash):
29
- sql_query = SQLiteQuery(f'data_source_{session_hash}.db')
 
30
  try:
31
  result = sql_query.run(queries)
32
  return {"reply": result["results"][0]}
 
6
  pd.set_option('display.width', None)
7
  pd.set_option('display.max_colwidth', None)
8
  import sqlite3
9
+ from utils import TEMP_DIR
10
 
11
  @component
12
  class SQLiteQuery:
 
27
 
28
 
29
  def sqlite_query_func(queries: List[str], session_hash):
30
+ dir_path = TEMP_DIR / str(session_hash)
31
+ sql_query = SQLiteQuery(f'{dir_path}/data_source.db')
32
  try:
33
  result = sql_query.run(queries)
34
  return {"reply": result["results"][0]}
tools.py CHANGED
@@ -1,7 +1,9 @@
1
  import sqlite3
 
2
 
3
  def tools_call(session_hash):
4
- connection = sqlite3.connect(f'data_source_{session_hash}.db')
 
5
  print("Querying Database in Tools.py");
6
  cur=connection.execute('select * from data_source')
7
  columns = [i[0] for i in cur.description]
@@ -35,7 +37,7 @@ def tools_call(session_hash):
35
  "type": "function",
36
  "function": {
37
  "name": "chart_generation_func",
38
- "description": f"This an chart generation tool useful to generate charts and graphs from queried data from our SQL table called 'data_source with the following Columns: {columns}. Returns an iframe string which will be displayed inline in our chat window. Do not edit the string returned from the chart_generation_func function in any way and display it fully to the user. You can add your own text supplementary to it for context if desired.",
39
  "parameters": {
40
  "type": "object",
41
  "properties": {
@@ -43,7 +45,36 @@ def tools_call(session_hash):
43
  "type": "array",
44
  "description": """The data points to use in the chart generation. Infer this from the user's message.
45
  Send a chart.js dictionary with options that correspond to the users request. But also format this dictionary as a string as this will allow javascript to be interpreted by the API we are using.
46
- Return nothing else.""",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  "items": {
48
  "type": "string",
49
  }
 
1
  import sqlite3
2
+ from utils import TEMP_DIR
3
 
4
  def tools_call(session_hash):
5
+ dir_path = TEMP_DIR / str(session_hash)
6
+ connection = sqlite3.connect(f'{dir_path}/data_source.db')
7
  print("Querying Database in Tools.py");
8
  cur=connection.execute('select * from data_source')
9
  columns = [i[0] for i in cur.description]
 
37
  "type": "function",
38
  "function": {
39
  "name": "chart_generation_func",
40
+ "description": f"This a chart generation tool useful to generate charts and graphs from queried data from our SQL table called 'data_source with the following Columns: {columns}. Returns an iframe string which will be displayed inline in our chat window. Do not edit the string returned from the chart_generation_func function in any way and display it fully to the user in the chat window. You can add your own text supplementary to it for context if desired.",
41
  "parameters": {
42
  "type": "object",
43
  "properties": {
 
45
  "type": "array",
46
  "description": """The data points to use in the chart generation. Infer this from the user's message.
47
  Send a chart.js dictionary with options that correspond to the users request. But also format this dictionary as a string as this will allow javascript to be interpreted by the API we are using.
48
+ Send nothing else.""",
49
+ "items": {
50
+ "type": "string",
51
+ }
52
+ }
53
+ },
54
+ "required": ["question"],
55
+ },
56
+ },
57
+ },
58
+ {
59
+ "type": "function",
60
+ "function": {
61
+ "name": "table_generation_func",
62
+ "description": f"""This an table generation tool useful to format data as a table from queried data from our SQL table called
63
+ 'data_source with the following Columns: {columns}. Returns an html string generated from the pandas library and pandas.to_html()
64
+ function which will be displayed inline in our chat window. There will also be a link to download the CSV included in the HTML string.
65
+ Do not change or edit this link in any way. Do not edit the HTML returned by the function in any way.
66
+ Do not edit the string returned from the table_generation_func function in any way and display it fully to the user in the chat window.
67
+ You can add your own text next to the returned string for context if desired.""",
68
+ "parameters": {
69
+ "type": "object",
70
+ "properties": {
71
+ "data": {
72
+ "type": "array",
73
+ "description": """The data points to use in the table generation. Infer this from the user's message.
74
+ Send a python dictionary object with query data that correspond to data that will be converted into a pandas DataFrame and that correspond to the users request.
75
+ The keys of this python dictionary object will be the names of the columns and values will be a list of values for each object.
76
+ Make sure this is a dictionary object and not a string or an array.
77
+ Send nothing else.""",
78
  "items": {
79
  "type": "string",
80
  }
utils.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ current_dir = Path(__file__).parent
4
+
5
+ TEMP_DIR = current_dir / 'temp'