ppsingh commited on
Commit
2204865
·
verified ·
1 Parent(s): 5f23fcf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -3
app.py CHANGED
@@ -115,7 +115,7 @@ class SessionManager:
115
  # Initialize session manager
116
  session_manager = SessionManager()
117
 
118
- async def chat(query,history,sources,reports,subtype, client_ip=None, session_id = None, request:gr.Request = None):
119
  """taking a query and a message history, use a pipeline (reformulation, retriever, answering)
120
  to yield a tuple of:(messages in gradio format/messages in langchain format, source documents)
121
  """
@@ -138,6 +138,51 @@ async def chat(query,history,sources,reports,subtype, client_ip=None, session_id
138
  #print(f"year:{year}")
139
  docs_html = ""
140
  output_query = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  ##------------------------fetch collection from vectorstore------------------------------
143
  vectorstore = vectorstores["docling"]
@@ -150,6 +195,29 @@ async def chat(query,history,sources,reports,subtype, client_ip=None, session_id
150
  sources=sources,subtype=subtype)
151
  end_time = time.time()
152
  print("Time for retriever:",end_time - start_time)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  context_retrieved_formatted = "||".join(doc.page_content for doc in context_retrieved)
154
  context_retrieved_lst = [doc.page_content for doc in context_retrieved]
155
 
@@ -633,7 +701,7 @@ with gr.Blocks(title="Audit Q&A", css= "style.css", theme=theme,elem_id = "main-
633
  .submit(get_client_ip_handler, [textbox], [client_ip], api_name="get_ip_textbox")
634
  .then(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
635
  .then(chat,
636
- [textbox, chatbot, dropdown_sources, dropdown_reports, dropdown_category, client_ip, session_id],
637
  [chatbot, sources_textbox, feedback_state, session_id],
638
  queue=True, concurrency_limit=8, api_name="chat_textbox")
639
  .then(show_feedback, [feedback_state], [feedback_row, feedback_thanks, feedback_state], api_name="show_feedback_textbox")
@@ -643,7 +711,7 @@ with gr.Blocks(title="Audit Q&A", css= "style.css", theme=theme,elem_id = "main-
643
  .change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_examples")
644
  .then(get_client_ip_handler, [examples_hidden], [client_ip], api_name="get_ip_examples")
645
  .then(chat,
646
- [examples_hidden, chatbot, dropdown_sources, dropdown_reports, dropdown_category, client_ip, session_id],
647
  [chatbot, sources_textbox, feedback_state, session_id],
648
  concurrency_limit=8, api_name="chat_examples")
649
  .then(show_feedback, [feedback_state], [feedback_row, feedback_thanks, feedback_state], api_name="show_feedback_examples")
 
115
  # Initialize session manager
116
  session_manager = SessionManager()
117
 
118
+ async def chat(query,history, method, sources,reports,subtype, client_ip=None, session_id = None, request:gr.Request = None):
119
  """taking a query and a message history, use a pipeline (reformulation, retriever, answering)
120
  to yield a tuple of:(messages in gradio format/messages in langchain format, source documents)
121
  """
 
138
  #print(f"year:{year}")
139
  docs_html = ""
140
  output_query = ""
141
+ if method == "Search by Report Name":
142
+ if len(reports) == 0:
143
+ warning_message = "⚠️ **No Report Selected"
144
+ history[-1] = (query, warning_message)
145
+ # Update logs with the warning instead of answer
146
+ logs_data = {
147
+ "record_id": str(uuid4()),
148
+ "session_id": session_id,
149
+ "session_duration_seconds": session_duration,
150
+ "client_location": session_data['location_info'],
151
+ "platform": session_data['platform_info'],
152
+ "question": query,
153
+ "retriever": model_config.get('retriever','MODEL'),
154
+ "endpoint_type": model_config.get('reader','TYPE'),
155
+ "reader": model_config.get('reader','NVIDIA_MODEL'),
156
+ "answer": warning_message,
157
+ "no_results": True # Flag to indicate no results were found
158
+ }
159
+ yield [tuple(x) for x in history], "", logs_data, session_id
160
+ # Save log for the warning response
161
+ save_logs(scheduler, JSON_DATASET_PATH, logs_data)
162
+ return
163
+ else:
164
+ if sources is None and subtype is None:
165
+ warning_message = "⚠️ **No Report Selected"
166
+ history[-1] = (query, warning_message)
167
+ # Update logs with the warning instead of answer
168
+ logs_data = {
169
+ "record_id": str(uuid4()),
170
+ "session_id": session_id,
171
+ "session_duration_seconds": session_duration,
172
+ "client_location": session_data['location_info'],
173
+ "platform": session_data['platform_info'],
174
+ "question": query,
175
+ "retriever": model_config.get('retriever','MODEL'),
176
+ "endpoint_type": model_config.get('reader','TYPE'),
177
+ "reader": model_config.get('reader','NVIDIA_MODEL'),
178
+ "answer": warning_message,
179
+ "no_results": True # Flag to indicate no results were found
180
+ }
181
+ yield [tuple(x) for x in history], "", logs_data, session_id
182
+ # Save log for the warning response
183
+ save_logs(scheduler, JSON_DATASET_PATH, logs_data)
184
+ return
185
+
186
 
187
  ##------------------------fetch collection from vectorstore------------------------------
188
  vectorstore = vectorstores["docling"]
 
195
  sources=sources,subtype=subtype)
196
  end_time = time.time()
197
  print("Time for retriever:",end_time - start_time)
198
+
199
+
200
+ if not context_retrieved or len(context_retrieved) == 0:
201
+ warning_message = "⚠️ **No relevant information was found in the audit reports pertaining your query.** Please try rephrasing your question or selecting different report filters."
202
+ history[-1] = (query, warning_message)
203
+ # Update logs with the warning instead of answer
204
+ logs_data = {
205
+ "record_id": str(uuid4()),
206
+ "session_id": session_id,
207
+ "session_duration_seconds": session_duration,
208
+ "client_location": session_data['location_info'],
209
+ "platform": session_data['platform_info'],
210
+ "question": query,
211
+ "retriever": model_config.get('retriever','MODEL'),
212
+ "endpoint_type": model_config.get('reader','TYPE'),
213
+ "reader": model_config.get('reader','NVIDIA_MODEL'),
214
+ "answer": warning_message,
215
+ "no_results": True # Flag to indicate no results were found
216
+ }
217
+ yield [tuple(x) for x in history], "", logs_data, session_id
218
+ # Save log for the warning response
219
+ save_logs(scheduler, JSON_DATASET_PATH, logs_data)
220
+ return
221
  context_retrieved_formatted = "||".join(doc.page_content for doc in context_retrieved)
222
  context_retrieved_lst = [doc.page_content for doc in context_retrieved]
223
 
 
701
  .submit(get_client_ip_handler, [textbox], [client_ip], api_name="get_ip_textbox")
702
  .then(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
703
  .then(chat,
704
+ [textbox, chatbot, search_method, dropdown_sources, dropdown_reports, dropdown_category, client_ip, session_id],
705
  [chatbot, sources_textbox, feedback_state, session_id],
706
  queue=True, concurrency_limit=8, api_name="chat_textbox")
707
  .then(show_feedback, [feedback_state], [feedback_row, feedback_thanks, feedback_state], api_name="show_feedback_textbox")
 
711
  .change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_examples")
712
  .then(get_client_ip_handler, [examples_hidden], [client_ip], api_name="get_ip_examples")
713
  .then(chat,
714
+ [examples_hidden, chatbot, searh_method, dropdown_sources, dropdown_reports, dropdown_category, client_ip, session_id],
715
  [chatbot, sources_textbox, feedback_state, session_id],
716
  concurrency_limit=8, api_name="chat_examples")
717
  .then(show_feedback, [feedback_state], [feedback_row, feedback_thanks, feedback_state], api_name="show_feedback_examples")