arjunanand13 commited on
Commit
9cea9a4
·
verified ·
1 Parent(s): 71dae17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -92
app.py CHANGED
@@ -1,24 +1,21 @@
1
  import os
2
- import multiprocessing
3
- import concurrent.futures
4
- from langchain_community.document_loaders import TextLoader, DirectoryLoader
5
- from langchain.text_splitter import RecursiveCharacterTextSplitter
6
- from langchain_community.vectorstores import FAISS
7
- from sentence_transformers import SentenceTransformer
8
- import faiss
9
  import numpy as np
10
- from datetime import datetime
11
- import json
12
  import gradio as gr
13
  import re
14
- from threading import Thread
15
  from openai import OpenAI
 
 
 
 
16
 
17
  class MultiAgentRAG:
18
  def __init__(self, embedding_model_name, openai_model_id, data_folder, api_key=None):
 
19
  self.all_splits = self.load_documents(data_folder)
20
  self.embeddings = SentenceTransformer(embedding_model_name)
21
- self.gpu_index = self.create_faiss_index()
22
  self.openai_client = OpenAI(api_key=api_key or os.environ.get("OPENAI_API_KEY"))
23
  self.openai_model_id = openai_model_id
24
 
@@ -27,10 +24,6 @@ class MultiAgentRAG:
27
  documents = loader.load()
28
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=250)
29
  all_splits = text_splitter.split_documents(documents)
30
- print('Length of documents:', len(documents))
31
- print("LEN of all_splits", len(all_splits))
32
- for i in range(min(3, len(all_splits))):
33
- print(all_splits[i].page_content)
34
  return all_splits
35
 
36
  def create_faiss_index(self):
@@ -38,9 +31,12 @@ class MultiAgentRAG:
38
  embeddings = self.embeddings.encode(all_texts, convert_to_tensor=True).cpu().numpy()
39
  index = faiss.IndexFlatL2(embeddings.shape[1])
40
  index.add(embeddings)
41
- gpu_resource = faiss.StandardGpuResources()
42
- gpu_index = faiss.index_cpu_to_gpu(gpu_resource, 0, index)
43
- return gpu_index
 
 
 
44
 
45
  def generate_openai_response(self, messages, max_tokens=1000):
46
  try:
@@ -54,88 +50,49 @@ class MultiAgentRAG:
54
  presence_penalty=0
55
  )
56
  return response.choices[0].message.content
57
- except Exception as e:
58
- print(f"Error in generate_openai_response: {str(e)}")
59
  return "Text generation process encountered an error"
60
 
61
  def retrieval_agent(self, query):
62
  query_embedding = self.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
63
- distances, indices = self.gpu_index.search(np.array([query_embedding]), k=3)
64
  content = ""
65
- for idx, distance in zip(indices[0], distances[0]):
66
- content += "-" * 50 + "\n"
67
  content += self.all_splits[idx].page_content + "\n"
68
  return content
69
 
70
  def grading_agent(self, query, retrieved_content):
71
  messages = [
72
- {"role": "system", "content": "You are an expert at evaluating the relevance of retrieved content to a query."},
73
- {"role": "user", "content": f"""
74
- Evaluate the relevance of the following retrieved content to the given query:
75
-
76
- Query: {query}
77
-
78
- Retrieved Content:
79
- {retrieved_content}
80
-
81
- Rate the relevance on a scale of 1-10 and explain your rating:
82
- """}
83
  ]
84
-
85
  grading_response = self.generate_openai_response(messages)
86
-
87
- # Extract the numerical rating from the response
88
  match = re.search(r'\b([1-9]|10)\b', grading_response)
89
- rating = int(match.group()) if match else 5 # Default to 5 if no rating found
90
  return rating, grading_response
91
 
92
  def query_rewrite_agent(self, original_query):
93
  messages = [
94
- {"role": "system", "content": "You are an expert at rewriting queries to improve information retrieval results."},
95
- {"role": "user", "content": f"""
96
- The following query did not yield relevant results. Please rewrite it to potentially improve retrieval:
97
-
98
- Original Query: {original_query}
99
-
100
- Rewritten Query:
101
- """}
102
  ]
103
-
104
- rewritten_query = self.generate_openai_response(messages)
105
- return rewritten_query.strip()
106
 
107
  def generation_agent(self, query, retrieved_content):
108
  messages = [
109
- {"role": "system", "content": "You are a knowledgeable assistant with access to a comprehensive database."},
110
- {"role": "user", "content": f"""
111
- I need you to answer my question and provide related information in a specific format.
112
- I have provided five relatable json files {retrieved_content}, choose the most suitable chunks for answering the query.
113
- RETURN ONLY SOLUTION without additional comments, sign-offs, retrived chunks, refrence to any Ticket or extra phrases. Be direct and to the point.
114
- IF THERE IS NO ANSWER RELATABLE IN RETRIEVED CHUNKS, RETURN "NO SOLUTION AVAILABLE".
115
- DO NOT GIVE REFRENCE TO ANY CHUNKS OR TICKETS, BE ON POINT.
116
-
117
- Here's my question:
118
- Query: {query}
119
- Solution==>
120
- """}
121
  ]
122
-
123
  return self.generate_openai_response(messages)
124
 
125
  def run_multi_agent_rag(self, query):
126
- max_iterations = 3
127
- for i in range(max_iterations):
128
  retrieved_content = self.retrieval_agent(query)
129
-
130
  relevance_score, grading_explanation = self.grading_agent(query, retrieved_content)
131
-
132
- if relevance_score >= 7:
133
- answer = self.generation_agent(query, retrieved_content)
134
- return answer, retrieved_content, grading_explanation
135
- else:
136
- query = self.query_rewrite_agent(query)
137
-
138
- return "Unable to find a relevant answer after multiple attempts.", "", "Low relevance across all attempts."
139
 
140
  def qa_infer_gradio(self, query):
141
  answer, retrieved_content, grading_explanation = self.run_multi_agent_rag(query)
@@ -143,25 +100,14 @@ class MultiAgentRAG:
143
 
144
  def launch_interface(doc_retrieval_gen):
145
  css_code = """
146
- .gradio-container {
147
- background-color: #daccdb;
148
- }
149
- button {
150
- background-color: #927fc7;
151
- color: black;
152
- border: 1px solid black;
153
- padding: 10px;
154
- margin-right: 10px;
155
- font-size: 16px;
156
- font-weight: bold;
157
- }
158
  """
159
  EXAMPLES = [
160
  "On which devices can the VIP and CSI2 modules operate simultaneously?",
161
  "I'm using Code Composer Studio 5.4.0.00091 and enabled FPv4SPD16 floating point support for CortexM4 in TDA2. However, after building the project, the .asm file shows --float_support=vfplib instead of FPv4SPD16. Why is this happening?",
162
- "Could you clarify the maximum number of cameras that can be connected simultaneously to the video input ports on the TDA2x SoC, considering it supports up to 10 multiplexed input ports and includes 3 dedicated video input modules?"
163
  ]
164
-
165
  interface = gr.Interface(
166
  fn=doc_retrieval_gen.qa_infer_gradio,
167
  inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")],
@@ -172,13 +118,14 @@ def launch_interface(doc_retrieval_gen):
172
  css=css_code,
173
  title="TI E2E FORUM Multi-Agent RAG"
174
  )
175
-
176
  interface.launch(debug=True)
177
 
178
  if __name__ == "__main__":
179
  embedding_model_name = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L12'
180
- openai_model_id = "gpt-4-turbo"
181
  data_folder = 'sample_embedding_folder2'
182
-
183
- multi_agent_rag = MultiAgentRAG(embedding_model_name, openai_model_id, data_folder)
184
- launch_interface(multi_agent_rag)
 
 
 
1
  import os
2
+ import torch.cuda
 
 
 
 
 
 
3
  import numpy as np
4
+ import faiss
 
5
  import gradio as gr
6
  import re
 
7
  from openai import OpenAI
8
+ from langchain_community.document_loaders import TextLoader, DirectoryLoader
9
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
10
+ from langchain_community.vectorstores import FAISS
11
+ from sentence_transformers import SentenceTransformer
12
 
13
  class MultiAgentRAG:
14
  def __init__(self, embedding_model_name, openai_model_id, data_folder, api_key=None):
15
+ self.use_gpu = torch.cuda.is_available()
16
  self.all_splits = self.load_documents(data_folder)
17
  self.embeddings = SentenceTransformer(embedding_model_name)
18
+ self.faiss_index = self.create_faiss_index()
19
  self.openai_client = OpenAI(api_key=api_key or os.environ.get("OPENAI_API_KEY"))
20
  self.openai_model_id = openai_model_id
21
 
 
24
  documents = loader.load()
25
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=250)
26
  all_splits = text_splitter.split_documents(documents)
 
 
 
 
27
  return all_splits
28
 
29
  def create_faiss_index(self):
 
31
  embeddings = self.embeddings.encode(all_texts, convert_to_tensor=True).cpu().numpy()
32
  index = faiss.IndexFlatL2(embeddings.shape[1])
33
  index.add(embeddings)
34
+ try:
35
+ gpu_resource = faiss.StandardGpuResources()
36
+ gpu_index = faiss.index_cpu_to_gpu(gpu_resource, 0, index)
37
+ return gpu_index
38
+ except:
39
+ return index
40
 
41
  def generate_openai_response(self, messages, max_tokens=1000):
42
  try:
 
50
  presence_penalty=0
51
  )
52
  return response.choices[0].message.content
53
+ except:
 
54
  return "Text generation process encountered an error"
55
 
56
  def retrieval_agent(self, query):
57
  query_embedding = self.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
58
+ distances, indices = self.faiss_index.search(np.array([query_embedding]), k=3)
59
  content = ""
60
+ for idx in indices[0]:
 
61
  content += self.all_splits[idx].page_content + "\n"
62
  return content
63
 
64
  def grading_agent(self, query, retrieved_content):
65
  messages = [
66
+ {"role": "system", "content": "You are an expert at evaluating relevance."},
67
+ {"role": "user", "content": f"Query: {query}\nRetrieved Content:\n{retrieved_content}\nRate the relevance on a scale of 1-10."}
 
 
 
 
 
 
 
 
 
68
  ]
 
69
  grading_response = self.generate_openai_response(messages)
 
 
70
  match = re.search(r'\b([1-9]|10)\b', grading_response)
71
+ rating = int(match.group()) if match else 5
72
  return rating, grading_response
73
 
74
  def query_rewrite_agent(self, original_query):
75
  messages = [
76
+ {"role": "system", "content": "You are an expert at rewriting queries."},
77
+ {"role": "user", "content": f"Original Query: {original_query}\nRewritten Query:"}
 
 
 
 
 
 
78
  ]
79
+ return self.generate_openai_response(messages).strip()
 
 
80
 
81
  def generation_agent(self, query, retrieved_content):
82
  messages = [
83
+ {"role": "system", "content": "You are a knowledgeable assistant."},
84
+ {"role": "user", "content": f"Query: {query}\nSolution==>"}
 
 
 
 
 
 
 
 
 
 
85
  ]
 
86
  return self.generate_openai_response(messages)
87
 
88
  def run_multi_agent_rag(self, query):
89
+ for _ in range(3):
 
90
  retrieved_content = self.retrieval_agent(query)
 
91
  relevance_score, grading_explanation = self.grading_agent(query, retrieved_content)
92
+ if relevance_score >= 7:
93
+ return self.generation_agent(query, retrieved_content), retrieved_content, grading_explanation
94
+ query = self.query_rewrite_agent(query)
95
+ return "Unable to find a relevant answer.", "", "Low relevance across all attempts."
 
 
 
 
96
 
97
  def qa_infer_gradio(self, query):
98
  answer, retrieved_content, grading_explanation = self.run_multi_agent_rag(query)
 
100
 
101
  def launch_interface(doc_retrieval_gen):
102
  css_code = """
103
+ .gradio-container { background-color: #daccdb; }
104
+ button { background-color: #927fc7; color: black; border: 1px solid black; padding: 10px; margin-right: 10px; font-size: 16px; font-weight: bold; }
 
 
 
 
 
 
 
 
 
 
105
  """
106
  EXAMPLES = [
107
  "On which devices can the VIP and CSI2 modules operate simultaneously?",
108
  "I'm using Code Composer Studio 5.4.0.00091 and enabled FPv4SPD16 floating point support for CortexM4 in TDA2. However, after building the project, the .asm file shows --float_support=vfplib instead of FPv4SPD16. Why is this happening?",
109
+ "Could you clarify the maximum number of cameras that can be connected simultaneously to the video input ports on the TDA2x SoC?"
110
  ]
 
111
  interface = gr.Interface(
112
  fn=doc_retrieval_gen.qa_infer_gradio,
113
  inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")],
 
118
  css=css_code,
119
  title="TI E2E FORUM Multi-Agent RAG"
120
  )
 
121
  interface.launch(debug=True)
122
 
123
  if __name__ == "__main__":
124
  embedding_model_name = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L12'
125
+ openai_model_id = "gpt-4-turbo"
126
  data_folder = 'sample_embedding_folder2'
127
+ try:
128
+ multi_agent_rag = MultiAgentRAG(embedding_model_name, openai_model_id, data_folder)
129
+ launch_interface(multi_agent_rag)
130
+ except Exception as e:
131
+ print(f"Error initializing Multi-Agent RAG: {str(e)}")