# LISA-demo / ragchain.py
"""
Main RAG chain based on langchain.
"""
from langchain.chains import LLMChain
from langchain.prompts import (
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
ChatPromptTemplate,
PromptTemplate,
)
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import StuffDocumentsChain
def get_cite_combine_docs_chain(llm):
"""Get doc chain which adds metadata to text chunks."""
# Ref: https://github.com/langchain-ai/langchain/issues/7239
# Function to format each document with an index, source, and content.
def format_document(doc, index, prompt):
"""Format a document into a string based on a prompt template."""
# Create a dictionary with document content and metadata.
base_info = {
"page_content": doc.page_content,
"index": index,
"source": doc.metadata["source"],
}
# Check if any metadata is missing.
missing_metadata = set(prompt.input_variables).difference(base_info)
if len(missing_metadata) > 0:
raise ValueError(f"Missing metadata: {list(missing_metadata)}.")
# Filter only necessary variables for the prompt.
document_info = {k: base_info[k] for k in prompt.input_variables}
return prompt.format(**document_info)
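    # Illustrative result (assuming the document_prompt defined further below;
    # "paper.pdf" and the chunk text are hypothetical values):
    #   format_document(doc, 1, document_prompt)
    #   -> "[[1]]\nsource: paper.pdf\n<chunk text>"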
    class StuffDocumentsWithIndexChain(StuffDocumentsChain):
        """Custom chain class to handle document combination with source indices."""
def _get_inputs(self, docs, **kwargs):
"""Overwrite _get_inputs to add metadata for text chunks."""
# Format each document and combine them.
doc_strings = [
format_document(doc, i, self.document_prompt)
for i, doc in enumerate(docs, 1)
]
# Filter only relevant input variables for the LLM chain prompt.
inputs = {
k: v
for k, v in kwargs.items()
if k in self.llm_chain.prompt.input_variables
}
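            # Join the formatted chunks into a single "context" string; with two
            # chunks and the default "\n\n" separator, the result looks like
            # (illustrative values):
            #   "[[1]]\nsource: a.pdf\n<chunk 1>\n\n[[2]]\nsource: b.pdf\n<chunk 2>"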
inputs[self.document_variable_name] = self.document_separator.join(
doc_strings
)
return inputs
# Main prompt for RAG chain with citation
# Ref: https://huggingface.co./spaces/Ekimetrics/climate-question-answering/blob/main/climateqa/engine/prompts.py
# Define a chat prompt with instructions for citing documents.
combine_doc_prompt = PromptTemplate(
input_variables=["context", "question"],
template="""You are given a question and passages. Provide a clear and structured Helpful Answer based on the passages provided,
the context and the guidelines.
Guidelines:
- If the passages have useful facts or numbers, use them in your answer.
- When you use information from a passage, mention where it came from by using format [[i]] at the end of the sentence. i stands for the paper index of the document.
- Do not cite the passage in a style like 'passage i', always use format [[i]] where i stands for the passage index of the document.
- Do not use the sentence such as 'Doc i says ...' or '... in Doc i' or 'Passage i ...' to say where information came from.
- If the same thing is said in more than one document, you can mention all of them like this: [[i]], [[j]], [[k]].
- Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
- If it makes sense, use bullet points and lists to make your answers easier to understand.
- You do not need to use every passage. Only use the ones that help answer the question.
- If the documents do not have the information needed to answer the question, just say you do not have enough information.
- If the passage is the caption of a picture, you can still use it as part of your answer as any other document.
-----------------------
Passages:
{context}
-----------------------
Question: {question}
Helpful Answer with format citations:""",
)
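    # Illustrative answer shape the prompt asks the model for (hypothetical):
    #   "Kadi4Mat stores structured research data [[1]], [[2]]."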
# Initialize the custom chain with a specific document format.
combine_docs_chain = StuffDocumentsWithIndexChain(
llm_chain=LLMChain(
llm=llm,
prompt=combine_doc_prompt,
),
document_prompt=PromptTemplate(
input_variables=["index", "source", "page_content"],
template="[[{index}]]\nsource: {source}:\n{page_content}",
),
document_variable_name="context",
)
return combine_docs_chain
class RAGChain:
"""Main RAG chain."""
def __init__(
self, memory_key="chat_history", output_key="answer", return_messages=True
):
self.memory_key = memory_key
self.output_key = output_key
self.return_messages = return_messages
def create(self, retriever, llm, add_citation=False):
"""Create a rag chain instance."""
        # Memory is kept for later support of conversational chat; the window
        # keeps only the k most recent exchanges.
        memory = ConversationBufferWindowMemory(  # Or ConversationBufferMemory
            k=2,
memory_key=self.memory_key,
return_messages=self.return_messages,
output_key=self.output_key,
)
# Ref: https://github.com/langchain-ai/langchain/issues/4608
conversation_chain = ConversationalRetrievalChain.from_llm(
llm=llm,
retriever=retriever,
memory=memory,
return_source_documents=True,
            rephrase_question=False,  # disable question rephrasing, for testing purposes
get_chat_history=lambda x: x,
# return_generated_question=True, # for debug
# combine_docs_chain_kwargs={"prompt": PROMPT}, # additional prompt control
# condense_question_prompt=CONDENSE_QUESTION_PROMPT, # additional prompt control
)
        # Add citation support. ATTENTION: experimental.
if add_citation:
cite_combine_docs_chain = get_cite_combine_docs_chain(llm)
conversation_chain.combine_docs_chain = cite_combine_docs_chain
return conversation_chain
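

# Minimal usage sketch (illustrative, not part of the original module): wires
# the chain to stand-in components. FakeEmbeddings/FakeListLLM are langchain's
# built-in test doubles; FAISS here assumes the faiss-cpu package is installed.
# A real deployment would pass an actual LLM and a retriever over real documents.
if __name__ == "__main__":
    from langchain.embeddings.fake import FakeEmbeddings
    from langchain.llms.fake import FakeListLLM
    from langchain.vectorstores import FAISS

    # Tiny in-memory corpus; "kadi.pdf" is a hypothetical source name.
    store = FAISS.from_texts(
        ["Kadi4Mat is a research data management infrastructure."],
        FakeEmbeddings(size=8),
        metadatas=[{"source": "kadi.pdf"}],
    )
    # Canned response so the sketch runs without a real model.
    llm = FakeListLLM(
        responses=["Kadi4Mat is a research data management infrastructure [[1]]."]
    )
    chain = RAGChain().create(store.as_retriever(), llm, add_citation=True)
    result = chain({"question": "What is Kadi4Mat?"})
    print(result["answer"])
    for doc in result["source_documents"]:
        print(doc.metadata["source"])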