import os
import gradio as gr
import pdfplumber
import docx
import json
import re
import sqlalchemy
import requests
from tenacity import retry, stop_after_attempt, wait_exponential
from langchain_core.documents import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_postgres import PGVector
# API keys and database connection, read from environment variables so
# secrets are not hardcoded in the source
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
NEON_CONNECTION_STRING = os.getenv("NEON_CONNECTION_STRING")
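# bge-small-en is a compact 384-dimensional sentence-embedding model; it is
# used to embed both the schema DDL and the generated example queries.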
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en")
# Extract text from various document types
def extract_text_from_doc(file_path):
    if file_path.endswith(".pdf"):
        with pdfplumber.open(file_path) as pdf:
            return "\n".join([page.extract_text() for page in pdf.pages if page.extract_text()])
    elif file_path.endswith(".docx"):
        doc = docx.Document(file_path)
        return "\n".join([p.text for p in doc.paragraphs])
    elif file_path.endswith(".txt"):
        with open(file_path, "r", encoding="utf-8") as f:
            return f.read()
    return ""
# Database connection
engine = sqlalchemy.create_engine(url=NEON_CONNECTION_STRING, pool_pre_ping=True, pool_recycle=300)
vector_store = PGVector(embeddings=embeddings, connection=engine, use_jsonb=True, collection_name="text-to-sql-context")
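# use_jsonb=True stores document metadata as JSONB, which is what enables the
# {"topic": {"$eq": ...}} metadata filters used in the similarity searches below.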
# Retry API calls with exponential backoff (uses the tenacity imports above)
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
def call_groq_api(prompt):
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {GROQ_API_KEY}",
    }
    data = {
        "model": "llama-3.3-70b-versatile",
        "messages": [{"role": "user", "content": prompt}],
    }
    response = requests.post("https://api.groq.com/openai/v1/chat/completions", headers=headers, json=data)
    if response.status_code != 200:
        raise Exception(f"Groq API error: {response.text}")
    result = response.json()
    return result.get("choices", [{}])[0].get("message", {}).get("content", "").strip()
# Remove surrounding prose and extract only the JSON array
def extract_json(text):
    match = re.search(r"\[.*\]", text, re.DOTALL)
    if match:
        return match.group(0)  # Greedy match: first '[' through last ']'
    return None  # No JSON array found
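# Illustrative behavior (assumed model output, for clarity): given
#   'Here are your queries:\n[{"question": "...", "query": "..."}]\nEnjoy!'
# the greedy pattern matches from the first '[' to the last ']' and returns
#   '[{"question": "...", "query": "..."}]'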
# SQL query generation prompt
generation_prompt = """
Generate 50 SQL queries based on this schema:
<schema>
{SCHEMA}
</schema>
Return only a JSON array of objects, each with a 'question' and a 'query' key.
"""
# Process the schema, generate example queries, and answer the user's question
def process_and_query(file, question):
    schema_text = extract_text_from_doc(file.name)
    # Generate example question/query pairs from the schema
    response = call_groq_api(generation_prompt.format(SCHEMA=schema_text))
    # Extract only the valid JSON part of the response
    json_response = extract_json(response)
    if not json_response:
        return f"Error: Unexpected response format from Groq API: {response}"
    try:
        qa_pairs = json.loads(json_response)
    except json.JSONDecodeError:
        return f"Error: Could not parse JSON: {json_response}"
    # Store the schema and the generated queries in the vector DB
    schema_doc = Document(page_content=schema_text, metadata={"id": "schema", "topic": "ddl"})
    query_docs = [
        Document(page_content=json.dumps(pair), metadata={"id": f"query-{i}", "topic": "query"})
        for i, pair in enumerate(qa_pairs)
    ]
    vector_store.add_documents([schema_doc] + query_docs, ids=[doc.metadata["id"] for doc in [schema_doc] + query_docs])
    # Retrieve the relevant schema and the most similar example queries
    relevant_ddl = vector_store.similarity_search(query=question, k=5, filter={"topic": {"$eq": "ddl"}})
    similar_queries = vector_store.similarity_search(query=question, k=3, filter={"topic": {"$eq": "query"}})
    schema = "\n".join([doc.page_content for doc in relevant_ddl])
    example_pairs = [json.loads(doc.page_content) for doc in similar_queries]
    examples = "\n".join(f"{pair['question']}\nSQL: {pair['query']}" for pair in example_pairs)
query_prompt = f""" | |
You are an SQL expert. Generate a valid SQL query based on the schema and example queries. | |
1. Some DDL statements describing tables, columns and indexes in the database: | |
<schema> | |
{schema} | |
</schema> | |
2. Some example pairs demonstrating how to convert natural language text into a corresponding SQL query for this schema: | |
<examples> | |
{examples} | |
</examples> | |
3. The actual natural language question to convert into an SQL query: | |
<question> | |
{question} | |
</question> | |
Follow the instructions below: | |
1. Your task is to generate an SQL query that will retrieve the data needed to answer the question, based on the database schema. | |
2. First, carefully study the provided schema and examples to understand the structure of the database and how the examples map natural language to SQL for this schema. | |
3. Your answer should have two parts: | |
- Inside <scratchpad> XML tag, write out step-by-step reasoning to explain how you are generating the query based on the schema, example, and question. | |
- Then, inside <sql> XML tag, output your generated SQL. | |
SQL Query: | |
""" | |
    query_response = call_groq_api(query_prompt)
    return query_response
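# Note: the model's reply still contains the <scratchpad> and <sql> tags
# requested above; it is returned verbatim so the user sees the reasoning
# alongside the generated SQL.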
# Gradio UI
with gr.Blocks() as app:
    gr.Markdown("# Text-to-SQL Converter")
    file_input = gr.File(label="Upload Schema File")
    question_input = gr.Textbox(label="Ask a SQL-related Question")
    submit_button = gr.Button("Process & Generate SQL")
    query_output = gr.Textbox(label="Generated SQL Query")
    submit_button.click(process_and_query, inputs=[file_input, question_input], outputs=query_output)
app.launch(share=True)
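# share=True requests a temporary public gradio.live link; it is unnecessary
# when the app is hosted on Spaces and can be dropped for local-only testing.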