import faiss
import numpy as np
import pickle
import threading
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
from rank_bm25 import BM25Okapi


class FinancialChatbot:
    def __init__(self):
        # Load the prebuilt FAISS index and the index -> document mapping
        self.faiss_index = faiss.read_index("financial_faiss.index")
        with open("index_map.pkl", "rb") as f:
            self.index_map = pickle.load(f)

        # Extract document texts so BM25 can be built dynamically
        self.documents = list(self.index_map.values())

        # Build the BM25 index (whitespace tokenization on lowercased text)
        self.bm25_corpus = [doc.lower().split() for doc in self.documents]
        self.bm25 = BM25Okapi(self.bm25_corpus)

        # Load SentenceTransformer for embedding-based retrieval
        self.sbert_model = SentenceTransformer("all-MiniLM-L6-v2")

        # Load the Qwen model (the canonical hub id uses a capital "B")
        model_name = "Qwen/Qwen2.5-1.5B"
        self.qwen_model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype="auto", device_map="auto", trust_remote_code=True
        )
        self.qwen_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        # Guardrail: queries containing any of these words are refused
        self.BLOCKED_WORDS = [
            "hack", "bypass", "illegal", "exploit", "scam", "kill", "laundering",
            "murder", "suicide", "self-harm", "assault", "bomb", "terrorism",
            "attack", "genocide", "mass shooting", "credit card number",
        ]

        # Minimum retrieval confidence required before answering
        self.min_similarity_threshold = 0.7
    def moderate_query(self, query):
        """Check whether the query contains blocked words."""
        query_lower = query.lower()
        for word in self.BLOCKED_WORDS:
            if word in query_lower:
                return False  # Block query
        return True  # Allow query
    def query_faiss(self, query, top_k=5):
        """Retrieve relevant documents with FAISS and compute confidence scores."""
        query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
        distances, indices = self.faiss_index.search(query_embedding, top_k)

        results = []
        confidence_scores = []
        for idx, dist in zip(indices[0], distances[0]):
            if idx in self.index_map:
                similarity = 1 / (1 + dist)  # Map L2 distance into (0, 1]
                results.append(self.index_map[idx])
                confidence_scores.append(similarity)
        return results, confidence_scores
    def query_bm25(self, query, top_k=5):
        """Retrieve relevant documents using BM25 keyword-based search."""
        tokenized_query = query.lower().split()
        scores = self.bm25.get_scores(tokenized_query)
        top_indices = np.argsort(scores)[::-1][:top_k]

        results = []
        confidence_scores = []
        for idx in top_indices:
            if scores[idx] > 0:  # Ignore zero-score matches
                results.append(self.documents[idx])
                confidence_scores.append(scores[idx])
        return results, confidence_scores
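
    # --- Illustrative sketch, not wired into the original flow ---
    # Raw BM25 scores are unbounded, while the FAISS similarities above lie in
    # (0, 1], yet get_answer() compares both against the same 0.7 threshold.
    # One hypothetical fix is to min-max normalize the BM25 scores first:
    @staticmethod
    def _normalize_scores(scores):
        """Min-max normalize a list of scores into [0, 1] (assumed helper)."""
        if not scores:
            return scores
        lo, hi = min(scores), max(scores)
        if hi == lo:
            return [1.0 for _ in scores]  # All equal: treat each as maximal
        return [(s - lo) / (hi - lo) for s in scores]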
    def generate_answer(self, context, question):
        """Generate an answer using the Qwen model."""
        input_text = f"Context: {context}\nQuestion: {question}\nAnswer:"
        inputs = self.qwen_tokenizer.encode(input_text, return_tensors="pt").to(self.qwen_model.device)
        # max_new_tokens bounds only the continuation; max_length would also
        # count the (potentially long) prompt and could cut generation short.
        outputs = self.qwen_model.generate(inputs, max_new_tokens=100)
        return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
    def get_answer(self, query, timeout=200):
        """Fetch an answer via hybrid retrieval plus Qwen, with a timeout."""
        result = ["No relevant information found", 0.0]  # Default response

        def task():
            # Greeting shortcut
            if query.lower() in ["hi", "hello", "hey"]:
                result[:] = ["Hi, how can I help you?", 1.0]
                return
            # Hardcoded off-topic keywords that should bypass retrieval
            if query.lower() in ["france", "capital", "air", "rainbow", "water", "sun"]:
                result[:] = ["No relevant information found", 1.0]
                return
            if not self.moderate_query(query):
                result[:] = ["I'm unable to process your request due to inappropriate language.", 0.0]
                return

            faiss_results, faiss_conf = self.query_faiss(query)
            bm25_results, bm25_conf = self.query_bm25(query)
            all_results = faiss_results + bm25_results
            all_conf = faiss_conf + bm25_conf

            # Check whether the retrieved results are relevant enough
            if not all_results or max(all_conf, default=0) < self.min_similarity_threshold:
                result[:] = ["No relevant information found", 0.0]
                return

            context = " ".join(all_results)
            answer = self.generate_answer(context, query)

            # Keep only the text after the final "Answer:" marker, dropping the
            # marker itself so it is not echoed back to the user
            last_index = answer.rfind("Answer:")
            extracted_answer = answer[last_index + len("Answer:"):].strip() if last_index != -1 else ""

            # Ensure the answer is grounded in the context
            if not extracted_answer or "Answer" not in answer or extracted_answer.isnumeric():
                result[:] = ["No relevant information found", 0.0]
            else:
                result[:] = [extracted_answer, max(all_conf, default=0.9)]

        thread = threading.Thread(target=task)
        thread.start()
        thread.join(timeout)
        if thread.is_alive():
            return "No relevant information found", 0.0  # Timeout case
        return tuple(result)
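

# Minimal usage sketch, assuming "financial_faiss.index" and "index_map.pkl"
# exist next to this file and the Qwen weights are downloadable; the query
# string below is a hypothetical example, not from the original file.
if __name__ == "__main__":
    bot = FinancialChatbot()
    answer, confidence = bot.get_answer("What was the total revenue last year?")
    print(f"{answer} (confidence: {confidence:.2f})")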