nicole-ait committed
Commit 402f092 · Parent: 4c4129f

global embeddings & qa chain
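
In short: the commit replaces per-call construction with module-level state. load_embeddings() now caches a single HuggingFaceInstructEmbeddings instance, create_qa_chain() writes the chain into a global that the UI handlers rebuild whenever a setting changes, and bot() just reads that global. A minimal sketch of the lazy-global pattern in isolation (expensive_load and get_model are hypothetical stand-ins, not names from this commit):

_model = None

def expensive_load():
    # Stand-in for slow work such as downloading embedding weights.
    print("loading model...")
    return object()

def get_model():
    # First call pays the load cost; later calls reuse the cached instance.
    global _model
    if _model is None:
        _model = expensive_load()
    return _model

get_model()  # prints "loading model..."
get_model()  # silent: returns the cached instance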

Files changed (1):
  1. app.py +37 -14
app.py CHANGED
@@ -10,10 +10,19 @@ from langchain.llms import HuggingFaceHub
 from langchain.chains import ConversationalRetrievalChain
 
 
+embeddings = None
+qa_chain = None
+
+
 def load_embeddings():
-    print("loading embeddings...")
-    model_name = os.environ['HUGGINGFACEHUB_EMBEDDINGS_MODEL_NAME']
-    return HuggingFaceInstructEmbeddings(model_name=model_name)
+    global embeddings
+
+    if not embeddings:
+        print("loading embeddings...")
+        model_name = os.environ['HUGGINGFACEHUB_EMBEDDINGS_MODEL_NAME']
+        embeddings = HuggingFaceInstructEmbeddings(model_name=model_name)
+
+    return embeddings
 
 
 def split_file(file, chunk_size, chunk_overlap):
@@ -73,6 +82,10 @@ def load_vectordb(file_name):
 
 def create_qa_chain(collection_name, temperature, max_length):
     print('creating qa chain...', collection_name, temperature, max_length)
+    if not collection_name:
+        return
+
+    global qa_chain
     memory = ConversationBufferMemory(
         memory_key="chat_history", return_messages=True)
     llm = HuggingFaceHub(
@@ -80,7 +93,8 @@ def create_qa_chain(collection_name, temperature, max_length):
         model_kwargs={"temperature": temperature, "max_length": max_length}
     )
     vectordb = load_vectordb(collection_name)
-    return ConversationalRetrievalChain.from_llm(llm=llm, retriever=vectordb.as_retriever(), memory=memory)
+    qa_chain = ConversationalRetrievalChain.from_llm(
+        llm=llm, retriever=vectordb.as_retriever(), memory=memory)
 
 
 def submit_message(bot_history, text):
@@ -88,12 +102,12 @@ def submit_message(bot_history, text):
     return bot_history, ""
 
 
-def bot(bot_history, collection_name, temperature, max_length):
-    qa = create_qa_chain(collection_name, temperature, max_length)
-    print(qa, bot_history[-1][1])
-    qa.run(bot_history[-1][0])
-
-    bot_history[-1][1] = 'so cool!'
+def bot(bot_history):
+    global qa_chain
+    print(qa_chain, bot_history[-1][1])
+    result = qa_chain.run(bot_history[-1][0])
+    print(result)
+    bot_history[-1][1] = result
     return bot_history
 
 
@@ -122,7 +136,8 @@ with gr.Blocks() as demo:
             choices, value=choices[0] if choices else None, label="Document")
         temperature = gr.Slider(
            0.0, 1.0, value=0.5, step=0.05, label="Temperature")
-        max_length = gr.Slider(20, 1000, value=64, label="Max Length")
+        max_length = gr.Slider(
+            20, 1000, value=100, step=10, label="Max Length")
 
     with gr.Column(scale=0.5):
         chatbot = gr.Chatbot([], elem_id="chatbot").style(height=550)
@@ -130,11 +145,19 @@
             show_label=False, placeholder="Ask me anything!")
         clear = gr.Button("Clear")
 
-    process.click(process_file, [upload, chunk_size,
-                  chunk_overlap], [result, collection])
+    process.click(
+        process_file,
+        [upload, chunk_size, chunk_overlap],
+        [result, collection]
+    )
+
+    create_qa_chain(collection.value, temperature.value, max_length.value)
+    collection.change(create_qa_chain, [collection, temperature, max_length])
+    temperature.change(create_qa_chain, [collection, temperature, max_length])
+    max_length.change(create_qa_chain, [collection, temperature, max_length])
 
     message.submit(submit_message, [chatbot, message], [chatbot, message]).then(
-        bot, [chatbot, collection, temperature, max_length], chatbot
+        bot, chatbot, chatbot
    )
     clear.click(clear_bot, None, chatbot)
 
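
The wiring above keeps bot() free of settings inputs: the chain is rebuilt eagerly whenever Document, Temperature, or Max Length changes, and the chat handler only consumes the shared result. A minimal sketch of the same rebuild-on-change wiring, with a hypothetical build_pipeline standing in for create_qa_chain:

import gradio as gr

pipeline = None

def build_pipeline(temperature, max_length):
    # Rebuild the shared object whenever any control changes.
    global pipeline
    pipeline = {"temperature": temperature, "max_length": max_length}
    print("rebuilt:", pipeline)

with gr.Blocks() as demo:
    temperature = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Temperature")
    max_length = gr.Slider(20, 1000, value=100, step=10, label="Max Length")
    # Every control triggers the same handler; no outputs are needed.
    temperature.change(build_pipeline, [temperature, max_length])
    max_length.change(build_pipeline, [temperature, max_length])

demo.launch()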