# virtual-data-analyst/functions/chat_functions.py
from data_sources import process_data_upload
from utils import TEMP_DIR

import gradio as gr
from haystack.dataclasses import ChatMessage
from haystack.components.generators.chat import OpenAIChatGenerator

import os
import ast
from getpass import getpass
from dotenv import load_dotenv

load_dotenv()

if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")

chat_generator = OpenAIChatGenerator(model="gpt-4o")
response = None
message_dict = {}
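# chat_generator is shared across sessions; message_dict keeps each user's running
# ChatMessage history keyed by Gradio's session_hash (populated lazily in
# chatbot_with_fc and cleared again in delete_db / data_options).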

def example_question_generator(session_hash):
    import sqlite3

    example_response = None
    example_messages = [
        ChatMessage.from_system(
            "You are a helpful and knowledgeable agent who has access to an SQLite database which has a table called 'data_source'."
        )
    ]

    # Read the column names of the uploaded data so the model can suggest relevant questions
    dir_path = TEMP_DIR / str(session_hash)
    connection = sqlite3.connect(f'{dir_path}/data_source.db')
    print("Querying questions")
    cur = connection.execute('select * from data_source')
    columns = [i[0] for i in cur.description]
    print("QUESTION COLUMNS")
    print(columns)
    cur.close()
    connection.close()

    example_messages.append(ChatMessage.from_user(text=f"""We have a SQLite database with the following columns: {columns}.
      We also have an AI agent with access to the same database that will be performing data analysis.
      Please return an array of seven strings, each one being a question we can suggest to our data analysis agent,
      questions that you believe will be insightful or helpful to a data analyst looking for
      data insights. Return nothing more than the array of questions because I need that specific data structure
      to process your response. No other response type or data structure will work."""))

    example_response = chat_generator.run(messages=example_messages)
    return example_response["replies"][0].text
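# The model is expected to reply with a Python-style list literal of question strings;
# data_options() below parses that reply with ast.literal_eval and falls back to a
# default set of questions if parsing fails.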

def chatbot_with_fc(message, history, session_hash):
    from functions import sqlite_query_func, chart_generation_func, table_generation_func
    from pipelines import rag_pipeline_func
    import tools

    available_functions = {
        "sql_query_func": sqlite_query_func,
        "rag_pipeline_func": rag_pipeline_func,
        "chart_generation_func": chart_generation_func,
        "table_generation_func": table_generation_func,
    }

    if message_dict.get(session_hash) is not None:
        message_dict[session_hash].append(ChatMessage.from_user(message))
    else:
        messages = [
            ChatMessage.from_system(
                "You are a helpful and knowledgeable agent who has access to an SQLite database which has a table called 'data_source'. You also have access to a chart API that uses chart.js dictionaries formatted as a string to generate charts and graphs."
            )
        ]
        messages.append(ChatMessage.from_user(message))
        message_dict[session_hash] = messages

    response = chat_generator.run(messages=message_dict[session_hash], generation_kwargs={"tools": tools.tools_call(session_hash)})

    while True:
        # If the OpenAI response is a tool call, run the requested function(s) and feed the results back
        if response and (response["replies"][0].meta["finish_reason"] == "tool_calls" or response["replies"][0].tool_calls):
            function_calls = response["replies"][0].tool_calls
            for function_call in function_calls:
                message_dict[session_hash].append(ChatMessage.from_assistant(tool_calls=[function_call]))

                ## Parse function calling information
                function_name = function_call.tool_name
                function_args = function_call.arguments

                ## Find the corresponding function and call it with the given arguments
                function_to_call = available_functions[function_name]
                function_response = function_to_call(**function_args, session_hash=session_hash)
                print(function_name)

                ## Append function response to the messages list using `ChatMessage.from_tool`
                message_dict[session_hash].append(ChatMessage.from_tool(tool_result=function_response['reply'], origin=function_call))

            response = chat_generator.run(messages=message_dict[session_hash], generation_kwargs={"tools": tools.tools_call(session_hash)})

        # Regular conversation: append the assistant reply and exit the tool-calling loop
        else:
            message_dict[session_hash].append(response["replies"][0])
            break

    return response["replies"][0].text
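# Illustrative shape of one tool-calling round trip handled above (a hypothetical
# single SQL call; the real tool names and arguments come from tools.tools_call):
#   system prompt
#   ChatMessage.from_user("What month had the highest revenue?")
#   ChatMessage.from_assistant(tool_calls=[<sql_query_func call>])
#   ChatMessage.from_tool(tool_result=<query result>, origin=<that call>)
#   ChatMessage.from_assistant("Revenue peaked in ...")  # final text reply returned to Gradio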

def delete_db(req: gr.Request):
    import shutil

    # Remove the session's temporary SQLite database and forget its chat history
    dir_path = TEMP_DIR / str(req.session_hash)
    if os.path.exists(dir_path):
        shutil.rmtree(dir_path)
    message_dict[req.session_hash] = None

def run_example(input):
    return input

def example_display(input):
    if input is None:
        display = True
    else:
        display = False
    return [gr.update(visible=display), gr.update(visible=display)]
css= ".file_marker .large{min-height:50px !important;} .example_btn{max-width:300px;}"
with gr.Blocks(css=css, delete_cache=(3600,3600)) as demo:
title = gr.HTML("<h1 style='text-align:center;'>Virtual Data Analyst</h1>")
description = gr.HTML("""<p style='text-align:center;'>Upload a data file and chat with our virtual data analyst
to get insights on your data set. Currently accepts CSV, TSV, TXT, XLS, XLSX, XML, and JSON files.
Can now generate charts and graphs!
Try a sample file to get started!</p>
<p style='text-align:center;'>This tool is under active development. If you experience bugs with use,
open a discussion in the community tab and I will respond.</p>""")
example_file_1 = gr.File(visible=False, value="samples/bank_marketing_campaign.csv")
example_file_2 = gr.File(visible=False, value="samples/online_retail_data.csv")
with gr.Row():
example_btn_1 = gr.Button(value="Try Me: bank_marketing_campaign.csv", elem_classes="example_btn", size="md", variant="primary")
example_btn_2 = gr.Button(value="Try Me: online_retail_data.csv", elem_classes="example_btn", size="md", variant="primary")
file_output = gr.File(label="Data File (CSV, TSV, TXT, XLS, XLSX, XML, JSON)", show_label=True, elem_classes="file_marker", file_types=['.csv','.xlsx','.txt','.json','.ndjson','.xml','.xls','.tsv'])
example_btn_1.click(fn=run_example, inputs=example_file_1, outputs=file_output)
example_btn_2.click(fn=run_example, inputs=example_file_2, outputs=file_output)
file_output.change(fn=example_display, inputs=file_output, outputs=[example_btn_1, example_btn_2])
    @gr.render(inputs=file_output)
    def data_options(filename, request: gr.Request):
        print(filename)
        message_dict[request.session_hash] = None

        if filename:
            process_upload(filename, request.session_hash)

            # Canned example questions for the bundled sample files, generated ones otherwise
            if "bank_marketing_campaign" in filename:
                example_questions = [
                    ["Describe the dataset"],
                    ["What levels of education have the highest and lowest average balance?"],
                    ["What job is most and least common for a yes response from the individuals, not counting 'unknown'?"],
                    ["Can you generate a bar chart of education vs. average balance?"],
                    ["Can you generate a table of levels of education versus average balance, percent married, percent with a loan, and percent in default?"]
                ]
            elif "online_retail_data" in filename:
                example_questions = [
                    ["Describe the dataset"],
                    ["What month had the highest revenue?"],
                    ["Is revenue higher in the morning or afternoon?"],
                    ["Can you generate a line graph of revenue per month?"],
                    ["Can you generate a table of revenue per month?"]
                ]
            else:
                try:
                    generated_examples = ast.literal_eval(example_question_generator(request.session_hash))
                    example_questions = [
                        ["Describe the dataset"]
                    ]
                    for example in generated_examples:
                        example_questions.append([example])
                except Exception:
                    example_questions = [
                        ["Describe the dataset"],
                        ["List the columns in the dataset"],
                        ["What could this data be used for?"],
                    ]

            parameters = gr.Textbox(visible=False, value=request.session_hash)
            bot = gr.Chatbot(type='messages', label="CSV Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
            chat = gr.ChatInterface(
                fn=chatbot_with_fc,
                type='messages',
                chatbot=bot,
                title="Chat with your data file",
                concurrency_limit=None,
                examples=example_questions,
                additional_inputs=parameters
            )
    def process_upload(upload_value, session_hash):
        if upload_value:
            process_data_upload(upload_value, session_hash)
        return [], []

    demo.unload(delete_db)
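# Note: demo.launch() is not called in this module; the Space entry point
# (presumably app.py) is expected to import `demo` from here and launch it.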