"""
Main RAG chain based on langchain.
"""
from langchain.chains import (
    ConversationalRetrievalChain,
    LLMChain,
    StuffDocumentsChain,
)
from langchain.chains.conversation.memory import (
    ConversationBufferWindowMemory,
)
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    PromptTemplate,
    SystemMessagePromptTemplate,
)


def get_cite_combine_docs_chain(llm):
    """Get a combine-documents chain that adds citation metadata to text chunks."""
    # Ref: https://github.com/langchain-ai/langchain/issues/7239

    # Format each document with an index, source, and content.
    def format_document(doc, index, prompt):
        """Format a document into a string based on a prompt template."""
        # Collect the document content and metadata used by the prompt.
        base_info = {
            "page_content": doc.page_content,
            "index": index,
            "source": doc.metadata["source"],
        }
        # Raise early if the prompt expects variables that are not available.
        missing_metadata = set(prompt.input_variables).difference(base_info)
        if len(missing_metadata) > 0:
            raise ValueError(f"Missing metadata: {list(missing_metadata)}.")
        # Keep only the variables the prompt actually uses.
        document_info = {k: base_info[k] for k in prompt.input_variables}
        return prompt.format(**document_info)
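
    # Example (illustrative, not executed): with the citation document prompt
    # defined below and index=1, a chunk loaded from a file named "paper.pdf"
    # (placeholder name) would be rendered as:
    #
    #   [[1]]
    #   source: paper.pdf:
    #   <chunk text>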

    class StuffDocumentsWithIndexChain(StuffDocumentsChain):
        """Custom chain class to handle document combination with source indices."""

        def _get_inputs(self, docs, **kwargs):
            """Overwrite _get_inputs to add metadata to the text chunks."""
            # Format each document and join them into one context string.
            doc_strings = [
                format_document(doc, i, self.document_prompt)
                for i, doc in enumerate(docs, 1)
            ]
            # Keep only the input variables used by the LLM chain prompt.
            inputs = {
                k: v
                for k, v in kwargs.items()
                if k in self.llm_chain.prompt.input_variables
            }
            inputs[self.document_variable_name] = self.document_separator.join(
                doc_strings
            )
            return inputs

    # Main prompt for the RAG chain, with instructions for citing passages.
    # Ref: https://huggingface.co./spaces/Ekimetrics/climate-question-answering/blob/main/climateqa/engine/prompts.py
    combine_doc_prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="""You are given a question and passages. Provide a clear and structured Helpful Answer based on the provided passages, the context, and the guidelines below.

Guidelines:
- If the passages contain useful facts or numbers, use them in your answer.
- When you use information from a passage, mark where it came from by appending [[i]] to the end of the sentence, where i stands for the index of the passage.
- Do not cite a passage in a style like 'passage i'; always use the format [[i]], where i stands for the index of the passage.
- Do not use sentences such as 'Doc i says ...' or '... in Doc i' or 'Passage i ...' to say where information came from.
- If the same thing is said in more than one passage, you can cite all of them like this: [[i]], [[j]], [[k]].
- Do not just summarize each passage one by one. Group your summaries to highlight the key parts of the explanation.
- If it makes sense, use bullet points and lists to make your answer easier to understand.
- You do not need to use every passage. Only use the ones that help answer the question.
- If the passages do not contain the information needed to answer the question, just say you do not have enough information.
- If a passage is the caption of a picture, you can still use it in your answer like any other passage.
-----------------------
Passages:
{context}
-----------------------
Question: {question}
Helpful Answer with formatted citations:""",
    )

    # Initialize the custom chain with a specific document format.
    combine_docs_chain = StuffDocumentsWithIndexChain(
        llm_chain=LLMChain(
            llm=llm,
            prompt=combine_doc_prompt,
        ),
        document_prompt=PromptTemplate(
            input_variables=["index", "source", "page_content"],
            template="[[{index}]]\nsource: {source}:\n{page_content}",
        ),
        document_variable_name="context",
    )
    return combine_docs_chain


class RAGChain:
    """Main RAG chain."""

    def __init__(
        self, memory_key="chat_history", output_key="answer", return_messages=True
    ):
        self.memory_key = memory_key
        self.output_key = output_key
        self.return_messages = return_messages

    def create(self, retriever, llm, add_citation=False):
        """Create a RAG chain instance."""
        # Memory is kept for later support of conversational chat.
        memory = ConversationBufferWindowMemory(  # Or ConversationBufferMemory
            k=2,
            memory_key=self.memory_key,
            return_messages=self.return_messages,
            output_key=self.output_key,
        )
        # Ref: https://github.com/langchain-ai/langchain/issues/4608
        conversation_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            memory=memory,
            return_source_documents=True,
            rephrase_question=False,  # disable question rephrasing, for test purposes
            get_chat_history=lambda x: x,
            # return_generated_question=True,  # for debugging
            # combine_docs_chain_kwargs={"prompt": PROMPT},  # additional prompt control
            # condense_question_prompt=CONDENSE_QUESTION_PROMPT,  # additional prompt control
        )
        # Add citation support. ATTENTION: experimental.
        if add_citation:
            cite_combine_docs_chain = get_cite_combine_docs_chain(llm)
            conversation_chain.combine_docs_chain = cite_combine_docs_chain
        return conversation_chain
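

# Minimal usage sketch (assumptions, not part of the original module: an
# OpenAI API key is configured, and faiss-cpu plus the legacy langchain
# OpenAI integrations are installed; the sample text and source name below
# are illustrative only).
if __name__ == "__main__":
    from langchain.chat_models import ChatOpenAI
    from langchain.embeddings import OpenAIEmbeddings
    from langchain.vectorstores import FAISS

    # Each chunk carries a "source" in its metadata, since the citation
    # document prompt formats it into the context.
    vectorstore = FAISS.from_texts(
        ["LangChain is a framework for building LLM applications."],
        OpenAIEmbeddings(),
        metadatas=[{"source": "example.txt"}],
    )
    chain = RAGChain().create(
        retriever=vectorstore.as_retriever(),
        llm=ChatOpenAI(temperature=0),
        add_citation=True,
    )
    result = chain({"question": "What is LangChain?"})
    print(result["answer"])  # The answer should cite the passage as [[1]]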