import gradio as gr
import pdfplumber
import docx
import json
import os
import re
import sqlalchemy
import requests
from tenacity import retry, stop_after_attempt, wait_exponential
from langchain_core.documents import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_postgres import PGVector
# API key and database connection string, read from the environment so that
# credentials are not hard-coded in the source (assumes both variables are set)
GROQ_API_KEY = os.environ["GROQ_API_KEY"]
NEON_CONNECTION_STRING = os.environ["NEON_CONNECTION_STRING"]
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en")
# Extract text from various document types
def extract_text_from_doc(file_path):
    if file_path.endswith(".pdf"):
        with pdfplumber.open(file_path) as pdf:
            return "\n".join([page.extract_text() for page in pdf.pages if page.extract_text()])
    elif file_path.endswith(".docx"):
        doc = docx.Document(file_path)
        return "\n".join([p.text for p in doc.paragraphs])
    elif file_path.endswith(".txt"):
        with open(file_path, "r", encoding="utf-8") as f:
            return f.read()
    return ""
# Database Connection
engine = sqlalchemy.create_engine(url=NEON_CONNECTION_STRING, pool_pre_ping=True, pool_recycle=300)
vector_store = PGVector(embeddings=embeddings, connection=engine, use_jsonb=True, collection_name="text-to-sql-context")
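# Note: use_jsonb=True stores document metadata as Postgres JSONB, which is what
# makes metadata filters like {"topic": {"$eq": "ddl"}} (used further below) work.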
# Retry for API calls
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=10))
def call_groq_api(prompt):
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {GROQ_API_KEY}",
    }
    data = {
        "model": "llama-3.3-70b-versatile",
        "messages": [{"role": "user", "content": prompt}],
    }
    response = requests.post("https://api.groq.com/openai/v1/chat/completions", headers=headers, json=data)
    if response.status_code != 200:
        raise Exception(f"Groq API error: {response.text}")
    result = response.json()
    return result.get("choices", [{}])[0].get("message", {}).get("content", "").strip()
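# Usage sketch (illustrative prompt, not part of the app): tenacity retries failed
# calls up to 5 times with exponential backoff capped between 2 and 10 seconds,
# then re-raises the last exception if every attempt fails.
#   answer = call_groq_api("Return the word OK.")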
# Remove extra text and extract only the JSON
def extract_json(text):
    match = re.search(r"\[.*\]", text, re.DOTALL)
    if match:
        return match.group(0)  # extract only the JSON array
    return None  # invalid format
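# Example (illustrative text, not actual model output):
#   extract_json('Here you go: [{"question": "...", "query": "..."}] Hope this helps!')
#   -> '[{"question": "...", "query": "..."}]'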
# SQL query generation prompt
generation_prompt = """
Generate 50 SQL queries based on this schema:
<schema>
{SCHEMA}
</schema>
Respond with a JSON array of objects, each containing a 'question' and its corresponding SQL 'query'.
"""
# Process the schema, generate example queries, and answer the user's question
def process_and_query(file, question):
    schema_text = extract_text_from_doc(file.name)
    # Generate question/query pairs from the schema
    response = call_groq_api(generation_prompt.format(SCHEMA=schema_text))
    # Keep only the valid JSON part of the reply
    json_response = extract_json(response)
    if not json_response:
        return f"Error: Unexpected response format from Groq API: {response}"
    try:
        qa_pairs = json.loads(json_response)
    except json.JSONDecodeError:
        return f"Error: Could not parse JSON: {json_response}"
    # Store the schema and the generated pairs in the vector DB
    schema_doc = Document(page_content=schema_text, metadata={"id": "schema", "topic": "ddl"})
    query_docs = [
        Document(page_content=json.dumps(pair), metadata={"id": f"query-{i}", "topic": "query"})
        for i, pair in enumerate(qa_pairs)
    ]
    vector_store.add_documents(
        [schema_doc] + query_docs,
        ids=[doc.metadata["id"] for doc in [schema_doc] + query_docs],
    )
    # Retrieve the relevant schema and the most similar example queries
    relevant_ddl = vector_store.similarity_search(query=question, k=5, filter={"topic": {"$eq": "ddl"}})
    similar_queries = vector_store.similarity_search(query=question, k=3, filter={"topic": {"$eq": "query"}})
    schema = "\n".join([doc.page_content for doc in relevant_ddl])
    examples = "\n".join(
        [json.loads(doc.page_content)["question"] + "\nSQL: " + json.loads(doc.page_content)["query"]
         for doc in similar_queries]
    )
    query_prompt = f"""
You are an SQL expert. Generate a valid SQL query based on the schema and example queries.
1. Some DDL statements describing tables, columns and indexes in the database:
<schema>
{schema}
</schema>
2. Some example pairs demonstrating how to convert natural language text into a corresponding SQL query for this schema:
<examples>
{examples}
</examples>
3. The actual natural language question to convert into an SQL query:
<question>
{question}
</question>
Follow the instructions below:
1. Your task is to generate an SQL query that will retrieve the data needed to answer the question, based on the database schema.
2. First, carefully study the provided schema and examples to understand the structure of the database and how the examples map natural language to SQL for this schema.
3. Your answer should have two parts:
- Inside a <scratchpad> XML tag, write out step-by-step reasoning to explain how you are generating the query based on the schema, examples, and question.
- Then, inside an <sql> XML tag, output your generated SQL.
SQL Query:
"""
    query_response = call_groq_api(query_prompt)
    return query_response
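# The prompt above asks the model to wrap its answer in <scratchpad> and <sql> tags,
# yet process_and_query returns the raw response. A minimal sketch of a parser for
# that convention (extract_sql is not part of the original app and is not wired in):
def extract_sql(response_text):
    # Pull the contents of the <sql> tag; fall back to the full text if the tag is absent
    match = re.search(r"<sql>(.*?)</sql>", response_text, re.DOTALL)
    return match.group(1).strip() if match else response_text.strip()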
# Gradio UI
with gr.Blocks() as app:
    gr.Markdown("# Text-to-SQL Converter")
    file_input = gr.File(label="Upload Schema File")
    question_input = gr.Textbox(label="Ask a SQL-related Question")
    submit_button = gr.Button("Process & Generate SQL")
    query_output = gr.Textbox(label="Generated SQL Query")
    submit_button.click(process_and_query, inputs=[file_input, question_input], outputs=query_output)

app.launch(share=True)