nicole-ait committed on
Commit
4c4129f
·
1 Parent(s): 65a1209

update collection selector

Browse files
Files changed (1) hide show
  1. app.py +13 -10
app.py CHANGED
@@ -11,13 +11,13 @@ from langchain.chains import ConversationalRetrievalChain
11
 
12
 
13
  def load_embeddings():
14
- print("Loading embeddings...")
15
  model_name = os.environ['HUGGINGFACEHUB_EMBEDDINGS_MODEL_NAME']
16
  return HuggingFaceInstructEmbeddings(model_name=model_name)
17
 
18
 
19
  def split_file(file, chunk_size, chunk_overlap):
20
- print('spliting file', file.name)
21
  loader = TextLoader(file.name)
22
  documents = loader.load()
23
  text_splitter = CharacterTextSplitter(
@@ -30,17 +30,17 @@ def get_persist_directory(file_name):
30
 
31
 
32
  def process_file(file, chunk_size, chunk_overlap):
33
- docs = split_file(file, chunk_size, chunk_overlap)
34
- embeddings = load_embeddings()
35
-
36
  file_name, _ = os.path.splitext(os.path.basename(file.name))
37
  persist_directory = get_persist_directory(file_name)
38
  print("persist directory", persist_directory)
 
 
 
39
  vectordb = Chroma.from_documents(documents=docs, embedding=embeddings,
40
  collection_name=file_name, persist_directory=persist_directory)
41
  print(vectordb._client.list_collections())
42
  vectordb.persist()
43
- return 'Done!'
44
 
45
 
46
  def is_dir(root, name):
@@ -53,8 +53,9 @@ def get_vector_dbs():
53
  if not os.path.exists(root):
54
  return []
55
 
 
56
  files = os.listdir(root)
57
- dirs = filter(lambda x: is_dir(root, x), files)
58
  print(dirs)
59
  return dirs
60
 
@@ -71,7 +72,7 @@ def load_vectordb(file_name):
71
 
72
 
73
  def create_qa_chain(collection_name, temperature, max_length):
74
- print('creating qa chain...')
75
  memory = ConversationBufferMemory(
76
  memory_key="chat_history", return_messages=True)
77
  llm = HuggingFaceHub(
@@ -116,8 +117,9 @@ with gr.Blocks() as demo:
116
  with gr.Tab("Bot"):
117
  with gr.Row():
118
  with gr.Column(scale=0.5):
 
119
  collection = gr.Dropdown(
120
- choices=get_vector_dbs(), label="Document")
121
  temperature = gr.Slider(
122
  0.0, 1.0, value=0.5, step=0.05, label="Temperature")
123
  max_length = gr.Slider(20, 1000, value=64, label="Max Length")
@@ -128,7 +130,8 @@ with gr.Blocks() as demo:
128
  show_label=False, placeholder="Ask me anything!")
129
  clear = gr.Button("Clear")
130
 
131
- process.click(process_file, [upload, chunk_size, chunk_overlap], result)
 
132
 
133
  message.submit(submit_message, [chatbot, message], [chatbot, message]).then(
134
  bot, [chatbot, collection, temperature, max_length], chatbot
 
11
 
12
 
13
  def load_embeddings():
14
+ print("loading embeddings...")
15
  model_name = os.environ['HUGGINGFACEHUB_EMBEDDINGS_MODEL_NAME']
16
  return HuggingFaceInstructEmbeddings(model_name=model_name)
17
 
18
 
19
  def split_file(file, chunk_size, chunk_overlap):
20
+ print('spliting file...', file.name, chunk_size, chunk_overlap)
21
  loader = TextLoader(file.name)
22
  documents = loader.load()
23
  text_splitter = CharacterTextSplitter(
 
30
 
31
 
32
  def process_file(file, chunk_size, chunk_overlap):
 
 
 
33
  file_name, _ = os.path.splitext(os.path.basename(file.name))
34
  persist_directory = get_persist_directory(file_name)
35
  print("persist directory", persist_directory)
36
+
37
+ docs = split_file(file, chunk_size, chunk_overlap)
38
+ embeddings = load_embeddings()
39
  vectordb = Chroma.from_documents(documents=docs, embedding=embeddings,
40
  collection_name=file_name, persist_directory=persist_directory)
41
  print(vectordb._client.list_collections())
42
  vectordb.persist()
43
+ return 'Done!', gr.Dropdown.update(choices=get_vector_dbs(), value=file_name)
44
 
45
 
46
  def is_dir(root, name):
 
53
  if not os.path.exists(root):
54
  return []
55
 
56
+ print('get vector dbs...', root)
57
  files = os.listdir(root)
58
+ dirs = list(filter(lambda x: is_dir(root, x), files))
59
  print(dirs)
60
  return dirs
61
 
 
72
 
73
 
74
  def create_qa_chain(collection_name, temperature, max_length):
75
+ print('creating qa chain...', collection_name, temperature, max_length)
76
  memory = ConversationBufferMemory(
77
  memory_key="chat_history", return_messages=True)
78
  llm = HuggingFaceHub(
 
117
  with gr.Tab("Bot"):
118
  with gr.Row():
119
  with gr.Column(scale=0.5):
120
+ choices = get_vector_dbs()
121
  collection = gr.Dropdown(
122
+ choices, value=choices[0] if choices else None, label="Document")
123
  temperature = gr.Slider(
124
  0.0, 1.0, value=0.5, step=0.05, label="Temperature")
125
  max_length = gr.Slider(20, 1000, value=64, label="Max Length")
 
130
  show_label=False, placeholder="Ask me anything!")
131
  clear = gr.Button("Clear")
132
 
133
+ process.click(process_file, [upload, chunk_size,
134
+ chunk_overlap], [result, collection])
135
 
136
  message.submit(submit_message, [chatbot, message], [chatbot, message]).then(
137
  bot, [chatbot, collection, temperature, max_length], chatbot