import ast
import gradio as gr
from functions import sql_example_question_generator, sql_chatbot_with_fc
from data_sources import connect_sql_db
from utils import message_dict
def hide_info():
    """Return a Gradio update that hides the component it is applied to.

    Wired to the Submit button in this file with the description HTML as its
    output, so the introductory text disappears once the user submits.
    """
    hidden = gr.update(visible=False)
    return hidden
# Host of the bundled demo database. It is both the URL field's default value
# and the marker that selects the curated example questions, so it lives in
# one place to keep the two from drifting apart.
DEMO_DB_HOST = "virtual-data-analyst-pg.cyetm2yjzppu.us-west-1.rds.amazonaws.com"


def _sql_example_questions(url, session_hash, db_tables, db_name):
    """Build the example prompts offered under the chat box.

    Args:
        url: Database host entered by the user.
        session_hash: Gradio session identifier, forwarded to the generator.
        db_tables: Table description returned by the connection step.
        db_name: Database name, forwarded to the generator.

    Returns:
        A list of single-item lists (one prompt each), the shape passed to
        ``gr.ChatInterface(examples=...)`` below.
    """
    if DEMO_DB_HOST in url:
        # Curated prompts that match the demo database.
        return [
            ["Describe the dataset"],
            ["What is the total revenue generated by each store?"],
            ["Can you generate and display a bar chart of film category to number of films in that category?"],
            ["Can you generate a pie chart showing the top 10 most rented films by revenue vs all other films?"],
            ["Can you generate a line chart of rental revenue over time?"],
            ["What is the relationship between film length and rental frequency?"],
        ]

    try:
        generated_examples = ast.literal_eval(
            sql_example_question_generator(session_hash, db_tables, db_name)
        )
        # literal_eval happily returns a str/dict/number; iterating a str would
        # produce one "question" per character, so only accept real sequences.
        if not isinstance(generated_examples, (list, tuple, set)):
            raise ValueError(
                f"expected a list of questions, got {type(generated_examples).__name__}"
            )
        return [["Describe the dataset"]] + [[example] for example in generated_examples]
    except Exception as e:
        # Deliberate best-effort: any generator or parsing failure falls back
        # to generic prompts instead of breaking the chat UI.
        print("SQL QUESTION GENERATION ERROR")
        print(e)
        return [
            ["Describe the dataset"],
            ["List the columns in the dataset"],
            ["What could this data be used for?"],
        ]


with gr.Blocks() as demo:
    description = gr.HTML("""
This tool allows users to communicate with and query real time data from a SQL DB (postgres for now, others can be added if requested) using natural
language and the above features.
Notice: the way this system is designed, no login information is retained and credentials are passed as session variables until the user leaves or
refreshes the page in which they disappear. They are never saved to any files. I also make use of the Pandas read_sql_query function to apply SQL
queries, which can't delete, drop, or add database lines to avoid unhappy accidents or glitches.
That being said, it's probably not a good idea to connect a production database to a strange AI tool with an unfamiliar author.
This should be for demonstration purposes.
Contact me if this is something you would like built in your organization, on your infrastructure, and with the requisite privacy and control a production
database analytics tool requires.
""", elem_classes="description_component")
    sql_url = gr.Textbox(label="URL", value=DEMO_DB_HOST)
    with gr.Row():
        sql_port = gr.Textbox(label="Port", value="5432")
        sql_user = gr.Textbox(label="Username", value="postgres")
        sql_pass = gr.Textbox(label="Password", value="Vda-1988", type="password")
        sql_db_name = gr.Textbox(label="Database Name", value="dvdrental")

    submit = gr.Button(value="Submit")
    # Hide the introductory description as soon as the user submits.
    submit.click(fn=hide_info, outputs=description)

    @gr.render(
        inputs=[sql_url, sql_port, sql_user, sql_pass, sql_db_name],
        triggers=[submit.click],
    )
    def sql_chat(
        request: gr.Request,
        url=sql_url.value,
        sql_port=sql_port.value,
        sql_user=sql_user.value,
        sql_pass=sql_pass.value,
        sql_db_name=sql_db_name.value,
    ):
        """Connect to the database and render the chat UI for this session.

        Runs on every Submit click. Shows the connection status message and,
        when the connection succeeded, a chat interface preloaded with example
        questions and the connection details as additional inputs.

        Args:
            request: Gradio request; its ``session_hash`` keys ``message_dict``.
            url: Database host from the URL field.
            sql_port: Port from the Port field.
            sql_user: Username from the Username field.
            sql_pass: Password from the Password field.
            sql_db_name: Database name from the Database Name field.
        """
        if request.session_hash not in message_dict:
            message_dict[request.session_hash] = {}
        # Reset the 'sql' entry on every submit, including for sessions that
        # already exist. NOTE(review): presumably this slot holds the SQL chat
        # state used by sql_chatbot_with_fc - confirm in utils/functions.
        message_dict[request.session_hash]['sql'] = None

        if not url:
            return

        print("SQL APP")
        process_message = process_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, request.session_hash)
        # Index 1 is rendered as the status HTML whether or not we connected.
        gr.HTML(value=process_message[1], padding=False)
        if process_message[0] != "success":
            return

        example_questions = _sql_example_questions(
            url, request.session_hash, process_message[2], sql_db_name
        )

        # Hidden fields carry the session and connection details to the chat
        # function as additional inputs.
        # NOTE(review): the password travels through a hidden Textbox; confirm
        # this is acceptable outside of demo deployments.
        session_hash = gr.Textbox(visible=False, value=request.session_hash)
        db_url = gr.Textbox(visible=False, value=url)
        db_port = gr.Textbox(visible=False, value=sql_port)
        db_user = gr.Textbox(visible=False, value=sql_user)
        db_pass = gr.Textbox(visible=False, value=sql_pass)
        db_name = gr.Textbox(visible=False, value=sql_db_name)
        db_tables = gr.Textbox(value=process_message[2], interactive=False, label="SQL Tables")

        # NOTE(review): the label says "CSV" although this is the SQL chat -
        # looks like a copy-paste leftover; left unchanged pending confirmation.
        bot = gr.Chatbot(type='messages', label="CSV Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
        gr.ChatInterface(
            fn=sql_chatbot_with_fc,
            type='messages',
            chatbot=bot,
            title="Chat with your Database",
            examples=example_questions,
            concurrency_limit=None,
            additional_inputs=[session_hash, db_url, db_port, db_user, db_pass, db_name, db_tables],
        )
def process_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, session_hash):
    """Connect to the SQL database for a session, if a host URL was supplied.

    Thin wrapper around ``connect_sql_db`` that skips the connection attempt
    when ``url`` is empty.

    Args:
        url: Database host; a falsy value (``""`` or ``None``) skips the call.
        sql_user: Username forwarded to ``connect_sql_db``.
        sql_port: Port forwarded to ``connect_sql_db``.
        sql_pass: Password forwarded to ``connect_sql_db``.
        sql_db_name: Database name forwarded to ``connect_sql_db``.
        session_hash: Gradio session identifier forwarded to ``connect_sql_db``.

    Returns:
        Whatever ``connect_sql_db`` returns, or ``None`` when ``url`` is falsy.
        The caller in this file reads index 0 as a status string, index 1 as
        an HTML message and index 2 as the table description - confirm the
        exact shape against ``connect_sql_db``.
    """
    if not url:
        # Nothing to connect to; make the "no result" outcome explicit.
        return None
    return connect_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, session_hash)
# Start the Gradio server only when this file is executed directly, so that
# importing the module (e.g. to reuse `demo`) does not launch anything.
if __name__ == "__main__":
    demo.launch()