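"""Autism RAG chatbot (app.py).

Hybrid retrieval (dense FAISS plus sparse BM25) over a curated PubMed
corpus, reranked with a cross-encoder and answered by a small local
instruct model through a Gradio chat interface.
"""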
import json
import spaces
from typing import List
import gradio as gr
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline, ChatHuggingFace
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.documents import Document
from langchain_core.vectorstores import VectorStoreRetriever
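
# Generator: SmolLM2-360M-Instruct served through a local transformers
# text-generation pipeline and wrapped as a LangChain chat model.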
model = HuggingFacePipeline.from_model_id(
    model_id="HuggingFaceTB/SmolLM2-360M-Instruct",
    task="text-generation",
    pipeline_kwargs=dict(
        max_new_tokens=512,
        do_sample=False,
        repetition_penalty=1.03,
        return_full_text=False,
    ),
)
llm = ChatHuggingFace(llm=model)
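
# Dense embeddings: BGE-M3 on CPU, with normalized vectors so that
# inner-product search behaves like cosine similarity.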
def create_embeddings_model() -> HuggingFaceEmbeddings:
    model_name = "BAAI/bge-m3"
    model_kwargs = {
        'device': 'cpu',
        'trust_remote_code': True,
    }
    encode_kwargs = {'normalize_embeddings': True}
    return HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
        show_progress=True,
    )
embeddings = create_embeddings_model()
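
# Index loaders: the FAISS index is deserialized from disk (built with the
# embeddings above), and the BM25 index is rebuilt from serialized documents.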
def load_faiss_retriever(path: str) -> VectorStoreRetriever:
    vector_store = FAISS.load_local(path, embeddings, allow_dangerous_deserialization=True)
    return vector_store.as_retriever(search_kwargs={"k": 10})
def load_bm25_retriever(load_path: str) -> BM25Retriever:
    with open(load_path, "r", encoding="utf-8") as f:
        docs_json = json.load(f)
    documents = [Document(page_content=doc["page_content"], metadata=doc["metadata"]) for doc in docs_json]
    # Note: BM25Retriever takes no `language` option; tokenization is controlled
    # via `preprocess_func` (default: whitespace split).
    return BM25Retriever.from_documents(documents)
class EmbeddingBM25RerankerRetriever:
    """Hybrid retriever: merge dense (FAISS) and sparse (BM25) results,
    then rerank the union with a cross-encoder."""

    def __init__(self, vector_retriever, bm25_retriever, reranker):
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        self.reranker = reranker

    def invoke(self, query: str):
        vector_docs = self.vector_retriever.invoke(query)
        bm25_docs = self.bm25_retriever.invoke(query)
        # Keep only the BM25 hits the dense retriever missed (Document equality is field-based).
        combined_docs = vector_docs + [doc for doc in bm25_docs if doc not in vector_docs]
        return self.reranker.compress_documents(combined_docs, query)
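
# Wire everything together: load the persisted indexes and the cross-encoder
# reranker (bge-reranker-v2-m3 keeps the top 4 of the merged candidates).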
faiss_path = "VectorDB/faiss_index"
bm25_path = "VectorDB/bm25_index.json"
faiss_retriever = load_faiss_retriever(faiss_path)
bm25_retriever = load_bm25_retriever(bm25_path)
reranker_model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
reranker = CrossEncoderReranker(top_n=4, model=reranker_model)
retriever = EmbeddingBM25RerankerRetriever(faiss_retriever, bm25_retriever, reranker)
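
# Optional sanity check (hypothetical query, not part of the app flow):
# invoking the hybrid retriever directly shows the reranked documents the
# chat handler will see.
#   docs = retriever.invoke("early signs of autism in toddlers")
#   for doc in docs:  # at most top_n=4 docs after cross-encoder reranking
#       print(doc.page_content[:100])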
# Prompt: system instructions, then the running chat history, then the
# retrieved context and the current question.
qa_prompt = ChatPromptTemplate.from_messages([
    ("system", """You are an AI research assistant specializing in Autism research, powered by a retrieval system of curated PubMed documents.

Response Guidelines:
- Provide precise, evidence-based answers drawing directly from retrieved medical research
- Synthesize information from multiple documents when possible
- Clearly distinguish between established findings and emerging research
- Maintain scientific rigor and objectivity

Query Handling:
- Prioritize direct, informative responses
- When document evidence is incomplete, explain the current state of research
- Highlight areas where more research is needed
- Never introduce speculation or unsupported claims

Contextual Integrity:
- Ensure all statements are traceable to specific research documents
- Preserve the nuance and complexity of scientific findings
- Communicate with clarity, avoiding unnecessary medical jargon

Knowledge Limitations:
- If no relevant information is found, state: "Current research documents do not provide a comprehensive answer to this specific query."
"""),
    MessagesPlaceholder("chat_history"),
    ("human", "Context:\n{context}\n\nQuestion: {input}"),
])
def format_context(docs) -> str:
    return "\n\n".join([f"Doc {i+1}: {doc.page_content}" for i, doc in enumerate(docs)])
@spaces.GPU
def chat_with_rag(query: str, history: List[tuple[str, str]]) -> str:
    # Rebuild the LangChain message history from Gradio's (user, bot) tuples.
    chat_history = []
    for human, ai in history:
        chat_history.append(HumanMessage(content=human))
        chat_history.append(AIMessage(content=ai))
    docs = retriever.invoke(query)
    context = format_context(docs)
    prompt_input = {
        "chat_history": chat_history,
        "context": context,
        "input": query,
    }
    # format_messages keeps the system/history/human structure; plain .format()
    # would flatten everything into one string before it reached the chat model.
    messages = qa_prompt.format_messages(**prompt_input)
    response = llm.invoke(messages)
    return response.content
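
# Gradio chat UI: ChatInterface calls chat_with_rag(message, history) once per turn.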
chat_interface = gr.ChatInterface(
    fn=chat_with_rag,
    title="Autism RAG Chatbot",
    description="Ask questions about Autism research, answered from curated PubMed documents.",
    examples=["What causes Autism?", "How is Autism treated?", "What is Autism?"],
)
if __name__ == "__main__":
    chat_interface.launch(share=True)