Refactor functions and improve llm accuracy (#14)
Browse files- refactor functions and improve llm accuracy (3d660e2779859c6a266e023b0f93f48de373bb3a)
- app.py +99 -2
- functions/__init__.py +2 -2
- functions/chart_functions.py +77 -26
- functions/chat_functions.py +3 -111
- pipelines/__init__.py +0 -3
- pipelines/pipelines.py +0 -91
- requirements.txt +1 -1
- tools.py +12 -7
- utils.py +3 -1
app.py
CHANGED
@@ -1,5 +1,9 @@
|
|
1 |
-
from
|
|
|
|
|
|
|
2 |
|
|
|
3 |
import os
|
4 |
from getpass import getpass
|
5 |
from dotenv import load_dotenv
|
@@ -9,5 +13,98 @@ load_dotenv()
|
|
9 |
if "OPENAI_API_KEY" not in os.environ:
|
10 |
os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
## Uncomment the line below to launch the chat app with UI
|
13 |
-
demo.launch(debug=True, allowed_paths=["temp/"])
|
|
|
1 |
+
from data_sources import process_data_upload
|
2 |
+
from functions import example_question_generator, chatbot_with_fc
|
3 |
+
from utils import TEMP_DIR, message_dict
|
4 |
+
import gradio as gr
|
5 |
|
6 |
+
import ast
|
7 |
import os
|
8 |
from getpass import getpass
|
9 |
from dotenv import load_dotenv
|
|
|
13 |
if "OPENAI_API_KEY" not in os.environ:
|
14 |
os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")
|
15 |
|
16 |
+
def delete_db(req: gr.Request):
|
17 |
+
import shutil
|
18 |
+
dir_path = TEMP_DIR / str(req.session_hash)
|
19 |
+
if os.path.exists(dir_path):
|
20 |
+
shutil.rmtree(dir_path)
|
21 |
+
message_dict[req.session_hash] = None
|
22 |
+
|
23 |
+
def run_example(input):
|
24 |
+
return input
|
25 |
+
|
26 |
+
def example_display(input):
|
27 |
+
if input == None:
|
28 |
+
display = True
|
29 |
+
else:
|
30 |
+
display = False
|
31 |
+
return [gr.update(visible=display),gr.update(visible=display)]
|
32 |
+
|
33 |
+
css= ".file_marker .large{min-height:50px !important;} .example_btn{max-width:300px;}"
|
34 |
+
|
35 |
+
with gr.Blocks(css=css, delete_cache=(3600,3600)) as demo:
|
36 |
+
title = gr.HTML("<h1 style='text-align:center;'>Virtual Data Analyst</h1>")
|
37 |
+
description = gr.HTML("""<p style='text-align:center;'>Upload a data file and chat with our virtual data analyst
|
38 |
+
to get insights on your data set. Currently accepts CSV, TSV, TXT, XLS, XLSX, XML, and JSON files.
|
39 |
+
Can now generate charts and graphs!
|
40 |
+
Try a sample file to get started!</p>
|
41 |
+
<p style='text-align:center;'>This tool is under active development. If you experience bugs with use,
|
42 |
+
open a discussion in the community tab and I will respond.</p>""")
|
43 |
+
example_file_1 = gr.File(visible=False, value="samples/bank_marketing_campaign.csv")
|
44 |
+
example_file_2 = gr.File(visible=False, value="samples/online_retail_data.csv")
|
45 |
+
with gr.Row():
|
46 |
+
example_btn_1 = gr.Button(value="Try Me: bank_marketing_campaign.csv", elem_classes="example_btn", size="md", variant="primary")
|
47 |
+
example_btn_2 = gr.Button(value="Try Me: online_retail_data.csv", elem_classes="example_btn", size="md", variant="primary")
|
48 |
+
|
49 |
+
file_output = gr.File(label="Data File (CSV, TSV, TXT, XLS, XLSX, XML, JSON)", show_label=True, elem_classes="file_marker", file_types=['.csv','.xlsx','.txt','.json','.ndjson','.xml','.xls','.tsv'])
|
50 |
+
example_btn_1.click(fn=run_example, inputs=example_file_1, outputs=file_output)
|
51 |
+
example_btn_2.click(fn=run_example, inputs=example_file_2, outputs=file_output)
|
52 |
+
file_output.change(fn=example_display, inputs=file_output, outputs=[example_btn_1, example_btn_2])
|
53 |
+
|
54 |
+
@gr.render(inputs=file_output)
|
55 |
+
def data_options(filename, request: gr.Request):
|
56 |
+
print(filename)
|
57 |
+
message_dict[request.session_hash] = None
|
58 |
+
if filename:
|
59 |
+
process_upload(filename, request.session_hash)
|
60 |
+
if "bank_marketing_campaign" in filename:
|
61 |
+
example_questions = [
|
62 |
+
["Describe the dataset"],
|
63 |
+
["What levels of education have the highest and lowest average balance?"],
|
64 |
+
["What job is most and least common for a yes response from the individuals, not counting 'unknown'?"],
|
65 |
+
["Can you generate a bar chart of education vs. average balance?"],
|
66 |
+
["Can you generate a table of levels of education versus average balance, percent married, percent with a loan, and percent in default?"]
|
67 |
+
]
|
68 |
+
elif "online_retail_data" in filename:
|
69 |
+
example_questions = [
|
70 |
+
["Describe the dataset"],
|
71 |
+
["What month had the highest revenue?"],
|
72 |
+
["Is revenue higher in the morning or afternoon?"],
|
73 |
+
["Can you generate a line graph of revenue per month?"],
|
74 |
+
["Can you generate a table of revenue per month?"]
|
75 |
+
]
|
76 |
+
else:
|
77 |
+
try:
|
78 |
+
generated_examples = ast.literal_eval(example_question_generator(request.session_hash))
|
79 |
+
example_questions = [
|
80 |
+
["Describe the dataset"]
|
81 |
+
]
|
82 |
+
for example in generated_examples:
|
83 |
+
example_questions.append([example])
|
84 |
+
except:
|
85 |
+
example_questions = [
|
86 |
+
["Describe the dataset"],
|
87 |
+
["List the columns in the dataset"],
|
88 |
+
["What could this data be used for?"],
|
89 |
+
]
|
90 |
+
parameters = gr.Textbox(visible=False, value=request.session_hash)
|
91 |
+
bot = gr.Chatbot(type='messages', label="CSV Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
|
92 |
+
chat = gr.ChatInterface(
|
93 |
+
fn=chatbot_with_fc,
|
94 |
+
type='messages',
|
95 |
+
chatbot=bot,
|
96 |
+
title="Chat with your data file",
|
97 |
+
concurrency_limit=None,
|
98 |
+
examples=example_questions,
|
99 |
+
additional_inputs=parameters
|
100 |
+
)
|
101 |
+
|
102 |
+
def process_upload(upload_value, session_hash):
|
103 |
+
if upload_value:
|
104 |
+
process_data_upload(upload_value, session_hash)
|
105 |
+
return [], []
|
106 |
+
|
107 |
+
demo.unload(delete_db)
|
108 |
+
|
109 |
## Uncomment the line below to launch the chat app with UI
|
110 |
+
demo.launch(debug=True, allowed_paths=["temp/"])
|
functions/__init__.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from .sqlite_functions import SQLiteQuery, sqlite_query_func
|
2 |
from .chart_functions import chart_generation_func, table_generation_func
|
3 |
-
from .chat_functions import
|
4 |
|
5 |
-
__all__ = ["SQLiteQuery","sqlite_query_func","chart_generation_func","table_generation_func","
|
|
|
1 |
from .sqlite_functions import SQLiteQuery, sqlite_query_func
|
2 |
from .chart_functions import chart_generation_func, table_generation_func
|
3 |
+
from .chat_functions import example_question_generator, chatbot_with_fc
|
4 |
|
5 |
+
__all__ = ["SQLiteQuery","sqlite_query_func","chart_generation_func","table_generation_func","example_question_generator","chatbot_with_fc"]
|
functions/chart_functions.py
CHANGED
@@ -1,45 +1,96 @@
|
|
1 |
from typing import List
|
2 |
-
from
|
|
|
3 |
import pandas as pd
|
4 |
from utils import TEMP_DIR
|
5 |
import os
|
|
|
6 |
from dotenv import load_dotenv
|
7 |
|
8 |
load_dotenv()
|
9 |
|
10 |
root_url = os.getenv("ROOT_URL")
|
11 |
|
12 |
-
def chart_generation_func(
|
13 |
print("CHART GENERATION")
|
14 |
-
|
15 |
-
print(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
20 |
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
23 |
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
interactive_url = url_base + '/chart-maker/view/' + url_id
|
28 |
-
edit_url = url_base + '/chart-maker/edit/' + url_id
|
29 |
|
30 |
-
|
31 |
|
32 |
-
|
33 |
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
print("TABLE GENERATION")
|
37 |
print(data)
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from typing import List
|
2 |
+
from typing import Dict
|
3 |
+
import plotly.io as pio
|
4 |
import pandas as pd
|
5 |
from utils import TEMP_DIR
|
6 |
import os
|
7 |
+
import ast
|
8 |
from dotenv import load_dotenv
|
9 |
|
10 |
load_dotenv()
|
11 |
|
12 |
root_url = os.getenv("ROOT_URL")
|
13 |
|
14 |
+
def chart_generation_func(data: List[dict], session_hash: str, layout: Dict[str,str]={}):
|
15 |
print("CHART GENERATION")
|
16 |
+
print(data)
|
17 |
+
print(layout)
|
18 |
+
try:
|
19 |
+
dir_path = TEMP_DIR / str(session_hash)
|
20 |
+
chart_path = f'{dir_path}/chart.html'
|
21 |
+
|
22 |
+
#Processing data to account for variation from LLM
|
23 |
+
data_list = []
|
24 |
+
layout_dict = {}
|
25 |
+
if isinstance(data, list):
|
26 |
+
data_list = data
|
27 |
+
else:
|
28 |
+
data_list.append(data)
|
29 |
+
|
30 |
+
if isinstance(data[0], str):
|
31 |
+
data_list[0] = ast.literal_eval(data_list[0])
|
32 |
|
33 |
+
if isinstance(layout, list):
|
34 |
+
layout_obj = layout[0]
|
35 |
+
else:
|
36 |
+
layout_obj = layout
|
37 |
|
38 |
+
if isinstance(layout_obj, str):
|
39 |
+
layout_dict = ast.literal_eval(layout_obj)
|
40 |
+
else:
|
41 |
+
layout_dict = layout_obj
|
42 |
+
|
43 |
|
44 |
+
fig = dict({"data": data_list,
|
45 |
+
"layout": layout_dict})
|
46 |
+
pio.write_html(fig, chart_path, full_html=False)
|
|
|
|
|
47 |
|
48 |
+
chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/chart.html'
|
49 |
|
50 |
+
iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
|
51 |
|
52 |
+
return {"reply": iframe}
|
53 |
+
|
54 |
+
except Exception as e:
|
55 |
+
print("CHART ERROR")
|
56 |
+
reply = f"""There was an error generating the Plotly Chart from {data} and {layout}
|
57 |
+
The error is {e},
|
58 |
+
You should probably try again.
|
59 |
+
"""
|
60 |
+
return {"reply": reply}
|
61 |
+
|
62 |
+
def table_generation_func(data: List[dict], session_hash):
|
63 |
print("TABLE GENERATION")
|
64 |
print(data)
|
65 |
+
try:
|
66 |
+
dir_path = TEMP_DIR / str(session_hash)
|
67 |
+
csv_path = f'{dir_path}/data.csv'
|
68 |
+
|
69 |
+
#Processing data to account for variation from LLM
|
70 |
+
if isinstance(data, list):
|
71 |
+
data_obj = data[0]
|
72 |
+
else:
|
73 |
+
data_obj = data
|
74 |
+
|
75 |
+
if isinstance(data_obj, str):
|
76 |
+
data_dict = ast.literal_eval(data_obj)
|
77 |
+
else:
|
78 |
+
data_dict = data_obj
|
79 |
+
|
80 |
+
df = pd.DataFrame.from_dict(data_dict)
|
81 |
+
print(df)
|
82 |
+
df.to_csv(csv_path)
|
83 |
+
|
84 |
+
download_path = f'{root_url}/gradio_api/file/temp/{session_hash}/data.csv'
|
85 |
+
html_table = df.to_html() + f'<p>Download as a <a href="{download_path}">CSV file</a></p>'
|
86 |
+
print(html_table)
|
87 |
+
|
88 |
+
return {"reply": html_table}
|
89 |
+
|
90 |
+
except Exception as e:
|
91 |
+
print("TABLE ERROR")
|
92 |
+
reply = f"""There was an error generating the Pandas DataFrame table from {data}
|
93 |
+
The error is {e},
|
94 |
+
You should probably try again.
|
95 |
+
"""
|
96 |
+
return {"reply": reply}
|
functions/chat_functions.py
CHANGED
@@ -1,24 +1,10 @@
|
|
1 |
-
from
|
2 |
-
from utils import TEMP_DIR
|
3 |
-
|
4 |
-
import gradio as gr
|
5 |
|
6 |
from haystack.dataclasses import ChatMessage
|
7 |
from haystack.components.generators.chat import OpenAIChatGenerator
|
8 |
|
9 |
-
import os
|
10 |
-
import ast
|
11 |
-
from getpass import getpass
|
12 |
-
from dotenv import load_dotenv
|
13 |
-
|
14 |
-
load_dotenv()
|
15 |
-
|
16 |
-
if "OPENAI_API_KEY" not in os.environ:
|
17 |
-
os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")
|
18 |
-
|
19 |
chat_generator = OpenAIChatGenerator(model="gpt-4o")
|
20 |
response = None
|
21 |
-
message_dict = {}
|
22 |
|
23 |
def example_question_generator(session_hash):
|
24 |
import sqlite3
|
@@ -51,10 +37,9 @@ def example_question_generator(session_hash):
|
|
51 |
|
52 |
def chatbot_with_fc(message, history, session_hash):
|
53 |
from functions import sqlite_query_func, chart_generation_func, table_generation_func
|
54 |
-
from pipelines import rag_pipeline_func
|
55 |
import tools
|
56 |
|
57 |
-
available_functions = {"sql_query_func": sqlite_query_func, "
|
58 |
|
59 |
if message_dict[session_hash] != None:
|
60 |
message_dict[session_hash].append(ChatMessage.from_user(message))
|
@@ -62,7 +47,7 @@ def chatbot_with_fc(message, history, session_hash):
|
|
62 |
messages = [
|
63 |
ChatMessage.from_system(
|
64 |
"""You are a helpful and knowledgeable agent who has access to an SQLite database which has a table called 'data_source'.
|
65 |
-
You also have access to a chart
|
66 |
You also have access to a function, called table_generation_func, that builds table formatted html and generates a link to download as CSV."""
|
67 |
)
|
68 |
]
|
@@ -95,97 +80,4 @@ def chatbot_with_fc(message, history, session_hash):
|
|
95 |
break
|
96 |
return response["replies"][0].text
|
97 |
|
98 |
-
def delete_db(req: gr.Request):
|
99 |
-
import shutil
|
100 |
-
dir_path = TEMP_DIR / str(req.session_hash)
|
101 |
-
if os.path.exists(dir_path):
|
102 |
-
shutil.rmtree(dir_path)
|
103 |
-
message_dict[req.session_hash] = None
|
104 |
-
|
105 |
-
def run_example(input):
|
106 |
-
return input
|
107 |
-
|
108 |
-
def example_display(input):
|
109 |
-
if input == None:
|
110 |
-
display = True
|
111 |
-
else:
|
112 |
-
display = False
|
113 |
-
return [gr.update(visible=display),gr.update(visible=display)]
|
114 |
-
|
115 |
-
css= ".file_marker .large{min-height:50px !important;} .example_btn{max-width:300px;}"
|
116 |
-
|
117 |
-
with gr.Blocks(css=css, delete_cache=(3600,3600)) as demo:
|
118 |
-
title = gr.HTML("<h1 style='text-align:center;'>Virtual Data Analyst</h1>")
|
119 |
-
description = gr.HTML("""<p style='text-align:center;'>Upload a data file and chat with our virtual data analyst
|
120 |
-
to get insights on your data set. Currently accepts CSV, TSV, TXT, XLS, XLSX, XML, and JSON files.
|
121 |
-
Can now generate charts and graphs!
|
122 |
-
Try a sample file to get started!</p>
|
123 |
-
<p style='text-align:center;'>This tool is under active development. If you experience bugs with use,
|
124 |
-
open a discussion in the community tab and I will respond.</p>""")
|
125 |
-
example_file_1 = gr.File(visible=False, value="samples/bank_marketing_campaign.csv")
|
126 |
-
example_file_2 = gr.File(visible=False, value="samples/online_retail_data.csv")
|
127 |
-
with gr.Row():
|
128 |
-
example_btn_1 = gr.Button(value="Try Me: bank_marketing_campaign.csv", elem_classes="example_btn", size="md", variant="primary")
|
129 |
-
example_btn_2 = gr.Button(value="Try Me: online_retail_data.csv", elem_classes="example_btn", size="md", variant="primary")
|
130 |
-
|
131 |
-
file_output = gr.File(label="Data File (CSV, TSV, TXT, XLS, XLSX, XML, JSON)", show_label=True, elem_classes="file_marker", file_types=['.csv','.xlsx','.txt','.json','.ndjson','.xml','.xls','.tsv'])
|
132 |
-
example_btn_1.click(fn=run_example, inputs=example_file_1, outputs=file_output)
|
133 |
-
example_btn_2.click(fn=run_example, inputs=example_file_2, outputs=file_output)
|
134 |
-
file_output.change(fn=example_display, inputs=file_output, outputs=[example_btn_1, example_btn_2])
|
135 |
-
|
136 |
-
@gr.render(inputs=file_output)
|
137 |
-
def data_options(filename, request: gr.Request):
|
138 |
-
print(filename)
|
139 |
-
message_dict[request.session_hash] = None
|
140 |
-
if filename:
|
141 |
-
process_upload(filename, request.session_hash)
|
142 |
-
if "bank_marketing_campaign" in filename:
|
143 |
-
example_questions = [
|
144 |
-
["Describe the dataset"],
|
145 |
-
["What levels of education have the highest and lowest average balance?"],
|
146 |
-
["What job is most and least common for a yes response from the individuals, not counting 'unknown'?"],
|
147 |
-
["Can you generate a bar chart of education vs. average balance?"],
|
148 |
-
["Can you generate a table of levels of education versus average balance, percent married, percent with a loan, and percent in default?"]
|
149 |
-
]
|
150 |
-
elif "online_retail_data" in filename:
|
151 |
-
example_questions = [
|
152 |
-
["Describe the dataset"],
|
153 |
-
["What month had the highest revenue?"],
|
154 |
-
["Is revenue higher in the morning or afternoon?"],
|
155 |
-
["Can you generate a line graph of revenue per month?"],
|
156 |
-
["Can you generate a table of revenue per month?"]
|
157 |
-
]
|
158 |
-
else:
|
159 |
-
try:
|
160 |
-
generated_examples = ast.literal_eval(example_question_generator(request.session_hash))
|
161 |
-
example_questions = [
|
162 |
-
["Describe the dataset"]
|
163 |
-
]
|
164 |
-
for example in generated_examples:
|
165 |
-
example_questions.append([example])
|
166 |
-
except:
|
167 |
-
example_questions = [
|
168 |
-
["Describe the dataset"],
|
169 |
-
["List the columns in the dataset"],
|
170 |
-
["What could this data be used for?"],
|
171 |
-
]
|
172 |
-
parameters = gr.Textbox(visible=False, value=request.session_hash)
|
173 |
-
bot = gr.Chatbot(type='messages', label="CSV Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
|
174 |
-
chat = gr.ChatInterface(
|
175 |
-
fn=chatbot_with_fc,
|
176 |
-
type='messages',
|
177 |
-
chatbot=bot,
|
178 |
-
title="Chat with your data file",
|
179 |
-
concurrency_limit=None,
|
180 |
-
examples=example_questions,
|
181 |
-
additional_inputs=parameters
|
182 |
-
)
|
183 |
-
|
184 |
-
def process_upload(upload_value, session_hash):
|
185 |
-
if upload_value:
|
186 |
-
process_data_upload(upload_value, session_hash)
|
187 |
-
return [], []
|
188 |
-
|
189 |
-
demo.unload(delete_db)
|
190 |
-
|
191 |
|
|
|
1 |
+
from utils import TEMP_DIR, message_dict
|
|
|
|
|
|
|
2 |
|
3 |
from haystack.dataclasses import ChatMessage
|
4 |
from haystack.components.generators.chat import OpenAIChatGenerator
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
chat_generator = OpenAIChatGenerator(model="gpt-4o")
|
7 |
response = None
|
|
|
8 |
|
9 |
def example_question_generator(session_hash):
|
10 |
import sqlite3
|
|
|
37 |
|
38 |
def chatbot_with_fc(message, history, session_hash):
|
39 |
from functions import sqlite_query_func, chart_generation_func, table_generation_func
|
|
|
40 |
import tools
|
41 |
|
42 |
+
available_functions = {"sql_query_func": sqlite_query_func, "chart_generation_func": chart_generation_func, "table_generation_func":table_generation_func }
|
43 |
|
44 |
if message_dict[session_hash] != None:
|
45 |
message_dict[session_hash].append(ChatMessage.from_user(message))
|
|
|
47 |
messages = [
|
48 |
ChatMessage.from_system(
|
49 |
"""You are a helpful and knowledgeable agent who has access to an SQLite database which has a table called 'data_source'.
|
50 |
+
You also have access to a chart function that uses plotly dictionaries to generate charts and graphs.
|
51 |
You also have access to a function, called table_generation_func, that builds table formatted html and generates a link to download as CSV."""
|
52 |
)
|
53 |
]
|
|
|
80 |
break
|
81 |
return response["replies"][0].text
|
82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
pipelines/__init__.py
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
from .pipelines import rag_pipeline_func
|
2 |
-
|
3 |
-
__all__ = ["rag_pipeline_func"]
|
|
|
|
|
|
|
|
pipelines/pipelines.py
DELETED
@@ -1,91 +0,0 @@
|
|
1 |
-
from haystack import Pipeline
|
2 |
-
from haystack.components.builders import PromptBuilder
|
3 |
-
from haystack.components.generators.openai import OpenAIGenerator
|
4 |
-
from haystack.components.routers import ConditionalRouter
|
5 |
-
|
6 |
-
from functions import SQLiteQuery
|
7 |
-
|
8 |
-
from typing import List
|
9 |
-
import sqlite3
|
10 |
-
|
11 |
-
import os
|
12 |
-
from getpass import getpass
|
13 |
-
from dotenv import load_dotenv
|
14 |
-
|
15 |
-
load_dotenv()
|
16 |
-
|
17 |
-
if "OPENAI_API_KEY" not in os.environ:
|
18 |
-
os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")
|
19 |
-
|
20 |
-
from haystack.components.builders import PromptBuilder
|
21 |
-
from haystack.components.generators import OpenAIGenerator
|
22 |
-
|
23 |
-
llm = OpenAIGenerator(model="gpt-4o")
|
24 |
-
def rag_pipeline_func(queries: str, session_hash):
|
25 |
-
sql_query = SQLiteQuery(f'data_source_{session_hash}.db')
|
26 |
-
|
27 |
-
connection = sqlite3.connect(f'data_source_{session_hash}.db')
|
28 |
-
cur=connection.execute('select * from data_source')
|
29 |
-
columns = [i[0] for i in cur.description]
|
30 |
-
cur.close()
|
31 |
-
|
32 |
-
#Rag Pipeline
|
33 |
-
prompt = PromptBuilder(template="""Please generate an SQL query. The query should answer the following Question: {{question}};
|
34 |
-
If the question cannot be answered given the provided table and columns, return 'no_answer'
|
35 |
-
The query is to be answered for the table is called 'data_source' with the following
|
36 |
-
Columns: {{columns}};
|
37 |
-
Answer:""")
|
38 |
-
|
39 |
-
routes = [
|
40 |
-
{
|
41 |
-
"condition": "{{'no_answer' not in replies[0]}}",
|
42 |
-
"output": "{{replies}}",
|
43 |
-
"output_name": "sql",
|
44 |
-
"output_type": List[str],
|
45 |
-
},
|
46 |
-
{
|
47 |
-
"condition": "{{'no_answer' in replies[0]}}",
|
48 |
-
"output": "{{question}}",
|
49 |
-
"output_name": "go_to_fallback",
|
50 |
-
"output_type": str,
|
51 |
-
},
|
52 |
-
]
|
53 |
-
|
54 |
-
router = ConditionalRouter(routes)
|
55 |
-
|
56 |
-
fallback_prompt = PromptBuilder(template="""User entered a query that cannot be answered with the given table.
|
57 |
-
The query was: {{question}} and the table had columns: {{columns}}.
|
58 |
-
Let the user know why the question cannot be answered""")
|
59 |
-
fallback_llm = OpenAIGenerator(model="gpt-4")
|
60 |
-
|
61 |
-
conditional_sql_pipeline = Pipeline()
|
62 |
-
conditional_sql_pipeline.add_component("prompt", prompt)
|
63 |
-
conditional_sql_pipeline.add_component("llm", llm)
|
64 |
-
conditional_sql_pipeline.add_component("router", router)
|
65 |
-
conditional_sql_pipeline.add_component("fallback_prompt", fallback_prompt)
|
66 |
-
conditional_sql_pipeline.add_component("fallback_llm", fallback_llm)
|
67 |
-
conditional_sql_pipeline.add_component("sql_querier", sql_query)
|
68 |
-
|
69 |
-
conditional_sql_pipeline.connect("prompt", "llm")
|
70 |
-
conditional_sql_pipeline.connect("llm.replies", "router.replies")
|
71 |
-
conditional_sql_pipeline.connect("router.sql", "sql_querier.queries")
|
72 |
-
conditional_sql_pipeline.connect("router.go_to_fallback", "fallback_prompt.question")
|
73 |
-
conditional_sql_pipeline.connect("fallback_prompt", "fallback_llm")
|
74 |
-
|
75 |
-
print("RAG PIPELINE FUNCTION")
|
76 |
-
result = conditional_sql_pipeline.run({"prompt": {"question": queries,
|
77 |
-
"columns": columns},
|
78 |
-
"router": {"question": queries},
|
79 |
-
"fallback_prompt": {"columns": columns}})
|
80 |
-
|
81 |
-
if 'sql_querier' in result:
|
82 |
-
reply = result['sql_querier']['results'][0]
|
83 |
-
elif 'fallback_llm' in result:
|
84 |
-
reply = result['fallback_llm']['replies'][0]
|
85 |
-
else:
|
86 |
-
reply = result["llm"]["replies"][0]
|
87 |
-
|
88 |
-
print("reply content")
|
89 |
-
print(reply.content)
|
90 |
-
|
91 |
-
return {"reply": reply.content}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -2,5 +2,5 @@ haystack-ai
|
|
2 |
python-dotenv
|
3 |
gradio
|
4 |
pandas
|
5 |
-
|
6 |
openpyxl
|
|
|
2 |
python-dotenv
|
3 |
gradio
|
4 |
pandas
|
5 |
+
plotly
|
6 |
openpyxl
|
tools.py
CHANGED
@@ -29,7 +29,7 @@ def tools_call(session_hash):
|
|
29 |
}
|
30 |
}
|
31 |
},
|
32 |
-
"required": ["
|
33 |
},
|
34 |
},
|
35 |
},
|
@@ -44,17 +44,22 @@ def tools_call(session_hash):
|
|
44 |
"parameters": {
|
45 |
"type": "object",
|
46 |
"properties": {
|
47 |
-
"
|
48 |
"type": "array",
|
49 |
-
"description": """The
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
52 |
"items": {
|
53 |
"type": "string",
|
54 |
}
|
55 |
}
|
56 |
},
|
57 |
-
"required": ["
|
58 |
},
|
59 |
},
|
60 |
},
|
@@ -82,7 +87,7 @@ def tools_call(session_hash):
|
|
82 |
}
|
83 |
}
|
84 |
},
|
85 |
-
"required": ["
|
86 |
},
|
87 |
},
|
88 |
}
|
|
|
29 |
}
|
30 |
}
|
31 |
},
|
32 |
+
"required": ["queries"],
|
33 |
},
|
34 |
},
|
35 |
},
|
|
|
44 |
"parameters": {
|
45 |
"type": "object",
|
46 |
"properties": {
|
47 |
+
"data": {
|
48 |
"type": "array",
|
49 |
+
"description": """The list containing a dictionary that contains the 'data' portion of the plotly chart generation. Infer this from the user's message.""",
|
50 |
+
"items": {
|
51 |
+
"type": "string",
|
52 |
+
}
|
53 |
+
},
|
54 |
+
"layout": {
|
55 |
+
"type": "array",
|
56 |
+
"description": """The dictionary that contains the 'layout' portion of the plotly chart generation""",
|
57 |
"items": {
|
58 |
"type": "string",
|
59 |
}
|
60 |
}
|
61 |
},
|
62 |
+
"required": ["data"],
|
63 |
},
|
64 |
},
|
65 |
},
|
|
|
87 |
}
|
88 |
}
|
89 |
},
|
90 |
+
"required": ["data"],
|
91 |
},
|
92 |
},
|
93 |
}
|
utils.py
CHANGED
@@ -2,4 +2,6 @@ from pathlib import Path
|
|
2 |
|
3 |
current_dir = Path(__file__).parent
|
4 |
|
5 |
-
TEMP_DIR = current_dir / 'temp'
|
|
|
|
|
|
2 |
|
3 |
current_dir = Path(__file__).parent
|
4 |
|
5 |
+
TEMP_DIR = current_dir / 'temp'
|
6 |
+
|
7 |
+
message_dict = {}
|