awacke1 committed on
Commit
427f085
·
verified ·
1 Parent(s): cb6c7be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -195
app.py CHANGED
@@ -6,82 +6,75 @@ from huggingface_hub import InferenceClient
6
  import re
7
  from datetime import datetime
8
  import json
 
9
  import arxiv
10
  from utils import get_md_text_abstract, search_cleaner, get_arxiv_live_search
11
- import os
12
- import glob
13
 
14
- # 🎛️ App configuration - tweak these knobs for maximum brain power! 🧠💪
15
- retrieve_results = 20
16
- show_examples = True
17
- llm_models_to_choose = ['mistralai/Mixtral-8x7B-Instruct-v0.1', 'mistralai/Mistral-7B-Instruct-v0.2', 'google/gemma-7b-it', 'None']
 
18
 
19
- # 🎭 LLM acting instructions - "To be, or not to be... verbose" 🤔
20
  generate_kwargs = dict(
21
  temperature = None,
22
  max_new_tokens = 512,
23
  top_p = None,
24
  do_sample = False,
25
- )
26
 
27
- # 🧙‍♂️ Summoning the RAG model - "Accio knowledge!" 📚✨
28
  RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert")
29
 
30
  try:
31
- gr.Info("๐Ÿ—๏ธ Setting up the knowledge retriever, please wait... ๐Ÿ•ฐ๏ธ")
32
- rag_initial_output = RAG.search("What is Generative AI in Healthcare?", k = 1)
33
- gr.Info("๐ŸŽ‰ Retriever is up and running! Time to flex those brain muscles! ๐Ÿ’ช๐Ÿง ")
 
34
  except:
35
- gr.Warning("๐Ÿ˜ฑ Oh no! The retriever took a coffee break. Try again later! โ˜•")
36
 
37
- # 📜 The grand introduction - roll out the red carpet! 🎭
38
- mark_text = '# ๐Ÿฉบ๐Ÿ” Search Results\n'
39
- header_text = "## ๐Ÿ“šArxiv๐Ÿ“–Paper๐Ÿ”Search - ๐Ÿ•ต๏ธโ€โ™€๏ธ Uncover, ๐Ÿ“ Summarize, and ๐Ÿงฉ Solve ๐Ÿ”ฌ Research ๐Ÿค”โ“ Puzzles โœ๏ธ with ๐Ÿ“š Papers and ๐Ÿค– RAG AI ๐Ÿง \n"
40
 
41
- # 🕰️ Time travel to find when our knowledge was last updated 🚀
42
  try:
43
- with open("README.md", "r") as f:
44
- mdfile = f.read()
45
- date_pattern = r'Index Last Updated : \d{4}-\d{2}-\d{2}'
46
- match = re.search(date_pattern, mdfile)
47
- date = match.group().split(': ')[1]
48
- formatted_date = datetime.strptime(date, '%Y-%m-%d').strftime('%d %b %Y')
49
- header_text += f'Index Last Updated: {formatted_date}\n'
50
- index_info = f"Semantic Search - up to {formatted_date}"
51
  except:
52
- index_info = "Semantic Search"
53
 
54
- database_choices = [index_info, 'Arxiv Search - Latest - (EXPERIMENTAL)']
55
 
56
- # 🦉 Arxiv API - the wise old owl of academic knowledge 📜
57
  arx_client = arxiv.Client()
58
  is_arxiv_available = True
59
- check_arxiv_result = get_arxiv_live_search("What is Self Rewarding AI and how can it be used in Multi-Agent Systems?", arx_client, retrieve_results)
60
  if len(check_arxiv_result) == 0:
61
- is_arxiv_available = False
62
- print("๐Ÿ˜ด Arxiv search is taking a nap, switching to default search ...")
63
- database_choices = [index_info]
 
 
64
 
65
- # 🎭 Show examples - a teaser trailer for your brain! 🍿🧠
66
- sample_outputs = {
67
- 'output_placeholder': 'The LLM will provide an answer to your question here...',
68
- 'search_placeholder': '''
69
- 1. What is MoE?
70
- 2. What are Multi Agent Systems?
71
- 3. What is Self Rewarding AI?
72
- 4. What is Semantic and Episodic memory?
73
- 5. What is AutoGen?
74
- 6. What is ChatDev?
75
- 7. What is Omniverse?
76
- 8. What is Lumiere?
77
- 9. What is SORA?
78
- '''
79
- }
80
 
81
- output_placeholder = sample_outputs['output_placeholder']
82
- md_text_initial = sample_outputs['search_placeholder']
83
 
84
- # 🧹 Clean up the RAG output - nobody likes a messy mind! 🧼🧠
85
  def rag_cleaner(inp):
86
  rank = inp['rank']
87
  title = inp['document_metadata']['title']
@@ -89,20 +82,19 @@ def rag_cleaner(inp):
89
  date = inp['document_metadata']['_time']
90
  return f"{rank}. <b> {title} </b> \n Date : {date} \n Abstract: {content}"
91
 
92
- # 🎭 Craft the perfect prompt - it's showtime for the LLM! 🎬
93
  def get_prompt_text(question, context, formatted = True, llm_model_picked = 'mistralai/Mistral-7B-Instruct-v0.2'):
94
  if formatted:
95
- sys_instruction = f"Context:\n {context} \n Given the following scientific paper abstracts, take a deep breath and let's think step by step to answer the question. Cite the titles of your sources when answering, do not cite links or dates."
96
- message = f"Question: {question}"
97
 
98
- if 'mistralai' in llm_model_picked:
99
- return f"<s>" + f"[INST] {sys_instruction}" + f" {message}[/INST]"
100
- elif 'gemma' in llm_model_picked:
101
- return f"<bos><start_of_turn>user\n{sys_instruction}" + f" {message}<end_of_turn>\n"
 
102
 
103
- return f"Context:\n {context} \n Given the following info, take a deep breath and let's think step by step to answer the question: {question}. Cite the titles of your sources when answering.\n\n"
104
 
105
- # 🕵️‍♀️ Get those juicy references - time to go treasure hunting! 💎📚
106
  def get_references(question, retriever, k = retrieve_results):
107
  rag_out = retriever.search(query=question, k=k)
108
  return rag_out
@@ -110,169 +102,85 @@ def get_references(question, retriever, k = retrieve_results):
110
  def get_rag(message):
111
  return get_references(message, RAG)
112
 
113
- # 🎤 Save the response and read it aloud - it's karaoke time for your brain! 🧠🎶
114
- def SaveResponseAndRead(result):
115
- documentHTML5='''
116
- <!DOCTYPE html>
117
- <html>
118
- <head>
119
- <title>Read It Aloud</title>
120
- <script type="text/javascript">
121
- function readAloud() {
122
- const text = document.getElementById("textArea").value;
123
- const speech = new SpeechSynthesisUtterance(text);
124
- window.speechSynthesis.speak(speech);
125
- }
126
- </script>
127
- </head>
128
- <body>
129
- <h1>๐Ÿ”Š Read It Aloud</h1>
130
- <textarea id="textArea" rows="10" cols="80">
131
- '''
132
- documentHTML5 = documentHTML5 + result
133
- documentHTML5 = documentHTML5 + '''
134
- </textarea>
135
- <br>
136
- <button onclick="readAloud()">๐Ÿ”Š Read Aloud</button>
137
- </body>
138
- </html>
139
- '''
140
- gr.HTML(documentHTML5)
141
-
142
- # 📁 File management functions - because even AI needs a filing system! 🗄️🤖
143
-
144
- def save_response_as_markdown(question, response):
145
- timestamp = datetime.now().strftime("%Y%m%d%H%M")
146
- filename = f"{timestamp}_{question[:50]}.md" # Truncate question to 50 chars for filename
147
- with open(filename, "w", encoding="utf-8") as f:
148
- f.write(response)
149
- return filename
150
-
151
- def list_markdown_files():
152
- files = glob.glob("*.md")
153
- files.sort(key=os.path.getmtime, reverse=True)
154
- return [f for f in files if f != "README.md"]
155
-
156
- def delete_file(filename):
157
- if filename != "README.md":
158
- os.remove(filename)
159
- return f"Deleted {filename}"
160
- return "Cannot delete README.md"
161
-
162
- def display_markdown_contents():
163
- files = list_markdown_files()
164
- output = ""
165
- for file in files:
166
- with open(file, "r", encoding="utf-8") as f:
167
- content = f.read()
168
- output += f"## {file}\n\n```markdown\n{content}\n```\n\n"
169
- return output
170
-
171
- # 🎨 Building the UI - it's like LEGO, but for brains! 🧠🏗️
172
  with gr.Blocks(theme = gr.themes.Soft()) as demo:
173
  header = gr.Markdown(header_text)
174
 
175
  with gr.Group():
176
- msg = gr.Textbox(label = 'Search', placeholder = 'What is Generative AI in Healthcare?')
177
 
178
- with gr.Accordion("Advanced Settings", open=False):
179
- with gr.Row(equal_height = True):
180
- llm_model = gr.Dropdown(choices = llm_models_to_choose, value = 'mistralai/Mistral-7B-Instruct-v0.2', label = 'LLM Model')
181
- llm_results = gr.Slider(minimum=4, maximum=10, value=5, step=1, interactive=True, label="Top n results as context")
182
- database_src = gr.Dropdown(choices = database_choices, value = index_info, label = 'Search Source')
183
- stream_results = gr.Checkbox(value = True, label = "Stream output", visible = False)
184
 
185
  output_text = gr.Textbox(show_label = True, container = True, label = 'LLM Answer', visible = True, placeholder = output_placeholder)
186
  input = gr.Textbox(show_label = False, visible = False)
187
  gr_md = gr.Markdown(mark_text + md_text_initial)
188
 
189
- with gr.Tab("Saved Responses"):
190
- refresh_button = gr.Button("๐Ÿ”„ Refresh File List")
191
- file_list = gr.Dropdown(choices=list_markdown_files(), label="Saved Responses")
192
- delete_button = gr.Button("๐Ÿ—‘๏ธ Delete Selected File")
193
- markdown_display = gr.Markdown()
194
-
195
- # 🔄 Update the file list - keeping things fresh! 🌿
196
- def update_file_list():
197
- return gr.Dropdown(choices=list_markdown_files())
198
-
199
- refresh_button.click(update_file_list, outputs=[file_list])
200
- delete_button.click(delete_file, inputs=[file_list], outputs=[markdown_display]).then(update_file_list, outputs=[file_list])
201
- file_list.change(lambda x: open(x, "r", encoding="utf-8").read() if x else "", inputs=[file_list], outputs=[markdown_display])
202
-
203
- # 🎭 The grand finale - where the magic happens! 🎩✨
204
  def update_with_rag_md(message, llm_results_use = 5, database_choice = index_info, llm_model_picked = 'mistralai/Mistral-7B-Instruct-v0.2'):
205
  prompt_text_from_data = ""
206
  database_to_use = database_choice
207
  if database_choice == index_info:
208
- rag_out = get_rag(message)
209
  else:
210
- arxiv_search_success = True
211
- try:
212
- rag_out = get_arxiv_live_search(message, arx_client, retrieve_results)
213
- if len(rag_out) == 0:
214
- arxiv_search_success = False
215
- except:
216
- arxiv_search_success = False
217
-
218
- if not arxiv_search_success:
219
- gr.Warning("๐Ÿ˜ด Arxiv Search is taking a siesta, switching to semantic search ...")
220
- rag_out = get_rag(message)
221
- database_to_use = index_info
 
222
 
223
  md_text_updated = mark_text
224
  for i in range(retrieve_results):
225
- rag_answer = rag_out[i]
226
- if i < llm_results_use:
227
- md_text_paper, prompt_text = get_md_text_abstract(rag_answer, source = database_to_use, return_prompt_formatting = True)
228
- prompt_text_from_data += f"{i+1}. {prompt_text}"
229
- else:
230
- md_text_paper = get_md_text_abstract(rag_answer, source = database_to_use)
231
- md_text_updated += md_text_paper
232
  prompt = get_prompt_text(message, prompt_text_from_data, llm_model_picked = llm_model_picked)
233
  return md_text_updated, prompt
234
 
235
- # 🧠 Asking the LLM - it's like a really smart magic 8-ball! 🎱✨
236
  def ask_llm(prompt, llm_model_picked = 'mistralai/Mistral-7B-Instruct-v0.2', stream_outputs = False):
237
- model_disabled_text = "LLM Model is taking a vacation. Try again later! ๐Ÿ–๏ธ"
238
- output = ""
239
 
240
- if llm_model_picked == 'None':
241
- if stream_outputs:
242
- for out in model_disabled_text:
243
- output += out
244
- yield output
245
- return output
246
- else:
247
- return model_disabled_text
248
 
249
- client = InferenceClient(llm_model_picked)
250
- try:
251
- stream = client.text_generation(prompt, **generate_kwargs, stream=stream_outputs, details=False, return_full_text=False)
252
 
253
- except:
254
- gr.Warning("๐Ÿšฆ LLM Inference hit a traffic jam! Take a breather and try again later.")
255
- return ""
256
 
257
- if stream_outputs:
258
- for response in stream:
259
- output += response
260
- SaveResponseAndRead(response)
261
- yield output
262
- return output
263
- else:
264
- return stream
265
 
266
- # 🎬 Action! Process the query and save the response
267
- def process_and_save(message, llm_results_use, database_choice, llm_model_picked):
268
- md_text_updated, prompt = update_with_rag_md(message, llm_results_use, database_choice, llm_model_picked)
269
- llm_response = ask_llm(prompt, llm_model_picked, stream_outputs=False)
270
- full_response = f"Question: {message}\n\nResponse:\n{llm_response}\n\nReferences:\n{md_text_updated}"
271
- filename = save_response_as_markdown(message, full_response)
272
- return md_text_updated, prompt, llm_response, filename
273
 
274
- # 🎬 Lights, camera, action! Let's get this show on the road! 🚀
275
- msg.submit(process_and_save, [msg, llm_results, database_src, llm_model], [gr_md, input, output_text, file_list]).then(update_file_list, outputs=[file_list])
276
 
277
- # 🎉 Launch the app - let the knowledge party begin! 🎊🧠
278
  demo.queue().launch()
 
6
  import re
7
  from datetime import datetime
8
  import json
9
+ import os
10
  import arxiv
11
  from utils import get_md_text_abstract, search_cleaner, get_arxiv_live_search
 
 
12
 
13
+ retrieve_results = 10
14
+ show_examples = False
15
+ llm_models_to_choose = ['mistralai/Mixtral-8x7B-Instruct-v0.1','mistralai/Mistral-7B-Instruct-v0.2', 'google/gemma-2-2b-it', 'None']
16
+
17
+ token = os.getenv("HF_TOKEN")
18
 
 
19
  generate_kwargs = dict(
20
  temperature = None,
21
  max_new_tokens = 512,
22
  top_p = None,
23
  do_sample = False,
24
+ )
25
 
26
+ ## RAG Model
27
  RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert")
28
 
29
  try:
30
+ gr.Info("Setting up retriever, please wait...")
31
+ rag_initial_output = RAG.search("what is Mistral?", k = 1)
32
+ gr.Info("Retriever working successfully!")
33
+
34
  except:
35
+ gr.Warning("Retriever not working!")
36
 
37
+ ## Header
38
+ mark_text = '# ๐Ÿ” Search Results\n'
39
+ header_text = "# ArXiv CS RAG \n"
40
 
 
41
  try:
42
+ with open("README.md", "r") as f:
43
+ mdfile = f.read()
44
+ date_pattern = r'Index Last Updated : \d{4}-\d{2}-\d{2}'
45
+ match = re.search(date_pattern, mdfile)
46
+ date = match.group().split(': ')[1]
47
+ formatted_date = datetime.strptime(date, '%Y-%m-%d').strftime('%d %b %Y')
48
+ header_text += f'Index Last Updated: {formatted_date}\n'
49
+ index_info = f"Semantic Search - up to {formatted_date}"
50
  except:
51
+ index_info = "Semantic Search"
52
 
53
+ database_choices = [index_info,'Arxiv Search - Latest - (EXPERIMENTAL)']
54
 
55
+ ## Arxiv API
56
  arx_client = arxiv.Client()
57
  is_arxiv_available = True
58
+ check_arxiv_result = get_arxiv_live_search("What is Mistral?", arx_client, retrieve_results)
59
  if len(check_arxiv_result) == 0:
60
+ is_arxiv_available = False
61
+ print("Arxiv search not working, switching to default search ...")
62
+ database_choices = [index_info]
63
+
64
+
65
 
66
+ ## Show examples (disabled)
67
+ if show_examples:
68
+ with open("sample_outputs.json", "r") as f:
69
+ sample_outputs = json.load(f)
70
+ output_placeholder = sample_outputs['output_placeholder']
71
+ md_text_initial = sample_outputs['search_placeholder']
72
+
73
+ else:
74
+ output_placeholder = None
75
+ md_text_initial = ''
 
 
 
 
 
76
 
 
 
77
 
 
78
  def rag_cleaner(inp):
79
  rank = inp['rank']
80
  title = inp['document_metadata']['title']
 
82
  date = inp['document_metadata']['_time']
83
  return f"{rank}. <b> {title} </b> \n Date : {date} \n Abstract: {content}"
84
 
 
85
  def get_prompt_text(question, context, formatted = True, llm_model_picked = 'mistralai/Mistral-7B-Instruct-v0.2'):
86
  if formatted:
87
+ sys_instruction = f"Context:\n {context} \n Given the following scientific paper abstracts, take a deep breath and lets think step by step to answer the question. Cite the titles of your sources when answering, do not cite links or dates."
88
+ message = f"Question: {question}"
89
 
90
+ if 'mistralai' in llm_model_picked:
91
+ return f"<s>" + f"[INST] {sys_instruction}" + f" {message}[/INST]"
92
+
93
+ elif 'gemma' in llm_model_picked:
94
+ return f"<bos><start_of_turn>user\n{sys_instruction}" + f" {message}<end_of_turn>\n"
95
 
96
+ return f"Context:\n {context} \n Given the following info, take a deep breath and lets think step by step to answer the question: {question}. Cite the titles of your sources when answering.\n\n"
97
 
 
98
  def get_references(question, retriever, k = retrieve_results):
99
  rag_out = retriever.search(query=question, k=k)
100
  return rag_out
 
102
  def get_rag(message):
103
  return get_references(message, RAG)
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  with gr.Blocks(theme = gr.themes.Soft()) as demo:
106
  header = gr.Markdown(header_text)
107
 
108
  with gr.Group():
109
+ msg = gr.Textbox(label = 'Search', placeholder = 'What is Mistral?')
110
 
111
+ with gr.Accordion("Advanced Settings", open=False):
112
+ with gr.Row(equal_height = True):
113
+ llm_model = gr.Dropdown(choices = llm_models_to_choose, value = 'mistralai/Mistral-7B-Instruct-v0.2', label = 'LLM Model')
114
+ llm_results = gr.Slider(minimum=4, maximum=10, value=5, step=1, interactive=True, label="Top n results as context")
115
+ database_src = gr.Dropdown(choices = database_choices, value = index_info, label = 'Search Source')
116
+ stream_results = gr.Checkbox(value = True, label = "Stream output", visible = False)
117
 
118
  output_text = gr.Textbox(show_label = True, container = True, label = 'LLM Answer', visible = True, placeholder = output_placeholder)
119
  input = gr.Textbox(show_label = False, visible = False)
120
  gr_md = gr.Markdown(mark_text + md_text_initial)
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  def update_with_rag_md(message, llm_results_use = 5, database_choice = index_info, llm_model_picked = 'mistralai/Mistral-7B-Instruct-v0.2'):
123
  prompt_text_from_data = ""
124
  database_to_use = database_choice
125
  if database_choice == index_info:
126
+ rag_out = get_rag(message)
127
  else:
128
+ arxiv_search_success = True
129
+ try:
130
+ rag_out = get_arxiv_live_search(message, arx_client, retrieve_results)
131
+ if len(rag_out) == 0:
132
+ arxiv_search_success = False
133
+ except:
134
+ arxiv_search_success = False
135
+
136
+
137
+ if not arxiv_search_success:
138
+ gr.Warning("Arxiv Search not working, switching to semantic search ...")
139
+ rag_out = get_rag(message)
140
+ database_to_use = index_info
141
 
142
  md_text_updated = mark_text
143
  for i in range(retrieve_results):
144
+ rag_answer = rag_out[i]
145
+ if i < llm_results_use:
146
+ md_text_paper, prompt_text = get_md_text_abstract(rag_answer, source = database_to_use, return_prompt_formatting = True)
147
+ prompt_text_from_data += f"{i+1}. {prompt_text}"
148
+ else:
149
+ md_text_paper = get_md_text_abstract(rag_answer, source = database_to_use)
150
+ md_text_updated += md_text_paper
151
  prompt = get_prompt_text(message, prompt_text_from_data, llm_model_picked = llm_model_picked)
152
  return md_text_updated, prompt
153
 
 
154
  def ask_llm(prompt, llm_model_picked = 'mistralai/Mistral-7B-Instruct-v0.2', stream_outputs = False):
155
+ model_disabled_text = "LLM Model is disabled"
156
+ output = ""
157
 
158
+ if llm_model_picked == 'None':
159
+ if stream_outputs:
160
+ for out in model_disabled_text:
161
+ output += out
162
+ yield output
163
+ return output
164
+ else:
165
+ return model_disabled_text
166
 
167
+ client = InferenceClient(llm_model_picked, token = token)
168
+ try:
169
+ stream = client.text_generation(prompt, **generate_kwargs, stream=stream_outputs, details=False, return_full_text=False)
170
 
171
+ except:
172
+ gr.Warning("LLM Inference rate limit reached, try again later!")
173
+ return ""
174
 
175
+ if stream_outputs:
176
+ for response in stream:
177
+ output += response
178
+ yield output
179
+ return output
180
+ else:
181
+ return stream
 
182
 
 
 
 
 
 
 
 
183
 
184
+ msg.submit(update_with_rag_md, [msg, llm_results, database_src, llm_model], [gr_md, input]).success(ask_llm, [input, llm_model, stream_results], output_text)
 
185
 
 
186
  demo.queue().launch()