import pickle
import re
import threading
import time

import faiss
import numpy as np
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer


class FinancialChatbot:
    """Retrieval-augmented chatbot over a pre-built document index.

    A query is screened against a blocklist, matched against the documents
    with both a FAISS vector index and a BM25 keyword index, and the
    retrieved text is passed as context to a Qwen causal language model
    that generates the answer.
    """

    def __init__(self):
        """Load the indexes and models from fixed paths / model names."""
        # Load FAISS index
        self.faiss_index = faiss.read_index("financial_faiss.index")

        # SECURITY: pickle.load executes arbitrary code embedded in the file.
        # Only load an index_map.pkl that this project produced itself.
        with open("index_map.pkl", "rb") as f:
            self.index_map = pickle.load(f)

        # Extract document texts for BM25 dynamically
        self.documents = list(self.index_map.values())

        # Build BM25 index over lower-cased, whitespace-tokenized documents
        self.bm25_corpus = [doc.lower().split() for doc in self.documents]
        self.bm25 = BM25Okapi(self.bm25_corpus)

        # Load SentenceTransformer for embedding-based retrieval
        self.sbert_model = SentenceTransformer("all-MiniLM-L6-v2")

        # Load Qwen model
        model_name = "Qwen/Qwen2.5-1.5b"
        self.qwen_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="auto",
            trust_remote_code=True,
        )
        self.qwen_tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )

        # Guardrail: blocked words
        self.BLOCKED_WORDS = [
            "hack", "bypass", "illegal", "exploit", "scam", "kill",
            "laundering", "murder", "suicide", "self-harm", "assault",
            "bomb", "terrorism", "attack", "genocide", "mass shooting",
            "credit card number",
        ]

        # Relevance threshold
        self.min_similarity_threshold = 0.7

    def moderate_query(self, query):
        """Check whether the query is free of blocked terms.

        A blocked term only matches at the start of a word, so "kill",
        "killing" and "money-laundering" are caught while unrelated words
        that merely contain a term ("skill", "cyberattack") are not.

        Args:
            query: The raw user query.

        Returns:
            False if the query must be blocked, True if it is allowed.
        """
        query_lower = query.lower()
        # Evaluated against the current BLOCKED_WORDS on every call so that
        # changes made to the list after construction are honoured.
        return not any(
            re.search(r"\b" + re.escape(word), query_lower)
            for word in self.BLOCKED_WORDS
        )

    def query_faiss(self, query, top_k=5):
        """Retrieve documents with FAISS and compute confidence scores.

        Args:
            query: The user query to embed and search for.
            top_k: Number of neighbours requested from the index.

        Returns:
            A ``(results, confidence_scores)`` pair of parallel lists. Each
            score is ``1 / (1 + distance)`` for the corresponding document.
        """
        query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
        distances, indices = self.faiss_index.search(query_embedding, top_k)

        results = []
        confidence_scores = []
        for idx, dist in zip(indices[0], distances[0]):
            # Ids that are not in the map are skipped.
            if idx in self.index_map:
                # Convert L2 distance to similarity
                similarity = 1 / (1 + dist)
                results.append(self.index_map[idx])
                confidence_scores.append(similarity)
        return results, confidence_scores

    def query_bm25(self, query, top_k=5):
        """Retrieve documents with BM25 keyword search.

        Args:
            query: The user query; lower-cased and split on whitespace.
            top_k: Maximum number of documents to return.

        Returns:
            A ``(results, confidence_scores)`` pair of parallel lists holding
            the best-scoring documents and their raw BM25 scores. Documents
            with a score of zero or less are left out.
        """
        tokenized_query = query.lower().split()
        scores = self.bm25.get_scores(tokenized_query)
        top_indices = np.argsort(scores)[::-1][:top_k]

        # Ignore zero-score matches
        matched = [idx for idx in top_indices if scores[idx] > 0]
        results = [self.documents[idx] for idx in matched]
        confidence_scores = [scores[idx] for idx in matched]
        return results, confidence_scores

    def generate_answer(self, context, question):
        """Generate an answer with the Qwen model.

        Args:
            context: Retrieved text to ground the answer in.
            question: The user question.

        Returns:
            The decoded model output. It contains the prompt (ending in
            "Answer:") followed by the generated continuation.
        """
        input_text = f"Context: {context}\nQuestion: {question}\nAnswer:"
        inputs = self.qwen_tokenizer.encode(input_text, return_tensors="pt")
        # The model is loaded with device_map="auto", so it may not be on the
        # CPU; the prompt tensor has to live on the same device.
        inputs = inputs.to(self.qwen_model.device)
        # max_new_tokens bounds only the generated part. max_length would also
        # count the prompt, which already holds the retrieved context and can
        # exceed the limit on its own, leaving no room for an answer.
        outputs = self.qwen_model.generate(inputs, max_new_tokens=100)
        return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)

    def _resolve_answer(self, query):
        """Run the full pipeline for one query and return (answer, confidence)."""
        normalized = query.strip().lower()

        if normalized in ("hi", "hello", "hey"):
            return "Hi, how can I help you?", 1.0

        # Hard-coded single-word queries that are answered as not relevant.
        if normalized in ("france", "capital", "air", "rainbow", "water", "sun"):
            return "No relevant information found", 1.0

        if not self.moderate_query(query):
            return (
                "I'm unable to process your request due to inappropriate language.",
                0.0,
            )

        faiss_results, faiss_conf = self.query_faiss(query)
        bm25_results, bm25_conf = self.query_bm25(query)

        all_results = faiss_results + bm25_results
        # NOTE(review): FAISS scores are 1 / (1 + distance) while BM25 scores
        # are raw and unnormalized, yet both are compared against the same
        # threshold and the maximum is reported as the confidence.
        all_conf = faiss_conf + bm25_conf

        # Check if results are relevant
        if not all_results or max(all_conf, default=0) < self.min_similarity_threshold:
            return "No relevant information found", 0.0

        context = " ".join(all_results)
        answer = self.generate_answer(context, query)

        # Keep the text from the last "Answer" marker onwards; the marker
        # itself stays part of the returned text.
        last_index = answer.rfind("Answer")
        extracted_answer = answer[last_index:].strip() if last_index != -1 else ""

        if not extracted_answer or extracted_answer.isnumeric():
            return "No relevant information found", 0.0
        return extracted_answer, max(all_conf, default=0.9)

    def get_answer(self, query, timeout=200):
        """Fetch an answer for the query, giving up after a timeout.

        The work runs in a worker thread that is joined with ``timeout``.

        Args:
            query: The user query.
            timeout: Seconds to wait for the worker thread.

        Returns:
            An ``(answer, confidence)`` tuple. If the worker has not finished
            within the timeout, or it terminates with an exception, the
            result is ``("No relevant information found", 0.0)``.
        """
        result = ["No relevant information found", 0.0]  # Default response

        def task():
            result[:] = self._resolve_answer(query)

        # Daemon thread: a worker still running after the timeout cannot be
        # cancelled, but it must not keep the interpreter alive at shutdown.
        thread = threading.Thread(target=task, daemon=True)
        thread.start()
        thread.join(timeout)

        if thread.is_alive():
            return "No relevant information found", 0.0  # Timeout case
        return tuple(result)