Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
@@ -2,8 +2,10 @@ import os
|
|
2 |
import tempfile
|
3 |
import zipfile
|
4 |
from typing import List, Optional
|
|
|
|
|
5 |
|
6 |
-
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
|
7 |
from fastapi.responses import FileResponse, StreamingResponse
|
8 |
|
9 |
from llm_initialization import get_llm
|
@@ -27,6 +29,12 @@ MONGO_CLUSTER_URL = os.getenv("CONNECTION_STRING")
|
|
27 |
|
28 |
app = FastAPI(title="VectorStore & Document Management API")
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
# Global variables (initialized on startup)
|
31 |
llm = None
|
32 |
embeddings = None
|
@@ -37,13 +45,10 @@ vector_store_manager = None
|
|
37 |
vector_store = None
|
38 |
k = 3 # Number of documents to retrieve per query
|
39 |
|
40 |
-
# Global MongoDB collection to store retrieval chain configuration per chat session.
|
41 |
-
chat_chains_collection = None
|
42 |
-
|
43 |
# ----------------------- Startup Event -----------------------
|
44 |
@app.on_event("startup")
|
45 |
async def startup_event():
|
46 |
-
global llm, embeddings, chat_manager, document_loader, text_splitter, vector_store_manager, vector_store
|
47 |
|
48 |
print("Starting up: Initializing components...")
|
49 |
|
@@ -66,38 +71,26 @@ async def startup_event():
|
|
66 |
text_splitter = TextSplitter()
|
67 |
print("Document loader and text splitter initialized.")
|
68 |
|
69 |
-
# Initialize vector store manager and
|
70 |
vector_store_manager = VectorStoreManager(embeddings)
|
71 |
-
vector_store = vector_store_manager.vectorstore
|
72 |
print("Vector store initialized.")
|
73 |
|
74 |
-
|
75 |
-
client = MongoClient(MONGO_CLUSTER_URL)
|
76 |
-
db = client[MONGO_DATABASE_NAME]
|
77 |
-
chat_chains_collection = db["chat_chains"]
|
78 |
-
print("Chat chains collection initialized in MongoDB.")
|
79 |
-
|
80 |
-
|
81 |
-
# ----------------------- Root Endpoint -----------------------
|
82 |
-
@app.get("/")
|
83 |
-
def root():
|
84 |
-
"""
|
85 |
-
Root endpoint that returns a welcome message.
|
86 |
-
"""
|
87 |
-
return {"message": "Welcome to the VectorStore & Document Management API!"}
|
88 |
-
|
89 |
-
|
90 |
-
# ----------------------- New Chat Endpoint -----------------------
|
91 |
@app.post("/new_chat")
|
92 |
-
def new_chat():
|
93 |
"""
|
94 |
-
Create a new chat session.
|
95 |
"""
|
96 |
-
new_chat_id =
|
|
|
|
|
|
|
|
|
|
|
97 |
return {"chat_id": new_chat_id}
|
98 |
|
99 |
-
|
100 |
-
# ----------------------- Create Chain Endpoint -----------------------
|
101 |
@app.post("/create_chain")
|
102 |
def create_chain(
|
103 |
chat_id: str = Query(..., description="Existing chat session ID"),
|
@@ -105,9 +98,8 @@ def create_chain(
|
|
105 |
"quiz_solving",
|
106 |
description="Select prompt template. Options: quiz_solving, assignment_solving, paper_solving, quiz_creation, assignment_creation, paper_creation",
|
107 |
),
|
|
|
108 |
):
|
109 |
-
global chat_chains_collection # Ensure we reference the global variable
|
110 |
-
|
111 |
valid_templates = [
|
112 |
"quiz_solving",
|
113 |
"assignment_solving",
|
@@ -119,42 +111,35 @@ def create_chain(
|
|
119 |
if template not in valid_templates:
|
120 |
raise HTTPException(status_code=400, detail="Invalid template selection.")
|
121 |
|
122 |
-
#
|
123 |
-
|
124 |
-
{"
|
|
|
125 |
)
|
126 |
|
127 |
return {"message": "Retrieval chain configuration stored successfully.", "chat_id": chat_id, "template": template}
|
128 |
|
129 |
-
|
130 |
# ----------------------- Chat Endpoint -----------------------
|
131 |
@app.get("/chat")
|
132 |
-
def chat(
|
|
|
|
|
|
|
|
|
133 |
"""
|
134 |
Process a chat query using the retrieval chain associated with the given chat_id.
|
135 |
-
|
136 |
-
This endpoint uses the following code:
|
137 |
-
|
138 |
-
try:
|
139 |
-
stream_generator = retrieval_chain.stream_chat_response(
|
140 |
-
query=query,
|
141 |
-
chat_id=chat_id,
|
142 |
-
get_chat_history=chat_manager.get_chat_history,
|
143 |
-
initialize_chat_history=chat_manager.initialize_chat_history,
|
144 |
-
)
|
145 |
-
except Exception as e:
|
146 |
-
raise HTTPException(status_code=500, detail=f"Error processing chat query: {str(e)}")
|
147 |
-
|
148 |
-
return StreamingResponse(stream_generator, media_type="text/event-stream")
|
149 |
-
|
150 |
-
It first retrieves the configuration from MongoDB, re-creates the chain, and then streams the response.
|
151 |
"""
|
152 |
-
# Retrieve
|
153 |
-
|
154 |
-
|
|
|
|
|
|
|
|
|
|
|
155 |
raise HTTPException(status_code=400, detail="Chat configuration not found. Please create a chain using /create_chain.")
|
156 |
|
157 |
-
template =
|
158 |
if template == "quiz_solving":
|
159 |
prompt = PromptTemplates.get_quiz_solving_prompt()
|
160 |
elif template == "assignment_solving":
|
@@ -170,7 +155,6 @@ def chat(query: str, chat_id: str = Query(..., description="Chat session ID crea
|
|
170 |
else:
|
171 |
raise HTTPException(status_code=400, detail="Invalid chat configuration.")
|
172 |
|
173 |
-
# Re-create the retrieval chain for this chat session.
|
174 |
retrieval_chain = RetrievalChain(
|
175 |
llm,
|
176 |
vector_store.as_retriever(search_kwargs={"k": k}),
|
@@ -188,36 +172,24 @@ def chat(query: str, chat_id: str = Query(..., description="Chat session ID crea
|
|
188 |
except Exception as e:
|
189 |
raise HTTPException(status_code=500, detail=f"Error processing chat query: {str(e)}")
|
190 |
|
191 |
-
|
192 |
-
|
193 |
|
194 |
-
|
195 |
-
from typing import Any, Optional
|
196 |
|
|
|
|
|
197 |
@app.post("/add_document")
|
198 |
async def add_document(
|
199 |
-
file: Optional[
|
200 |
wiki_query: Optional[str] = Query(None),
|
201 |
wiki_url: Optional[str] = Query(None)
|
202 |
):
|
203 |
-
"""
|
204 |
-
Upload a document OR load data from a Wikipedia query or URL.
|
205 |
-
|
206 |
-
- If a file is provided, the document is loaded from the file.
|
207 |
-
- If 'wiki_query' is provided, the Wikipedia page(s) are loaded using document_loader.wikipedia_query.
|
208 |
-
- If 'wiki_url' is provided, the URL is loaded using document_loader.load_urls.
|
209 |
-
|
210 |
-
The loaded document(s) are then split into chunks and added to the vector store.
|
211 |
-
"""
|
212 |
-
# If file is provided but not as an UploadFile (e.g. an empty string), set it to None.
|
213 |
if not isinstance(file, UploadFile):
|
214 |
file = None
|
215 |
|
216 |
-
# Ensure at least one input is provided.
|
217 |
if file is None and wiki_query is None and wiki_url is None:
|
218 |
raise HTTPException(status_code=400, detail="No document input provided (file, wiki_query, or wiki_url).")
|
219 |
|
220 |
-
# Load document(s) based on input priority: file > wiki_query > wiki_url.
|
221 |
if file is not None:
|
222 |
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
223 |
contents = await file.read()
|
@@ -265,13 +237,8 @@ async def add_document(
|
|
265 |
|
266 |
return {"message": f"Added {len(chunks)} document chunks.", "ids": ids}
|
267 |
|
268 |
-
|
269 |
-
# ----------------------- Delete Document Endpoint -----------------------
|
270 |
@app.post("/delete_document")
|
271 |
def delete_document(ids: List[str]):
|
272 |
-
"""
|
273 |
-
Delete document(s) from the vector store using their IDs.
|
274 |
-
"""
|
275 |
try:
|
276 |
success = vector_store_manager.delete_documents(ids)
|
277 |
except Exception as e:
|
@@ -280,15 +247,8 @@ def delete_document(ids: List[str]):
|
|
280 |
raise HTTPException(status_code=400, detail="Failed to delete documents.")
|
281 |
return {"message": f"Deleted documents with IDs: {ids}"}
|
282 |
|
283 |
-
|
284 |
-
# ----------------------- Save Vectorstore Endpoint -----------------------
|
285 |
@app.get("/save_vectorstore")
|
286 |
def save_vectorstore():
|
287 |
-
"""
|
288 |
-
Save the current vector store locally.
|
289 |
-
If it is a directory, it will be zipped.
|
290 |
-
Returns the file as a downloadable response.
|
291 |
-
"""
|
292 |
try:
|
293 |
save_result = vector_store_manager.save("faiss_index")
|
294 |
except Exception as e:
|
@@ -299,19 +259,12 @@ def save_vectorstore():
|
|
299 |
filename=save_result["serve_filename"],
|
300 |
)
|
301 |
|
302 |
-
|
303 |
-
# ----------------------- Load Vectorstore Endpoint -----------------------
|
304 |
@app.post("/load_vectorstore")
|
305 |
async def load_vectorstore(file: UploadFile = File(...)):
|
306 |
-
"""
|
307 |
-
Load a vector store from an uploaded file (raw or zipped).
|
308 |
-
This will replace the current vector store.
|
309 |
-
"""
|
310 |
tmp_filename = None
|
311 |
try:
|
312 |
-
# Save the uploaded file content to a temporary file.
|
313 |
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
314 |
-
file_bytes = await file.read()
|
315 |
tmp.write(file_bytes)
|
316 |
tmp_filename = tmp.name
|
317 |
|
@@ -325,22 +278,15 @@ async def load_vectorstore(file: UploadFile = File(...)):
|
|
325 |
vector_store_manager = instance
|
326 |
return {"message": message}
|
327 |
|
328 |
-
|
329 |
-
# ----------------------- Merge Vectorstore Endpoint -----------------------
|
330 |
@app.post("/merge_vectorstore")
|
331 |
async def merge_vectorstore(file: UploadFile = File(...)):
|
332 |
-
"""
|
333 |
-
Merge an uploaded vector store (raw or zipped) into the current vector store.
|
334 |
-
"""
|
335 |
tmp_filename = None
|
336 |
try:
|
337 |
-
# Save the uploaded file content to a temporary file.
|
338 |
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
339 |
-
file_bytes = await file.read()
|
340 |
tmp.write(file_bytes)
|
341 |
tmp_filename = tmp.name
|
342 |
|
343 |
-
# Pass the filename (a string) to the merge method.
|
344 |
result = vector_store_manager.merge(tmp_filename, embeddings)
|
345 |
except Exception as e:
|
346 |
raise HTTPException(status_code=500, detail=f"Error merging vectorstore: {str(e)}")
|
@@ -349,7 +295,6 @@ async def merge_vectorstore(file: UploadFile = File(...)):
|
|
349 |
os.remove(tmp_filename)
|
350 |
return result
|
351 |
|
352 |
-
|
353 |
if __name__ == "__main__":
|
354 |
import uvicorn
|
355 |
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
2 |
import tempfile
|
3 |
import zipfile
|
4 |
from typing import List, Optional
|
5 |
+
import uuid
|
6 |
+
from datetime import datetime
|
7 |
|
8 |
+
from fastapi import FastAPI, File, UploadFile, HTTPException, Query, Depends
|
9 |
from fastapi.responses import FileResponse, StreamingResponse
|
10 |
|
11 |
from llm_initialization import get_llm
|
|
|
29 |
|
30 |
app = FastAPI(title="VectorStore & Document Management API")
|
31 |
|
32 |
+
# Import auth router and dependencies
|
33 |
+
from auth import router as auth_router, get_current_user, users_collection
|
34 |
+
|
35 |
+
# Mount auth endpoints under /auth
|
36 |
+
app.include_router(auth_router, prefix="/auth")
|
37 |
+
|
38 |
# Global variables (initialized on startup)
|
39 |
llm = None
|
40 |
embeddings = None
|
|
|
45 |
vector_store = None
|
46 |
k = 3 # Number of documents to retrieve per query
|
47 |
|
|
|
|
|
|
|
48 |
# ----------------------- Startup Event -----------------------
|
49 |
@app.on_event("startup")
|
50 |
async def startup_event():
|
51 |
+
global llm, embeddings, chat_manager, document_loader, text_splitter, vector_store_manager, vector_store
|
52 |
|
53 |
print("Starting up: Initializing components...")
|
54 |
|
|
|
71 |
text_splitter = TextSplitter()
|
72 |
print("Document loader and text splitter initialized.")
|
73 |
|
74 |
+
# Initialize vector store manager and set vector store
|
75 |
vector_store_manager = VectorStoreManager(embeddings)
|
76 |
+
vector_store = vector_store_manager.vectorstore
|
77 |
print("Vector store initialized.")
|
78 |
|
79 |
+
# ----------------------- New Chat Endpoint (Updated) -----------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
@app.post("/new_chat")
def new_chat(current_user: dict = Depends(get_current_user)):
    """
    Create a new chat session under the current user's document.

    A random UUID4 string is generated as the chat ID, and a new entry
    holding that ID, a creation timestamp and an empty message list is
    pushed onto the ``chat_histories`` array of the user document whose
    ``email`` matches the authenticated user.

    Args:
        current_user: The authenticated user, resolved by the
            ``get_current_user`` dependency. Only its ``"email"`` key is read.

    Returns:
        A dict of the form ``{"chat_id": <new chat id>}``.

    Raises:
        HTTPException: 404 if no user document matched the e-mail, in which
            case the session was not stored and the ID must not be handed out.
    """
    # Local import: the file-level import only brings in ``datetime``.
    from datetime import timezone

    new_chat_id = str(uuid.uuid4())
    chat_session = {
        "chat_id": new_chat_id,
        # Timezone-aware replacement for the deprecated datetime.utcnow().
        "created_at": datetime.now(timezone.utc),
        "messages": [],
    }

    # Append the new chat session to the user's chat_histories.
    result = users_collection.update_one(
        {"email": current_user["email"]},
        {"$push": {"chat_histories": chat_session}},
    )

    # NOTE(review): assumes users_collection is a synchronous PyMongo
    # collection (it is called without await), so update_one returns an
    # UpdateResult exposing matched_count — confirm against auth.py.
    if result.matched_count == 0:
        raise HTTPException(status_code=404, detail="User not found.")

    return {"chat_id": new_chat_id}
|
92 |
|
93 |
+
# ----------------------- Create Chain Endpoint (Updated) -----------------------
|
|
|
94 |
@app.post("/create_chain")
|
95 |
def create_chain(
|
96 |
chat_id: str = Query(..., description="Existing chat session ID"),
|
|
|
98 |
"quiz_solving",
|
99 |
description="Select prompt template. Options: quiz_solving, assignment_solving, paper_solving, quiz_creation, assignment_creation, paper_creation",
|
100 |
),
|
101 |
+
current_user: dict = Depends(get_current_user)
|
102 |
):
|
|
|
|
|
103 |
valid_templates = [
|
104 |
"quiz_solving",
|
105 |
"assignment_solving",
|
|
|
111 |
if template not in valid_templates:
|
112 |
raise HTTPException(status_code=400, detail="Invalid template selection.")
|
113 |
|
114 |
+
# Update the specific chat session's configuration in the user's document
|
115 |
+
users_collection.update_one(
|
116 |
+
{"email": current_user["email"], "chat_histories.chat_id": chat_id},
|
117 |
+
{"$set": {"chat_histories.$.template": template}}
|
118 |
)
|
119 |
|
120 |
return {"message": "Retrieval chain configuration stored successfully.", "chat_id": chat_id, "template": template}
|
121 |
|
|
|
122 |
# ----------------------- Chat Endpoint -----------------------
|
123 |
@app.get("/chat")
|
124 |
+
def chat(
|
125 |
+
query: str,
|
126 |
+
chat_id: str = Query(..., description="Chat session ID"),
|
127 |
+
current_user: dict = Depends(get_current_user)
|
128 |
+
):
|
129 |
"""
|
130 |
Process a chat query using the retrieval chain associated with the given chat_id.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
"""
|
132 |
+
# Retrieve chat configuration from the user's document
|
133 |
+
user = current_user
|
134 |
+
chat_config = None
|
135 |
+
for chat in user.get("chat_histories", []):
|
136 |
+
if chat.get("chat_id") == chat_id:
|
137 |
+
chat_config = chat
|
138 |
+
break
|
139 |
+
if not chat_config:
|
140 |
raise HTTPException(status_code=400, detail="Chat configuration not found. Please create a chain using /create_chain.")
|
141 |
|
142 |
+
template = chat_config.get("template", "quiz_solving")
|
143 |
if template == "quiz_solving":
|
144 |
prompt = PromptTemplates.get_quiz_solving_prompt()
|
145 |
elif template == "assignment_solving":
|
|
|
155 |
else:
|
156 |
raise HTTPException(status_code=400, detail="Invalid chat configuration.")
|
157 |
|
|
|
158 |
retrieval_chain = RetrievalChain(
|
159 |
llm,
|
160 |
vector_store.as_retriever(search_kwargs={"k": k}),
|
|
|
172 |
except Exception as e:
|
173 |
raise HTTPException(status_code=500, detail=f"Error processing chat query: {str(e)}")
|
174 |
|
175 |
+
# Optionally update the user's chat_histories with the new messages here
|
|
|
176 |
|
177 |
+
return StreamingResponse(stream_generator, media_type="text/event-stream")
|
|
|
178 |
|
179 |
+
# ----------------------- Remaining Endpoints -----------------------
|
180 |
+
# (The endpoints below for adding, deleting, saving, loading, and merging documents are unchanged by the auth integration and do not depend on get_current_user.)
|
181 |
@app.post("/add_document")
|
182 |
async def add_document(
|
183 |
+
file: Optional[any] = File(None),
|
184 |
wiki_query: Optional[str] = Query(None),
|
185 |
wiki_url: Optional[str] = Query(None)
|
186 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
if not isinstance(file, UploadFile):
|
188 |
file = None
|
189 |
|
|
|
190 |
if file is None and wiki_query is None and wiki_url is None:
|
191 |
raise HTTPException(status_code=400, detail="No document input provided (file, wiki_query, or wiki_url).")
|
192 |
|
|
|
193 |
if file is not None:
|
194 |
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
195 |
contents = await file.read()
|
|
|
237 |
|
238 |
return {"message": f"Added {len(chunks)} document chunks.", "ids": ids}
|
239 |
|
|
|
|
|
240 |
@app.post("/delete_document")
|
241 |
def delete_document(ids: List[str]):
|
|
|
|
|
|
|
242 |
try:
|
243 |
success = vector_store_manager.delete_documents(ids)
|
244 |
except Exception as e:
|
|
|
247 |
raise HTTPException(status_code=400, detail="Failed to delete documents.")
|
248 |
return {"message": f"Deleted documents with IDs: {ids}"}
|
249 |
|
|
|
|
|
250 |
@app.get("/save_vectorstore")
|
251 |
def save_vectorstore():
|
|
|
|
|
|
|
|
|
|
|
252 |
try:
|
253 |
save_result = vector_store_manager.save("faiss_index")
|
254 |
except Exception as e:
|
|
|
259 |
filename=save_result["serve_filename"],
|
260 |
)
|
261 |
|
|
|
|
|
262 |
@app.post("/load_vectorstore")
|
263 |
async def load_vectorstore(file: UploadFile = File(...)):
|
|
|
|
|
|
|
|
|
264 |
tmp_filename = None
|
265 |
try:
|
|
|
266 |
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
267 |
+
file_bytes = await file.read()
|
268 |
tmp.write(file_bytes)
|
269 |
tmp_filename = tmp.name
|
270 |
|
|
|
278 |
vector_store_manager = instance
|
279 |
return {"message": message}
|
280 |
|
|
|
|
|
281 |
@app.post("/merge_vectorstore")
|
282 |
async def merge_vectorstore(file: UploadFile = File(...)):
|
|
|
|
|
|
|
283 |
tmp_filename = None
|
284 |
try:
|
|
|
285 |
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
286 |
+
file_bytes = await file.read()
|
287 |
tmp.write(file_bytes)
|
288 |
tmp_filename = tmp.name
|
289 |
|
|
|
290 |
result = vector_store_manager.merge(tmp_filename, embeddings)
|
291 |
except Exception as e:
|
292 |
raise HTTPException(status_code=500, detail=f"Error merging vectorstore: {str(e)}")
|
|
|
295 |
os.remove(tmp_filename)
|
296 |
return result
|
297 |
|
|
|
298 |
if __name__ == "__main__":
    # Script entry point: serve the FastAPI app with uvicorn.
    # uvicorn is imported here, inside the guard, so it is only needed
    # when the module is executed directly.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)
|