Spaces:
Running
Running
File size: 5,650 Bytes
982bc60 311764c 4cf11fc 0c1268a 982bc60 4876c75 0240527 982bc60 45660fc 57c8644 982bc60 311764c 0240527 311764c 0240527 982bc60 324eee5 311764c 324eee5 982bc60 4da099c 982bc60 324eee5 4cf11fc 324eee5 4cf11fc 982bc60 4876c75 982bc60 324eee5 57c8644 4876c75 982bc60 311764c 4876c75 311764c 4876c75 311764c 4876c75 311764c 324eee5 4876c75 324eee5 4cf11fc 324eee5 4cf11fc 982bc60 324eee5 982bc60 57c8644 982bc60 4cf11fc 6b94b38 311764c 4cf11fc 982bc60 324eee5 982bc60 57c8644 324eee5 4cf11fc 324eee5 4cf11fc 982bc60 324eee5 982bc60 57c8644 982bc60 311764c 4cf11fc 311764c 4876c75 4cf11fc 982bc60 4a4df8f 982bc60 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import os
import shutil
import tempfile
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
from fastapi.encoders import jsonable_encoder
from bson import ObjectId
from models import InitializeBotResponse, NewChatResponse, QueryRequest, QueryResponse
from trainer_manager import get_trainer
from config import CUSTOM_PROMPT
from prompt_templates import PromptTemplates
# Router on which every endpoint in this module is registered.
router = APIRouter()
# Trainer obtained once at import time; all handlers below call into this instance.
trainer = get_trainer()
@router.post("/initialize_bot", response_model=InitializeBotResponse)
def initialize_bot(prompt_type: str = Query(None)):
    """
    Allocate a fresh bot id through the trainer and return it to the caller.

    The optional 'prompt_type' query parameter is accepted for the frontend's
    benefit but is not used here; the template is chosen later in create_bot.
    Any failure is reported as an HTTP 500 carrying the error text.
    """
    try:
        return InitializeBotResponse(bot_id=trainer.initialize_bot_id())
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
@router.post("/upload_document")
async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...)):
    """
    Store the uploaded file in a temporary location and add it to a bot's knowledge base.

    Parameters:
        bot_id: identifier of the bot that receives the document (form field).
        file: the uploaded document; its extension (if any) is kept on the
            temporary file so the trainer sees the same suffix.

    Returns:
        A dict with a success message.

    Raises:
        HTTPException (500): if reading, writing or ingesting the document fails.

    The temporary file is always removed, including when ingestion fails.
    """
    tmp_path = None
    try:
        # An upload may arrive without a filename; fall back to no suffix.
        suffix = os.path.splitext(file.filename or "")[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # Record the path first so cleanup still runs if the write fails.
            tmp_path = tmp.name
            tmp.write(await file.read())
        # The file is closed (and flushed) before the trainer opens it by path.
        trainer.add_document_from_path(tmp_path, bot_id)
        return {"message": "Document uploaded and added successfully."}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # delete=False means we own the cleanup; do it on success and failure alike.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                # Best-effort: a failed cleanup must not mask the real outcome.
                pass
@router.post("/create_bot/{bot_id}")
def create_bot(bot_id: str, prompt_type: str = Query(None)):
    """
    Build the bot identified by bot_id using a prompt template chosen by prompt_type.

    Each recognised prompt_type maps to one PromptTemplates factory method.
    A missing or unrecognised prompt_type falls back to the quiz-solving prompt.
    Any failure is reported as an HTTP 500 carrying the error text.
    """
    fallback_factory = "get_quiz_solving_prompt"
    factory_by_type = {
        "university": "get_university_chatbot_prompt",
        "quiz_solving": "get_quiz_solving_prompt",
        "assignment_solving": "get_assignment_solving_prompt",
        "paper_solving": "get_paper_solving_prompt",
        "quiz_creation": "get_quiz_creation_prompt",
        "assignment_creation": "get_assignment_creation_prompt",
        "paper_creation": "get_paper_creation_prompt",
        "check_quiz": "get_check_quiz_prompt",
        "check_assignment": "get_check_assignment_prompt",
        "check_paper": "get_check_paper_prompt",
    }
    try:
        # Resolve by name so only the selected factory is looked up and called.
        factory_name = factory_by_type.get(prompt_type, fallback_factory)
        chosen_template = getattr(PromptTemplates, factory_name)()
        # Build the bot with the selected template.
        trainer.create_bot(bot_id, chosen_template)
        return {"message": f"Bot {bot_id} created successfully."}
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
@router.post("/new_chat/{bot_id}", response_model=NewChatResponse)
def new_chat(bot_id: str):
    """
    Open a new chat session on the given bot and return the session's chat_id.

    Any failure is reported as an HTTP 500 carrying the error text.
    """
    try:
        return NewChatResponse(chat_id=trainer.new_chat(bot_id))
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
@router.post("/query", response_model=QueryResponse)
def send_query(query_request: QueryRequest):
    """
    Forward a query to the trainer and return its answer together with the web sources.

    The request body supplies the query text, the bot_id and the chat_id.
    Any failure is reported as an HTTP 500 carrying the error text.
    """
    try:
        answer, sources = trainer.get_response(
            query_request.query,
            query_request.bot_id,
            query_request.chat_id,
        )
        return QueryResponse(response=answer, web_sources=sources)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
@router.get("/list_chats/{bot_id}")
def list_chats(bot_id: str):
    """
    Return whatever the trainer lists as chat sessions for the given bot.

    Any failure is reported as an HTTP 500 carrying the error text.
    """
    try:
        return trainer.list_chats(bot_id)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
@router.get("/chat_history/{chat_id}")
def chat_history(chat_id: str, bot_id: str = Query(None)):
    """
    Fetch the stored history of one chat session in JSON-compatible form.

    The optional 'bot_id' query parameter is accepted but not used by the lookup.
    ObjectId values found in the history are rendered as strings.
    Any failure is reported as an HTTP 500 carrying the error text.
    """
    try:
        chat_record = trainer.get_chat_by_id(chat_id=chat_id)
        object_id_as_str = {ObjectId: str}
        return jsonable_encoder(chat_record, custom_encoder=object_id_as_str)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
|