File size: 6,354 Bytes
1a20a59
 
 
 
646f8c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a20a59
646f8c2
 
 
 
 
1a20a59
646f8c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a20a59
 
646f8c2
1a20a59
 
646f8c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a20a59
646f8c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a20a59
 
646f8c2
 
 
 
 
 
 
 
1a20a59
 
 
 
646f8c2
 
 
 
 
 
1a20a59
646f8c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""
Main RAG chain based on langchain.
"""

from langchain.chains import LLMChain

from langchain.prompts import (
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
    ChatPromptTemplate,
    PromptTemplate,
)

from langchain.chains import ConversationalRetrievalChain
from langchain.chains.conversation.memory import (
    ConversationBufferWindowMemory,
)
from langchain.chains import StuffDocumentsChain


def get_cite_combine_docs_chain(llm):
    """Build a combine-documents chain that tags each chunk for citation.

    Every retrieved chunk is rendered with a numeric index and its source,
    so the LLM can cite passages inline using the [[i]] format.
    """

    # Ref: https://github.com/langchain-ai/langchain/issues/7239
    def _render_doc(doc, position, doc_prompt):
        """Render one document through *doc_prompt*, tagged with its index."""

        # Everything we can offer the per-document prompt.
        available = {
            "page_content": doc.page_content,
            "index": position,
            "source": doc.metadata["source"],
        }

        # Fail loudly if the prompt asks for metadata we cannot supply.
        absent = set(doc_prompt.input_variables) - set(available)
        if absent:
            raise ValueError(f"Missing metadata: {list(absent)}.")

        # Pass along only the variables the prompt actually declares.
        return doc_prompt.format(
            **{name: available[name] for name in doc_prompt.input_variables}
        )

    class StuffDocumentsWithIndexChain(StuffDocumentsChain):
        """StuffDocumentsChain variant that numbers chunks with source info."""

        def _get_inputs(self, docs, **kwargs):
            """Override to prepend an index/source header to every chunk."""

            # Render each chunk, numbering from 1.
            rendered = [
                _render_doc(doc, position, self.document_prompt)
                for position, doc in enumerate(docs, 1)
            ]

            # Keep only the kwargs the LLM prompt actually declares.
            allowed = self.llm_chain.prompt.input_variables
            inputs = {key: value for key, value in kwargs.items() if key in allowed}
            inputs[self.document_variable_name] = self.document_separator.join(
                rendered
            )
            return inputs

    # Main prompt for RAG chain with citation
    # Ref: https://huggingface.co./spaces/Ekimetrics/climate-question-answering/blob/main/climateqa/engine/prompts.py
    # Prompt instructing the LLM how to cite the numbered passages.
    combine_doc_prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="""You are given a question and passages. Provide a clear and structured Helpful Answer based on the passages provided, 
        the context and the guidelines.
        
        Guidelines:
        - If the passages have useful facts or numbers, use them in your answer.
        - When you use information from a passage, mention where it came from by using format [[i]] at the end of the sentence. i stands for the paper index of the document.
        - Do not cite the passage in a style like 'passage i', always use format [[i]] where i stands for the passage index of the document.
        - Do not use the sentence such as 'Doc i says ...' or '... in Doc i' or 'Passage i ...'  to say where information came from.
        - If the same thing is said in more than one document, you can mention all of them like this: [[i]], [[j]], [[k]].
        - Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
        - If it makes sense, use bullet points and lists to make your answers easier to understand.
        - You do not need to use every passage. Only use the ones that help answer the question.
        - If the documents do not have the information needed to answer the question, just say you do not have enough information.
        - If the passage is the caption of a picture, you can still use it as part of your answer as any other document.

        -----------------------
        Passages:
        {context}
        -----------------------
        Question: {question}

        Helpful Answer with format citations:""",
    )

    # Wire the custom chain with the per-document citation format.
    return StuffDocumentsWithIndexChain(
        llm_chain=LLMChain(
            llm=llm,
            prompt=combine_doc_prompt,
        ),
        document_prompt=PromptTemplate(
            input_variables=["index", "source", "page_content"],
            template="[[{index}]]\nsource: {source}:\n{page_content}",
        ),
        document_variable_name="context",
    )


class RAGChain:
    """Factory for a conversational retrieval (RAG) chain."""

    def __init__(
        self, memory_key="chat_history", output_key="answer", return_messages=True
    ):
        # Stored as-is; consumed when building the memory in create().
        self.memory_key = memory_key
        self.output_key = output_key
        self.return_messages = return_messages

    def create(self, retriever, llm, add_citation=False):
        """Assemble and return a ConversationalRetrievalChain instance."""

        # Windowed memory (last 2 turns) kept for later conversational support.
        window_memory = ConversationBufferWindowMemory(  # Or ConversationBufferMemory
            k=2,
            memory_key=self.memory_key,
            return_messages=self.return_messages,
            output_key=self.output_key,
        )

        # Ref: https://github.com/langchain-ai/langchain/issues/4608
        chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            memory=window_memory,
            return_source_documents=True,
            rephrase_question=False,  # disable rephrase, for test purpose
            get_chat_history=lambda history: history,
            # return_generated_question=True,  # for debug
            # combine_docs_chain_kwargs={"prompt": PROMPT},  # additional prompt control
            # condense_question_prompt=CONDENSE_QUESTION_PROMPT,  # additional prompt control
        )

        # Add citation, ATTENTION: experimental
        if add_citation:
            chain.combine_docs_chain = get_cite_combine_docs_chain(llm)

        return chain