from langchain_core.prompts import (
    ChatPromptTemplate,
    MessagesPlaceholder,
)
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory

from caching.lfu import LFUCache

def _initialize_llm(model_name: str) -> ChatGoogleGenerativeAI:
    """
    Initializes and returns the LLM instance for the given model name.
    """
    return ChatGoogleGenerativeAI(model=model_name)


class LLMService:
    def __init__(
        self,
        logger,
        system_prompt: str,
        web_retriever: VectorStoreRetriever,
        cache_capacity: int = 50,
        llm_model_name: str = "gemini-2.0-flash-thinking-exp-01-21",
    ):
        self._conversational_rag_chain = None
        self._logger = logger
        self.system_prompt = system_prompt
        self._web_retriever = web_retriever
        self.llm = _initialize_llm(llm_model_name)

        ### Statefully manage chat history ###
        # Created before the chain so _get_session_history can rely on it.
        self.store = LFUCache(capacity=cache_capacity)

        self._initialize_conversational_rag_chain()

    def _initialize_conversational_rag_chain(self):
        """
        Initializes the conversational RAG chain.
        """
        ### Contextualize question ###
        contextualize_q_system_prompt = (
            "Given a chat history and the latest user question "
            "which might reference context in the chat history, formulate a standalone question "
            "which can be understood without the chat history. Do NOT answer the question, "
            "just reformulate it if needed and otherwise return it as is."
        )

        contextualize_q_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", contextualize_q_system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )

        history_aware_retriever = create_history_aware_retriever(
            self.llm, self._web_retriever, contextualize_q_prompt
        )

        qa_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )

        question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)
        rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

        # Wrap the RAG chain so each invocation reads and writes the history
        # selected by config["configurable"]["session_id"].
        self._conversational_rag_chain = RunnableWithMessageHistory(
            rag_chain,
            self._get_session_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer",
        )
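        # Informal sketch of the per-turn data flow (key names follow the
        # standard create_retrieval_chain / RunnableWithMessageHistory contract):
        #   {"input", "chat_history"} -> contextualized standalone question
        #   -> web_retriever -> documents -> qa_prompt -> llm -> {"answer", ...}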

    def _get_session_history(self, session_id: str) -> BaseChatMessageHistory:
        """
        Returns the session's chat history, creating it on first use. Rarely
        used sessions may be evicted from the LFU store, in which case the
        same session_id starts over with an empty history.
        """
        history = self.store.get(session_id)
        if history is None:
            history = ChatMessageHistory()
            self.store.put(session_id, history)
        return history

    def conversational_rag_chain(self):
        """
        Returns the initialized conversational RAG chain.

        Returns:
            The conversational RAG chain instance.
        """
        return self._conversational_rag_chain
    
    def get_llm(self) -> ChatGoogleGenerativeAI:
        """
        Returns the LLM instance.
        """
        if self.llm is None:
            raise RuntimeError("LLM is not initialized")

        return self.llm
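

# --- Usage sketch (illustrative only; not exercised by this module) ---
# A minimal example of driving LLMService. The FAISS store, embedding model,
# session id, and prompt text below are assumptions for demonstration; any
# VectorStoreRetriever works. Note that create_stuff_documents_chain expects
# a "{context}" placeholder in the system prompt.
#
#   import logging
#   from langchain_community.vectorstores import FAISS
#   from langchain_google_genai import GoogleGenerativeAIEmbeddings
#
#   retriever = FAISS.from_texts(
#       ["LangChain composes LLM applications out of runnables."],
#       GoogleGenerativeAIEmbeddings(model="models/embedding-001"),
#   ).as_retriever()
#
#   service = LLMService(
#       logger=logging.getLogger(__name__),
#       system_prompt="Answer from the following context:\n\n{context}",
#       web_retriever=retriever,
#   )
#
#   # RunnableWithMessageHistory selects the history via the configurable
#   # session_id, so repeated calls with the same id share state.
#   result = service.conversational_rag_chain().invoke(
#       {"input": "What does LangChain do?"},
#       config={"configurable": {"session_id": "user-42"}},
#   )
#   print(result["answer"])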