from langchain_community.vectorstores import Chroma, FAISS
from langchain_community.llms import Ollama
from langchain_core.output_parsers import StrOutputParser
from langchain_community.embeddings import OllamaEmbeddings
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
from langchain.prompts import ChatPromptTemplate
from rerank_code import rerank_topn
from Config.config import VECTOR_DB, DB_directory
from langchain_elasticsearch.vectorstores import ElasticsearchStore
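
# NOTE: Config/config.py is not part of this file. Judging from how the values are
# used below, it is assumed to provide something like the following (VECTOR_DB picks
# the vector store: 1 = Chroma, 2 = FAISS, 3 = Elasticsearch; DB_directory is the
# local vector-store directory). This is an assumption, not the actual config module:
#
#     VECTOR_DB = 1
#     DB_directory = "E:/pycode/jupyter_code/langGraph/sss2/chroma.sqlite3/"
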
class RAG_class:
    def __init__(self, model="qwen2:7b", embed="milkey/dmeta-embedding-zh:f16", c_name="sss1",
                 persist_directory="E:/pycode/jupyter_code/langGraph/sss2/chroma.sqlite3/", es_url="http://localhost:9200"):
        # Answer prompt: answer strictly from the retrieved context; if the context
        # does not contain the answer, say the answer cannot be derived from the references.
        template = """
        根据上下文回答以下问题,不要自己发挥,要根据以下参考内容总结答案,如果以下内容无法得到答案,就返回无法根据参考内容获取答案,
        参考内容为:{context}
        问题: {question}
        """
        self.prompts = ChatPromptTemplate.from_template(template)
        # Question-expansion prompt: decompose the input into sub-questions, then answer
        # them recursively to build up the final answer.
        template1 = """你是一个乐于助人的助手,可以生成与输入问题相关的多个子问题。
        目标是将输入分解为一组可以单独回答的子问题/子问题。
        生成多个与以下内容相关的搜索查询:{question}
        输出4个相关问题,以换行符隔开:"""
        self.prompt_questions = ChatPromptTemplate.from_template(template1)
        # Q&A-pair prompt: combine the current question, the accumulated Q&A pairs,
        # and the retrieved context into one answer.
        template2 = """
        以下是您需要回答的问题:
        \n--\n {question} \n---\n
        以下是任何可用的背景问答对:
        \n--\n {q_a_pairs} \n---\n
        以下是与该问题相关的其他上下文:
        \n--\n {context} \n---\n
        使用以上上下文和背景问答对来回答问题,问题是:{question} ,答案是:
        """
        self.decomposition_prompt = ChatPromptTemplate.from_template(template2)
        self.llm = Ollama(model=model)
        self.embeding = OllamaEmbeddings(model=embed)
        # Load the configured vector store; skipped gracefully when only the LLM is needed.
        try:
            if VECTOR_DB == 1:
                self.vectstore = Chroma(embedding_function=self.embeding, collection_name=c_name,
                                        persist_directory=persist_directory)
            elif VECTOR_DB == 2:
                self.vectstore = FAISS.load_local(folder_path=persist_directory + c_name, embeddings=self.embeding,
                                                  allow_dangerous_deserialization=True)
            elif VECTOR_DB == 3:
                self.vectstore = ElasticsearchStore(
                    es_url=es_url,
                    index_name=c_name,
                    embedding=self.embeding
                )
            self.retriever = self.vectstore.as_retriever()
        except Exception as e:
            print("No vector store loaded (LLM-only mode):", e)
    # Post-processing: join the retrieved documents into a single context string.
    def format_docs(self, docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # Traditional retrieval: recall with the single question, then let the LLM summarize the answer.
    def simple_chain(self, question):
        _chain = (
            {"context": self.retriever | self.format_docs, "question": RunnablePassthrough()}
            | self.prompts
            | self.llm
            | StrOutputParser()
        )
        answer = _chain.invoke(question)
        return answer
    # Rerank retrieval: recall the top 10 documents, rerank them down to 5, then summarize with the LLM.
    def rerank_chain(self, question):
        retriever = self.vectstore.as_retriever(search_kwargs={"k": 10})
        docs = retriever.invoke(question)
        docs = rerank_topn(question, docs, N=5)
        _chain = (
            self.prompts
            | self.llm
            | StrOutputParser()
        )
        answer = _chain.invoke({"context": self.format_docs(docs), "question": question})
        return answer

    # Format a single question/answer pair for use as background context.
    def format_qa_pairs(self, question, answer):
        return f"Question: {question}\nAnswer: {answer}\n\n"
    # Expand the question into related sub-questions.
    def decomposition_chain(self, question):
        _chain = (
            {"question": RunnablePassthrough()}
            | self.prompt_questions
            | self.llm
            | StrOutputParser()
            | (lambda x: x.split("\n"))
        )
        # Generated sub-questions, with the original question appended at the end.
        questions = _chain.invoke(question) + [question]
        return questions
    # Recursive multi-question retrieval: each question is answered in turn, and the accumulated
    # Q&A pairs are fed back as background for the next retrieval round.
    def rag_chain(self, questions):
        q_a_pairs = ""
        answer = ""
        for q in questions:
            _chain = (
                {"context": itemgetter("question") | self.retriever,
                 "question": itemgetter("question"),
                 "q_a_pairs": itemgetter("q_a_pairs")
                 }
                | self.decomposition_prompt
                | self.llm
                | StrOutputParser()
            )
            answer = _chain.invoke({"question": q, "q_a_pairs": q_a_pairs})
            # Accumulate the new pair instead of overwriting the history.
            q_a_pairs = q_a_pairs + "\n----\n" + self.format_qa_pairs(q, answer)
        return answer
    # Format the chat history as a single string.
    def format_chat_history(self, history):
        formatted_history = ""
        for role, content in history:
            formatted_history += f"{role}: {content}\n"
        return formatted_history

    # Multi-turn chat with the Ollama model only, without using the knowledge base.
    def mult_chat(self, chat_history):
        # Format the chat history
        formatted_history = self.format_chat_history(chat_history)
        # Ask the model to generate the reply
        response = self.llm.invoke(formatted_history)
        return response
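
# NOTE: rerank_code is not included in this file; rerank_topn(question, docs, N) is only
# known from how it is called in rerank_chain above (it should return the N documents most
# relevant to the question). The helper below is a minimal, hypothetical sketch of such a
# reranker built on a sentence-transformers cross-encoder; the actual rerank_code module
# may be implemented differently.
def _rerank_topn_sketch(question, docs, N=5):
    from sentence_transformers import CrossEncoder  # assumed extra dependency
    model = CrossEncoder("BAAI/bge-reranker-base")  # assumed reranker checkpoint
    # Score every (question, passage) pair and keep the N highest-scoring documents.
    scores = model.predict([(question, doc.page_content) for doc in docs])
    ranked = sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)
    return [doc for doc, _ in ranked[:N]]
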
# if __name__ == "__main__":
#     rag = RAG_class(model="deepseek-r1:14b")
#     question = "人卫社官网网址是?"
#     questions = rag.decomposition_chain(question)
#     print(questions)
#     answer = rag.rag_chain(questions)
#     print(answer)
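#
#     # Illustrative calls for the remaining chains (output depends on the locally
#     # installed Ollama models and on the vector index that has been built beforehand):
#     print(rag.simple_chain(question))   # single-question retrieval + LLM summary
#     print(rag.rerank_chain(question))   # top-10 retrieval reranked down to 5 before answering
#     print(rag.mult_chat([("user", "Hello"), ("assistant", "Hi, how can I help?"), ("user", question)]))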