nolanzandi committed on
Commit
24371db
·
verified ·
1 Parent(s): 85079bb

Upload 11 files

Browse files

initial demo files

__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .main import data_url
2
+
3
+ __all__ = ["data_url"]
data_sources/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .upload_file import process_data_upload
2
+
3
+ __all__ = ["process_data_upload"]
data_sources/upload_file.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pandas as pd
import sqlite3

def process_data_upload(data_file):
    """Load a semicolon-delimited CSV and persist it to the 'data_source' table.

    data_file: a path or file-like object accepted by pandas.read_csv.
    Side effect: (re)creates the 'data_source' table in data_source.db.
    """
    df = pd.read_csv(data_file, sep=";")

    # SQLite column names are easier to query without spaces or slashes,
    # so normalise both to underscores before writing.
    df.columns = df.columns.str.replace(' ', '_')
    df.columns = df.columns.str.replace('/', '_')

    connection = sqlite3.connect('data_source.db')
    try:
        print("Opened database successfully")
        print(df.columns)
        # Replace any previous upload wholesale; the chat tools always
        # query the single 'data_source' table.
        df.to_sql('data_source', connection, if_exists='replace', index=False)
        connection.commit()
    finally:
        # Bug fix: the connection was previously leaked when read_csv or
        # to_sql raised; always release it.
        connection.close()
functions/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .sqlite_functions import SQLiteQuery, sqlite_query_func
2
+ from .chat_functions import demo
3
+
4
+ __all__ = ["SQLiteQuery","sqlite_query_func","demo"]
functions/chat_functions.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from data_sources import process_data_upload

import gradio as gr
import json

from haystack.dataclasses import ChatMessage
from haystack.components.generators.chat import OpenAIChatGenerator

import os
from getpass import getpass
from dotenv import load_dotenv

# Pull OPENAI_API_KEY (and any other settings) from a local .env file.
load_dotenv()

# Fall back to an interactive prompt when no key is configured.
if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")

# Single module-level chat generator shared by every chat turn.
chat_generator = OpenAIChatGenerator(model="gpt-4o")
response = None
# Conversation history; grows in place across calls to chatbot_with_fc, so it
# is shared by every user of this process.  NOTE(review): confirm this is
# acceptable for multi-user deployments.
messages = [
    ChatMessage.from_system(
        "You are a helpful and knowledgeable agent who has access to an SQL database which has a table called 'data_source'"
    )
]
25
+
26
def chatbot_with_fc(message, history):
    """Gradio chat callback: answer *message*, invoking SQL tools when the LLM asks.

    message: the latest user utterance.
    history: Gradio-supplied chat history (unused; state lives in the
        module-level `messages` list).
    Returns the assistant's final reply content.
    """
    print("CHATBOT FUNCTIONS")
    # Imported lazily — and `tools` reloaded — so the tool schemas are rebuilt
    # against the most recently uploaded database, not the one present at
    # process start.
    from functions import sqlite_query_func
    from pipelines import rag_pipeline_func
    import tools
    import importlib
    importlib.reload(tools)

    # Keys here must match the "name" fields declared in tools.tools.
    available_functions = {"sql_query_func": sqlite_query_func, "rag_pipeline_func": rag_pipeline_func}
    messages.append(ChatMessage.from_user(message))
    response = chat_generator.run(messages=messages, generation_kwargs={"tools": tools.tools})

    # Keep looping while the model requests tool calls; exit once it produces
    # an ordinary assistant message.
    while True:
        # if OpenAI response is a tool call
        if response and response["replies"][0].meta["finish_reason"] == "tool_calls":
            # NOTE(review): assumes the reply content is a JSON list of
            # OpenAI-style tool-call objects — confirm against the installed
            # haystack version's ChatMessage format.
            function_calls = json.loads(response["replies"][0].content)
            for function_call in function_calls:
                ## Parse function calling information
                function_name = function_call["function"]["name"]
                function_args = json.loads(function_call["function"]["arguments"])

                ## Find the corresponding function and call it with the given arguments
                function_to_call = available_functions[function_name]
                function_response = function_to_call(**function_args)
                ## Append function response to the messages list using `ChatMessage.from_function`
                messages.append(ChatMessage.from_function(content=function_response['reply'], name=function_name))
            # Ask the model again, now that it can see the tool results.
            response = chat_generator.run(messages=messages, generation_kwargs={"tools": tools.tools})

        # Regular Conversation
        else:
            messages.append(response["replies"][0])
            break
    return response["replies"][0].content
59
+
60
# Give the file-drop area a minimum height so it is an easy drag target.
css= ".file_marker .large{min-height:50px !important;}"

# Top-level Gradio app: a CSV upload box that, once populated, reveals a chat
# interface wired to chatbot_with_fc.
with gr.Blocks(css=css) as demo:
    title = gr.HTML("<h1 style='text-align:center;'>Virtual Data Analyst</h1>")
    description = gr.HTML("<p style='text-align:center;'>Upload a CSV file and chat with our virtual data analyst to get insights on your data set</p>")
    file_output = gr.File(label="CSV File", show_label=True, elem_classes="file_marker", file_types=['.csv'])

    # Re-render this section whenever the uploaded file changes.
    @gr.render(inputs=file_output)
    def data_options(filename):
        print(filename)
        if filename:
            bot = gr.Chatbot(type='messages', label="CSV Chat Window", show_label=True, render=False, visible=True, elem_classes="chatbot")
            chat = gr.ChatInterface(
                fn=chatbot_with_fc,
                type='messages',
                chatbot=bot,
                title="Chat with your data file",
                examples=[
                    ["Describe the dataset"],
                    ["List the columns in the dataset"],
                    ["What could this data be used for?"],
                ],
            )

        # Ingest the uploaded CSV into SQLite so the chat tools can query it.
        process_upload(filename)

    def process_upload(upload_value):
        # Load the file into the 'data_source' table when one is present;
        # returns empty lists as a no-op update for Gradio.
        if upload_value:
            print("UPLOAD VALUE")
            print(upload_value)
            process_data_upload(upload_value)
        return [], []
functions/sqlite_functions.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from haystack import component
3
+ import pandas as pd
4
+ import sqlite3
5
+
6
@component
class SQLiteQuery:
    """Haystack component that runs SQL queries against a SQLite database.

    A fresh connection is opened for every run() call so the component can be
    invoked repeatedly.  The previous implementation closed its single
    long-lived connection at the end of the first run, so every later call
    failed with "Cannot operate on a closed database".
    """

    def __init__(self, sql_database: str):
        # Path to the SQLite database file; connections are created lazily.
        self.sql_database = sql_database

    @component.output_types(results=List[str], queries=List[str])
    def run(self, queries: List[str]):
        """Execute each query and return its result rendered as a string."""
        # check_same_thread=False: Gradio may invoke tools from worker threads.
        connection = sqlite3.connect(self.sql_database, check_same_thread=False)
        try:
            # Render each result DataFrame to text for the LLM to read.
            results = [f"{pd.read_sql(query, connection)}" for query in queries]
        finally:
            connection.close()
        return {"results": results, "queries": queries}
20
+
21
+
22
# Module-level component shared by every tool invocation.
sql_query = SQLiteQuery('data_source.db')

def sqlite_query_func(queries: List[str]):
    """Tool wrapper: run *queries* and return {'reply': <first result>}.

    On any failure the error text is folded into the reply so the LLM can
    apologise or retry instead of the app crashing.
    """
    try:
        result = sql_query.run(queries)
        # Only the first query's result is surfaced to the model.
        return {"reply": result["results"][0]}

    except Exception as e:
        reply = f"""There was an error running the SQL Query = {queries}
                    The error is {e},
                    You should probably try again.
                    """
        return {"reply": reply}
main.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from functions import demo

import os
from getpass import getpass
from dotenv import load_dotenv

# Read configuration (OPENAI_API_KEY) from a local .env file if present.
load_dotenv()

# Prompt interactively when no API key has been configured.
if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")

## Launch the Gradio chat app with debugging and a public share link enabled.
demo.launch(debug=True, share=True)
pipelines/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .pipelines import conditional_sql_pipeline, rag_pipeline_func
2
+
3
+ __all__ = ["conditional_sql_pipeline", "rag_pipeline_func"]
pipelines/pipelines.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from haystack import Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.generators.openai import OpenAIGenerator
from haystack.components.routers import ConditionalRouter

from functions import SQLiteQuery

from typing import List
import sqlite3

import os
from getpass import getpass
from dotenv import load_dotenv

# Load OPENAI_API_KEY from .env, falling back to an interactive prompt.
load_dotenv()

if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")
# NOTE(review): the triple-quoted block below is dead code kept from an
# earlier, unconditional SQL pipeline (it references an undefined `SQLQuery`
# and an undefined `columns`); consider deleting it.
'''
prompt = PromptBuilder(template="""Please generate an SQL query. The query should answer the following Question: {{question}};
The query is to be answered for the table is called 'data_source' with the following
Columns: {{columns}};
Answer:""")
sql_query = SQLQuery('data_source.db')
llm = OpenAIGenerator(model="gpt-4")

sql_pipeline = Pipeline()
sql_pipeline.add_component("prompt", prompt)
sql_pipeline.add_component("llm", llm)
sql_pipeline.add_component("sql_querier", sql_query)

sql_pipeline.connect("prompt", "llm")
sql_pipeline.connect("llm.replies", "sql_querier.queries")

# If you want to draw the pipeline, uncomment below 👇
sql_pipeline.show()
print("PIPELINE RUNNING")
result = sql_pipeline.run({"prompt": {"question": "On which days of the week are average sales highest?",
                                      "columns": columns}})

print(result["sql_querier"]["results"][0])
'''
43
from haystack.components.builders import PromptBuilder
from haystack.components.generators import OpenAIGenerator

llm = OpenAIGenerator(model="gpt-4o")
sql_query = SQLiteQuery('data_source.db')

# Read the current table's column names once so they can be embedded in the
# prompts below.
connection = sqlite3.connect('data_source.db')
cur = connection.execute('select * from data_source')
columns = [i[0] for i in cur.description]
print("COLUMNS 2")
print(columns)
cur.close()
# Bug fix: the connection was previously left open for the life of the
# process; it is only needed to read the column names.
connection.close()

#Rag Pipeline
# Prompt that either produces an SQL query or the sentinel 'no_answer'.
prompt = PromptBuilder(template="""Please generate an SQL query. The query should answer the following Question: {{question}};
If the question cannot be answered given the provided table and columns, return 'no_answer'
The query is to be answered for the table is called 'data_source' with the following
Columns: {{columns}};
Answer:""")

# Route real SQL to the querier, 'no_answer' to the fallback explainer.
routes = [
    {
        "condition": "{{'no_answer' not in replies[0]}}",
        "output": "{{replies}}",
        "output_name": "sql",
        "output_type": List[str],
    },
    {
        "condition": "{{'no_answer' in replies[0]}}",
        "output": "{{question}}",
        "output_name": "go_to_fallback",
        "output_type": str,
    },
]

router = ConditionalRouter(routes)

# Explains to the user why their question is unanswerable with this table.
fallback_prompt = PromptBuilder(template="""User entered a query that cannot be answered with the given table.
The query was: {{question}} and the table had columns: {{columns}}.
Let the user know why the question cannot be answered""")
fallback_llm = OpenAIGenerator(model="gpt-4")

conditional_sql_pipeline = Pipeline()
conditional_sql_pipeline.add_component("prompt", prompt)
conditional_sql_pipeline.add_component("llm", llm)
conditional_sql_pipeline.add_component("router", router)
conditional_sql_pipeline.add_component("fallback_prompt", fallback_prompt)
conditional_sql_pipeline.add_component("fallback_llm", fallback_llm)
conditional_sql_pipeline.add_component("sql_querier", sql_query)

conditional_sql_pipeline.connect("prompt", "llm")
conditional_sql_pipeline.connect("llm.replies", "router.replies")
conditional_sql_pipeline.connect("router.sql", "sql_querier.queries")
conditional_sql_pipeline.connect("router.go_to_fallback", "fallback_prompt.question")
conditional_sql_pipeline.connect("fallback_prompt", "fallback_llm")

# Bug fix: a leftover debug invocation ran the full pipeline here with the
# hard-coded question "When is my birthday?" on every import, issuing paid
# OpenAI API calls whose result was never used.  It has been removed; use
# rag_pipeline_func below to execute the pipeline.
104
+
105
+
106
def rag_pipeline_func(question: str, columns: str):
    """Run the conditional SQL pipeline for *question* and return {'reply': text}.

    columns: the data_source table's column names, injected into the prompts.
    """
    result = conditional_sql_pipeline.run({"prompt": {"question": question,
                                                      "columns": columns},
                                           "router": {"question": question},
                                           "fallback_prompt": {"columns": columns}})

    # Prefer a real SQL answer, then the fallback explanation, then the raw reply.
    if 'sql_querier' in result:
        reply = result['sql_querier']['results'][0]
    elif 'fallback_llm' in result:
        reply = result['fallback_llm']['replies'][0]
    else:
        reply = result["llm"]["replies"][0]

    # Bug fix: SQLiteQuery results and OpenAIGenerator replies are plain
    # strings, which have no `.content` attribute (only ChatMessage does), so
    # the previous unconditional `reply.content` raised AttributeError on the
    # successful SQL path.  Use the attribute when present, else the string.
    content = getattr(reply, "content", reply)

    print("reply content")
    print(content)

    return {"reply": content}
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ haystack-ai
2
+ hayhooks
3
+ sentence-transformers>=3.0.0
4
+ python-dotenv
5
+ gradio
6
+ pandas
7
+ openpyxl
8
+ snowflake-haystack
9
+ psutil
tools.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import sqlite3

# Introspect the current data_source table so the tool descriptions can tell
# the LLM which columns exist.  Runs at import time (and again via
# importlib.reload from chat_functions) so it tracks the latest upload.
connection = sqlite3.connect('data_source.db')
print("Querying Database in Tools.py")
cur = connection.execute('select * from data_source')
columns = [i[0] for i in cur.description]
print("COLUMNS 2")
print(columns)
cur.close()
# Bug fix: the connection was previously never closed.
connection.close()

# OpenAI function-calling tool schemas for the two query tools.
tools = [
    {
        "type": "function",
        "function": {
            "name": "sql_query_func",
            "description": f"This a tool useful to query a SQL table called 'data_source' with the following Columns: {columns}",
            "parameters": {
                "type": "object",
                "properties": {
                    "queries": {
                        "type": "array",
                        "description": "The query to use in the search. Infer this from the user's message. It should be a question or a statement",
                        "items": {
                            "type": "string",
                        }
                    }
                },
                # Bug fix: previously required "question", which is not a
                # declared property — the only parameter is "queries".
                "required": ["queries"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "rag_pipeline_func",
            "description": f"This a tool useful to query a SQL table called 'data_source' with the following Columns: {columns}",
            "parameters": {
                "type": "object",
                # Bug fix: the schema previously declared a "query" array, but
                # rag_pipeline_func(question: str, columns: str) is called with
                # these arguments verbatim (**function_args), so every call
                # raised TypeError.  Declare the function's real parameters.
                "properties": {
                    "question": {
                        "type": "string",
                        "description": "The query to use in the search. Infer this from the user's message. It should be a question or a statement",
                    },
                    "columns": {
                        "type": "string",
                        "description": f"The columns of the 'data_source' table, e.g. {columns}",
                    },
                },
                "required": ["question", "columns"],
            },
        },
    }
]