BaRiDo committed on
Commit
96e47ba
·
verified ·
1 Parent(s): 7239a0e

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +190 -0
rag.py CHANGED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import string
4
+ import json
5
+ import gzip
6
+
7
+ import chromadb
8
+ from ibm_watsonx_ai.client import APIClient
9
+ from ibm_watsonx_ai.foundation_models import ModelInference, Rerank
10
+ from ibm_watsonx_ai.foundation_models.embeddings.sentence_transformer_embeddings import SentenceTransformerEmbeddings
11
+
12
+
13
def get_credentials():
    """
    Build the credentials mapping used to talk to Watsonx.ai.

    The service URL is fixed to the us-south endpoint; the API key is read
    from the ``IBM_API_KEY`` environment variable and is ``None`` when that
    variable is not set.
    """
    api_key = os.getenv("IBM_API_KEY")
    credentials = dict(url="https://us-south.ml.cloud.ibm.com", apikey=api_key)
    return credentials
21
+
22
+
23
def rerank(client, documents, query, top_n):
    """
    Rerank a list of documents for a query using the Watsonx.ai Rerank model.

    Args:
        client: Watsonx.ai API client, handed to ``Rerank`` as ``api_client``.
        documents: Sequence of document strings to reorder.
        query: The query the documents are scored against.
        top_n: Number of results requested through ``return_options.top_n``.

    Returns:
        A new list with the documents picked by the reranker, in the order
        the service returned them (expected to be most relevant first).
        When ``documents`` is empty, an empty list is returned and the
        service is not contacted.
    """
    # Nothing to score: skip the remote call entirely.
    if not documents:
        return []

    reranker = Rerank(
        model_id="cross-encoder/ms-marco-minilm-l-12-v2",
        api_client=client,
        params={
            "return_options": {
                "top_n": top_n
            },
            "truncate_input_tokens": 512
        }
    )

    reranked_results = reranker.generate(query=query, inputs=documents)["results"]

    # Each result carries the position of its document in the input list;
    # use it to map the ranking back onto the original strings.
    return [documents[result["index"]] for result in reranked_results]
48
+
49
+
50
def RAGinit(vector_index_id="14c14504-5f45-4e6c-8f0f-25f2378a1d99"):
    """
    Initialize everything needed by RAG_proximity_search.

    Sets up:
      - Watsonx.ai client
      - Foundation model
      - Embeddings
      - ChromaDB collection
      - Vector index properties
      - Top N for query

    Args:
        vector_index_id: Id of the Watsonx.ai data asset holding the vector
            index. Defaults to the index this module was written for, so
            calling ``RAGinit()`` without arguments behaves as before.

    Returns:
        Tuple ``(client, model, emb, chroma_collection,
        vector_index_properties, top_n)``.
    """
    # Project/Space come from the environment.
    project_id = os.getenv("IBM_PROJECT_ID")
    space_id = os.getenv("IBM_SPACE_ID")

    # Watsonx.ai client
    client = APIClient(credentials=get_credentials(), project_id=project_id)

    # Model inference (greedy decoding, up to 900 new tokens)
    model = ModelInference(
        model_id="ibm/granite-3-8b-instruct",
        params={
            "decoding_method": "greedy",
            "max_new_tokens": 900,
            "min_new_tokens": 0,
            "repetition_penalty": 1,
        },
        credentials=get_credentials(),
        project_id=project_id,
        space_id=space_id,
    )

    # Vector index details
    vector_index_details = client.data_assets.get_details(vector_index_id)
    vector_index_properties = vector_index_details["entity"]["vector_index"]
    settings = vector_index_properties["settings"]

    # With reranking enabled, fetch a wider candidate set (20) so the
    # reranker has something to choose from; otherwise fetch exactly top_k.
    if settings.get("rerank"):
        top_n = 20
    else:
        top_n = int(settings["top_k"])

    # Embedding model
    emb = SentenceTransformerEmbeddings('sentence-transformers/all-MiniLM-L6-v2')

    # Hydrate ChromaDB with embeddings from the vector index
    chroma_collection = _hydrate_chromadb(client, vector_index_id)

    return client, model, emb, chroma_collection, vector_index_properties, top_n
101
+
102
+
103
def _hydrate_chromadb(client, vector_index_id):
    """
    Load the stored embedding data from Watsonx.ai into a fresh ChromaDB
    collection.

    The data asset content is gzip-compressed UTF-8 JSON: a list of entries,
    each with ``embedding``, ``content`` and ``metadata`` keys. Any existing
    collection named ``my_collection`` in the on-disk store ``./chroma_db``
    is dropped and recreated before the entries are inserted.

    Args:
        client: Watsonx.ai API client used to download the data asset.
        vector_index_id: Id of the data asset holding the vector index.

    Returns:
        The populated ChromaDB collection (left empty when the index
        contains no vectors).
    """
    data = client.data_assets.get_content(vector_index_id)
    raw = gzip.decompress(data)
    vectors = json.loads(raw.decode("utf-8"))

    # Use a persistent ChromaDB client (on-disk)
    chroma_client = chromadb.PersistentClient(path="./chroma_db")

    # Create or clear the collection
    collection_name = "my_collection"
    try:
        chroma_client.delete_collection(name=collection_name)
    except Exception:
        # The exception type raised for a missing collection is not the same
        # in every ChromaDB release, so catch broadly, but never swallow
        # KeyboardInterrupt/SystemExit the way a bare ``except:`` would.
        print("Collection didn't exist - nothing to do.")

    collection = chroma_client.create_collection(name=collection_name)

    # Prepare data for insertion
    vector_embeddings = []
    vector_documents = []
    vector_metadatas = []
    vector_ids = []

    for vector in vectors:
        metadata = vector["metadata"]
        lines = metadata["loc"]["lines"]
        asset_id = metadata["asset_id"]

        vector_embeddings.append(vector["embedding"])
        vector_documents.append(vector["content"])

        # Keep only flat metadata fields (the nested "loc" structure is
        # flattened into "from"/"to").
        vector_metadatas.append({
            "asset_id": asset_id,
            "asset_name": metadata["asset_name"],
            "url": metadata["url"],
            "from": lines["from"],
            "to": lines["to"],
        })

        # A random suffix keeps ids distinct when several chunks share the
        # same asset and line range.
        random_string = ''.join(random.choices(string.ascii_uppercase + string.digits, k=10))
        vector_ids.append(f"{asset_id}:{lines['from']}-{lines['to']}-{random_string}")

    # Add all data to the collection. Skip the call for an empty index:
    # there is nothing to insert, and ChromaDB is expected to reject an
    # empty batch.
    if vector_ids:
        collection.add(
            embeddings=vector_embeddings,
            documents=vector_documents,
            metadatas=vector_metadatas,
            ids=vector_ids
        )

    return collection
164
+
165
+
166
def RAG_proximity_search(question, client, model, emb, chroma_collection, vector_index_properties, top_n):
    """
    Execute a proximity search in the ChromaDB collection for the question.

    Args:
        question: The user question to search for.
        client: Watsonx.ai API client, only used when reranking.
        model: Foundation model handle; accepted for interface symmetry with
            ``RAGinit`` but not used by the search itself.
        emb: Embedding object exposing ``embed_query``.
        chroma_collection: ChromaDB collection to query.
        vector_index_properties: Vector index properties; its ``settings``
            entry decides whether reranking runs (``rerank``) and how many
            documents the reranker keeps (``top_k``).
        top_n: Number of nearest neighbours to request from ChromaDB.

    Returns:
        The matching documents joined with newlines (an empty string when
        nothing matched).
    """
    # Embed query
    query_vectors = emb.embed_query(question)

    # Query top_n results from ChromaDB
    query_result = chroma_collection.query(
        query_embeddings=query_vectors,
        n_results=top_n,
        include=["documents", "metadatas", "distances"]
    )

    # One query embedding was sent, so the hits are in the first result row.
    documents = query_result["documents"][0]

    # Rerank only when enabled and there is something to rerank. top_k is
    # cast to int to match how RAGinit reads the same setting.
    settings = vector_index_properties["settings"]
    if documents and settings.get("rerank"):
        documents = rerank(client, documents, question, int(settings["top_k"]))

    # Return them as a single string
    return "\n".join(documents)