# rag.py — last updated by SatyamD31
# Commit 5813bdb (verified): "Update rag.py (#6)"
import pickle
import re
import threading
import time

import faiss
import numpy as np
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer


class FinancialChatbot:
    """Retrieval-augmented financial Q&A bot.

    Combines dense retrieval (FAISS over SentenceTransformer embeddings) with
    keyword retrieval (BM25), then prompts a Qwen causal LM with the retrieved
    passages. Guardrails around the pipeline: a greeting shortcut, a hard-coded
    off-topic list, a blocked-term filter, a relevance threshold and a
    generation timeout.
    """

    def __init__(self):
        """Load the retrieval indexes and the embedding / generation models.

        Reads ``financial_faiss.index`` and ``index_map.pkl`` relative to the
        current working directory.
        """
        # Dense vector index over the documents.
        self.faiss_index = faiss.read_index("financial_faiss.index")
        # Maps FAISS row ids to document text.
        # NOTE(review): pickle.load can execute arbitrary code; only load an
        # index_map.pkl produced by a trusted pipeline.
        with open("index_map.pkl", "rb") as f:
            self.index_map = pickle.load(f)

        # BM25 positions refer to this list (dict value order), not FAISS ids.
        self.documents = list(self.index_map.values())
        self.bm25_corpus = [doc.lower().split() for doc in self.documents]  # Tokenization
        self.bm25 = BM25Okapi(self.bm25_corpus)

        # Query encoder for FAISS retrieval; presumably the same model used to
        # build financial_faiss.index — TODO confirm.
        self.sbert_model = SentenceTransformer("all-MiniLM-L6-v2")

        # Generator model.
        model_name = "Qwen/Qwen2.5-1.5b"
        self.qwen_model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype="auto", device_map="auto", trust_remote_code=True
        )
        self.qwen_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        # Guardrail: a query is rejected if one of these terms starts a word in it.
        self.BLOCKED_WORDS = [
            "hack", "bypass", "illegal", "exploit", "scam", "kill", "laundering",
            "murder", "suicide", "self-harm", "assault", "bomb", "terrorism",
            "attack", "genocide", "mass shooting", "credit card number",
        ]
        # Minimum best retrieval score required before generating an answer.
        self.min_similarity_threshold = 0.7

    def moderate_query(self, query):
        """Return True if ``query`` may be processed, False if it hits a blocked term.

        Terms are matched case-insensitively at the start of a word. So
        "hacking" and "attacks" are blocked, while "skill" (which contains
        "kill") and "shack" are not. Multi-word terms match as a phrase.
        """
        query_lower = query.lower()
        for word in self.BLOCKED_WORDS:
            # Word boundary on the left only: catches inflections ("killed")
            # without flagging words that merely contain the term ("skill").
            if re.search(r"\b" + re.escape(word), query_lower):
                return False  # Block query
        return True  # Allow query

    def query_faiss(self, query, top_k=5):
        """Retrieve up to ``top_k`` documents by embedding distance.

        Returns:
            ``(documents, scores)``. Each score is ``1 / (1 + d)`` for the
            distance ``d`` reported by the index, so closer hits score nearer 1.
            Ids not present in ``index_map`` are skipped. An example is the
            ``-1`` placeholder FAISS returns when it has fewer than ``top_k``
            hits.
        """
        query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
        distances, indices = self.faiss_index.search(query_embedding, top_k)
        results = []
        confidence_scores = []
        for idx, dist in zip(indices[0], distances[0]):
            if idx in self.index_map:
                results.append(self.index_map[idx])
                # Plain float so callers never receive numpy scalars.
                confidence_scores.append(float(1.0 / (1.0 + dist)))
        return results, confidence_scores

    def query_bm25(self, query, top_k=5):
        """Retrieve up to ``top_k`` documents by BM25 keyword score.

        Uses the same lowercase whitespace tokenization as the corpus.
        Documents with a score of 0 or less are dropped.

        Returns:
            ``(documents, scores)`` with raw (unnormalized) BM25 scores.
        """
        tokenized_query = query.lower().split()
        scores = self.bm25.get_scores(tokenized_query)
        top_indices = np.argsort(scores)[::-1][:top_k]
        results = []
        confidence_scores = []
        for idx in top_indices:
            if scores[idx] > 0:  # Ignore zero-score matches
                results.append(self.documents[idx])
                confidence_scores.append(float(scores[idx]))
        return results, confidence_scores

    def generate_answer(self, context, question, max_new_tokens=100):
        """Generate an answer with the Qwen model.

        Args:
            context: Retrieved passages to condition on.
            question: The user's question.
            max_new_tokens: Maximum number of generated tokens, excluding the prompt.

        Returns:
            The decoded prompt followed by the model's continuation, with
            special tokens removed.
        """
        input_text = f"Context: {context}\nQuestion: {question}\nAnswer:"
        # Tokenize with an attention mask. Move the inputs to the model's
        # device: with device_map="auto" the model may not be on the CPU.
        inputs = self.qwen_tokenizer(input_text, return_tensors="pt").to(self.qwen_model.device)
        # max_new_tokens rather than max_length: max_length counts the prompt.
        # The retrieved context can easily exceed 100 tokens, which left no
        # budget for the answer itself.
        outputs = self.qwen_model.generate(**inputs, max_new_tokens=max_new_tokens)
        return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)

    def _answer_query(self, query):
        """Run guardrails, retrieval and generation synchronously; return ``(text, confidence)``."""
        normalized = query.strip().lower()
        if normalized in {"hi", "hello", "hey"}:
            return ("Hi, how can I help you?", 1.0)
        # Hard-coded single-word off-topic queries.
        if normalized in {"france", "capital", "air", "rainbow", "water", "sun"}:
            return ("No relevant information found", 1.0)
        if not self.moderate_query(query):
            return ("I'm unable to process your request due to inappropriate language.", 0.0)

        faiss_results, faiss_conf = self.query_faiss(query)
        bm25_results, bm25_conf = self.query_bm25(query)
        all_results = faiss_results + bm25_results
        all_conf = faiss_conf + bm25_conf

        # NOTE(review): FAISS scores lie in (0, 1], but BM25 scores are raw and
        # unbounded. A BM25 hit can therefore clear the threshold on its own,
        # and the returned confidence can exceed 1. Consider normalizing.
        if not all_results or max(all_conf, default=0) < self.min_similarity_threshold:
            return ("No relevant information found", 0.0)

        # A passage found by both retrievers goes into the prompt only once.
        context = " ".join(dict.fromkeys(all_results))
        answer = self.generate_answer(context, query)

        # generate_answer echoes the prompt, so keep the text from the last
        # "Answer" marker onward.
        last_index = answer.rfind("Answer")
        if last_index == -1:
            return ("No relevant information found", 0.0)
        extracted_answer = answer[last_index:].strip()
        # Reject generations that produced nothing after the marker.
        if not extracted_answer[len("Answer"):].lstrip(":").strip():
            return ("No relevant information found", 0.0)
        return (extracted_answer, max(all_conf, default=0.9))

    def get_answer(self, query, timeout=200):
        """Answer ``query``, giving up after ``timeout`` seconds.

        Returns:
            ``(text, confidence)``. On timeout, or if the worker raises, the
            result is ``("No relevant information found", 0.0)``.
        """
        result = ["No relevant information found", 0.0]  # Default response

        def task():
            result[:] = self._answer_query(query)

        # daemon=True: a generation that outlives the timeout must not keep the
        # interpreter alive at exit. The thread cannot be cancelled, so it may
        # keep running in the background after this method returns.
        thread = threading.Thread(target=task, daemon=True)
        thread.start()
        thread.join(timeout)
        if thread.is_alive():
            return "No relevant information found", 0.0  # Timeout case
        return tuple(result)