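# Autism RAG chatbot: a Gradio app that answers questions from a curated
# PubMed corpus. Retrieval combines FAISS (dense) and BM25 (sparse) results,
# reranked by a cross-encoder; generation uses SmolLM2-360M-Instruct.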
import json
import spaces
from typing import List
import gradio as gr
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline, ChatHuggingFace
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_core.documents import Document


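# Load a small instruction-tuned LLM through a local transformers pipeline.
# Greedy decoding (do_sample=False) keeps answers reproducible, and
# return_full_text=False strips the prompt from the generated text.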
model = HuggingFacePipeline.from_model_id(
    model_id="HuggingFaceTB/SmolLM2-360M-Instruct",
    task="text-generation",
    pipeline_kwargs=dict(
        max_new_tokens=512,
        do_sample=False,
        repetition_penalty=1.03,
        return_full_text=False,
    ),
)

llm = ChatHuggingFace(llm=model)

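# BGE-M3 sentence embeddings on CPU. normalize_embeddings=True makes
# inner-product scores equivalent to cosine similarity.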
def create_embeddings_model() -> HuggingFaceEmbeddings:
    model_name = "BAAI/bge-m3"
    model_kwargs = {
        'device': 'cpu',
        'trust_remote_code': True,
    }
    encode_kwargs = {'normalize_embeddings': True}
    return HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
        show_progress=True
    )

embeddings = create_embeddings_model()

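# Load the prebuilt FAISS index from disk and expose it as a top-10 retriever.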
def load_faiss_retriever(path: str) -> VectorStoreRetriever:
    # allow_dangerous_deserialization opts in to unpickling the stored docstore;
    # only enable it for index files you built yourself.
    vector_store = FAISS.load_local(path, embeddings, allow_dangerous_deserialization=True)
    return vector_store.as_retriever(search_kwargs={"k": 10})

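# Rebuild the sparse BM25 retriever from documents serialized to JSON.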
def load_bm25_retriever(load_path: str) -> BM25Retriever:
    with open(load_path, "r", encoding="utf-8") as f:
        docs_json = json.load(f)
    documents = [Document(page_content=doc["page_content"], metadata=doc["metadata"]) for doc in docs_json]
    # BM25Retriever takes no `language` argument; tokenization uses the default
    # preprocess_func unless a custom one is supplied.
    return BM25Retriever.from_documents(documents)

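# Hybrid retriever: take the union of dense (FAISS) and sparse (BM25) hits,
# drop duplicates, then let the cross-encoder rerank and keep the top_n docs.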
class EmbeddingBM25RerankerRetriever:
    def __init__(self, vector_retriever, bm25_retriever, reranker):
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        self.reranker = reranker

    def invoke(self, query: str):
        vector_docs = self.vector_retriever.invoke(query)
        bm25_docs = self.bm25_retriever.invoke(query)
        combined_docs = vector_docs + [doc for doc in bm25_docs if doc not in vector_docs]
        return self.reranker.compress_documents(combined_docs, query)

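# Index artifacts built offline and shipped alongside the app.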
faiss_path = "VectorDB/faiss_index"
bm25_path = "VectorDB/bm25_index.json"

faiss_retriever = load_faiss_retriever(faiss_path)
bm25_retriever = load_bm25_retriever(bm25_path)
reranker_model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
reranker = CrossEncoderReranker(top_n=4, model=reranker_model)
retriever = EmbeddingBM25RerankerRetriever(faiss_retriever, bm25_retriever, reranker)

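# System prompt that constrains answers to the retrieved PubMed context;
# prior turns are spliced in through the chat_history placeholder.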
qa_prompt = ChatPromptTemplate.from_messages([
    ("system", """You are an AI research assistant specializing in Autism research, powered by a retrieval system of curated PubMed documents.

                Response Guidelines:
                - Provide precise, evidence-based answers drawing directly from retrieved medical research
                - Synthesize information from multiple documents when possible
                - Clearly distinguish between established findings and emerging research
                - Maintain scientific rigor and objectivity

                Query Handling:
                - Prioritize direct, informative responses
                - When document evidence is incomplete, explain the current state of research
                - Highlight areas where more research is needed
                - Never introduce speculation or unsupported claims

                Contextual Integrity:
                - Ensure all statements are traceable to specific research documents
                - Preserve the nuance and complexity of scientific findings
                - Communicate with clarity, avoiding unnecessary medical jargon

                Knowledge Limitations:
                - If no relevant information is found, state: "Current research documents do not provide a comprehensive answer to this specific query."
                """),
    MessagesPlaceholder("chat_history"),
    ("human", "Context:\n{context}\n\nQuestion: {input}")
])

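# Concatenate the reranked documents into one numbered context block.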
def format_context(docs) -> str:
    return "\n\n".join([f"Doc {i+1}: {doc.page_content}" for i, doc in enumerate(docs)])

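# Gradio callback: rebuild LangChain messages from Gradio's (user, bot) history
# tuples, retrieve and rerank context for the new query, then generate a reply.
# @spaces.GPU requests an accelerator for this call on ZeroGPU Spaces.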
@spaces.GPU
def chat_with_rag(query: str, history: List[tuple[str, str]]) -> str:
    chat_history = []
    for human, ai in history:
        chat_history.append(HumanMessage(content=human))
        chat_history.append(AIMessage(content=ai))

    docs = retriever.invoke(query)
    context = format_context(docs)

    prompt_input = {
        "chat_history": chat_history,
        "context": context,
        "input": query
    }
    # format_messages preserves the system/history/human message structure;
    # qa_prompt.format() would flatten everything into a single string that the
    # chat template treats as one human turn.
    messages = qa_prompt.format_messages(**prompt_input)

    response = llm.invoke(messages)
    return response.content

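# gr.ChatInterface passes (message, history) to the callback, matching the
# chat_with_rag signature above.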
chat_interface = gr.ChatInterface(
    fn=chat_with_rag,
    title="Autism RAG Chatbot",
    description="Ask questions about Autism.",
    examples=["What causes Autism?", "How is Autism treated?", "What is Autism"],
)

if __name__ == "__main__":
    chat_interface.launch(share=True)