BaRiDo committed on
Commit
4be0291
·
verified ·
1 Parent(s): f94d355

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +171 -29
app.py CHANGED
@@ -1,35 +1,177 @@
 
 
 
 
 
1
  import streamlit as st
2
- from rag import RAGinit, RAG_proximity_search
3
 
4
- # Wrap RAGinit() inside a function that shows a spinner
5
- def load_resources():
6
- with st.spinner("Loading resources..."):
7
- client, model, emb, chroma_collection, vector_index_properties, top_n = RAGinit()
8
- return client, model, emb, chroma_collection, vector_index_properties, top_n
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # Initialize everything once at startup
11
- client, model, emb, chroma_collection, vector_index_properties, top_n = load_resources()
12
 
13
- def main():
14
- st.title("RAG-based QA App")
 
 
 
 
15
 
16
- question = st.text_input("Ask a question:")
 
 
 
 
17
 
18
- if st.button("Search"):
19
- if question.strip():
20
- answer = RAG_proximity_search(
21
- question,
22
- client,
23
- model,
24
- emb,
25
- chroma_collection,
26
- vector_index_properties,
27
- top_n
28
- )
29
- st.markdown("**Answer:**")
30
- st.write(answer)
31
- else:
32
- st.warning("Please enter a question before searching.")
33
-
34
- if __name__ == "__main__":
35
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import getpass
3
+
4
+ import sentence_transformers
5
+
6
  import streamlit as st
 
7
 
8
# Data-asset id of the exported vector index stored in watsonx.ai.
VECTOR_DB ="c8af7dfa-bcad-46e5-b69d-cd85ce9315d1"
9
+
10
def get_credentials():
    """Build the watsonx.ai credentials dictionary.

    The endpoint is the fixed us-south regional URL; the API key is read
    from the ``IBM_API_KEY`` environment variable (may be ``None`` when
    unset).
    """
    credentials = {}
    credentials["url"] = "https://us-south.ml.cloud.ibm.com"
    credentials["apikey"] = os.getenv("IBM_API_KEY")
    return credentials
15
+
16
# Foundation model used for answer generation.
model_id = "ibm/granite-3-8b-instruct"

# Generation parameters: deterministic greedy decoding, up to 900 new
# tokens, repetition penalty 1 (i.e. disabled).
parameters = {
    "decoding_method": "greedy",
    "max_new_tokens": 900,
    "min_new_tokens": 0,
    "repetition_penalty": 1
}

# Deployment scoping ids, read from the environment (may be None if unset).
project_id = os.getenv("IBM_PROJECT_ID")
space_id = os.getenv("IBM_SPACE_ID")

from ibm_watsonx_ai.foundation_models import ModelInference

# Inference handle for the Granite model; credentials come from the
# environment via get_credentials().
model = ModelInference(
    model_id = model_id,
    params = parameters,
    credentials = get_credentials(),
    project_id = project_id,
    space_id = space_id
)

from ibm_watsonx_ai.client import APIClient

# Separate API client used for data-asset access (vector index download).
wml_credentials = get_credentials()
client = APIClient(credentials=wml_credentials, project_id=project_id) #, space_id=space_id)

# Fetch the vector-index asset details and its stored settings.
vector_index_id = VECTOR_DB
vector_index_details = client.data_assets.get_details(vector_index_id)
vector_index_properties = vector_index_details["entity"]["vector_index"]

# When reranking is enabled, over-fetch 20 candidates so the reranker has
# something to choose from; otherwise query exactly top_k results.
top_n = 20 if vector_index_properties["settings"].get("rerank") else int(vector_index_properties["settings"]["top_k"])
48
+
49
def rerank( client, documents, query, top_n ):
    """Re-order candidate passages by cross-encoder relevance to *query*.

    Parameters
    ----------
    client : ibm_watsonx_ai APIClient used to reach the rerank service.
    documents : list[str]
        Candidate passages to be re-ranked.
    query : str
        The user question the passages are scored against.
    top_n : int
        Number of top-scoring passages to keep.

    Returns
    -------
    list[str]
        The ``top_n`` passages, most relevant first.
    """
    # Imported lazily: only needed when the index settings enable reranking.
    from ibm_watsonx_ai.foundation_models import Rerank

    reranker = Rerank(
        model_id="cross-encoder/ms-marco-minilm-l-12-v2",
        api_client=client,
        params={
            "return_options": {"top_n": top_n},
            # Cross-encoder input window; longer inputs are truncated.
            "truncate_input_tokens": 512,
        },
    )

    reranked_results = reranker.generate(query=query, inputs=documents)["results"]

    # Map the returned indices back onto the original passages
    # (comprehension replaces the original manual append loop).
    return [documents[result["index"]] for result in reranked_results]
72
+
73
from ibm_watsonx_ai.foundation_models.embeddings.sentence_transformer_embeddings import SentenceTransformerEmbeddings

# Query-embedding model. NOTE(review): this should match the model that
# produced the stored vectors in the index export — TODO confirm.
emb = SentenceTransformerEmbeddings('sentence-transformers/all-MiniLM-L6-v2')

import subprocess
import gzip
import json
import chromadb
import random
import string
 
84
def hydrate_chromadb():
    """Load the exported vector index into a local persistent ChromaDB collection.

    Downloads the gzipped JSON vector dump attached to the watsonx.ai data
    asset, then writes its embeddings, documents and (flattened) metadata
    into a freshly created collection.

    Returns
    -------
    The populated chromadb collection.
    """
    data = client.data_assets.get_content(vector_index_id)
    content = gzip.decompress(data)
    stringified_vectors = str(content, "utf-8")
    vectors = json.loads(stringified_vectors)

    #chroma_client = chromadb.Client()
    #chroma_client = chromadb.InMemoryClient()
    chroma_client = chromadb.PersistentClient(path="./chroma_db")

    # Make sure the collection is empty if it already existed.
    collection_name = "my_collection"
    try:
        # delete_collection returns nothing useful; the original bound its
        # result to `collection`, which create_collection overwrote anyway.
        chroma_client.delete_collection(name=collection_name)
    except Exception:
        # Was a bare `except:` which also swallowed SystemExit /
        # KeyboardInterrupt; narrowed to Exception.
        print("Collection didn't exist - nothing to do.")
    collection = chroma_client.create_collection(name=collection_name)

    vector_embeddings = []
    vector_documents = []
    vector_metadatas = []
    vector_ids = []

    for vector in vectors:
        vector_embeddings.append(vector["embedding"])
        vector_documents.append(vector["content"])
        metadata = vector["metadata"]
        lines = metadata["loc"]["lines"]
        # Chroma metadata must be flat scalars; keep only the fields used.
        clean_metadata = {
            "asset_id": metadata["asset_id"],
            "asset_name": metadata["asset_name"],
            "url": metadata["url"],
            "from": lines["from"],
            "to": lines["to"],
        }
        vector_metadatas.append(clean_metadata)
        # Random suffix keeps ids unique even if the same chunk range
        # appears twice. (Renamed from `id`, which shadowed the builtin.)
        random_suffix = ''.join(random.choices(string.ascii_uppercase + string.digits, k=10))
        vector_ids.append(
            "{}:{}-{}-{}".format(metadata["asset_id"], lines["from"], lines["to"], random_suffix)
        )

    collection.add(
        embeddings=vector_embeddings,
        documents=vector_documents,
        metadatas=vector_metadatas,
        ids=vector_ids,
    )
    return collection

# Build the collection once at import time so every query reuses it.
chroma_collection = hydrate_chromadb()
133
+
134
def proximity_search( question ):
    """Return newline-joined context passages relevant to *question*.

    Embeds the question, queries the Chroma collection for the top_n
    nearest chunks, and reranks them when the index settings enable it.
    """
    question_embedding = emb.embed_query(question)
    hits = chroma_collection.query(
        query_embeddings=question_embedding,
        n_results=top_n,
        include=["documents", "metadatas", "distances"],
    )

    # Reversal preserved from the original — presumably so the most
    # relevant passage ends up last, nearest the question in the prompt.
    passages = hits["documents"][0][::-1]

    settings = vector_index_properties["settings"]
    if settings.get("rerank"):
        passages = rerank(client, passages, question, settings["top_k"])

    return "\n".join(passages)
148
+
149
# Streamlit UI
# (Fixed mojibake in the UI strings: "πŸ”"/"πŸ“Œ"/"πŸ€–" were
# mis-encoded emoji.)
st.title("🔍 IBM Watson RAG Chatbot")

# User input in Streamlit
question = st.text_input("Enter your question:")

if question:
    # Retrieve relevant grounding context from the vector store.
    grounding = proximity_search(question)

    # Granite chat template: grounding context plus the question in the
    # user turn, then hand off to the assistant role.
    formatted_question = f"""<|start_of_role|>user<|end_of_role|>Use the following pieces of context to answer the question.
{grounding}
Question: {question}<|end_of_text|>
<|start_of_role|>assistant<|end_of_role|>"""

    # Optional prefix for the prompt (empty by default).
    prompt_input = ""  # Set this dynamically if needed
    prompt = f"""{prompt_input}{formatted_question}"""

    # Generate the answer with the watsonx.ai model. The original shipped
    # a simulated placeholder f-string and never called `model` at all.
    generated_response = model.generate_text(prompt=prompt)

    # Display results
    st.subheader("📌 Retrieved Context")
    st.write(grounding)

    st.subheader("🤖 AI Response")
    st.write(generated_response)