import faiss
import numpy as np
import pickle
import threading
import time
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
from rank_bm25 import BM25Okapi
class FinancialChatbot:
    def __init__(self):
        # Load the prebuilt FAISS index and its index -> document mapping
        self.faiss_index = faiss.read_index("financial_faiss.index")
        with open("index_map.pkl", "rb") as f:
            self.index_map = pickle.load(f)

        # Extract document texts so BM25 can be built dynamically at startup
        self.documents = list(self.index_map.values())

        # Build the BM25 index over lowercased, whitespace-tokenized documents
        self.bm25_corpus = [doc.lower().split() for doc in self.documents]
        self.bm25 = BM25Okapi(self.bm25_corpus)

        # Load SentenceTransformer for embedding-based retrieval
        self.sbert_model = SentenceTransformer("all-MiniLM-L6-v2")

        # Load the Qwen generator model
        model_name = "Qwen/Qwen2.5-1.5B"
        self.qwen_model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype="auto", device_map="auto", trust_remote_code=True
        )
        self.qwen_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        # Guardrail: queries containing any of these words are rejected outright
        self.BLOCKED_WORDS = [
            "hack", "bypass", "illegal", "exploit", "scam", "kill", "laundering",
            "murder", "suicide", "self-harm", "assault", "bomb", "terrorism",
            "attack", "genocide", "mass shooting", "credit card number"
        ]

        # Minimum retrieval confidence required before generating an answer
        self.min_similarity_threshold = 0.7
    def moderate_query(self, query):
        """Check whether the query contains any blocked words."""
        query_lower = query.lower()
        for word in self.BLOCKED_WORDS:
            if word in query_lower:
                return False  # Block query
        return True  # Allow query
    def query_faiss(self, query, top_k=5):
        """Retrieve relevant documents using FAISS and compute confidence scores."""
        query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
        distances, indices = self.faiss_index.search(query_embedding, top_k)

        results = []
        confidence_scores = []
        for idx, dist in zip(indices[0], distances[0]):
            if idx in self.index_map:
                similarity = 1 / (1 + dist)  # Map L2 distance to a (0, 1] similarity
                results.append(self.index_map[idx])
                confidence_scores.append(similarity)
        return results, confidence_scores
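
    # ------------------------------------------------------------------
    # Hedged sketch (not part of the original file): one way the
    # "financial_faiss.index" and "index_map.pkl" artifacts loaded in
    # __init__ could be built offline, assuming an IndexFlatL2 over
    # all-MiniLM-L6-v2 embeddings (consistent with the 1 / (1 + L2 distance)
    # scoring above). The docs variable is illustrative.
    #
    #   docs = ["...financial text chunk 1...", "...chunk 2..."]
    #   sbert = SentenceTransformer("all-MiniLM-L6-v2")
    #   emb = sbert.encode(docs, convert_to_numpy=True).astype("float32")
    #   index = faiss.IndexFlatL2(emb.shape[1])
    #   index.add(emb)
    #   faiss.write_index(index, "financial_faiss.index")
    #   with open("index_map.pkl", "wb") as f:
    #       pickle.dump({i: d for i, d in enumerate(docs)}, f)
    # ------------------------------------------------------------------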
    def query_bm25(self, query, top_k=5):
        """Retrieve relevant documents using BM25 keyword search."""
        tokenized_query = query.lower().split()
        scores = self.bm25.get_scores(tokenized_query)
        top_indices = np.argsort(scores)[::-1][:top_k]

        results = []
        confidence_scores = []
        for idx in top_indices:
            if scores[idx] > 0:  # Ignore zero-score matches
                results.append(self.documents[idx])
                confidence_scores.append(scores[idx])
        return results, confidence_scores
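
    # Hedged sketch (an assumption, not in the original file): raw BM25
    # scores are unbounded, while query_faiss returns values in (0, 1], so
    # get_answer's shared min_similarity_threshold compares two different
    # scales. A simple max-normalization like the hypothetical helper below
    # would put BM25 scores on a comparable footing.
    @staticmethod
    def normalize_bm25(scores):
        """Scale raw BM25 scores into [0, 1] by dividing by the top score."""
        top = max(scores, default=0.0)
        return [s / top for s in scores] if top > 0 else list(scores)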
    def generate_answer(self, context, question):
        """Generate an answer with the Qwen model from the retrieved context."""
        input_text = f"Context: {context}\nQuestion: {question}\nAnswer:"
        inputs = self.qwen_tokenizer.encode(input_text, return_tensors="pt")
        inputs = inputs.to(self.qwen_model.device)  # device_map="auto" may place the model off-CPU
        outputs = self.qwen_model.generate(inputs, max_new_tokens=100)
        return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
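
    # Hedged sketch (an assumption, not in the original file): if the
    # instruction-tuned checkpoint "Qwen/Qwen2.5-1.5B-Instruct" were used
    # instead of the base model, the tokenizer's built-in chat template
    # usually produces better-behaved answers than the raw
    # "Context/Question/Answer" prompt above:
    #
    #   messages = [
    #       {"role": "system", "content": "Answer using only the given context."},
    #       {"role": "user", "content": f"Context: {context}\nQuestion: {question}"},
    #   ]
    #   prompt = self.qwen_tokenizer.apply_chat_template(
    #       messages, tokenize=False, add_generation_prompt=True
    #   )
    #   inputs = self.qwen_tokenizer(prompt, return_tensors="pt").to(self.qwen_model.device)
    #   outputs = self.qwen_model.generate(**inputs, max_new_tokens=100)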
    def get_answer(self, query, timeout=200):
        """Fetch an answer via hybrid retrieval and the Qwen model, with a timeout."""
        result = ["No relevant information found", 0.0]  # Default response

        def task():
            # Greeting shortcut
            if query.lower() in ["hi", "hello", "hey"]:
                result[:] = ["Hi, how can I help you?", 1.0]
                return
            # Ad-hoc out-of-domain filter for a few known off-topic keywords
            if query.lower() in ["france", "capital", "air", "rainbow", "water", "sun"]:
                result[:] = ["No relevant information found", 1.0]
                return
            # Guardrail check
            if not self.moderate_query(query):
                result[:] = ["I'm unable to process your request due to inappropriate language.", 0.0]
                return

            # Hybrid retrieval: dense (FAISS) plus sparse (BM25)
            faiss_results, faiss_conf = self.query_faiss(query)
            bm25_results, bm25_conf = self.query_bm25(query)
            all_results = faiss_results + bm25_results
            all_conf = faiss_conf + bm25_conf

            # Require at least one result above the relevance threshold
            # (note: raw BM25 scores and FAISS similarities use different scales)
            if not all_results or max(all_conf, default=0) < self.min_similarity_threshold:
                result[:] = ["No relevant information found", 0.0]
                return

            context = " ".join(all_results)
            answer = self.generate_answer(context, query)

            # The decoded output repeats the prompt, so keep only the text
            # after the final "Answer:" label
            last_index = answer.rfind("Answer:")
            extracted_answer = answer[last_index + len("Answer:"):].strip() if last_index != -1 else ""

            # Ensure the answer is grounded in the context
            if not extracted_answer or extracted_answer.isnumeric():
                result[:] = ["No relevant information found", 0.0]
            else:
                result[:] = [extracted_answer, max(all_conf, default=0.9)]

        # Run retrieval and generation in a worker thread so a slow model
        # call cannot block past the timeout
        thread = threading.Thread(target=task)
        thread.start()
        thread.join(timeout)
        if thread.is_alive():
            return "No relevant information found", 0.0  # Timeout case
        return tuple(result)
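
# Hedged usage sketch (not part of the original file): how the class above
# might be exercised, assuming "financial_faiss.index" and "index_map.pkl"
# exist in the working directory. The sample question is illustrative.
if __name__ == "__main__":
    bot = FinancialChatbot()
    answer, confidence = bot.get_answer("What was the total revenue last year?")
    print(f"Answer: {answer} (confidence: {confidence:.2f})")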