import lancedb
import os
import gradio as gr
import openai
from sentence_transformers import SentenceTransformer
from sentence_transformers import CrossEncoder
from pathlib import Path
from dotenv import load_dotenv
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load environment variables from the .env file
load_dotenv()
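# Configuration expected in the .env file (variable names taken from the os.getenv calls below):
#   TABLE_NAME    - LanceDB table to query
#   VECTOR_COLUMN - column holding the embeddings (default: "vector")
#   TEXT_COLUMN   - column holding the chunk text (default: "text")
#   BATCH_SIZE    - batch size for encoding (default: 32)
#   EMB_MODEL     - sentence-transformers model name used to embed queries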
TABLE_NAME = os.getenv("TABLE_NAME")
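# Cross-encoder used to rerank retrieved chunks by relevance to the query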
ce = CrossEncoder('BAAI/bge-reranker-base')
# Determine the LanceDB path and log it
current_working_dir = Path(os.getcwd())
db_path = current_working_dir / ".lancedb"
logger.info(f"Connecting to LanceDB at {db_path}")
db = lancedb.connect(db_path)
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
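# Sentence-transformers model used to embed queries for non-OpenAI tables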
retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
def retrieve(query, k):
    """Embed the query, fetch the k nearest chunks from LanceDB, and rerank them with the cross-encoder."""
    TABLE = db.open_table(TABLE_NAME)
    # Embed the query with OpenAI when the table name indicates an OpenAI-embedded index,
    # otherwise with the local sentence-transformers model
    if "3" in TABLE_NAME:
        client = openai.OpenAI()
        query_vec = client.embeddings.create(input=[query], model="text-embedding-ada-002").data[0].embedding
    else:
        query_vec = retriever.encode(query)
    try:
        # Vector search: nearest k chunks by embedding similarity
        documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
        documents = [doc[TEXT_COLUMN] for doc in documents]
        # Rerank the retrieved chunks with the cross-encoder and keep the top three
        pairs = [(query, doc) for doc in documents]
        scores = ce.predict(pairs)
        scored_documents = list(zip(documents, scores))
        scored_documents.sort(key=lambda x: x[1], reverse=True)
        top_documents = [doc for doc, _ in scored_documents[:3]]
        return top_documents
    except Exception as e:
        raise gr.Error(str(e))
if __name__ == "__main__":
    res = retrieve("What is transformer?", 4)
    print(res)