File size: 1,740 Bytes
b7f4e8c
 
 
6b2076c
b7f4e8c
6b2076c
3f0e240
c44b083
a7432c6
 
6b2076c
a7432c6
 
 
c44b083
 
 
6b2076c
 
 
c44b083
bd075c2
1c860fb
bd075c2
1c860fb
3f0e240
b7f4e8c
 
 
 
 
 
 
6b2076c
 
 
 
 
 
9751345
6b2076c
b7f4e8c
 
 
 
 
6b2076c
 
 
 
 
 
 
 
b7f4e8c
 
c44b083
 
bd075c2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import lancedb
import os
import gradio as gr
import openai
from sentence_transformers import SentenceTransformer
from sentence_transformers import CrossEncoder
from pathlib import Path
from dotenv import load_dotenv
import logging


# Module-wide logging: INFO level on the root logger, plus a logger for this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pull settings from the .env file into the process environment, then read them.
# TABLE_NAME has no default and is None when the variable is not set.
load_dotenv()
TABLE_NAME = os.environ.get("TABLE_NAME")
VECTOR_COLUMN = os.environ.get("VECTOR_COLUMN", "vector")
TEXT_COLUMN = os.environ.get("TEXT_COLUMN", "text")
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 32))

# Cross-encoder used to rerank retrieved passages against the query.
ce = CrossEncoder("BAAI/bge-reranker-base")

# The LanceDB database lives in a ".lancedb" directory under the current
# working directory (i.e. wherever the process was started from).
current_working_dir = Path.cwd()
db_path = current_working_dir / ".lancedb"
db = lancedb.connect(db_path)

# Embedding model used to encode queries; its name comes from EMB_MODEL.
retriever = SentenceTransformer(os.environ.get("EMB_MODEL"))

def retrieve(query, k, top_n=3):
    """Return the passages most relevant to *query*, reranked by the cross-encoder.

    The query is embedded, the ``k`` nearest rows are fetched from the LanceDB
    table named by ``TABLE_NAME``, each candidate text is scored against the
    query with the cross-encoder ``ce``, and the highest-scoring texts are
    returned.

    Args:
        query: The search query string.
        k: Number of candidate rows to fetch from the vector search.
        top_n: Maximum number of reranked passages to return. Defaults to 3,
            which was the previously hard-coded cut-off.

    Returns:
        A list of at most ``top_n`` passage texts (values of ``TEXT_COLUMN``),
        ordered from highest to lowest cross-encoder score. Empty when the
        search yields no rows.

    Raises:
        gr.Error: If the vector search or the reranking step fails. Errors
            from opening the table or embedding the query propagate unchanged.
    """
    TABLE = db.open_table(TABLE_NAME)

    # Tables whose name contains "3" are queried with OpenAI embeddings; every
    # other table uses the local SentenceTransformer model.
    # NOTE(review): presumably this mirrors how each table was built - confirm.
    if "3" in TABLE_NAME:
        client = openai.OpenAI()
        query_vec = client.embeddings.create(
            input=[query], model="text-embedding-ada-002"
        ).data[0].embedding
    else:
        query_vec = retriever.encode(query)

    try:
        rows = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
        documents = [row[TEXT_COLUMN] for row in rows]

        # Nothing to rerank: skip the cross-encoder, which may not accept an
        # empty batch.
        if not documents:
            return []

        scores = ce.predict([(query, doc) for doc in documents])

        # Stable sort, best score first; ties keep their retrieval order.
        ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
        return [doc for doc, _ in ranked[:top_n]]

    except Exception as e:
        # Keep the full traceback server-side; the UI only sees the message.
        logging.getLogger(__name__).exception("Retrieval failed for query %r", query)
        raise gr.Error(str(e)) from e
    
if __name__ == "__main__":
    # Manual smoke test: fetch 4 candidates for a sample question and print
    # whatever retrieve() returns.
    print(retrieve("What is transformer?", 4))