# RAG/backend/semantic_search.py
# thenativefox
# Added recursive tables and reranking
# 6b2076c
import lancedb
import os
import gradio as gr
import openai
from sentence_transformers import SentenceTransformer
from sentence_transformers import CrossEncoder
from pathlib import Path
from dotenv import load_dotenv
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables from the .env file before any os.getenv() below.
load_dotenv()

# Name of the LanceDB table that retrieve() opens; it cannot work without it.
TABLE_NAME = os.getenv("TABLE_NAME")
if not TABLE_NAME:
    logger.warning("TABLE_NAME is not set; retrieve() will fail until it is configured")

# Cross-encoder used to rerank the vector-search hits.
ce = CrossEncoder('BAAI/bge-reranker-base')

# Determine the LanceDB path (./.lancedb under the current working directory) and log it
current_working_dir = Path.cwd()
db_path = current_working_dir / ".lancedb"
logger.info("Connecting to LanceDB at %s", db_path)
db = lancedb.connect(db_path)

# Column names in the LanceDB table; overridable via the environment.
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
# NOTE(review): BATCH_SIZE is not used by retrieve(); confirm whether other modules import it.
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))

# Local embedding model used by retrieve() for tables not routed to OpenAI embeddings.
retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
def retrieve(query, k, top_n=3):
    """Vector-search the configured LanceDB table, then rerank hits with the cross-encoder.

    The query is embedded (OpenAI ``text-embedding-ada-002`` when ``TABLE_NAME``
    contains "3", otherwise the local ``retriever`` model). The ``k`` nearest rows
    are fetched from ``VECTOR_COLUMN``, and their ``TEXT_COLUMN`` texts are scored
    against the query by ``ce``.

    Args:
        query: The search string.
        k: Number of candidates to fetch from the vector search before reranking.
        top_n: Maximum number of reranked texts to return (default 3, the
            previously hard-coded cut-off).

    Returns:
        A list of at most ``top_n`` texts, highest cross-encoder score first.
        Empty if the vector search returns no rows.

    Raises:
        gr.Error: If ``TABLE_NAME`` is unset, or if opening the table, embedding
            the query, searching or reranking fails.
    """
    # Without this guard an unset TABLE_NAME surfaces as an opaque TypeError below.
    if not TABLE_NAME:
        raise gr.Error("TABLE_NAME environment variable is not set")
    try:
        table = db.open_table(TABLE_NAME)
        # Tables whose name contains "3" are queried with OpenAI ada-002 embeddings;
        # all others with the local SentenceTransformer model.
        # NOTE(review): this name-based switch is fragile — confirm it matches how
        # each table was embedded.
        if "3" in TABLE_NAME:
            client = openai.OpenAI()
            query_vec = client.embeddings.create(
                input=[query], model="text-embedding-ada-002"
            ).data[0].embedding
        else:
            query_vec = retriever.encode(query)

        rows = (
            table.search(query_vec, vector_column_name=VECTOR_COLUMN)
            .limit(k)
            .to_list()
        )
        documents = [row[TEXT_COLUMN] for row in rows]
        if not documents:
            # Nothing to rerank; avoid calling the cross-encoder on an empty batch.
            return []

        scores = ce.predict([(query, doc) for doc in documents])
        ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
        return [doc for doc, _ in ranked[:top_n]]
    except Exception as e:
        # Log the full traceback server-side; show only the message in the UI.
        logger.exception("Retrieval failed for table %s", TABLE_NAME)
        raise gr.Error(str(e)) from e
if __name__ == "__main__":
    # Manual smoke test: fetch 4 candidates for a sample question and print the reranked texts.
    sample_query = "What is transformer?"
    print(retrieve(sample_query, 4))