# rag/rag_class.py
from langchain_community.vectorstores import Chroma, FAISS
from langchain_community.llms import Ollama
from langchain_core.output_parsers import StrOutputParser
from langchain_community.embeddings import OllamaEmbeddings
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
from langchain.prompts import ChatPromptTemplate
from rerank_code import rerank_topn
from Config.config import VECTOR_DB, DB_directory
from langchain_elasticsearch.vectorstores import ElasticsearchStore
class RAG_class:
    """RAG helper built on Ollama models and a Chroma / FAISS / Elasticsearch vector store."""

    def __init__(self, model="qwen2:7b", embed="milkey/dmeta-embedding-zh:f16", c_name="sss1",
                 persist_directory="E:/pycode/jupyter_code/langGraph/sss2/chroma.sqlite3/", es_url="http://localhost:9200"):
template = """
根据上下文回答以下问题,不要自己发挥,要根据以下参考内容总结答案,如果以下内容无法得到答案,就返回无法根据参考内容获取答案,
参考内容为:{context}
问题: {question}
"""
self.prompts = ChatPromptTemplate.from_template(template)
        # Question-decomposition prompt: expand the input question into several
        # sub-questions, then answer them recursively to build the final answer.
        template1 = """你是一个乐于助人的助手,可以生成与输入问题相关的多个子问题。
        目标是将输入分解为一组可以单独回答的子问题/子问题。
        生成多个与以下内容相关的搜索查询:{question}
        输出4个相关问题,以换行符隔开:"""
        self.prompt_questions = ChatPromptTemplate.from_template(template1)
        # Q&A-pair prompt: answer the current question using both the retrieved
        # context and the previously accumulated question/answer pairs.
        template2 = """
        以下是您需要回答的问题:
        \n--\n {question} \n---\n
        以下是任何可用的背景问答对:
        \n--\n {q_a_pairs} \n---\n
        以下是与该问题相关的其他上下文:
        \n--\n {context} \n---\n
        使用以上上下文和背景问答对来回答问题,问题是:{question} ,答案是:
        """
        self.decomposition_prompt = ChatPromptTemplate.from_template(template2)
        self.llm = Ollama(model=model)
        self.embedding = OllamaEmbeddings(model=embed)
        # Load the configured vector store (1 = Chroma, 2 = FAISS, 3 = Elasticsearch);
        # skipped gracefully when only the LLM is needed and no store is available.
        try:
            if VECTOR_DB == 1:
                self.vectstore = Chroma(embedding_function=self.embedding, collection_name=c_name,
                                        persist_directory=persist_directory)
            elif VECTOR_DB == 2:
                self.vectstore = FAISS.load_local(folder_path=persist_directory + c_name, embeddings=self.embedding,
                                                  allow_dangerous_deserialization=True)
            elif VECTOR_DB == 3:
                self.vectstore = ElasticsearchStore(
                    es_url=es_url,
                    index_name=c_name,
                    embedding=self.embedding
                )
            self.retriever = self.vectstore.as_retriever()
        except Exception as e:
            print("No vector store loaded (model-only mode):", e)
    # Post-processing: join the retrieved documents into a single context string
    def format_docs(self, docs):
        return "\n\n".join(doc.page_content for doc in docs)
    # Plain retrieval: recall with the single question, then let the LLM summarize the answer
    def simple_chain(self, question):
        _chain = (
            {"context": self.retriever | self.format_docs, "question": RunnablePassthrough()}
            | self.prompts
            | self.llm
            | StrOutputParser()
        )
        # Invoke with the bare question string so both the retriever and the
        # passthrough receive the query text itself.
        answer = _chain.invoke(question)
        return answer
    # Rerank retrieval: recall the top-10 candidates, rerank them down to the top 5, then answer
    def rerank_chain(self, question):
        retriever = self.vectstore.as_retriever(search_kwargs={"k": 10})
        docs = retriever.invoke(question)
        docs = rerank_topn(question, docs, N=5)
        _chain = (
            self.prompts
            | self.llm
            | StrOutputParser()
        )
        answer = _chain.invoke({"context": self.format_docs(docs), "question": question})
        return answer
    # Format a single question/answer pair as text
    def format_qa_pairs(self, question, answer):
        return f"Question: {question}\nAnswer:{answer}\n\n"
    # Generate the expanded sub-questions for the input question
    def decomposition_chain(self, question):
        _chain = (
            {"question": RunnablePassthrough()}
            | self.prompt_questions
            | self.llm
            | StrOutputParser()
            | (lambda x: [q for q in x.split("\n") if q.strip()])
        )
        # Invoke with the bare question string; the original question is kept
        # as the last element so it is always answered as well.
        questions = _chain.invoke(question) + [question]
        return questions
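    # Illustrative shape of the result (the actual sub-questions depend on the
    # model; only the final element, the original question, is guaranteed):
    #   ["sub-question 1", "sub-question 2", "sub-question 3", "sub-question 4", question]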
    # Recursive multi-question retrieval: each answer is added to the accumulated
    # Q&A pairs, which are fed back as background for the next sub-question.
    def rag_chain(self, questions):
        q_a_pairs = ""
        answer = ""
        for q in questions:
            _chain = (
                {"context": itemgetter("question") | self.retriever | self.format_docs,
                 "question": itemgetter("question"),
                 "q_a_pairs": itemgetter("q_a_pairs")
                 }
                | self.decomposition_prompt
                | self.llm
                | StrOutputParser()
            )
            answer = _chain.invoke({"question": q, "q_a_pairs": q_a_pairs})
            # Append the new pair instead of overwriting the accumulated history
            q_a_pairs = q_a_pairs + "\n----\n" + self.format_qa_pairs(q, answer)
        return answer
    # Format the chat history into a single string
    def format_chat_history(self, history):
        formatted_history = ""
        for role, content in history:
            formatted_history += f"{role}: {content}\n"
        return formatted_history
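    # Expected history format (an illustration; the role strings are free-form
    # and are not validated anywhere in this class):
    #   [("user", "Hello"), ("assistant", "Hi, how can I help?"), ("user", "...")]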
    # Multi-turn chat with the Ollama LLM only, without using the knowledge base
    def mult_chat(self, chat_history):
        # Format the chat history
        formatted_history = self.format_chat_history(chat_history)
        # Call the model to generate a reply
        response = self.llm.invoke(formatted_history)
        return response
# if __name__ == "__main__":
# rag = RAG_class(model="deepseek-r1:14b")
# question = "人卫社官网网址是?"
# questions = rag.decomposition_chain(question)
# print(questions)
# answer = rag.rag_chain(questions)
# print(answer)
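
# A minimal usage sketch for the remaining chains (an illustration, not part of
# the original example above): it assumes a local Ollama server with the default
# models pulled, and a vector store already built under the configured
# path/collection; the question string reuses the example from the comment above.
if __name__ == "__main__":
    rag = RAG_class(model="qwen2:7b", c_name="sss1")
    question = "人卫社官网网址是?"
    # Single-pass retrieval + answer
    print(rag.simple_chain(question))
    # Retrieve top-10 candidates, rerank to top-5, then answer
    print(rag.rerank_chain(question))
    # Multi-turn chat that does not touch the knowledge base
    history = [("user", "Hello"), ("assistant", "Hi, how can I help?"), ("user", "Tell me about yourself")]
    print(rag.mult_chat(history))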