Spaces: nicole-ait
committed · Commit 4c4129f · 1 Parent(s): 65a1209
update collection selector

app.py CHANGED
@@ -11,13 +11,13 @@ from langchain.chains import ConversationalRetrievalChain
 
 
 def load_embeddings():
-    print("
+    print("loading embeddings...")
     model_name = os.environ['HUGGINGFACEHUB_EMBEDDINGS_MODEL_NAME']
     return HuggingFaceInstructEmbeddings(model_name=model_name)
 
 
 def split_file(file, chunk_size, chunk_overlap):
-    print('spliting file', file.name)
+    print('spliting file...', file.name, chunk_size, chunk_overlap)
     loader = TextLoader(file.name)
     documents = loader.load()
     text_splitter = CharacterTextSplitter(
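For orientation: `chunk_size` and `chunk_overlap` control how `CharacterTextSplitter` cuts the loaded file into overlapping pieces before they are embedded. A minimal sketch of the same pattern, assuming the LangChain 0.0.x imports this app uses (the file name and parameter values here are illustrative only):

```python
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter

# Load a plain-text file into LangChain Document objects.
documents = TextLoader("notes.txt").load()  # hypothetical file

# Split into ~1000-character chunks overlapping by 100 characters,
# so content that straddles a boundary appears in both chunks.
splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
docs = splitter.split_documents(documents)
print(len(docs), "chunks")
```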
@@ -30,17 +30,17 @@ def get_persist_directory(file_name):
 
 
 def process_file(file, chunk_size, chunk_overlap):
-    docs = split_file(file, chunk_size, chunk_overlap)
-    embeddings = load_embeddings()
-
     file_name, _ = os.path.splitext(os.path.basename(file.name))
     persist_directory = get_persist_directory(file_name)
     print("persist directory", persist_directory)
+
+    docs = split_file(file, chunk_size, chunk_overlap)
+    embeddings = load_embeddings()
     vectordb = Chroma.from_documents(documents=docs, embedding=embeddings,
                                      collection_name=file_name, persist_directory=persist_directory)
     print(vectordb._client.list_collections())
     vectordb.persist()
-    return 'Done!'
+    return 'Done!', gr.Dropdown.update(choices=get_vector_dbs(), value=file_name)
 
 
 def is_dir(root, name):
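The notable change is the second return value: in Gradio 3.x an event handler may return `gr.Dropdown.update(...)` for an output component to change its properties at runtime, which is how the freshly processed file shows up in the collection selector. A self-contained sketch of that pattern (component names are illustrative, not from the commit):

```python
import gradio as gr

def add_item(items, new_item):
    # Return the grown list for the State output and an update that
    # refreshes the dropdown's choices and selects the new entry.
    items = items + [new_item]
    return items, gr.Dropdown.update(choices=items, value=new_item)

with gr.Blocks() as demo:
    items = gr.State([])                       # hidden list of known items
    name = gr.Textbox(label="New item")
    add = gr.Button("Add")
    selector = gr.Dropdown(choices=[], label="Items")
    add.click(add_item, [items, name], [items, selector])

demo.launch()
```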
@@ -53,8 +53,9 @@ def get_vector_dbs():
     if not os.path.exists(root):
         return []
 
+    print('get vector dbs...', root)
     files = os.listdir(root)
-    dirs = filter(lambda x: is_dir(root, x), files)
+    dirs = list(filter(lambda x: is_dir(root, x), files))
     print(dirs)
     return dirs
 
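Wrapping `filter` in `list` is a real fix, not cosmetics: `filter` returns a lazy, single-use iterator, so the old code's `print(dirs)` showed `<filter object ...>` instead of the names, and any consumer iterating it afterwards would find it empty on a second pass. A quick illustration:

```python
names = ["db_a", "db_b", "notes.txt"]

lazy = filter(lambda x: not x.endswith(".txt"), names)
print(lazy)        # <filter object at 0x...> -- not the items
print(list(lazy))  # ['db_a', 'db_b'] -- this consumes the iterator
print(list(lazy))  # [] -- already exhausted

eager = list(filter(lambda x: not x.endswith(".txt"), names))
print(eager)       # ['db_a', 'db_b'] -- a plain list, safely reusable
```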
@@ -71,7 +72,7 @@ def load_vectordb(file_name):
 
 
 def create_qa_chain(collection_name, temperature, max_length):
-    print('creating qa chain...')
+    print('creating qa chain...', collection_name, temperature, max_length)
     memory = ConversationBufferMemory(
         memory_key="chat_history", return_messages=True)
     llm = HuggingFaceHub(
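The rest of `create_qa_chain` falls outside this hunk. Based on the `ConversationalRetrievalChain` import visible in the first hunk header and the `load_vectordb` helper named above, a plausible completion would look like the sketch below; the `repo_id` env var and the `model_kwargs` names are assumptions, not taken from the commit:

```python
import os
from langchain.chains import ConversationalRetrievalChain
from langchain.llms import HuggingFaceHub
from langchain.memory import ConversationBufferMemory

# Sketch under stated assumptions -- everything past
# `llm = HuggingFaceHub(` is inferred, not from the diff.
def create_qa_chain(collection_name, temperature, max_length):
    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True)
    llm = HuggingFaceHub(
        repo_id=os.environ.get('HUGGINGFACEHUB_REPO_ID',  # assumed env var
                               'google/flan-t5-base'),
        model_kwargs={'temperature': temperature, 'max_length': max_length})
    vectordb = load_vectordb(collection_name)  # defined earlier in app.py
    return ConversationalRetrievalChain.from_llm(
        llm, retriever=vectordb.as_retriever(), memory=memory)
```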
@@ -116,8 +117,9 @@ with gr.Blocks() as demo:
     with gr.Tab("Bot"):
         with gr.Row():
             with gr.Column(scale=0.5):
+                choices = get_vector_dbs()
                 collection = gr.Dropdown(
-                    choices=
+                    choices, value=choices[0] if choices else None, label="Document")
                 temperature = gr.Slider(
                     0.0, 1.0, value=0.5, step=0.05, label="Temperature")
                 max_length = gr.Slider(20, 1000, value=64, label="Max Length")
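Computing `choices` once at build time makes the dropdown reflect the collections already persisted when the Space starts, and the `choices[0] if choices else None` guard avoids an IndexError on a fresh Space with no collections yet. The same pattern in isolation (the storage root here is illustrative):

```python
import os
import gradio as gr

def list_collections(root="db"):  # hypothetical storage root
    if not os.path.exists(root):
        return []
    return [d for d in os.listdir(root)
            if os.path.isdir(os.path.join(root, d))]

with gr.Blocks() as demo:
    choices = list_collections()
    # Preselect the first collection if any exist; otherwise start empty.
    collection = gr.Dropdown(
        choices, value=choices[0] if choices else None, label="Document")

demo.launch()
```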
@@ -128,7 +130,8 @@ with gr.Blocks() as demo:
             show_label=False, placeholder="Ask me anything!")
         clear = gr.Button("Clear")
 
-    process.click(process_file, [upload, chunk_size,
+    process.click(process_file, [upload, chunk_size,
+                  chunk_overlap], [result, collection])
 
     message.submit(submit_message, [chatbot, message], [chatbot, message]).then(
         bot, [chatbot, collection, temperature, max_length], chatbot
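With `[result, collection]` as the outputs, the two values `process_file` now returns map positionally onto the status textbox and the dropdown, so one click both reports completion and refreshes the selector. A stripped-down sketch of that wiring (`upload`, `chunk_size`, `chunk_overlap`, and `result` appear in the diff, but their definitions do not, so the constructors below are assumptions):

```python
import gradio as gr

def process_file(file, chunk_size, chunk_overlap):
    # ... split, embed, and persist as in app.py ...
    name = "example"  # placeholder for the derived collection name
    return 'Done!', gr.Dropdown.update(choices=[name], value=name)

with gr.Blocks() as demo:
    upload = gr.File()                                                   # assumed
    chunk_size = gr.Slider(100, 2000, value=1000, label="Chunk Size")    # assumed
    chunk_overlap = gr.Slider(0, 500, value=100, label="Chunk Overlap")  # assumed
    process = gr.Button("Process")
    result = gr.Textbox(label="Result")                                  # assumed
    collection = gr.Dropdown(choices=[], label="Document")

    # The handler's two return values map positionally onto
    # the outputs list: 'Done!' -> result, the update -> collection.
    process.click(process_file, [upload, chunk_size, chunk_overlap],
                  [result, collection])

demo.launch()
```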