|
import ast
|
|
import gradio as gr
|
|
from functions import sql_example_question_generator, sql_chatbot_with_fc
|
|
from data_sources import connect_sql_db
|
|
from utils import message_dict
|
|
|
|
def hide_info():
    """Return a Gradio update that hides the component it is applied to.

    Wired to the Submit button's click event with the intro/description HTML
    as the output, so the notice disappears once the user connects.
    """
    return gr.update(visible=False)
|
|
|
|
# Hostname of the bundled demo database. It is both the default value of the
# URL textbox and the marker used to pick the curated example questions, so it
# is kept in one place to stop the two copies drifting apart.
_DEMO_DB_HOST = "virtual-data-analyst-pg.cyetm2yjzppu.us-west-1.rds.amazonaws.com"


def _build_example_questions(session_hash, url, db_tables, db_name):
    """Return the example prompts shown in the chat UI.

    Each prompt is wrapped in its own single-item list, which is the shape
    passed to ``gr.ChatInterface(examples=...)`` below.

    Args:
        session_hash: Gradio session identifier, forwarded to the generator.
        url: Database URL entered by the user.
        db_tables: Table listing returned by the connection step.
        db_name: Database name entered by the user.

    Returns:
        A fresh list of ``[prompt]`` lists. Curated prompts are used for the
        demo database, generated prompts otherwise, and a generic fallback
        when generation or parsing fails.
    """
    if _DEMO_DB_HOST in url:
        return [
            ["Describe the dataset"],
            ["What is the total revenue generated by each store?"],
            ["Can you generate and display a bar chart of film category to number of films in that category?"],
            ["Can you generate a pie chart showing the top 10 most rented films by revenue vs all other films?"],
            ["Can you generate a line chart of rental revenue over time?"],
            ["What is the relationship between film length and rental frequency?"],
        ]

    try:
        generated = ast.literal_eval(
            sql_example_question_generator(session_hash, db_tables, db_name)
        )
        # literal_eval accepts any Python literal. A bare string would be
        # iterated character by character below, so insist on a list/tuple of
        # strings and otherwise take the fallback path.
        if not isinstance(generated, (list, tuple)) or not all(
            isinstance(question, str) for question in generated
        ):
            raise ValueError("generated examples are not a list of strings")
        return [["Describe the dataset"]] + [[question] for question in generated]
    except Exception as e:
        # Best-effort feature: any failure (generator call, parsing, shape)
        # degrades to generic prompts instead of breaking the chat UI.
        print("SQL QUESTION GENERATION ERROR")
        print(e)
        return [
            ["Describe the dataset"],
            ["List the columns in the dataset"],
            ["What could this data be used for?"],
        ]


with gr.Blocks() as demo:
    # Intro / disclaimer panel; hidden by hide_info once Submit is clicked.
    description = gr.HTML("""
    <!-- Header -->
    <div class="max-w-4xl mx-auto mb-12 text-center">
        <div class="bg-blue-50 border border-blue-200 rounded-lg max-w-2xl mx-auto">
            <p>This tool allows users to communicate with and query real time data from a SQL DB (postgres for now, others can be added if requested) using natural
            language and the above features.</p>
            <p style="font-weight:bold;">Notice: the way this system is designed, no login information is retained and credentials are passed as session variables until the user leaves or
            refreshes the page in which they disappear. They are never saved to any files. I also make use of the Pandas read_sql_query function to apply SQL
            queries, which can't delete, drop, or add database lines to avoid unhappy accidents or glitches.
            That being said, it's probably not a good idea to connect a production database to a strange AI tool with an unfamiliar author.
            This should be for demonstration purposes.</p>
            <p>Contact me if this is something you would like built in your organization, on your infrastructure, and with the requisite privacy and control a production
            database analytics tool requires.</p>
        </div>
    </div>
    """, elem_classes="description_component")

    # Connection form, pre-filled with the demo database settings.
    # NOTE(review): the defaults below embed a real-looking password in source;
    # presumably a throwaway demo credential - confirm before reuse.
    sql_url = gr.Textbox(label="URL", value=_DEMO_DB_HOST)
    with gr.Row():
        sql_port = gr.Textbox(label="Port", value="5432")
        sql_user = gr.Textbox(label="Username", value="postgres")
        sql_pass = gr.Textbox(label="Password", value="Vda-1988", type="password")
        sql_db_name = gr.Textbox(label="Database Name", value="dvdrental")

    submit = gr.Button(value="Submit")
    submit.click(fn=hide_info, outputs=description)

    @gr.render(inputs=[sql_url, sql_port, sql_user, sql_pass, sql_db_name], triggers=[submit.click])
    def sql_chat(request: gr.Request, url=sql_url.value, sql_port=sql_port.value, sql_user=sql_user.value, sql_pass=sql_pass.value, sql_db_name=sql_db_name.value):
        """Connect to the database and render the chat UI for this session.

        Runs on every Submit click with the current form values. Shows the
        connection status message and, on success, the table listing and a
        chat interface backed by ``sql_chatbot_with_fc``.
        """
        if request.session_hash not in message_dict:
            message_dict[request.session_hash] = {}
        # Every (re)connect starts with an empty 'sql' slot for this session.
        message_dict[request.session_hash]['sql'] = None

        if not url:
            return

        print("SQL APP")
        # Indexed below as [0] status, [1] HTML status message, [2] table listing.
        process_message = process_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, request.session_hash)
        gr.HTML(value=process_message[1], padding=False)
        if process_message[0] != "success":
            return

        example_questions = _build_example_questions(
            request.session_hash, url, process_message[2], sql_db_name
        )

        # additional_inputs must be components, so the connection details are
        # carried to the chat function through hidden textboxes.
        # NOTE(review): this includes the password as a component value -
        # confirm that exposure is acceptable for the deployment.
        session_hash = gr.Textbox(visible=False, value=request.session_hash)
        db_url = gr.Textbox(visible=False, value=url)
        db_port = gr.Textbox(visible=False, value=sql_port)
        db_user = gr.Textbox(visible=False, value=sql_user)
        db_pass = gr.Textbox(visible=False, value=sql_pass)
        db_name = gr.Textbox(visible=False, value=sql_db_name)
        db_tables = gr.Textbox(value=process_message[2], interactive=False, label="SQL Tables")

        # NOTE(review): label reads "CSV" although this is the SQL chat -
        # left unchanged; confirm whether it is intentional.
        bot = gr.Chatbot(type='messages', label="CSV Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
        chat = gr.ChatInterface(
            fn=sql_chatbot_with_fc,
            type='messages',
            chatbot=bot,
            title="Chat with your Database",
            examples=example_questions,
            concurrency_limit=None,
            additional_inputs=[session_hash, db_url, db_port, db_user, db_pass, db_name, db_tables],
        )
|
|
|
|
def process_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, session_hash):
    """Connect to the SQL database described by the arguments.

    Thin wrapper around ``connect_sql_db`` that skips the connection attempt
    when no URL was supplied.

    Args:
        url: Database host/URL; a falsy value short-circuits the call.
        sql_user: Database username.
        sql_port: Database port.
        sql_pass: Database password.
        sql_db_name: Name of the database to connect to.
        session_hash: Gradio session identifier forwarded to the connector.

    Returns:
        Whatever ``connect_sql_db`` returns (the caller in this file indexes
        it as ``[0]`` status, ``[1]`` HTML message, ``[2]`` table listing),
        or ``None`` when ``url`` is empty.
    """
    # Explicit guard: no URL means nothing to connect to.
    if not url:
        return None
    return connect_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, session_hash)
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    demo.launch()