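# RAG chatbot over curated PubMed Autism research, built for a Hugging Face
# Space. Retrieval is hybrid (FAISS dense + BM25 sparse) with cross-encoder
# reranking; generation uses SmolLM2-360M-Instruct behind a Gradio chat UI.
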
import json
import spaces
from typing import List, Tuple
import gradio as gr
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline, ChatHuggingFace
from langchain_community.vectorstores import FAISS
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.documents import Document
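
# Small instruct model loaded locally through a transformers pipeline; greedy
# decoding (do_sample=False) keeps answers reproducible across calls.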
model = HuggingFacePipeline.from_model_id(
    model_id="HuggingFaceTB/SmolLM2-360M-Instruct",
    task="text-generation",
    pipeline_kwargs=dict(
        max_new_tokens=512,
        do_sample=False,
        repetition_penalty=1.03,
        return_full_text=False,
    ),
)
llm = ChatHuggingFace(llm=model)
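
# BGE-M3 sentence embeddings on CPU; normalized embeddings make dot-product
# scores equivalent to cosine similarity.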
def create_embeddings_model() -> HuggingFaceEmbeddings:
    model_name = "BAAI/bge-m3"
    model_kwargs = {
        "device": "cpu",
        "trust_remote_code": True,
    }
    encode_kwargs = {"normalize_embeddings": True}
    return HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
        show_progress=True,
    )
embeddings = create_embeddings_model()

def load_faiss_retriever(path: str) -> VectorStoreRetriever:
    # allow_dangerous_deserialization is needed because FAISS.load_local
    # unpickles the docstore; only load indexes you built yourself.
    vector_store = FAISS.load_local(path, embeddings, allow_dangerous_deserialization=True)
    return vector_store.as_retriever(search_kwargs={"k": 10})

def load_bm25_retriever(load_path: str) -> BM25Retriever:
    with open(load_path, "r", encoding="utf-8") as f:
        docs_json = json.load(f)
    documents = [
        Document(page_content=doc["page_content"], metadata=doc["metadata"])
        for doc in docs_json
    ]
    # BM25Retriever.from_documents takes no `language` argument; tokenization
    # is controlled by its preprocess_func (whitespace split by default).
    return BM25Retriever.from_documents(documents)
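
# Minimal hybrid retriever: union of dense (FAISS) and sparse (BM25) hits,
# deduplicated, then reranked by a cross-encoder down to the best passages.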
class EmbeddingBM25RerankerRetriever:
    def __init__(self, vector_retriever, bm25_retriever, reranker):
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        self.reranker = reranker

    def invoke(self, query: str) -> List[Document]:
        vector_docs = self.vector_retriever.invoke(query)
        bm25_docs = self.bm25_retriever.invoke(query)
        # Keep FAISS ranking order and append BM25 hits not already present.
        combined_docs = vector_docs + [doc for doc in bm25_docs if doc not in vector_docs]
        return self.reranker.compress_documents(combined_docs, query)
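
# Prebuilt indexes expected to ship alongside the app in VectorDB/.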
faiss_path = "VectorDB/faiss_index"
bm25_path = "VectorDB/bm25_index.json"
faiss_retriever = load_faiss_retriever(faiss_path)
bm25_retriever = load_bm25_retriever(bm25_path)
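
# The cross-encoder scores each (query, passage) pair jointly, which is slower
# than embedding similarity but more accurate; keep only the top 4 passages.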
reranker_model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
reranker = CrossEncoderReranker(top_n=4, model=reranker_model)
retriever = EmbeddingBM25RerankerRetriever(faiss_retriever, bm25_retriever, reranker)
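
# The system prompt pins the model to evidence-grounded answers; retrieved
# context is injected into the final human turn of the template.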
qa_prompt = ChatPromptTemplate.from_messages([
    ("system", """You are an AI research assistant specializing in Autism research, powered by a retrieval system of curated PubMed documents.

Response Guidelines:
- Provide precise, evidence-based answers drawing directly from retrieved medical research
- Synthesize information from multiple documents when possible
- Clearly distinguish between established findings and emerging research
- Maintain scientific rigor and objectivity

Query Handling:
- Prioritize direct, informative responses
- When document evidence is incomplete, explain the current state of research
- Highlight areas where more research is needed
- Never introduce speculation or unsupported claims

Contextual Integrity:
- Ensure all statements are traceable to specific research documents
- Preserve the nuance and complexity of scientific findings
- Communicate with clarity, avoiding unnecessary medical jargon

Knowledge Limitations:
- If no relevant information is found, state: "Current research documents do not provide a comprehensive answer to this specific query."
"""),
    MessagesPlaceholder("chat_history"),
    ("human", "Context:\n{context}\n\nQuestion: {input}"),
])

def format_context(docs) -> str:
    return "\n\n".join([f"Doc {i+1}: {doc.page_content}" for i, doc in enumerate(docs)])
@spaces.GPU
def chat_with_rag(query: str, history: List[Tuple[str, str]]) -> str:
    # Rebuild the LangChain message history from Gradio's (user, bot) pairs.
    chat_history = []
    for human, ai in history:
        chat_history.append(HumanMessage(content=human))
        chat_history.append(AIMessage(content=ai))
    docs = retriever.invoke(query)
    context = format_context(docs)
    # format_messages (not format) preserves the system/history/human roles so
    # ChatHuggingFace can apply the model's chat template correctly.
    messages = qa_prompt.format_messages(
        chat_history=chat_history,
        context=context,
        input=query,
    )
    response = llm.invoke(messages)
    return response.content

chat_interface = gr.ChatInterface(
    fn=chat_with_rag,
    title="Autism RAG Chatbot",
    description="Ask questions about Autism research, answered from curated PubMed literature.",
    examples=["What causes Autism?", "How is Autism treated?", "What is Autism?"],
)

if __name__ == "__main__":
    chat_interface.launch(share=True)