File size: 7,162 Bytes
32f5b77 fb65c41 c895015 32f5b77 c895015 32f5b77 fb65c41 32f5b77 fb65c41 32f5b77 c895015 32f5b77 fb65c41 c895015 32f5b77 fb65c41 32f5b77 fb65c41 fedee8b 00dae37 fedee8b 32f5b77 c895015 00dae37 73a1633 00dae37 32f5b77 fb65c41 32f5b77 8282cb1 c895015 8282cb1 c895015 8282cb1 fb65c41 c895015 32f5b77 fb65c41 8282cb1 fb65c41 32f5b77 fb65c41 32f5b77 fb65c41 32f5b77 fb65c41 32f5b77 fb65c41 32f5b77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
from data_sources import process_data_upload
import gradio as gr
from haystack.dataclasses import ChatMessage
from haystack.components.generators.chat import OpenAIChatGenerator
import os
from getpass import getpass
from dotenv import load_dotenv
load_dotenv()

# Ask interactively only if OPENAI_API_KEY is still unset after load_dotenv().
# (os.environ values are always strings, so .get() returns None only when
# the key is absent.)
if os.environ.get("OPENAI_API_KEY") is None:
    os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")

# Chat generator component used by chatbot_with_fc for every model call.
chat_generator = OpenAIChatGenerator(model="gpt-4o")

# Module-level placeholder. chatbot_with_fc assigns its own local `response`
# (no `global` statement) and never reads this name.
response = None

# Initial conversation: one system message telling the model it can query the
# SQL table named 'data_source'.
messages = [ChatMessage.from_system(
    "You are a helpful and knowledgeable agent who has access to an SQL database which has a table called 'data_source'"
)]
# Per-session chat transcripts keyed by Gradio session hash. Each transcript
# starts as a copy of the module-level `messages` list (the system prompt),
# so concurrent users no longer share one conversation.
# NOTE(review): entries are never evicted here; consider clearing them when a
# session unloads.
_SESSION_MESSAGES = {}


def chatbot_with_fc(message, history, session_hash):
    """Answer one user turn, letting the model call data tools as needed.

    The user's message is appended to this session's transcript and sent to
    ``chat_generator`` with the tool specs from
    ``tools.tools_call(session_hash)``. While the model replies with tool
    calls, every requested function is run and its ``'reply'`` value is fed
    back as a tool message. The model is then queried again. The loop ends on
    the first reply that contains no tool calls.

    Args:
        message: The user's chat message text.
        history: Chat history supplied by ``gr.ChatInterface``. Unused: the
            server-side transcript, which also keeps tool results, is used
            instead.
        session_hash: Gradio session identifier, forwarded to the tools and to
            ``tools.tools_call``.

    Returns:
        The text of the model's final (non-tool-call) reply.
    """
    # Project modules are imported lazily, as in the original implementation.
    from functions import sqlite_query_func, chart_generation_func
    from pipelines import rag_pipeline_func
    import tools

    available_functions = {
        "sql_query_func": sqlite_query_func,
        "rag_pipeline_func": rag_pipeline_func,
        "chart_generation_func": chart_generation_func,
    }

    session_messages = _SESSION_MESSAGES.setdefault(session_hash, list(messages))
    session_messages.append(ChatMessage.from_user(message))

    def _generate_reply():
        # Tool specs are rebuilt for every request, as before.
        result = chat_generator.run(
            messages=session_messages,
            generation_kwargs={"tools": tools.tools_call(session_hash)},
        )
        return result["replies"][0]

    reply = _generate_reply()
    # Looping on the actual tool_calls (not on finish_reason) avoids spinning
    # forever on a "tool_calls" finish reason that carries no calls.
    while reply.tool_calls:
        function_calls = reply.tool_calls
        # Record all calls from this reply in one assistant message, answer
        # each of them, and only then ask the model again (one LLM call per
        # round instead of one per tool call).
        session_messages.append(ChatMessage.from_assistant(tool_calls=function_calls))
        for function_call in function_calls:
            function_name = function_call.tool_name
            function_args = function_call.arguments
            print(function_name)

            function_to_call = available_functions.get(function_name)
            if function_to_call is None:
                # Report a hallucinated tool name back to the model instead of
                # crashing the chat turn with a KeyError.
                tool_result = f"Error: unknown tool '{function_name}'."
            else:
                function_response = function_to_call(**function_args, session_hash=session_hash)
                tool_result = function_response['reply']

            session_messages.append(ChatMessage.from_tool(tool_result=tool_result, origin=function_call))

        reply = _generate_reply()

    # Regular conversation: keep the final answer in the transcript.
    session_messages.append(reply)
    return reply.text
def delete_db(req: gr.Request):
    """Delete ``data_source_<session_hash>.db`` for the ending session.

    Called with the Gradio request of the session being unloaded. A missing
    file is not an error.

    Args:
        req: Gradio request; only ``req.session_hash`` is used.
    """
    db_file_path = f'data_source_{req.session_hash}.db'
    # EAFP: a separate exists()/remove() pair races if the file disappears
    # between the two calls.
    try:
        os.remove(db_file_path)
    except FileNotFoundError:
        pass
def run_example(input):
    """Return the given value unchanged.

    Wired as the click handler of the "Try Me" buttons: the hidden sample
    file component's value is passed straight through to the upload box.
    """
    passthrough = input
    return passthrough
def example_display(input):
    """Build visibility updates for the two example buttons.

    The buttons are shown only while no file is loaded (``input is None``)
    and are hidden otherwise.

    Args:
        input: Current value of the upload component (None when empty).

    Returns:
        A two-element list of ``gr.update`` objects, one per example button.
    """
    # Identity check per PEP 8; `== None` can be hijacked by custom __eq__.
    show_examples = input is None
    return [gr.update(visible=show_examples), gr.update(visible=show_examples)]
# Custom CSS: minimum height for the upload widget (elem_classes
# "file_marker") and maximum width for the example buttons ("example_btn").
css= ".file_marker .large{min-height:50px !important;} .example_btn{max-width:300px;}"

# Gradio UI. Component creation order inside these contexts defines the page
# layout, so the statement order matters.
with gr.Blocks(css=css) as demo:
    title = gr.HTML("<h1 style='text-align:center;'>Virtual Data Analyst</h1>")
    description = gr.HTML("""<p style='text-align:center;'>Upload a data file and chat with our virtual data analyst
to get insights on your data set. Currently accepts CSV, TSV, TXT, XLS, XLSX, XML, and JSON files.
Can now generate charts and graphs!
Try a sample file to get started!</p>
<p style='text-align:center;'>This tool is under active development. If you experience bugs with use,
open a discussion in the community tab and I will respond.</p>""")
    # Hidden components holding the bundled sample datasets; the "Try Me"
    # buttons copy their value into the visible upload box below.
    example_file_1 = gr.File(visible=False, value="samples/bank_marketing_campaign.csv")
    example_file_2 = gr.File(visible=False, value="samples/online_retail_data.csv")
    with gr.Row():
        example_btn_1 = gr.Button(value="Try Me: bank_marketing_campaign.csv", elem_classes="example_btn", size="md", variant="primary")
        example_btn_2 = gr.Button(value="Try Me: online_retail_data.csv", elem_classes="example_btn", size="md", variant="primary")
    file_output = gr.File(label="Data File (CSV, TSV, TXT, XLS, XLSX, XML, JSON)", show_label=True, elem_classes="file_marker", file_types=['.csv','.xlsx','.txt','.json','.xml','.xls','.tsv'])
    # Example buttons load a sample into file_output; changing file_output
    # toggles the example buttons' visibility (see example_display).
    example_btn_1.click(fn=run_example, inputs=example_file_1, outputs=file_output)
    example_btn_2.click(fn=run_example, inputs=example_file_2, outputs=file_output)
    file_output.change(fn=example_display, inputs=file_output, outputs=[example_btn_1, example_btn_2])

    # Dynamically rendered chat area driven by the current upload; presumably
    # re-run by Gradio whenever file_output's value changes (gr.render default
    # trigger; verify for the installed Gradio version).
    @gr.render(inputs=file_output)
    def data_options(filename, request: gr.Request):
        print(filename)
        if filename:
            # Suggested questions: dataset-specific for the two bundled
            # samples (matched by substring of the file path), generic otherwise.
            if "bank_marketing_campaign" in filename:
                example_questions = [
                    ["Describe the dataset"],
                    ["What levels of education have the highest and lowest average balance?"],
                    ["What job is most and least common for a yes response from the individuals, not counting 'unknown'?"],
                    ["Can you generate a bar chart of education vs. average balance?"]
                ]
            elif "online_retail_data" in filename:
                example_questions = [
                    ["Describe the dataset"],
                    ["What month had the highest revenue?"],
                    ["Is revenue higher in the morning or afternoon?"],
                    ["Can you generate a line graph of revenue per month?"]
                ]
            else:
                example_questions = [
                    ["Describe the dataset"],
                    ["List the columns in the dataset"],
                    ["What could this data be used for?"],
                ]
            # Hidden textbox carrying the session hash into chatbot_with_fc as
            # its `session_hash` argument (via additional_inputs).
            # NOTE(review): this value round-trips through the client, so a
            # modified client could presumably submit another session's hash;
            # consider deriving it from gr.Request on the server instead.
            parameters = gr.Textbox(visible=False, value=request.session_hash)
            bot = gr.Chatbot(type='messages', label="CSV Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
            chat = gr.ChatInterface(
                fn=chatbot_with_fc,
                type='messages',
                chatbot=bot,
                title="Chat with your data file",
                concurrency_limit=None,
                examples=example_questions,
                additional_inputs=parameters
            )
        # Loads the uploaded file for this session (no-op when filename is
        # falsy). process_upload is defined below; the name is resolved when
        # the render function runs, after this Blocks context has executed.
        process_upload(filename, request.session_hash)

    def process_upload(upload_value, session_hash):
        """Hand a non-empty upload to process_data_upload for this session.

        Returns two empty lists; the caller above ignores the return value.
        """
        if upload_value:
            process_data_upload(upload_value, session_hash)
        return [], []

    # Remove the session's SQLite file when the browser session ends.
    demo.unload(delete_db)
|