File size: 6,501 Bytes
1301284
 
c26c573
1301284
 
 
 
 
bd1b05d
1301284
 
 
bd1b05d
1301284
 
 
bd1b05d
1301284
 
 
 
30c0b2f
1301284
30c0b2f
 
 
 
1301284
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30c0b2f
1301284
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9896ffc
1301284
 
 
 
 
 
 
c26c573
1301284
9896ffc
c26c573
1301284
 
c26c573
1301284
 
c26c573
1301284
 
79b023d
 
c26c573
79b023d
 
c26c573
b965bb0
30c0b2f
1301284
 
c26c573
1301284
 
d59a8ff
 
bd1b05d
d59a8ff
 
 
 
 
 
d8f5f8c
30c0b2f
 
 
c26c573
30c0b2f
c26c573
5c262f5
1301284
 
c26c573
1301284
c26c573
bd1b05d
1301284
439d179
 
 
79b023d
 
 
 
1301284
 
bd1b05d
1301284
bd1b05d
1301284
bd1b05d
30c0b2f
1301284
 
 
 
 
 
 
bd1b05d
1301284
bd1b05d
30c0b2f
1301284
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd1b05d
 
1301284
bd1b05d
1301284
bd1b05d
1301284
 
 
bd1b05d
 
1301284
bd1b05d
 
ac98d2b
c4e43d3
bd1b05d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import hmac
import logging
import os
from typing import List, Optional, Tuple

import gradio as gr
import torch
from gradio_client import Client
from langchain.embeddings.base import Embeddings
from langchain_community.vectorstores import FAISS
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

# Configuration
# Gradio Space queried for answers. gradio_client.Client accepts either a Space id
# ("owner/name") or the URL of the running Gradio app itself; the
# huggingface.co/spaces/... page URL is neither, so the default is the Space id.
QWEN_API_URL = os.getenv("QWEN_API_URL", "Qwen/Qwen2.5-Max-Demo")
CHUNK_SIZE = 800  # maximum number of characters per chunk stored in a database
TOP_K_RESULTS = 150  # maximum number of retrieved chunks placed in the prompt context
SIMILARITY_THRESHOLD = 0.4  # retrieved chunks are kept only when their score is below this
# NOTE(review): despite its name, this value is compared as a plain-text password, and
# the hard-coded fallback lets anyone create databases when the variable is unset.
# Always set PASSWORD_HASH in the environment of a deployed instance.
PASSWORD_HASH = os.getenv("PASSWORD_HASH", "abc12345")

# System prompt sent to the model; "{context}" is replaced by the retrieved chunks.
BASE_SYSTEM_PROMPT = """
Répondez en français selon ces règles :

1. Utilisez EXCLUSIVEMENT le contexte fourni.
2. Structurez la réponse en :
   - Définition principale.
   - Caractéristiques clés (3 points maximum).
   - Relations avec d'autres concepts.
3. Si aucune information pertinente, indiquez-le clairement.

Contexte :
{context}
"""

# Log to both a file and the console. The explicit encoding keeps non-ASCII
# (e.g. French) text from failing on platforms whose default encoding is not UTF-8.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("mtc_chat.log", encoding="utf-8"),
        logging.StreamHandler()
    ]
)

class LocalEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter around a sentence-transformers model.

    Each text is encoded on its own with ``model.encode`` and the resulting
    array is converted to a plain Python list of floats via ``tolist()``.
    """

    def __init__(self, model):
        # Any object exposing ``encode(text)`` that returns something with ``tolist()``.
        self.model = model

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed every text in *texts*, in order, showing a tqdm progress bar."""
        progress = tqdm(texts, desc="Creating embeddings")
        return [self.model.encode(item).tolist() for item in progress]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string."""
        vector = self.model.encode(text)
        return vector.tolist()

def split_text_into_chunks(text: str, chunk_size: Optional[int] = None) -> List[str]:
    """Split *text* into consecutive, non-overlapping chunks of at most ``chunk_size`` characters.

    Each window of ``chunk_size`` characters is shortened to end just after the last
    sentence break inside it ('.', '!', '?' or a blank line), but only when that break
    lies past the middle of the window. Chunks therefore tend to end on whole sentences
    without becoming tiny. The final window is kept whole when the remaining text fits.
    Chunks are stripped of surrounding whitespace, and whitespace-only chunks are dropped.

    Args:
        text: Text to split.
        chunk_size: Maximum chunk length in characters; defaults to ``CHUNK_SIZE``.

    Returns:
        The chunks in document order (an empty list for empty or whitespace-only text).

    Raises:
        ValueError: If ``chunk_size`` is not positive (the loop could not advance).
    """
    size = CHUNK_SIZE if chunk_size is None else chunk_size
    if size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    chunks: List[str] = []
    start = 0
    text_length = len(text)

    while start < text_length:
        end = min(start + size, text_length)

        # Look for a sentence break only when more text follows; otherwise the
        # remainder already fits, and splitting it would just add a fragment.
        if end < text_length:
            window = text[start:end]
            cut = -1  # offset just past the last break found in the window
            for delimiter in (".", "!", "?", "\n\n"):
                pos = window.rfind(delimiter)
                if pos != -1:
                    cut = max(cut, pos + len(delimiter))
            # Ignore breaks in the first half of the window, so an early '.'
            # cannot produce a one-character chunk.
            if cut > size // 2:
                end = start + cut

        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        # end > start always holds (size > 0, and a cut is only taken when cut >= 1),
        # so the loop always advances.
        start = end

    return chunks

def create_new_database(file_content: str, db_name: str, password: str, progress=gr.Progress()) -> Tuple[str, List[str]]:
    """Build a FAISS vector store from uploaded text and save it to the working directory.

    The text is split with ``split_text_into_chunks`` and each chunk is embedded with the
    module-level ``embeddings`` object. The store is saved as ``<db_name>-index.faiss`` and
    ``<db_name>-index.pkl``, the file names that ``generate_response`` looks for.

    Args:
        file_content: Full text of the uploaded document.
        db_name: Name of the database to create; must be alphanumeric.
        password: Password that must match ``PASSWORD_HASH``.
        progress: Gradio progress tracker; not used by this function (embedding
            progress is reported on the console through tqdm).

    Returns:
        A ``(status_message, database_names)`` tuple. ``database_names`` lists every
        saved database on success and is empty on failure.
    """
    # Constant-time comparison, so response timing does not reveal how much of the
    # password matched. Bytes avoid compare_digest's ASCII-only restriction on str.
    if not hmac.compare_digest((password or "").encode("utf-8"), PASSWORD_HASH.encode("utf-8")):
        return "Incorrect password. Database creation failed.", []

    if not file_content or not file_content.strip():
        return "Uploaded file is empty. Database creation failed.", []

    if not db_name or not db_name.isalnum():
        return "Database name must be alphanumeric. Database creation failed.", []

    try:
        index_name = f"{db_name}-index"
        faiss_file = f"{index_name}.faiss"
        pkl_file = f"{index_name}.pkl"

        # Refuse to overwrite an existing database.
        if os.path.exists(faiss_file) or os.path.exists(pkl_file):
            return f"Database '{db_name}' already exists.", []

        chunks = split_text_into_chunks(file_content)
        if not chunks:
            return "No valid chunks generated. Database creation failed.", []

        logging.info(f"Creating {len(chunks)} chunks...")

        # Same per-chunk encoding as before, via the Embeddings API (with a tqdm bar).
        embeddings_list = embeddings.embed_documents(chunks)

        vector_store = FAISS.from_embeddings(
            text_embeddings=list(zip(chunks, embeddings_list)),
            embedding=embeddings,
        )

        # Without index_name, save_local writes index.faiss / index.pkl. Every database
        # would then overwrite the same files, and neither the existence check above
        # nor generate_response would ever find them.
        vector_store.save_local(".", index_name=index_name)

        suffix = "-index.faiss"
        db_list = sorted(
            name[: -len(suffix)] for name in os.listdir(".") if name.endswith(suffix)
        )

        return f"Database '{db_name}' created successfully.", db_list

    except Exception as e:
        # logging.exception also records the traceback.
        logging.exception(f"Database creation failed: {str(e)}")
        return f"Error creating database: {str(e)}", []

def generate_response(user_input: str, db_name: str) -> str:
    """Answer *user_input* with the Qwen2.5-Max Space, grounded in a saved FAISS database.

    Loads the database saved as ``<db_name>-index.faiss`` / ``<db_name>-index.pkl`` in the
    working directory. Retrieved chunks whose score is below ``SIMILARITY_THRESHOLD`` are
    kept, lowest score first, up to ``TOP_K_RESULTS`` of them. These chunks are sent as
    context in the system prompt.

    Args:
        user_input: The user's question.
        db_name: Name of the database to search.

    Returns:
        The model's answer. Otherwise, a user-facing message when no database is selected,
        the database files are missing, nothing relevant is found, the reply cannot be
        parsed, or an error occurs.
    """
    try:
        if not db_name:
            return "Please select a database to chat with."

        index_name = f"{db_name}-index"
        faiss_file = f"{index_name}.faiss"
        pkl_file = f"{index_name}.pkl"

        if not os.path.exists(faiss_file) or not os.path.exists(pkl_file):
            return f"Database '{db_name}' does not exist."

        # index_name selects this database's files; the default ("index") ignores db_name.
        # The .pkl docstore is unpickled, which recent langchain_community releases refuse
        # to do without allow_dangerous_deserialization. These files are expected to be
        # written by create_new_database. Never point this at untrusted .pkl files.
        vector_store = FAISS.load_local(
            ".",
            embeddings,
            index_name=index_name,
            allow_dangerous_deserialization=True,
        )

        # Over-fetch, then keep close matches only. This assumes the default L2 distance,
        # where a lower score means a closer match.
        docs_scores = vector_store.similarity_search_with_score(user_input, k=TOP_K_RESULTS * 3)
        filtered_docs = sorted(
            ((doc, score) for doc, score in docs_scores if score < SIMILARITY_THRESHOLD),
            key=lambda pair: pair[1],
        )

        if not filtered_docs:
            return "Aucune correspondance trouvée. Essayez des termes plus spécifiques."

        best_docs = [doc for doc, _ in filtered_docs[:TOP_K_RESULTS]]

        # NOTE(review): up to TOP_K_RESULTS chunks of up to CHUNK_SIZE characters each can
        # make a very large prompt. Confirm that the Space accepts it.
        context = "\n".join(f"=== Source {i+1} ===\n{doc.page_content}\n" for i, doc in enumerate(best_docs))

        client = Client(QWEN_API_URL)

        response = client.predict(
            query=user_input,
            history=[],
            system=BASE_SYSTEM_PROMPT.format(context=context),
            api_name="/model_chat"
        )

        # Presumably a tuple whose second element is the chat history, as a list of
        # [user, assistant] pairs. TODO: confirm against the Space's API page.
        if isinstance(response, tuple) and len(response) >= 2:
            chat_history = response[1]
            if chat_history and len(chat_history[-1]) >= 2:
                return chat_history[-1][1]

        return "Réponse indisponible - Veuillez reformuler votre question."

    except Exception as e:
        # logging.exception also records the traceback.
        logging.exception(f"Error generating response: {str(e)}")
        return "Erreur de génération."

# Initialize the embedding model shared by database creation and querying.
# Use the GPU when torch can see one, load the sentence-transformers model onto
# that device, and wrap it in the LangChain adapter used by FAISS.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = SentenceTransformer(
    "cnmoro/snowflake-arctic-embed-m-v2.0-cpu",
    device=device,
)
embeddings = LocalEmbeddings(model)

def _list_databases() -> List[str]:
    """Return the sorted names of the databases saved in the working directory.

    A database named ``foo`` is stored as ``foo-index.faiss`` (plus ``foo-index.pkl``).
    The ``-index.faiss`` suffix is stripped to recover the name.
    """
    suffix = "-index.faiss"
    return sorted(name[: -len(suffix)] for name in os.listdir(".") if name.endswith(suffix))


def _decode_upload(raw) -> str:
    """Decode an uploaded file's bytes as UTF-8, falling back to Latin-1; None gives ""."""
    if not raw:
        return ""
    try:
        return bytes(raw).decode("utf-8")
    except UnicodeDecodeError:
        return bytes(raw).decode("latin-1")


def _create_database_from_upload(raw_file, db_name, password):
    """Event handler: create a database from an upload, then refresh the database dropdown."""
    status, _ = create_new_database(_decode_upload(raw_file), db_name, password)
    # Re-list from disk instead of using the returned list, which is empty on failure
    # and would otherwise clear the dropdown.
    return status, gr.update(choices=_list_databases())


with gr.Blocks() as app:
    gr.Markdown("# Knowledge Assistant")

    with gr.Tab("Create Database"):
        # Database creation UI: upload a text file, name the database, confirm with the password.
        file_input = gr.File(label="Text file", file_types=[".txt", ".md"], type="binary")
        db_name_input = gr.Textbox(label="Database name (alphanumeric)")
        password_input = gr.Textbox(label="Password", type="password")
        create_button = gr.Button("Create database")
        create_status = gr.Textbox(label="Status", interactive=False)

    with gr.Tab("Chat"):
        db_selector = gr.Dropdown(choices=_list_databases(), label="Database")
        question_input = gr.Textbox(label="Question")
        ask_button = gr.Button("Ask")
        answer_output = gr.Markdown()

    create_button.click(
        fn=_create_database_from_upload,
        inputs=[file_input, db_name_input, password_input],
        outputs=[create_status, db_selector],
    )
    ask_button.click(
        fn=generate_response,
        inputs=[question_input, db_selector],
        outputs=answer_output,
    )
    question_input.submit(
        fn=generate_response,
        inputs=[question_input, db_selector],
        outputs=answer_output,
    )

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)