Spaces:
Running
Running
thenativefox
commited on
Commit
·
9751345
1
Parent(s):
939262b
removed table name and fixed new table names
Browse files- backend/semantic_search.py +17 -3
backend/semantic_search.py
CHANGED
@@ -3,18 +3,32 @@ import os
|
|
3 |
import gradio as gr
|
4 |
from sentence_transformers import SentenceTransformer
|
5 |
|
6 |
-
|
7 |
db = lancedb.connect(".lancedb")
|
8 |
|
9 |
-
|
|
|
|
|
|
|
10 |
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
|
11 |
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
|
12 |
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
|
13 |
|
14 |
retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
def retrieve(query, k):
|
|
|
|
|
18 |
query_vec = retriever.encode(query)
|
19 |
try:
|
20 |
documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
|
@@ -23,4 +37,4 @@ def retrieve(query, k):
|
|
23 |
return documents
|
24 |
|
25 |
except Exception as e:
|
26 |
-
raise gr.Error(str(e))
|
|
|
3 |
import gradio as gr
|
4 |
from sentence_transformers import SentenceTransformer
|
5 |
|
|
|
6 |
db = lancedb.connect(".lancedb")
|
7 |
|
8 |
+
MODEL1_STRATEGY1 = "model1_fixed.lance"
|
9 |
+
MODEL2_STRATEGY1 = "model2_fixed.lance"
|
10 |
+
MODEL3_STRATEGY1 = "model3_fixed.lance"
|
11 |
+
|
12 |
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
|
13 |
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
|
14 |
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
|
15 |
|
16 |
retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
|
17 |
|
18 |
+
def get_table_name():
|
19 |
+
emb_model = os.getenv("EMB_MODEL")
|
20 |
+
if emb_model == "sentence-transformers/all-MiniLM-L6-v2":
|
21 |
+
return MODEL1_STRATEGY1
|
22 |
+
elif emb_model == "BAAI/bge-large-en-v1.5":
|
23 |
+
return MODEL2_STRATEGY1
|
24 |
+
elif emb_model == "openai/text-embedding-ada-002":
|
25 |
+
return MODEL3_STRATEGY1
|
26 |
+
else:
|
27 |
+
raise ValueError(f"Unsupported embedding model: {emb_model}")
|
28 |
|
29 |
def retrieve(query, k):
|
30 |
+
table_name = get_table_name()
|
31 |
+
TABLE = db.open_table(table_name)
|
32 |
query_vec = retriever.encode(query)
|
33 |
try:
|
34 |
documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
|
|
|
37 |
return documents
|
38 |
|
39 |
except Exception as e:
|
40 |
+
raise gr.Error(str(e))
|