thenativefox committed on
Commit
9751345
·
1 Parent(s): 939262b

removed table name and fixed new table names

Browse files
Files changed (1) hide show
  1. backend/semantic_search.py +17 -3
backend/semantic_search.py CHANGED
@@ -3,18 +3,32 @@ import os
3
  import gradio as gr
4
  from sentence_transformers import SentenceTransformer
5
 
6
-
7
  db = lancedb.connect(".lancedb")
8
 
9
- TABLE = db.open_table(os.getenv("TABLE_NAME"))
 
 
 
10
  VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
11
  TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
12
  BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
13
 
14
  retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
15
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  def retrieve(query, k):
 
 
18
  query_vec = retriever.encode(query)
19
  try:
20
  documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
@@ -23,4 +37,4 @@ def retrieve(query, k):
23
  return documents
24
 
25
  except Exception as e:
26
- raise gr.Error(str(e))
 
3
  import gradio as gr
4
  from sentence_transformers import SentenceTransformer
5
 
 
6
  db = lancedb.connect(".lancedb")
7
 
8
+ MODEL1_STRATEGY1 = "model1_fixed.lance"
9
+ MODEL2_STRATEGY1 = "model2_fixed.lance"
10
+ MODEL3_STRATEGY1 = "model3_fixed.lance"
11
+
12
  VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
13
  TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
14
  BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
15
 
16
  retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
17
 
18
def get_table_name():
    """Return the table name that corresponds to the configured embedding model.

    The model identifier is read from the ``EMB_MODEL`` environment variable
    on every call and looked up among the supported models.

    Returns:
        One of ``MODEL1_STRATEGY1``, ``MODEL2_STRATEGY1`` or
        ``MODEL3_STRATEGY1``.

    Raises:
        ValueError: If ``EMB_MODEL`` is unset or names a model that has no
            associated table.
    """
    emb_model = os.getenv("EMB_MODEL")
    # Lookup table: embedding model identifier -> table name constant.
    tables_by_model = {
        "sentence-transformers/all-MiniLM-L6-v2": MODEL1_STRATEGY1,
        "BAAI/bge-large-en-v1.5": MODEL2_STRATEGY1,
        "openai/text-embedding-ada-002": MODEL3_STRATEGY1,
    }
    if emb_model not in tables_by_model:
        raise ValueError(f"Unsupported embedding model: {emb_model}")
    return tables_by_model[emb_model]
28
 
29
  def retrieve(query, k):
30
+ table_name = get_table_name()
31
+ TABLE = db.open_table(table_name)
32
  query_vec = retriever.encode(query)
33
  try:
34
  documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
 
37
  return documents
38
 
39
  except Exception as e:
40
+ raise gr.Error(str(e))