nicole-ait committed
Commit 65a1209 · Parent(s): 0658357

layout w/ tabs

Files changed (1):
  1. app.py +75 -32

app.py CHANGED
@@ -5,19 +5,23 @@ from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.text_splitter import CharacterTextSplitter
 from langchain.vectorstores import Chroma
 from langchain.document_loaders import TextLoader
+from langchain.memory import ConversationBufferMemory
+from langchain.llms import HuggingFaceHub
+from langchain.chains import ConversationalRetrievalChain


 def load_embeddings():
-    print(os.environ)
+    print("Loading embeddings...")
     model_name = os.environ['HUGGINGFACEHUB_EMBEDDINGS_MODEL_NAME']
     return HuggingFaceInstructEmbeddings(model_name=model_name)


-def split_file(file):
-    print(file.name)
+def split_file(file, chunk_size, chunk_overlap):
+    print('spliting file', file.name)
     loader = TextLoader(file.name)
     documents = loader.load()
-    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
+    text_splitter = CharacterTextSplitter(
+        chunk_size=chunk_size, chunk_overlap=chunk_overlap)
     return text_splitter.split_documents(documents)

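The hunk above replaces the hard-coded `chunk_size=1000, chunk_overlap=20` with values threaded in from the sliders added later in this commit. As a rough illustration of what those two knobs control (a standalone sketch, not part of the commit; the sample file name is made up):

```python
from langchain.text_splitter import CharacterTextSplitter

# Illustrative values; in app.py they arrive from the Gradio sliders.
splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=20)

# CharacterTextSplitter cuts on a separator ("\n\n" by default), then packs
# the pieces into chunks of at most roughly chunk_size characters, carrying
# about chunk_overlap characters over between neighbouring chunks for context.
with open("sample.txt") as f:  # hypothetical input file
    chunks = splitter.split_text(f.read())
print(len(chunks), max(len(c) for c in chunks))
```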
 
@@ -25,37 +29,69 @@ def get_persist_directory(file_name):
     return os.path.join(os.environ['CHROMADB_PERSIST_DIRECTORY'], file_name)


-def process_file(file):
+def process_file(file, chunk_size, chunk_overlap):
+    docs = split_file(file, chunk_size, chunk_overlap)
     embeddings = load_embeddings()
-    print(embeddings)
-    docs = split_file(file)
-    print(docs)

     file_name, _ = os.path.splitext(os.path.basename(file.name))
     persist_directory = get_persist_directory(file_name)
-    print(persist_directory)
+    print("persist directory", persist_directory)
     vectordb = Chroma.from_documents(documents=docs, embedding=embeddings,
                                      collection_name=file_name, persist_directory=persist_directory)
     print(vectordb._client.list_collections())
     vectordb.persist()
-    return None
+    return 'Done!'
+
+
+def is_dir(root, name):
+    path = os.path.join(root, name)
+    return os.path.isdir(path)
+
+
+def get_vector_dbs():
+    root = os.environ['CHROMADB_PERSIST_DIRECTORY']
+    if not os.path.exists(root):
+        return []
+
+    files = os.listdir(root)
+    dirs = filter(lambda x: is_dir(root, x), files)
+    print(dirs)
+    return dirs


 def load_vectordb(file_name):
     embeddings = load_embeddings()

     persist_directory = get_persist_directory(file_name)
+    print(persist_directory)
     vectordb = Chroma(collection_name=file_name,
                       embedding_function=embeddings, persist_directory=persist_directory)
+    print(vectordb._client.list_collections())
     return vectordb


-def add_text(bot_history, text):
+def create_qa_chain(collection_name, temperature, max_length):
+    print('creating qa chain...')
+    memory = ConversationBufferMemory(
+        memory_key="chat_history", return_messages=True)
+    llm = HuggingFaceHub(
+        repo_id=os.environ["HUGGINGFACEHUB_LLM_REPO_ID"],
+        model_kwargs={"temperature": temperature, "max_length": max_length}
+    )
+    vectordb = load_vectordb(collection_name)
+    return ConversationalRetrievalChain.from_llm(llm=llm, retriever=vectordb.as_retriever(), memory=memory)
+
+
+def submit_message(bot_history, text):
     bot_history = bot_history + [(text, None)]
     return bot_history, ""


-def bot(bot_history):
+def bot(bot_history, collection_name, temperature, max_length):
+    qa = create_qa_chain(collection_name, temperature, max_length)
+    print(qa, bot_history[-1][1])
+    qa.run(bot_history[-1][0])
+
     bot_history[-1][1] = 'so cool!'
     return bot_history

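Two details in this hunk are worth flagging. `get_vector_dbs` returns the lazy `filter` object itself, so `print(dirs)` shows `<filter object ...>` and the result can only be iterated once; and `bot` computes `qa.run(...)` but still overwrites the reply with the `'so cool!'` placeholder. A hedged sketch of how both could look once fully wired (my reading, not part of the commit):

```python
import os

def get_vector_dbs():
    # Materialize the directory listing so gr.Dropdown(choices=...) gets a
    # plain list instead of a single-use iterator.
    root = os.environ['CHROMADB_PERSIST_DIRECTORY']
    if not os.path.exists(root):
        return []
    return [name for name in os.listdir(root)
            if os.path.isdir(os.path.join(root, name))]

def bot(bot_history, collection_name, temperature, max_length):
    # Assumes create_qa_chain from the hunk above. With a single output key,
    # ConversationalRetrievalChain.run() returns the answer as a string.
    qa = create_qa_chain(collection_name, temperature, max_length)
    bot_history[-1][1] = qa.run(bot_history[-1][0])
    return bot_history
```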
 
@@ -69,26 +105,33 @@ title = "QnA Chatbot"
 with gr.Blocks() as demo:
     gr.Markdown(f"# {title}")

-    with gr.Row():
-        with gr.Column(scale=0.5):
-            upload = gr.File(file_types=["text"], label="Upload file")
-
-            process = gr.Button("Process")
-
-        with gr.Column(scale=0.5):
-            chatbot = gr.Chatbot([], elem_id="chatbot").style(height=750)
-
-            txt = gr.Textbox(
-                show_label=False,
-                placeholder="Enter text and press enter",
-            ).style(container=False)
-
-            clear = gr.Button("Clear")
-
-    process.click(process_file, upload, None)
-
-    txt.submit(add_text, [chatbot, txt], [chatbot, txt]).then(
-        bot, chatbot, chatbot
+    with gr.Tab("File"):
+        upload = gr.File(file_types=["text"], label="Upload File")
+        chunk_size = gr.Slider(
+            500, 5000, value=1000, step=100, label="Chunk Size")
+        chunk_overlap = gr.Slider(0, 30, value=20, label="Chunk Overlap")
+        process = gr.Button("Process")
+        result = gr.Label()
+
+    with gr.Tab("Bot"):
+        with gr.Row():
+            with gr.Column(scale=0.5):
+                collection = gr.Dropdown(
+                    choices=get_vector_dbs(), label="Document")
+                temperature = gr.Slider(
+                    0.0, 1.0, value=0.5, step=0.05, label="Temperature")
+                max_length = gr.Slider(20, 1000, value=64, label="Max Length")
+
+            with gr.Column(scale=0.5):
+                chatbot = gr.Chatbot([], elem_id="chatbot").style(height=550)
+                message = gr.Textbox(
+                    show_label=False, placeholder="Ask me anything!")
+                clear = gr.Button("Clear")
+
+    process.click(process_file, [upload, chunk_size, chunk_overlap], result)
+
+    message.submit(submit_message, [chatbot, message], [chatbot, message]).then(
+        bot, [chatbot, collection, temperature, max_length], chatbot
     )
     clear.click(clear_bot, None, chatbot)

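The layout hunk also shows the two-step chat pattern in Gradio Blocks: `message.submit(...)` first appends the user turn to the history, then `.then(...)` fills in the pending reply. A self-contained toy version of that wiring (Gradio 3.x, which the `.style()` calls above imply; the echo reply is mine):

```python
import gradio as gr

def add_user_message(history, text):
    # Step 1: append the user turn with an empty bot slot and clear the box.
    return history + [(text, None)], ""

def respond(history):
    # Step 2: fill the pending slot (app.py would call the QA chain here).
    user_msg, _ = history[-1]
    history[-1] = (user_msg, f"echo: {user_msg}")
    return history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot([])
    box = gr.Textbox(show_label=False, placeholder="Ask me anything!")
    box.submit(add_user_message, [chatbot, box], [chatbot, box]).then(
        respond, chatbot, chatbot)

demo.launch()
```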
 
 
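For reference, the environment variables app.py reads after this commit, collected from the `os.environ` lookups in the diff. The values below are placeholders for illustration, not taken from the Space:

```python
import os

# Names come from the diff; the values are made-up examples.
os.environ.setdefault("HUGGINGFACEHUB_EMBEDDINGS_MODEL_NAME", "hkunlp/instructor-base")
os.environ.setdefault("CHROMADB_PERSIST_DIRECTORY", "./chromadb")
os.environ.setdefault("HUGGINGFACEHUB_LLM_REPO_ID", "google/flan-t5-base")
# The HuggingFaceHub LLM wrapper additionally expects HUGGINGFACEHUB_API_TOKEN
# to be set for authenticated calls.
```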