nolanzandi committed on
Commit
95c52e2
·
verified ·
1 Parent(s): 63c3a67

Create example questions when file is uploaded (#11)

Browse files

- Create example questions when file is uploaded (cfbedd80b99e0cbb428c01c8b685d2e7158e40c6)

Files changed (1) hide show
  1. functions/chat_functions.py +58 -13
functions/chat_functions.py CHANGED
@@ -6,6 +6,7 @@ from haystack.dataclasses import ChatMessage
6
  from haystack.components.generators.chat import OpenAIChatGenerator
7
 
8
  import os
 
9
  from getpass import getpass
10
  from dotenv import load_dotenv
11
 
@@ -16,11 +17,35 @@ if "OPENAI_API_KEY" not in os.environ:
16
 
17
  chat_generator = OpenAIChatGenerator(model="gpt-4o")
18
  response = None
19
- messages = [
20
- ChatMessage.from_system(
21
- "You are a helpful and knowledgeable agent who has access to an SQLite database which has a table called 'data_source'. You also have access to a chart API that uses chart.js dictionaries formatted as a string to generate charts and graphs."
22
- )
23
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def chatbot_with_fc(message, history, session_hash):
26
  from functions import sqlite_query_func, chart_generation_func
@@ -29,15 +54,25 @@ def chatbot_with_fc(message, history, session_hash):
29
 
30
  available_functions = {"sql_query_func": sqlite_query_func, "rag_pipeline_func": rag_pipeline_func, "chart_generation_func": chart_generation_func}
31
 
32
- messages.append(ChatMessage.from_user(message))
33
- response = chat_generator.run(messages=messages, generation_kwargs={"tools": tools.tools_call(session_hash)})
 
 
 
 
 
 
 
 
 
 
34
 
35
  while True:
36
  # if OpenAI response is a tool call
37
  if response and response["replies"][0].meta["finish_reason"] == "tool_calls" or response["replies"][0].tool_calls:
38
  function_calls = response["replies"][0].tool_calls
39
  for function_call in function_calls:
40
- messages.append(ChatMessage.from_assistant(tool_calls=[function_call]))
41
  ## Parse function calling information
42
  function_name = function_call.tool_name
43
  function_args = function_call.arguments
@@ -47,12 +82,12 @@ def chatbot_with_fc(message, history, session_hash):
47
  function_response = function_to_call(**function_args, session_hash=session_hash)
48
  print(function_name)
49
  ## Append function response to the messages list using `ChatMessage.from_tool`
50
- messages.append(ChatMessage.from_tool(tool_result=function_response['reply'], origin=function_call))
51
- response = chat_generator.run(messages=messages, generation_kwargs={"tools": tools.tools_call(session_hash)})
52
 
53
  # Regular Conversation
54
  else:
55
- messages.append(response["replies"][0])
56
  break
57
  return response["replies"][0].text
58
 
@@ -60,6 +95,7 @@ def delete_db(req: gr.Request):
60
  db_file_path = f'data_source_{req.session_hash}.db'
61
  if os.path.exists(db_file_path):
62
  os.remove(db_file_path)
 
63
 
64
  def run_example(input):
65
  return input
@@ -95,7 +131,9 @@ with gr.Blocks(css=css, delete_cache=(3600,3600)) as demo:
95
  @gr.render(inputs=file_output)
96
  def data_options(filename, request: gr.Request):
97
  print(filename)
 
98
  if filename:
 
99
  if "bank_marketing_campaign" in filename:
100
  example_questions = [
101
  ["Describe the dataset"],
@@ -111,7 +149,15 @@ with gr.Blocks(css=css, delete_cache=(3600,3600)) as demo:
111
  ["Can you generate a line graph of revenue per month?"]
112
  ]
113
  else:
114
- example_questions = [
 
 
 
 
 
 
 
 
115
  ["Describe the dataset"],
116
  ["List the columns in the dataset"],
117
  ["What could this data be used for?"],
@@ -127,7 +173,6 @@ with gr.Blocks(css=css, delete_cache=(3600,3600)) as demo:
127
  examples=example_questions,
128
  additional_inputs=parameters
129
  )
130
- process_upload(filename, request.session_hash)
131
 
132
  def process_upload(upload_value, session_hash):
133
  if upload_value:
 
6
  from haystack.components.generators.chat import OpenAIChatGenerator
7
 
8
  import os
9
+ import ast
10
  from getpass import getpass
11
  from dotenv import load_dotenv
12
 
 
17
 
18
  chat_generator = OpenAIChatGenerator(model="gpt-4o")
19
  response = None
20
+ message_dict = {}
21
+
22
+ def example_question_generator(session_hash):
23
+ import sqlite3
24
+ example_response = None
25
+ example_messages = [
26
+ ChatMessage.from_system(
27
+ "You are a helpful and knowledgeable agent who has access to an SQLite database which has a table called 'data_source'."
28
+ )
29
+ ]
30
+ connection = sqlite3.connect(f'data_source_{session_hash}.db')
31
+ print("Querying questions");
32
+ cur=connection.execute('select * from data_source')
33
+ columns = [i[0] for i in cur.description]
34
+ print("QUESTION COLUMNS")
35
+ print(columns)
36
+ cur.close()
37
+ connection.close()
38
+
39
+ example_messages.append(ChatMessage.from_user(text=f"""We have a SQLite database with the following {columns}.
40
+ We also have an AI agent with access to the same database that will be performing data analysis.
41
+ Please return an array of seven strings, each one being a question for our data analysis agent
42
+ that we can suggest that you believe will be insightful or helpful to a data analysis looking for
43
+ data insights. Return nothing more than the array of questions because I need that specific data structure
44
+ to process your response. No other response type or data structure will work."""))
45
+
46
+ example_response = chat_generator.run(messages=example_messages)
47
+
48
+ return example_response["replies"][0].text
49
 
50
  def chatbot_with_fc(message, history, session_hash):
51
  from functions import sqlite_query_func, chart_generation_func
 
54
 
55
  available_functions = {"sql_query_func": sqlite_query_func, "rag_pipeline_func": rag_pipeline_func, "chart_generation_func": chart_generation_func}
56
 
57
+ if message_dict[session_hash] != None:
58
+ message_dict[session_hash].append(ChatMessage.from_user(message))
59
+ else:
60
+ messages = [
61
+ ChatMessage.from_system(
62
+ "You are a helpful and knowledgeable agent who has access to an SQLite database which has a table called 'data_source'. You also have access to a chart API that uses chart.js dictionaries formatted as a string to generate charts and graphs."
63
+ )
64
+ ]
65
+ messages.append(ChatMessage.from_user(message))
66
+ message_dict[session_hash] = messages
67
+
68
+ response = chat_generator.run(messages=message_dict[session_hash], generation_kwargs={"tools": tools.tools_call(session_hash)})
69
 
70
  while True:
71
  # if OpenAI response is a tool call
72
  if response and response["replies"][0].meta["finish_reason"] == "tool_calls" or response["replies"][0].tool_calls:
73
  function_calls = response["replies"][0].tool_calls
74
  for function_call in function_calls:
75
+ message_dict[session_hash].append(ChatMessage.from_assistant(tool_calls=[function_call]))
76
  ## Parse function calling information
77
  function_name = function_call.tool_name
78
  function_args = function_call.arguments
 
82
  function_response = function_to_call(**function_args, session_hash=session_hash)
83
  print(function_name)
84
  ## Append function response to the messages list using `ChatMessage.from_tool`
85
+ message_dict[session_hash].append(ChatMessage.from_tool(tool_result=function_response['reply'], origin=function_call))
86
+ response = chat_generator.run(messages=message_dict[session_hash], generation_kwargs={"tools": tools.tools_call(session_hash)})
87
 
88
  # Regular Conversation
89
  else:
90
+ message_dict[session_hash].append(response["replies"][0])
91
  break
92
  return response["replies"][0].text
93
 
 
95
  db_file_path = f'data_source_{req.session_hash}.db'
96
  if os.path.exists(db_file_path):
97
  os.remove(db_file_path)
98
+ message_dict[req.session_hash] = None
99
 
100
  def run_example(input):
101
  return input
 
131
  @gr.render(inputs=file_output)
132
  def data_options(filename, request: gr.Request):
133
  print(filename)
134
+ message_dict[request.session_hash] = None
135
  if filename:
136
+ process_upload(filename, request.session_hash)
137
  if "bank_marketing_campaign" in filename:
138
  example_questions = [
139
  ["Describe the dataset"],
 
149
  ["Can you generate a line graph of revenue per month?"]
150
  ]
151
  else:
152
+ try:
153
+ generated_examples = ast.literal_eval(example_question_generator(request.session_hash))
154
+ example_questions = [
155
+ ["Describe the dataset"]
156
+ ]
157
+ for example in generated_examples:
158
+ example_questions.append([example])
159
+ except:
160
+ example_questions = [
161
  ["Describe the dataset"],
162
  ["List the columns in the dataset"],
163
  ["What could this data be used for?"],
 
173
  examples=example_questions,
174
  additional_inputs=parameters
175
  )
 
176
 
177
  def process_upload(upload_value, session_hash):
178
  if upload_value: