Abhaykumar04 committed
Commit 4657135 · verified · 1 Parent(s): 2c8dc40

Update llm_retrieval_conversation_rerank.py

Files changed (1):
  1. llm_retrieval_conversation_rerank.py +238 -238
llm_retrieval_conversation_rerank.py CHANGED
@@ -1,239 +1,239 @@
 import json
 import os
 from dotenv import load_dotenv
 import yaml
 from together import Together
 from langchain.llms.together import Together as TogetherLLM
 from langchain.prompts import PromptTemplate
 from langchain.schema.runnable import RunnablePassthrough
 from langchain.schema.output_parser import StrOutputParser
 from pinecone import Pinecone
 from typing import List, Dict
 import cohere
 load_dotenv()


-API_FILE_PATH = r"C:\Users\abhay\Analytics Vidhya\API.yml"
-COURSES_FILE_PATH = r"C:\Users\abhay\Analytics Vidhya\courses.json"
+API_FILE_PATH = r".\API.yml"
+COURSES_FILE_PATH = r".\courses.json"

 # Global list to store conversation history
 conversation_history: List[Dict[str, str]] = []

 def load_api_keys(api_file_path):
     """Loads API keys from a YAML file."""
     with open(api_file_path, 'r') as f:
         api_keys = yaml.safe_load(f)
     return api_keys

 def generate_query_embedding(query, together_api_key):
     """Generates an embedding for the user query."""
     client = Together(api_key=together_api_key)
     response = client.embeddings.create(
         model="WhereIsAI/UAE-Large-V1", input=query
     )
     return response.data[0].embedding

 def initialize_pinecone(pinecone_api_key):
     """Initializes Pinecone with the API key."""
     return Pinecone(api_key=pinecone_api_key)

 def pinecone_similarity_search(pinecone_instance, index_name, query_embedding, top_k=10):
     """Performs a similarity search in Pinecone; top_k is kept high to leave room for reranking."""
     try:
         index = pinecone_instance.Index(index_name)
         results = index.query(vector=query_embedding, top_k=top_k, include_metadata=True)
         if not results.matches:
             return None
         return results
     except Exception as e:
         print(f"Error during similarity search: {e}")
         return None

 def create_prompt_template():
     """Creates a prompt template for the LLM."""
     template = """You are a helpful AI assistant that provides information on courses.
     Based on the following context, conversation history, and new user query,
     suggest relevant courses and explain why they might be useful, or respond accordingly if the user query is unrelated.
     If no relevant courses are found, please indicate that.

     Conversation History:
     {conversation_history}

     Context: {context}
     User Query: {query}

     Response: Let me help you find relevant courses based on your query.
     """
     return PromptTemplate(template=template, input_variables=["context", "query", "conversation_history"])

 def initialize_llm(together_api_key):
     """Initializes the Together LLM."""
     return TogetherLLM(
         model="mistralai/Mixtral-8x7B-Instruct-v0.1",
         together_api_key=together_api_key,
         temperature=0,
         max_tokens=250
     )

 def create_chain(llm, prompt):
     """Creates a chain using the RunnableSequence approach."""
     chain = (
         {"context": RunnablePassthrough(), "query": RunnablePassthrough(), "conversation_history": RunnablePassthrough()}
         | prompt
         | llm
         | StrOutputParser()
     )
     return chain


 def initialize_cohere_client(cohere_api_key):
     """Initializes the Cohere client."""
     return cohere.ClientV2(api_key=cohere_api_key)


 def rerank_results(cohere_client, query, documents, top_n=3):
     """Reranks documents using Cohere."""
     try:
         results = cohere_client.rerank(
             query=query,
             documents=documents,
             top_n=top_n,
             model="rerank-english-v3.0",
         )
         return results
     except Exception as e:
         print(f"Error reranking results: {e}")
         return None

 def generate_llm_response(chain, query, retrieved_data, history, cohere_client):
     """Generates an LLM response based on context and conversation history."""
     try:
         if not retrieved_data or not retrieved_data.matches:
             return "I couldn't find any relevant courses matching your query. Please try a different search term."

         # Prepare documents for reranking
         documents = []
         for match in retrieved_data.matches:
             metadata = match.metadata
             if metadata:
                 documents.append(
                     {"text": f"Title: {metadata.get('title', 'No title')}\nDescription: {metadata.get('text', 'No description')}\nLink: {metadata.get('course_link', 'No link')}"}
                 )

         if not documents:
             return "I found some matches but couldn't extract course information. Please try again."

         # Rerank the documents
         reranked_results = rerank_results(cohere_client, query, documents)

         if not reranked_results:
             return "I couldn't rerank the results, please try again."

         # Prepare context from reranked results
         context_parts = []
         for result in reranked_results.results:
             context_parts.append(documents[result.index]["text"])

         context = "\n\n".join(context_parts)

         # Format conversation history
         formatted_history = "\n".join(f"User: {item['user']}\nAssistant: {item['assistant']}" for item in history) if history else "No previous conversation."

         response = chain.invoke({"context": context, "query": query, "conversation_history": formatted_history})
         return response

     except Exception as e:
         print(f"Error generating response: {e}")
         return "I encountered an error while generating the response. Please try again."


 def check_context_similarity(query_embedding, previous_query_embedding, threshold=0.7):
     """Checks if the new query is related to the previous one."""
     if not previous_query_embedding:
         return False  # First query, no previous embedding to compare

     from numpy import dot
     from numpy.linalg import norm

     cos_sim = dot(query_embedding, previous_query_embedding) / (norm(query_embedding) * norm(previous_query_embedding))
     return cos_sim > threshold

 def main():
     global conversation_history
     previous_query_embedding = None

     try:
         api_keys = load_api_keys(API_FILE_PATH)
         together_api_key = api_keys["together_ai_api_key"]
         pinecone_api_key = api_keys["pinecone_api_key"]
         index_name = api_keys["pinecone_index_name"]
         cohere_api_key = api_keys["cohere_api_key"]
         print("Initializing services...")

         # Initialize Pinecone
         pinecone_instance = initialize_pinecone(pinecone_api_key)

         # Initialize Together LLM
         llm = initialize_llm(together_api_key)

         # Initialize Cohere client
         cohere_client = initialize_cohere_client(cohere_api_key)

         prompt = create_prompt_template()

         # Create chain
         chain = create_chain(llm, prompt)

         print("Ready to process queries!")

         while True:
             user_query = input("\nEnter your query (or 'quit' to exit): ").strip()

             if user_query.lower() == 'quit':
                 break

             if not user_query:
                 print("Please enter a valid query.")
                 continue

             try:
                 print("Generating query embedding...")
                 query_embedding = generate_query_embedding(user_query, together_api_key)

                 # Check context similarity
                 if previous_query_embedding and check_context_similarity(query_embedding, previous_query_embedding):
                     print("Continuing the previous conversation...")
                 else:
                     print("Starting a new conversation...")
                     conversation_history = []  # Clear history for a new conversation

                 print("Searching for relevant courses...")
                 pinecone_results = pinecone_similarity_search(
                     pinecone_instance, index_name, query_embedding
                 )

                 print("Generating response...")
                 llm_response = generate_llm_response(chain, user_query, pinecone_results, conversation_history, cohere_client)

                 print("\nResponse:")
                 print(llm_response)
                 print("\n" + "="*50)

                 # Update conversation history
                 conversation_history.append({"user": user_query, "assistant": llm_response})
                 previous_query_embedding = query_embedding  # Save for next turn

             except Exception as e:
                 print(f"Error processing query: {e}")
                 print("Please try again with a different query.")

     except Exception as e:
         print(f"An error occurred during initialization: {str(e)}")

 if __name__ == "__main__":
     main()
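
A note on the changed constants: `r".\API.yml"` and `r".\courses.json"` keep a Windows-style separator, and on a Linux host that backslash is read as part of the filename rather than as a path separator. A portable alternative (a sketch, not part of this commit) resolves both files next to the script with `pathlib`:

```python
from pathlib import Path

# Resolve the data files relative to this script, independent of OS and working directory.
BASE_DIR = Path(__file__).resolve().parent
API_FILE_PATH = BASE_DIR / "API.yml"
COURSES_FILE_PATH = BASE_DIR / "courses.json"
```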
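
Since the config file is now looked up at runtime relative to the working directory, a missing or incomplete `API.yml` only surfaces through the script's generic initialization error handler. A minimal fail-fast check is sketched below; the key names are exactly those read in `main()`, while the `validate_api_keys` helper itself is hypothetical:

```python
import yaml

# Key names exactly as read in main(); API.yml must define all four.
REQUIRED_KEYS = ("together_ai_api_key", "pinecone_api_key",
                 "pinecone_index_name", "cohere_api_key")

def validate_api_keys(path="API.yml"):
    """Load API.yml and raise early if any expected key is missing."""
    with open(path, "r") as f:
        keys = yaml.safe_load(f) or {}
    missing = [k for k in REQUIRED_KEYS if k not in keys]
    if missing:
        raise KeyError(f"API.yml is missing keys: {missing}")
    return keys
```

Calling `validate_api_keys()` in place of `load_api_keys(API_FILE_PATH)` would turn a silent misconfiguration into an explicit error message.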