File size: 5,650 Bytes
982bc60
 
 
311764c
4cf11fc
 
0c1268a
982bc60
4876c75
0240527
982bc60
 
45660fc
57c8644
982bc60
311764c
0240527
311764c
 
0240527
982bc60
324eee5
311764c
324eee5
982bc60
 
 
4da099c
982bc60
324eee5
4cf11fc
324eee5
4cf11fc
982bc60
4876c75
982bc60
 
 
 
 
324eee5
 
57c8644
4876c75
 
982bc60
 
 
 
 
311764c
 
4876c75
311764c
 
 
4876c75
 
311764c
 
 
 
4876c75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311764c
 
324eee5
 
4876c75
 
 
 
324eee5
 
4cf11fc
324eee5
4cf11fc
982bc60
324eee5
982bc60
 
 
 
57c8644
982bc60
 
4cf11fc
6b94b38
311764c
4cf11fc
982bc60
 
324eee5
982bc60
 
 
 
 
57c8644
324eee5
 
4cf11fc
324eee5
4cf11fc
982bc60
324eee5
982bc60
 
 
 
57c8644
982bc60
311764c
4cf11fc
 
311764c
4876c75
4cf11fc
982bc60
 
4a4df8f
982bc60
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import os
import shutil
import tempfile

from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
from fastapi.concurrency import run_in_threadpool
from fastapi.encoders import jsonable_encoder
from bson import ObjectId

from models import InitializeBotResponse, NewChatResponse, QueryRequest, QueryResponse
from trainer_manager import get_trainer
from config import CUSTOM_PROMPT
from prompt_templates import PromptTemplates

# Router that every endpoint in this module is registered on via @router.<method>.
router = APIRouter()
# Trainer obtained once at import time and shared by all handlers below.
# NOTE(review): presumably get_trainer() returns a shared/singleton instance —
# confirm in trainer_manager.
trainer = get_trainer()

@router.post("/initialize_bot", response_model=InitializeBotResponse)
def initialize_bot(prompt_type: str = Query(None)):
    """
    Allocate a new bot identifier and return it as an InitializeBotResponse.

    ``prompt_type`` is accepted as an optional query parameter (sent by the
    frontend) but this handler does not use it; the prompt template is chosen
    by the create_bot endpoint, which takes its own ``prompt_type``.
    Any failure is reported as HTTP 500 carrying the error text.
    """
    try:
        new_bot_id = trainer.initialize_bot_id()
        # NOTE: prompt_type is not stored with the bot record yet.
        payload = InitializeBotResponse(bot_id=new_bot_id)
        return payload
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))


@router.post("/upload_document")
async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...)):
    """
    Saves the uploaded file temporarily and adds it to the specified bot's knowledge base.

    The upload is written to a named temporary file that keeps the original
    extension. Presumably the trainer uses the suffix to choose a loader;
    confirm in trainer_manager. The file is then handed to
    ``trainer.add_document_from_path`` by path. The temporary file is always
    removed afterwards, including when ingestion fails.

    Args:
        bot_id: Identifier of the bot whose knowledge base receives the document.
        file: The uploaded document.

    Returns:
        A JSON message confirming the document was added.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    tmp_path = None
    try:
        contents = await file.read()
        # Some clients send no filename; fall back to an empty suffix instead
        # of letting os.path.splitext(None) raise TypeError.
        suffix = os.path.splitext(file.filename or "")[1]

        # delete=False keeps the file on disk after the with-block closes (and
        # flushes) it, so the trainer can reopen it by path. Cleanup happens in
        # the finally clause below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp_path = tmp.name  # record first so a failed write is still cleaned up
            tmp.write(contents)

        # add_document_from_path is synchronous. Calling it directly inside
        # this async handler would block the event loop for every other
        # request, so run it in the worker thread pool (the same pool FastAPI
        # uses for the sync handlers in this module).
        await run_in_threadpool(trainer.add_document_from_path, tmp_path, bot_id)
        return {"message": "Document uploaded and added successfully."}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                # Best-effort cleanup: a failed delete must not mask the
                # request's real result or error.
                pass


# Maps each supported ``prompt_type`` to the name of the PromptTemplates method
# that builds its prompt. Names are resolved with getattr() only when requested,
# so a template method is looked up lazily, exactly as the original chain did.
_PROMPT_TEMPLATE_GETTERS = {
    "university": "get_university_chatbot_prompt",
    "quiz_solving": "get_quiz_solving_prompt",
    "assignment_solving": "get_assignment_solving_prompt",
    "paper_solving": "get_paper_solving_prompt",
    "quiz_creation": "get_quiz_creation_prompt",
    "assignment_creation": "get_assignment_creation_prompt",
    "paper_creation": "get_paper_creation_prompt",
    "check_quiz": "get_check_quiz_prompt",
    "check_assignment": "get_check_assignment_prompt",
    "check_paper": "get_check_paper_prompt",
}
# Prompt type used when prompt_type is omitted or not a key of the table above.
_DEFAULT_PROMPT_TYPE = "quiz_solving"


@router.post("/create_bot/{bot_id}")
def create_bot(bot_id: str, prompt_type: str = Query(None)):
    """
    Finalizes the creation (build) of the bot identified by bot_id.

    ``prompt_type`` selects the prompt template from PromptTemplates. If it is
    omitted, or is not one of the recognised types, the "quiz_solving"
    template is used.

    Args:
        bot_id: Identifier of the bot to build.
        prompt_type: Optional prompt type name (see _PROMPT_TEMPLATE_GETTERS).

    Returns:
        A JSON message confirming the bot was created.

    Raises:
        HTTPException: 500 with the underlying error message if building the
        template or the trainer's create step fails.
    """
    try:
        # NOTE(review): unknown prompt types silently fall back to the default;
        # consider rejecting them with a 400 if callers can tolerate it.
        getter_name = _PROMPT_TEMPLATE_GETTERS.get(
            prompt_type, _PROMPT_TEMPLATE_GETTERS[_DEFAULT_PROMPT_TYPE]
        )
        prompt_template = getattr(PromptTemplates, getter_name)()

        # Create (build) the bot using the specified bot_id and prompt template.
        trainer.create_bot(bot_id, prompt_template)
        return {"message": f"Bot {bot_id} created successfully."}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/new_chat/{bot_id}", response_model=NewChatResponse)
def new_chat(bot_id: str):
    """
    Open a fresh chat session for ``bot_id`` and return its identifier.

    Any failure is reported as HTTP 500 carrying the error text.
    """
    try:
        return NewChatResponse(chat_id=trainer.new_chat(bot_id))
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))


@router.post("/query", response_model=QueryResponse)
def send_query(query_request: QueryRequest):
    """
    Answer a chat query, returning the bot's reply together with any web sources.

    The query text, ``bot_id`` and ``chat_id`` from ``query_request`` are passed
    to ``trainer.get_response``. Its two-element result fills the
    ``response`` and ``web_sources`` fields of the QueryResponse.
    Any failure is reported as HTTP 500 carrying the error text.
    """
    try:
        answer, sources = trainer.get_response(
            query_request.query,
            query_request.bot_id,
            query_request.chat_id,
        )
        return QueryResponse(response=answer, web_sources=sources)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))


@router.get("/list_chats/{bot_id}")
def list_chats(bot_id: str):
    """
    List the previous chat sessions stored for ``bot_id``.

    The trainer's result is returned as-is for FastAPI to serialise.
    Any failure is reported as HTTP 500 carrying the error text.
    """
    try:
        return trainer.list_chats(bot_id)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))


@router.get("/chat_history/{chat_id}")
def chat_history(chat_id: str, bot_id: str = Query(None)):
    """
    Fetch the message history of the chat session ``chat_id``.

    ``bot_id`` may be supplied as a query parameter but is not used by this
    handler; the lookup is by ``chat_id`` alone. The history is passed through
    ``jsonable_encoder`` with any ObjectId values rendered as strings.
    Any failure is reported as HTTP 500 carrying the error text.
    """
    try:
        record = trainer.get_chat_by_id(chat_id=chat_id)
        encoded = jsonable_encoder(record, custom_encoder={ObjectId: str})
        return encoded
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))