|
""" |
|
Main RAG chain based on langchain. |
|
""" |
|
|
|
from langchain.chains import LLMChain |
|
|
|
from langchain.prompts import ( |
|
SystemMessagePromptTemplate, |
|
HumanMessagePromptTemplate, |
|
ChatPromptTemplate, |
|
PromptTemplate, |
|
) |
|
|
|
from langchain.chains import ConversationalRetrievalChain |
|
from langchain.chains.conversation.memory import ( |
|
ConversationBufferWindowMemory, |
|
) |
|
from langchain.chains import StuffDocumentsChain |
|
|
|
|
|
def get_cite_combine_docs_chain(llm):
    """Build a combine-documents ("stuff") chain that supports citations.

    Each retrieved chunk is rendered with a 1-based passage index and its
    source, so the LLM can cite passages inline with the ``[[i]]`` format
    requested by the prompt.

    Args:
        llm: Language model driving the underlying ``LLMChain``.

    Returns:
        StuffDocumentsWithIndexChain: a ``StuffDocumentsChain`` subclass
        that injects the indexed passages into the ``context`` variable.
    """

    def format_document(doc, index, prompt):
        """Format a document into a string based on a prompt template.

        Args:
            doc: Document with ``page_content`` and a ``metadata`` dict.
            index: 1-based position of the document among the retrieved
                chunks; rendered as the citation index.
            prompt: ``PromptTemplate`` used to render the document.

        Raises:
            ValueError: if the prompt requires a variable that neither the
                base info nor the document metadata provides.
        """
        # Spread metadata first so the explicit keys below always win on
        # collision. Exposing all metadata (not just "source") lets custom
        # document prompts reference any metadata field.
        base_info = {
            **doc.metadata,
            "page_content": doc.page_content,
            "index": index,
        }
        # Validate BEFORE building document_info: previously a chunk missing
        # "source" raised a bare KeyError and the informative ValueError
        # below was unreachable for that key.
        missing_metadata = set(prompt.input_variables).difference(base_info)
        if missing_metadata:
            raise ValueError(f"Missing metadata: {list(missing_metadata)}.")
        document_info = {k: base_info[k] for k in prompt.input_variables}
        return prompt.format(**document_info)

    class StuffDocumentsWithIndexChain(StuffDocumentsChain):
        """Custom chain class to handle document combination with source indices."""

        def _get_inputs(self, docs, **kwargs):
            """Overwrite _get_inputs to add index/source metadata to chunks."""
            # Enumerate from 1 so citations read as [[1]], [[2]], ...
            doc_strings = [
                format_document(doc, i, self.document_prompt)
                for i, doc in enumerate(docs, 1)
            ]
            # Forward only the kwargs the LLM prompt actually declares.
            inputs = {
                k: v
                for k, v in kwargs.items()
                if k in self.llm_chain.prompt.input_variables
            }
            inputs[self.document_variable_name] = self.document_separator.join(
                doc_strings
            )
            return inputs

    combine_doc_prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="""You are given a question and passages. Provide a clear and structured Helpful Answer based on the passages provided,
the context and the guidelines.

Guidelines:
- If the passages have useful facts or numbers, use them in your answer.
- When you use information from a passage, mention where it came from by using format [[i]] at the end of the sentence. i stands for the paper index of the document.
- Do not cite the passage in a style like 'passage i', always use format [[i]] where i stands for the passage index of the document.
- Do not use the sentence such as 'Doc i says ...' or '... in Doc i' or 'Passage i ...' to say where information came from.
- If the same thing is said in more than one document, you can mention all of them like this: [[i]], [[j]], [[k]].
- Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
- If it makes sense, use bullet points and lists to make your answers easier to understand.
- You do not need to use every passage. Only use the ones that help answer the question.
- If the documents do not have the information needed to answer the question, just say you do not have enough information.
- If the passage is the caption of a picture, you can still use it as part of your answer as any other document.

-----------------------
Passages:
{context}
-----------------------
Question: {question}

Helpful Answer with format citations:""",
    )

    combine_docs_chain = StuffDocumentsWithIndexChain(
        llm_chain=LLMChain(
            llm=llm,
            prompt=combine_doc_prompt,
        ),
        document_prompt=PromptTemplate(
            input_variables=["index", "source", "page_content"],
            template="[[{index}]]\nsource: {source}:\n{page_content}",
        ),
        document_variable_name="context",
    )

    return combine_docs_chain
|
|
|
|
|
class RAGChain:
    """Factory for a conversational RAG chain.

    Wraps a ``ConversationalRetrievalChain`` with a windowed conversation
    memory and, optionally, a citation-aware combine-documents chain.
    """

    def __init__(
        self,
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
        memory_k=2,
    ):
        """Store configuration used later by :meth:`create`.

        Args:
            memory_key: Key under which the chat history is stored in memory.
            output_key: Key of the chain output persisted to memory.
            return_messages: Whether the memory returns message objects
                instead of a single concatenated string.
            memory_k: Number of past conversation turns kept in the window
                memory (was previously hard-coded to 2).
        """
        self.memory_key = memory_key
        self.output_key = output_key
        self.return_messages = return_messages
        self.memory_k = memory_k

    def create(self, retriever, llm, add_citation=False):
        """Create a rag chain instance.

        Args:
            retriever: Retriever supplying the source documents.
            llm: Language model used to generate answers.
            add_citation: If True, swap in the combine-docs chain that
                renders ``[[i]]`` citation indices for each passage.

        Returns:
            ConversationalRetrievalChain: configured chain with windowed
            memory and source documents returned alongside the answer.
        """
        # Keep only the last `memory_k` exchanges to bound prompt size.
        memory = ConversationBufferWindowMemory(
            k=self.memory_k,
            memory_key=self.memory_key,
            return_messages=self.return_messages,
            output_key=self.output_key,
        )

        conversation_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            memory=memory,
            return_source_documents=True,
            # Pass the user question through verbatim instead of rephrasing.
            rephrase_question=False,
            # History is already in the desired form; use it as-is.
            get_chat_history=lambda x: x,
        )

        if add_citation:
            # Replace the default combine-docs chain with the citation-aware one.
            cite_combine_docs_chain = get_cite_combine_docs_chain(llm)
            conversation_chain.combine_docs_chain = cite_combine_docs_chain

        return conversation_chain
|
|