# text2sql_rag / app.py
# Author: Sharath7693
# Commit c2b19f9 (verified): "Create app.py"
import hashlib
import json
import os
import re

import docx
import gradio as gr
import pdfplumber
import requests
import sqlalchemy
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.documents import Document
from langchain_postgres import PGVector
from langchain_postgres.vectorstores import PGVector
from tenacity import retry, stop_after_attempt, wait_exponential
# API key and database connection settings.
# Secrets are read from the environment (e.g. Hugging Face Space secrets)
# instead of being hard-coded: anything committed to source is readable by
# everyone with access to the repository. Never commit real credentials here,
# and rotate any key or password that has ever been committed.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
NEON_CONNECTION_STRING = os.environ.get("NEON_CONNECTION_STRING", "")
# Embedding model handed to the PGVector store below; it embeds the stored
# schema/example documents and the user's question for similarity search.
embeddings = HuggingFaceEmbeddings(
    model_name="BAAI/bge-small-en",
)
# Extract text from various document types
def extract_text_from_doc(file_path):
if file_path.endswith(".pdf"):
with pdfplumber.open(file_path) as pdf:
return "\n".join([page.extract_text() for page in pdf.pages if page.extract_text()])
elif file_path.endswith(".docx"):
doc = docx.Document(file_path)
return "\n".join([p.text for p in doc.paragraphs])
elif file_path.endswith(".txt"):
with open(file_path, "r", encoding="utf-8") as f:
return f.read()
return ""
# Database connection and the pgvector-backed store for schema/query documents.
if not NEON_CONNECTION_STRING:
    # Fail fast with an actionable message instead of SQLAlchemy's opaque
    # "Could not parse SQLAlchemy URL" error for an empty string.
    raise RuntimeError(
        "NEON_CONNECTION_STRING is not set; configure the Postgres connection "
        "string (e.g. as a Space secret) before starting the app."
    )
# pool_pre_ping checks each pooled connection before handing it out and
# pool_recycle=300 replaces connections older than five minutes, so
# connections closed by the server while idle are not reused.
engine = sqlalchemy.create_engine(
    url=NEON_CONNECTION_STRING,
    pool_pre_ping=True,
    pool_recycle=300,
)
vector_store = PGVector(
    embeddings=embeddings,
    connection=engine,
    use_jsonb=True,  # metadata stored as JSONB; retrieval filters query it
    collection_name="text-to-sql-context",
)
# Call Groq's OpenAI-compatible chat-completions endpoint, retrying failures
# with exponential backoff.
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    # Re-raise the last underlying exception instead of tenacity's RetryError
    # so callers (and the UI) see the actual API failure.
    reraise=True,
)
def call_groq_api(prompt, *, timeout=60, model="llama-3.3-70b-versatile"):
    """Send *prompt* as a single user message and return the reply text.

    Args:
        prompt: User message content.
        timeout: Seconds to wait on the HTTP request. Without a timeout a
            stalled connection blocks forever and the retry never fires.
        model: Groq model identifier.

    Returns:
        The first choice's message content, stripped; "" when the response
        carries no choices or no content.

    Raises:
        RuntimeError: The API answered with a non-200 status (after retries).
        requests.RequestException: Network error or timeout (after retries).
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {GROQ_API_KEY}",
    }
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
    }
    response = requests.post(
        "https://api.groq.com/openai/v1/chat/completions",
        headers=headers,
        json=payload,
        timeout=timeout,
    )
    if response.status_code != 200:
        raise RuntimeError(
            f"Groq API error ({response.status_code}): {response.text}"
        )
    # Tolerate an empty "choices" list and a null "message"/"content".
    choices = response.json().get("choices") or [{}]
    message = choices[0].get("message") or {}
    return (message.get("content") or "").strip()
# Pull the JSON array out of an LLM reply that may contain extra prose.
def extract_json(text):
    """Return the JSON-array substring of *text*, or None if there is none.

    LLM replies often wrap the payload in prose or markdown fences and may
    contain other bracketed text (e.g. "[note]" or "[1]"). Grabbing everything
    from the first "[" to the last "]" breaks as soon as such a stray bracket
    sits outside the payload. Instead, each "[" is tried as the start of a JSON
    value and the longest array that actually parses is returned.

    If no bracketed span parses as JSON, the greedy first-"["-to-last-"]"
    substring is returned so the caller's JSON-parse error path can report it.

    Args:
        text: Raw model output (may be None or empty).

    Returns:
        The JSON array text, the greedy bracketed fallback, or None when the
        text contains no "[...]" span at all.
    """
    if not text:
        return None
    decoder = json.JSONDecoder()
    best = None
    start = text.find("[")
    while start != -1:
        try:
            _, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("[", start + 1)
            continue
        if best is None or end - start > len(best):
            best = text[start:end]
        # Any brackets inside this array belong to it; resume after its end.
        start = text.find("[", end)
    if best is not None:
        return best
    match = re.search(r"\[.*\]", text, re.DOTALL)
    return match.group(0) if match else None
# Prompt asking the LLM for example question/SQL pairs for an uploaded schema.
# The reply is parsed with extract_json() and json.loads(), and each item is
# read via its "question" and "query" keys, so the prompt spells out that
# exact format. {SCHEMA} is the only str.format() placeholder: the prompt text
# must not contain any other literal braces.
generation_prompt = """
Generate 50 SQL queries based on this schema:
<schema>
{SCHEMA}
</schema>
Provide the output as a single JSON array of objects. Each object must have
exactly two string keys: "question" (a natural-language question about the
data) and "query" (the SQL query that answers it). Output only the JSON array,
with no explanations or markdown before or after it.
"""
# Process the schema, generate example queries, and answer the user's question.

# Prompt that turns the user's question into SQL, grounded in the uploaded
# schema and the most similar generated examples. Filled with str.format();
# substituted values may safely contain braces.
_QUERY_PROMPT_TEMPLATE = """
You are an SQL expert. Generate a valid SQL query based on the schema and example queries.
1. Some DDL statements describing tables, columns and indexes in the database:
<schema>
{schema}
</schema>
2. Some example pairs demonstrating how to convert natural language text into a corresponding SQL query for this schema:
<examples>
{examples}
</examples>
3. The actual natural language question to convert into an SQL query:
<question>
{question}
</question>
Follow the instructions below:
1. Your task is to generate an SQL query that will retrieve the data needed to answer the question, based on the database schema.
2. First, carefully study the provided schema and examples to understand the structure of the database and how the examples map natural language to SQL for this schema.
3. Your answer should have two parts:
- Inside <scratchpad> XML tag, write out step-by-step reasoning to explain how you are generating the query based on the schema, example, and question.
- Then, inside <sql> XML tag, output your generated SQL.
SQL Query:
"""


def _schema_id(schema_text):
    """Return a short, stable identifier (SHA-256 prefix) for *schema_text*."""
    return hashlib.sha256(schema_text.encode("utf-8")).hexdigest()[:16]


def _valid_pairs(qa_pairs):
    """Keep only items shaped like {"question": str, "query": str}."""
    if not isinstance(qa_pairs, list):
        return []
    return [
        pair
        for pair in qa_pairs
        if isinstance(pair, dict)
        and isinstance(pair.get("question"), str)
        and isinstance(pair.get("query"), str)
    ]


def process_and_query(file, question):
    """Generate example queries for an uploaded schema and answer *question*.

    Steps:
    1. Extract the schema text from the upload.
    2. Ask the LLM for example question/SQL pairs.
    3. Store the schema and pairs in the vector store.
    4. Retrieve the examples most similar to the question.
    5. Ask the LLM to write SQL for the question.

    Args:
        file: The Gradio upload, either a path string or an object whose
            ``name`` attribute holds the path (the shape varies across Gradio
            versions).
        question: Natural-language question to translate into SQL.

    Returns:
        The LLM's answer text, or a message starting with "Error:" explaining
        why no answer could be produced.
    """
    if file is None:
        return "Error: Please upload a schema file."
    if not question or not question.strip():
        return "Error: Please enter a question."

    schema_text = extract_text_from_doc(getattr(file, "name", file))
    if not schema_text.strip():
        return (
            "Error: Could not extract any schema text from the uploaded file "
            "(unsupported file type or empty document)."
        )

    # Generate example question/SQL pairs for this schema.
    response = call_groq_api(generation_prompt.format(SCHEMA=schema_text))
    json_response = extract_json(response)
    if not json_response:
        return f"Error: Unexpected response format from Groq API: {response}"
    try:
        qa_pairs = json.loads(json_response)
    except json.JSONDecodeError:
        return f"Error: Could not parse JSON: {json_response}"
    qa_pairs = _valid_pairs(qa_pairs)
    if not qa_pairs:
        return f"Error: No valid question/query pairs in Groq API response: {json_response}"

    # Store the schema and pairs under ids namespaced by a hash of the schema.
    # Fixed ids ("schema", "query-0", ...) were shared by every upload and user,
    # so examples left over from a different schema could be retrieved below.
    sid = _schema_id(schema_text)
    schema_doc = Document(
        page_content=schema_text,
        metadata={"id": f"schema-{sid}", "topic": "ddl", "schema_id": sid},
    )
    query_docs = [
        Document(
            page_content=json.dumps(pair),
            metadata={"id": f"query-{sid}-{i}", "topic": "query", "schema_id": sid},
        )
        for i, pair in enumerate(qa_pairs)
    ]
    documents = [schema_doc, *query_docs]
    vector_store.add_documents(documents, ids=[doc.metadata["id"] for doc in documents])

    # Within this schema's namespace the only DDL document is the uploaded
    # schema itself, so it is used directly. Only the examples are retrieved.
    similar_queries = vector_store.similarity_search(
        query=question,
        k=3,
        filter={"$and": [{"topic": {"$eq": "query"}}, {"schema_id": {"$eq": sid}}]},
    )
    pairs = [json.loads(doc.page_content) for doc in similar_queries]  # parse once each
    examples = "\n".join(pair["question"] + "\nSQL: " + pair["query"] for pair in pairs)

    query_prompt = _QUERY_PROMPT_TEMPLATE.format(
        schema=schema_text, examples=examples, question=question
    )
    return call_groq_api(query_prompt)
# Gradio UI: upload a schema file, ask a question, and show the model's answer.
with gr.Blocks() as app:
    gr.Markdown("# Text-to-SQL Converter")
    file_input = gr.File(label="Upload Schema File")
    question_input = gr.Textbox(label="Ask a SQL-related Question")
    submit_button = gr.Button("Process & Generate SQL")
    query_output = gr.Textbox(label="Generated SQL Query")
    submit_button.click(
        process_and_query,
        inputs=[file_input, question_input],
        outputs=query_output,
    )

# Launch only when run as a script, so importing this module (e.g. from tests
# or tooling) does not start a server and a public share link.
if __name__ == "__main__":
    app.launch(share=True)