mominah commited on
Commit
b16a6fa
·
verified ·
1 Parent(s): 0f75bc4

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +48 -103
main.py CHANGED
@@ -2,8 +2,10 @@ import os
2
  import tempfile
3
  import zipfile
4
  from typing import List, Optional
 
 
5
 
6
- from fastapi import FastAPI, File, UploadFile, HTTPException, Query
7
  from fastapi.responses import FileResponse, StreamingResponse
8
 
9
  from llm_initialization import get_llm
@@ -27,6 +29,12 @@ MONGO_CLUSTER_URL = os.getenv("CONNECTION_STRING")
27
 
28
  app = FastAPI(title="VectorStore & Document Management API")
29
 
 
 
 
 
 
 
30
  # Global variables (initialized on startup)
31
  llm = None
32
  embeddings = None
@@ -37,13 +45,10 @@ vector_store_manager = None
37
  vector_store = None
38
  k = 3 # Number of documents to retrieve per query
39
 
40
- # Global MongoDB collection to store retrieval chain configuration per chat session.
41
- chat_chains_collection = None
42
-
43
  # ----------------------- Startup Event -----------------------
44
  @app.on_event("startup")
45
  async def startup_event():
46
- global llm, embeddings, chat_manager, document_loader, text_splitter, vector_store_manager, vector_store, chat_chains_collection
47
 
48
  print("Starting up: Initializing components...")
49
 
@@ -66,38 +71,26 @@ async def startup_event():
66
  text_splitter = TextSplitter()
67
  print("Document loader and text splitter initialized.")
68
 
69
- # Initialize vector store manager and ensure vectorstore is set
70
  vector_store_manager = VectorStoreManager(embeddings)
71
- vector_store = vector_store_manager.vectorstore # Now properly initialized
72
  print("Vector store initialized.")
73
 
74
- # Connect to MongoDB and get the collection.
75
- client = MongoClient(MONGO_CLUSTER_URL)
76
- db = client[MONGO_DATABASE_NAME]
77
- chat_chains_collection = db["chat_chains"]
78
- print("Chat chains collection initialized in MongoDB.")
79
-
80
-
81
- # ----------------------- Root Endpoint -----------------------
82
- @app.get("/")
83
- def root():
84
- """
85
- Root endpoint that returns a welcome message.
86
- """
87
- return {"message": "Welcome to the VectorStore & Document Management API!"}
88
-
89
-
90
- # ----------------------- New Chat Endpoint -----------------------
91
  @app.post("/new_chat")
92
- def new_chat():
93
  """
94
- Create a new chat session.
95
  """
96
- new_chat_id = chat_manager.create_new_chat()
 
 
 
 
 
97
  return {"chat_id": new_chat_id}
98
 
99
-
100
- # ----------------------- Create Chain Endpoint -----------------------
101
  @app.post("/create_chain")
102
  def create_chain(
103
  chat_id: str = Query(..., description="Existing chat session ID"),
@@ -105,9 +98,8 @@ def create_chain(
105
  "quiz_solving",
106
  description="Select prompt template. Options: quiz_solving, assignment_solving, paper_solving, quiz_creation, assignment_creation, paper_creation",
107
  ),
 
108
  ):
109
- global chat_chains_collection # Ensure we reference the global variable
110
-
111
  valid_templates = [
112
  "quiz_solving",
113
  "assignment_solving",
@@ -119,42 +111,35 @@ def create_chain(
119
  if template not in valid_templates:
120
  raise HTTPException(status_code=400, detail="Invalid template selection.")
121
 
122
- # Upsert the configuration document for this chat session.
123
- chat_chains_collection.update_one(
124
- {"chat_id": chat_id}, {"$set": {"template": template}}, upsert=True
 
125
  )
126
 
127
  return {"message": "Retrieval chain configuration stored successfully.", "chat_id": chat_id, "template": template}
128
 
129
-
130
  # ----------------------- Chat Endpoint -----------------------
131
  @app.get("/chat")
132
- def chat(query: str, chat_id: str = Query(..., description="Chat session ID created via /new_chat and configured via /create_chain")):
 
 
 
 
133
  """
134
  Process a chat query using the retrieval chain associated with the given chat_id.
135
-
136
- This endpoint uses the following code:
137
-
138
- try:
139
- stream_generator = retrieval_chain.stream_chat_response(
140
- query=query,
141
- chat_id=chat_id,
142
- get_chat_history=chat_manager.get_chat_history,
143
- initialize_chat_history=chat_manager.initialize_chat_history,
144
- )
145
- except Exception as e:
146
- raise HTTPException(status_code=500, detail=f"Error processing chat query: {str(e)}")
147
-
148
- return StreamingResponse(stream_generator, media_type="text/event-stream")
149
-
150
- It first retrieves the configuration from MongoDB, re-creates the chain, and then streams the response.
151
  """
152
- # Retrieve the chat configuration from MongoDB.
153
- config = chat_chains_collection.find_one({"chat_id": chat_id})
154
- if not config:
 
 
 
 
 
155
  raise HTTPException(status_code=400, detail="Chat configuration not found. Please create a chain using /create_chain.")
156
 
157
- template = config.get("template", "quiz_solving")
158
  if template == "quiz_solving":
159
  prompt = PromptTemplates.get_quiz_solving_prompt()
160
  elif template == "assignment_solving":
@@ -170,7 +155,6 @@ def chat(query: str, chat_id: str = Query(..., description="Chat session ID crea
170
  else:
171
  raise HTTPException(status_code=400, detail="Invalid chat configuration.")
172
 
173
- # Re-create the retrieval chain for this chat session.
174
  retrieval_chain = RetrievalChain(
175
  llm,
176
  vector_store.as_retriever(search_kwargs={"k": k}),
@@ -188,36 +172,24 @@ def chat(query: str, chat_id: str = Query(..., description="Chat session ID crea
188
  except Exception as e:
189
  raise HTTPException(status_code=500, detail=f"Error processing chat query: {str(e)}")
190
 
191
- return StreamingResponse(stream_generator, media_type="text/event-stream")
192
-
193
 
194
- # ----------------------- Add Document Endpoint -----------------------
195
- from typing import Any, Optional
196
 
 
 
197
  @app.post("/add_document")
198
  async def add_document(
199
- file: Optional[Any] = File(None),
200
  wiki_query: Optional[str] = Query(None),
201
  wiki_url: Optional[str] = Query(None)
202
  ):
203
- """
204
- Upload a document OR load data from a Wikipedia query or URL.
205
-
206
- - If a file is provided, the document is loaded from the file.
207
- - If 'wiki_query' is provided, the Wikipedia page(s) are loaded using document_loader.wikipedia_query.
208
- - If 'wiki_url' is provided, the URL is loaded using document_loader.load_urls.
209
-
210
- The loaded document(s) are then split into chunks and added to the vector store.
211
- """
212
- # If file is provided but not as an UploadFile (e.g. an empty string), set it to None.
213
  if not isinstance(file, UploadFile):
214
  file = None
215
 
216
- # Ensure at least one input is provided.
217
  if file is None and wiki_query is None and wiki_url is None:
218
  raise HTTPException(status_code=400, detail="No document input provided (file, wiki_query, or wiki_url).")
219
 
220
- # Load document(s) based on input priority: file > wiki_query > wiki_url.
221
  if file is not None:
222
  with tempfile.NamedTemporaryFile(delete=False) as tmp:
223
  contents = await file.read()
@@ -265,13 +237,8 @@ async def add_document(
265
 
266
  return {"message": f"Added {len(chunks)} document chunks.", "ids": ids}
267
 
268
-
269
- # ----------------------- Delete Document Endpoint -----------------------
270
  @app.post("/delete_document")
271
  def delete_document(ids: List[str]):
272
- """
273
- Delete document(s) from the vector store using their IDs.
274
- """
275
  try:
276
  success = vector_store_manager.delete_documents(ids)
277
  except Exception as e:
@@ -280,15 +247,8 @@ def delete_document(ids: List[str]):
280
  raise HTTPException(status_code=400, detail="Failed to delete documents.")
281
  return {"message": f"Deleted documents with IDs: {ids}"}
282
 
283
-
284
- # ----------------------- Save Vectorstore Endpoint -----------------------
285
  @app.get("/save_vectorstore")
286
  def save_vectorstore():
287
- """
288
- Save the current vector store locally.
289
- If it is a directory, it will be zipped.
290
- Returns the file as a downloadable response.
291
- """
292
  try:
293
  save_result = vector_store_manager.save("faiss_index")
294
  except Exception as e:
@@ -299,19 +259,12 @@ def save_vectorstore():
299
  filename=save_result["serve_filename"],
300
  )
301
 
302
-
303
- # ----------------------- Load Vectorstore Endpoint -----------------------
304
  @app.post("/load_vectorstore")
305
  async def load_vectorstore(file: UploadFile = File(...)):
306
- """
307
- Load a vector store from an uploaded file (raw or zipped).
308
- This will replace the current vector store.
309
- """
310
  tmp_filename = None
311
  try:
312
- # Save the uploaded file content to a temporary file.
313
  with tempfile.NamedTemporaryFile(delete=False) as tmp:
314
- file_bytes = await file.read() # await to get bytes
315
  tmp.write(file_bytes)
316
  tmp_filename = tmp.name
317
 
@@ -325,22 +278,15 @@ async def load_vectorstore(file: UploadFile = File(...)):
325
  vector_store_manager = instance
326
  return {"message": message}
327
 
328
-
329
- # ----------------------- Merge Vectorstore Endpoint -----------------------
330
  @app.post("/merge_vectorstore")
331
  async def merge_vectorstore(file: UploadFile = File(...)):
332
- """
333
- Merge an uploaded vector store (raw or zipped) into the current vector store.
334
- """
335
  tmp_filename = None
336
  try:
337
- # Save the uploaded file content to a temporary file.
338
  with tempfile.NamedTemporaryFile(delete=False) as tmp:
339
- file_bytes = await file.read() # Await the file.read() coroutine!
340
  tmp.write(file_bytes)
341
  tmp_filename = tmp.name
342
 
343
- # Pass the filename (a string) to the merge method.
344
  result = vector_store_manager.merge(tmp_filename, embeddings)
345
  except Exception as e:
346
  raise HTTPException(status_code=500, detail=f"Error merging vectorstore: {str(e)}")
@@ -349,7 +295,6 @@ async def merge_vectorstore(file: UploadFile = File(...)):
349
  os.remove(tmp_filename)
350
  return result
351
 
352
-
353
  if __name__ == "__main__":
354
  import uvicorn
355
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
2
  import tempfile
3
  import zipfile
4
  from typing import List, Optional
5
+ import uuid
6
+ from datetime import datetime
7
 
8
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Query, Depends
9
  from fastapi.responses import FileResponse, StreamingResponse
10
 
11
  from llm_initialization import get_llm
 
29
 
30
  app = FastAPI(title="VectorStore & Document Management API")
31
 
32
+ # Import auth router and dependencies
33
+ from auth import router as auth_router, get_current_user, users_collection
34
+
35
+ # Mount auth endpoints under /auth
36
+ app.include_router(auth_router, prefix="/auth")
37
+
38
  # Global variables (initialized on startup)
39
  llm = None
40
  embeddings = None
 
45
  vector_store = None
46
  k = 3 # Number of documents to retrieve per query
47
 
 
 
 
48
  # ----------------------- Startup Event -----------------------
49
  @app.on_event("startup")
50
  async def startup_event():
51
+ global llm, embeddings, chat_manager, document_loader, text_splitter, vector_store_manager, vector_store
52
 
53
  print("Starting up: Initializing components...")
54
 
 
71
  text_splitter = TextSplitter()
72
  print("Document loader and text splitter initialized.")
73
 
74
+ # Initialize vector store manager and set vector store
75
  vector_store_manager = VectorStoreManager(embeddings)
76
+ vector_store = vector_store_manager.vectorstore
77
  print("Vector store initialized.")
78
 
79
+ # ----------------------- New Chat Endpoint (Updated) -----------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  @app.post("/new_chat")
81
+ def new_chat(current_user: dict = Depends(get_current_user)):
82
  """
83
+ Create a new chat session under the current user's document.
84
  """
85
+ new_chat_id = str(uuid.uuid4())
86
+ # Append a new chat session to the user's chat_histories
87
+ users_collection.update_one(
88
+ {"email": current_user["email"]},
89
+ {"$push": {"chat_histories": {"chat_id": new_chat_id, "created_at": datetime.utcnow(), "messages": []}}}
90
+ )
91
  return {"chat_id": new_chat_id}
92
 
93
+ # ----------------------- Create Chain Endpoint (Updated) -----------------------
 
94
  @app.post("/create_chain")
95
  def create_chain(
96
  chat_id: str = Query(..., description="Existing chat session ID"),
 
98
  "quiz_solving",
99
  description="Select prompt template. Options: quiz_solving, assignment_solving, paper_solving, quiz_creation, assignment_creation, paper_creation",
100
  ),
101
+ current_user: dict = Depends(get_current_user)
102
  ):
 
 
103
  valid_templates = [
104
  "quiz_solving",
105
  "assignment_solving",
 
111
  if template not in valid_templates:
112
  raise HTTPException(status_code=400, detail="Invalid template selection.")
113
 
114
+ # Update the specific chat session's configuration in the user's document
115
+ users_collection.update_one(
116
+ {"email": current_user["email"], "chat_histories.chat_id": chat_id},
117
+ {"$set": {"chat_histories.$.template": template}}
118
  )
119
 
120
  return {"message": "Retrieval chain configuration stored successfully.", "chat_id": chat_id, "template": template}
121
 
 
122
  # ----------------------- Chat Endpoint -----------------------
123
  @app.get("/chat")
124
+ def chat(
125
+ query: str,
126
+ chat_id: str = Query(..., description="Chat session ID"),
127
+ current_user: dict = Depends(get_current_user)
128
+ ):
129
  """
130
  Process a chat query using the retrieval chain associated with the given chat_id.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  """
132
+ # Retrieve chat configuration from the user's document
133
+ user = current_user
134
+ chat_config = None
135
+ for chat in user.get("chat_histories", []):
136
+ if chat.get("chat_id") == chat_id:
137
+ chat_config = chat
138
+ break
139
+ if not chat_config:
140
  raise HTTPException(status_code=400, detail="Chat configuration not found. Please create a chain using /create_chain.")
141
 
142
+ template = chat_config.get("template", "quiz_solving")
143
  if template == "quiz_solving":
144
  prompt = PromptTemplates.get_quiz_solving_prompt()
145
  elif template == "assignment_solving":
 
155
  else:
156
  raise HTTPException(status_code=400, detail="Invalid chat configuration.")
157
 
 
158
  retrieval_chain = RetrievalChain(
159
  llm,
160
  vector_store.as_retriever(search_kwargs={"k": k}),
 
172
  except Exception as e:
173
  raise HTTPException(status_code=500, detail=f"Error processing chat query: {str(e)}")
174
 
175
+ # Optionally update the user's chat_histories with the new messages here
 
176
 
177
+ return StreamingResponse(stream_generator, media_type="text/event-stream")
 
178
 
179
+ # ----------------------- Remaining Endpoints -----------------------
180
+ # (The endpoints for adding, deleting, saving, loading, and merging documents remain unchanged.)
181
  @app.post("/add_document")
182
  async def add_document(
183
+ file: Optional[any] = File(None),
184
  wiki_query: Optional[str] = Query(None),
185
  wiki_url: Optional[str] = Query(None)
186
  ):
 
 
 
 
 
 
 
 
 
 
187
  if not isinstance(file, UploadFile):
188
  file = None
189
 
 
190
  if file is None and wiki_query is None and wiki_url is None:
191
  raise HTTPException(status_code=400, detail="No document input provided (file, wiki_query, or wiki_url).")
192
 
 
193
  if file is not None:
194
  with tempfile.NamedTemporaryFile(delete=False) as tmp:
195
  contents = await file.read()
 
237
 
238
  return {"message": f"Added {len(chunks)} document chunks.", "ids": ids}
239
 
 
 
240
  @app.post("/delete_document")
241
  def delete_document(ids: List[str]):
 
 
 
242
  try:
243
  success = vector_store_manager.delete_documents(ids)
244
  except Exception as e:
 
247
  raise HTTPException(status_code=400, detail="Failed to delete documents.")
248
  return {"message": f"Deleted documents with IDs: {ids}"}
249
 
 
 
250
  @app.get("/save_vectorstore")
251
  def save_vectorstore():
 
 
 
 
 
252
  try:
253
  save_result = vector_store_manager.save("faiss_index")
254
  except Exception as e:
 
259
  filename=save_result["serve_filename"],
260
  )
261
 
 
 
262
  @app.post("/load_vectorstore")
263
  async def load_vectorstore(file: UploadFile = File(...)):
 
 
 
 
264
  tmp_filename = None
265
  try:
 
266
  with tempfile.NamedTemporaryFile(delete=False) as tmp:
267
+ file_bytes = await file.read()
268
  tmp.write(file_bytes)
269
  tmp_filename = tmp.name
270
 
 
278
  vector_store_manager = instance
279
  return {"message": message}
280
 
 
 
281
  @app.post("/merge_vectorstore")
282
  async def merge_vectorstore(file: UploadFile = File(...)):
 
 
 
283
  tmp_filename = None
284
  try:
 
285
  with tempfile.NamedTemporaryFile(delete=False) as tmp:
286
+ file_bytes = await file.read()
287
  tmp.write(file_bytes)
288
  tmp_filename = tmp.name
289
 
 
290
  result = vector_store_manager.merge(tmp_filename, embeddings)
291
  except Exception as e:
292
  raise HTTPException(status_code=500, detail=f"Error merging vectorstore: {str(e)}")
 
295
  os.remove(tmp_filename)
296
  return result
297
 
 
298
  if __name__ == "__main__":
299
  import uvicorn
300
  uvicorn.run(app, host="0.0.0.0", port=8000)