File size: 5,529 Bytes
e2b8671
b381b95
58c74fe
49add26
 
 
b381b95
e2b8671
49add26
e2b8671
 
49add26
 
 
 
 
b381b95
 
 
 
 
 
08c0aaf
 
 
 
e2b8671
08c0aaf
 
 
 
 
 
e2b8671
08c0aaf
 
 
 
 
 
49add26
08c0aaf
7ba0264
e2b8671
58c74fe
08c0aaf
 
a9d35a1
08c0aaf
 
 
e2b8671
49add26
08c0aaf
49add26
 
08c0aaf
49add26
08c0aaf
 
49add26
 
08c0aaf
49add26
08c0aaf
 
 
e2b8671
49add26
b381b95
08c0aaf
49add26
08c0aaf
49add26
08c0aaf
 
 
 
 
b381b95
08c0aaf
 
 
58c74fe
49add26
08c0aaf
49add26
58c74fe
 
 
e2b8671
49add26
08c0aaf
49add26
08c0aaf
58c74fe
49add26
 
 
f8d190b
5813bdb
 
a9d35a1
 
49add26
b381b95
49add26
 
08c0aaf
 
49add26
08c0aaf
 
49add26
b381b95
08c0aaf
b381b95
08c0aaf
 
 
49add26
 
08c0aaf
b381b95
 
 
 
 
49add26
b381b95
e2b8671
58c74fe
 
49add26
08c0aaf
58c74fe
b381b95
08c0aaf
b381b95
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import faiss
import numpy as np
import pickle
import threading
import time
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
from rank_bm25 import BM25Okapi

class FinancialChatbot:
    """Retrieval-augmented Q&A bot over a local financial document index.

    Retrieval combines a FAISS dense index (queried with a SentenceTransformer
    embedding) and a BM25 keyword index built over the same documents; the
    retrieved text is passed as context to a Qwen causal language model.
    """

    # Default reply when nothing usable can be produced.
    _NO_INFO = "No relevant information found"
    # Whole-query greetings answered with a canned reply.
    _GREETINGS = ("hi", "hello", "hey")
    # Hard-coded whole-query words answered with _NO_INFO at confidence 1.0
    # without running retrieval.
    _OFF_TOPIC = ("france", "capital", "air", "rainbow", "water", "sun")

    def __init__(self, faiss_index_path="financial_faiss.index", index_map_path="index_map.pkl"):
        """Load the retrieval indexes and the embedding / generation models.

        Args:
            faiss_index_path: path of the serialized FAISS index.
            index_map_path: path of a pickled dict mapping FAISS ids to
                document text.
        """
        # Load FAISS index
        self.faiss_index = faiss.read_index(faiss_index_path)
        # SECURITY NOTE: pickle.load can execute arbitrary code; only load an
        # index map produced by a trusted build step.
        with open(index_map_path, "rb") as f:
            self.index_map = pickle.load(f)

        # BM25 works positionally: self.documents[i] corresponds to
        # self.bm25_corpus[i] and to score i returned by BM25.
        self.documents = list(self.index_map.values())
        self.bm25_corpus = [doc.lower().split() for doc in self.documents]  # whitespace tokenization
        self.bm25 = BM25Okapi(self.bm25_corpus)

        # SentenceTransformer for embedding-based retrieval
        self.sbert_model = SentenceTransformer("all-MiniLM-L6-v2")

        # Qwen model used for answer generation
        model_name = "Qwen/Qwen2.5-1.5b"
        self.qwen_model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype="auto", device_map="auto", trust_remote_code=True
        )
        self.qwen_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        # Guardrail: blocked words (see moderate_query for matching rules)
        self.BLOCKED_WORDS = [
            "hack", "bypass", "illegal", "exploit", "scam", "kill", "laundering",
            "murder", "suicide", "self-harm", "assault", "bomb", "terrorism",
            "attack", "genocide", "mass shooting", "credit card number"
        ]

        # Minimum best retrieval score required before generating an answer.
        self.min_similarity_threshold = 0.7

    def moderate_query(self, query):
        """Return True if *query* may be processed, False if it hits the blocklist.

        Matching is a case-insensitive substring test, so a blocked word also
        matches inside longer words (e.g. "skill" contains "kill").
        NOTE(review): consider word-boundary matching if false positives matter.
        """
        query_lower = query.lower()
        return not any(word in query_lower for word in self.BLOCKED_WORDS)

    def query_faiss(self, query, top_k=5):
        """Retrieve up to *top_k* documents from the FAISS index.

        Returns:
            (documents, scores) where each score is 1 / (1 + distance), i.e.
            a value in (0, 1] that grows as the distance shrinks.
        """
        # FAISS expects float32 input; this is a no-op if encode already
        # returned float32.
        query_embedding = np.asarray(
            self.sbert_model.encode([query], convert_to_numpy=True), dtype=np.float32
        )
        distances, indices = self.faiss_index.search(query_embedding, top_k)

        results = []
        confidence_scores = []
        for idx, dist in zip(indices[0], distances[0]):
            # Ids not present in index_map (e.g. FAISS's -1 placeholder for
            # missing neighbours) are skipped.
            if idx in self.index_map:
                results.append(self.index_map[idx])
                # Assumes an L2-style distance where 0 means identical — TODO confirm index type.
                confidence_scores.append(float(1 / (1 + dist)))

        return results, confidence_scores

    def query_bm25(self, query, top_k=5):
        """Retrieve up to *top_k* documents by BM25 keyword score.

        Documents with a score of 0 (no term overlap) are dropped.

        Returns:
            (documents, scores) with raw BM25 scores, highest first. These
            scores are not bounded to [0, 1].
        """
        tokenized_query = query.lower().split()
        scores = self.bm25.get_scores(tokenized_query)
        top_indices = np.argsort(scores)[::-1][:top_k]

        results = []
        confidence_scores = []
        for idx in top_indices:
            if scores[idx] > 0:  # Ignore zero-score matches
                results.append(self.documents[idx])
                confidence_scores.append(float(scores[idx]))

        return results, confidence_scores

    def generate_answer(self, context, question, max_new_tokens=100):
        """Generate a completion for a Context/Question/Answer prompt with Qwen.

        Args:
            context: retrieved text placed in the prompt.
            question: the user's question.
            max_new_tokens: number of tokens to generate beyond the prompt.
                (The original used max_length, which also counted the prompt,
                so long contexts left no room to generate anything.)

        Returns:
            The decoded prompt followed by the generated continuation.
        """
        input_text = f"Context: {context}\nQuestion: {question}\nAnswer:"
        # Tokenizer __call__ yields input_ids *and* attention_mask; move both to
        # the model's device (device_map="auto" may place it on a GPU).
        inputs = self.qwen_tokenizer(input_text, return_tensors="pt").to(self.qwen_model.device)
        outputs = self.qwen_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.qwen_tokenizer.eos_token_id,
        )
        return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)

    @staticmethod
    def _extract_answer(generated):
        """Return the text after the last "Answer" label in *generated*, or ""."""
        label_pos = generated.rfind("Answer")
        if label_pos == -1:
            return ""
        # Drop the label itself (and its colon) so users don't see "Answer: ...".
        return generated[label_pos + len("Answer"):].lstrip(":").strip()

    def get_answer(self, query, timeout=200):
        """Answer *query* via hybrid retrieval plus Qwen generation, with a timeout.

        Args:
            query: the user's question.
            timeout: seconds to wait for the worker thread before giving up.

        Returns:
            (answer_text, confidence). confidence is the best retrieval score
            for generated answers, 1.0 for canned greeting / off-topic replies,
            and 0.0 when nothing usable was produced (blocked query, weak
            retrieval, empty or numeric-only answer, worker error, timeout).
        """
        result = [self._NO_INFO, 0.0]  # stays as-is if the worker raises

        def task():
            normalized = query.strip().lower()
            if normalized in self._GREETINGS:
                result[:] = ["Hi, how can I help you?", 1.0]
                return

            if normalized in self._OFF_TOPIC:
                result[:] = [self._NO_INFO, 1.0]
                return

            if not self.moderate_query(query):
                result[:] = ["I'm unable to process your request due to inappropriate language.", 0.0]
                return

            faiss_results, faiss_conf = self.query_faiss(query)
            bm25_results, bm25_conf = self.query_bm25(query)

            all_results = faiss_results + bm25_results
            # NOTE(review): FAISS scores are in (0, 1] but BM25 scores are raw
            # and unbounded, so the threshold and returned confidence mix scales.
            all_conf = faiss_conf + bm25_conf
            best_conf = max(all_conf, default=0)

            # Check if results are relevant
            if not all_results or best_conf < self.min_similarity_threshold:
                result[:] = [self._NO_INFO, 0.0]
                return

            # Both retrievers can return the same document; keep first occurrence only.
            context = " ".join(dict.fromkeys(all_results))
            extracted_answer = self._extract_answer(self.generate_answer(context, query))

            # Reject empty completions and bare numbers (author's grounding heuristic).
            if not extracted_answer or extracted_answer.isnumeric():
                result[:] = [self._NO_INFO, 0.0]
            else:
                result[:] = [extracted_answer, best_conf]

        # Daemon thread: a timed-out generation keeps running in the background
        # but cannot block interpreter shutdown.
        worker = threading.Thread(target=task, daemon=True)
        worker.start()
        worker.join(timeout)

        if worker.is_alive():
            return self._NO_INFO, 0.0  # Timeout case

        return tuple(result)