"""Retrieve passages from a LanceDB table and rerank them with a cross-encoder.

Configuration is read from environment variables (optionally loaded from a
``.env`` file): ``TABLE_NAME``, ``EMB_MODEL``, ``VECTOR_COLUMN``,
``TEXT_COLUMN`` and ``BATCH_SIZE``.
"""

import functools
import logging
import os
from pathlib import Path

import gradio as gr
import lancedb
import openai
from dotenv import load_dotenv
from sentence_transformers import CrossEncoder
from sentence_transformers import SentenceTransformer

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables from the .env file
load_dotenv()

TABLE_NAME = os.getenv("TABLE_NAME")

# Cross-encoder used to rerank the vector-search candidates.
ce = CrossEncoder('BAAI/bge-reranker-base')

# Determine the LanceDB path and log it
current_working_dir = Path(os.getcwd())
db_path = current_working_dir / ".lancedb"
logger.info("Connecting to LanceDB at %s", db_path)
db = lancedb.connect(db_path)

VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))

# Local embedding model, used for tables that are not OpenAI-embedded.
retriever = SentenceTransformer(os.getenv("EMB_MODEL"))

# Number of reranked passages returned by default (previously hard-coded).
DEFAULT_TOP_N = 3


@functools.lru_cache(maxsize=1)
def _openai_client():
    """Return a single OpenAI client, created on first use.

    Created lazily so importing this module does not require OpenAI
    credentials when only locally-embedded tables are queried.
    """
    return openai.OpenAI()


def _embed_query(query):
    """Embed *query* with the embedding backend selected by ``TABLE_NAME``."""
    # NOTE(review): a "3" in the table name selects OpenAI embeddings, yet the
    # model requested is ada-002 rather than a text-embedding-3 model --
    # confirm this matches how the table was built.
    if "3" in TABLE_NAME:
        response = _openai_client().embeddings.create(
            input=[query], model="text-embedding-ada-002"
        )
        return response.data[0].embedding
    return retriever.encode(query)


def retrieve(query, k, top_n=DEFAULT_TOP_N):
    """Return the best passages for *query* after cross-encoder reranking.

    Args:
        query: Natural-language search string.
        k: Number of nearest-neighbour candidates fetched from LanceDB.
        top_n: Number of reranked passages to return (default 3).

    Returns:
        Up to ``top_n`` passage texts (values of ``TEXT_COLUMN``), ordered by
        descending cross-encoder score. Empty list if the search found nothing.

    Raises:
        gr.Error: if ``TABLE_NAME`` is unset, or if opening the table,
            embedding the query, searching, or reranking fails. Raising this
            type makes Gradio show the message to the user.
    """
    if not TABLE_NAME:
        raise gr.Error("TABLE_NAME environment variable is not set")
    try:
        table = db.open_table(TABLE_NAME)
        query_vec = _embed_query(query)
        hits = (
            table.search(query_vec, vector_column_name=VECTOR_COLUMN)
            .limit(k)
            .to_list()
        )
        documents = [hit[TEXT_COLUMN] for hit in hits]
        if not documents:
            # Nothing to rerank; skip calling the cross-encoder on zero pairs.
            return []
        scores = ce.predict([(query, doc) for doc in documents])
    except Exception as e:
        logger.exception("Retrieval failed")
        raise gr.Error(str(e)) from e
    # sorted() is stable, so equal scores keep their vector-search order.
    ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
    return [doc for doc, _ in ranked[:top_n]]


if __name__ == "__main__":
    res = retrieve("What is transformer?", 4)
    print(res)