import os
import logging
from typing import List, Tuple

import torch
import gradio as gr
from sentence_transformers import SentenceTransformer
from langchain_community.vectorstores import FAISS
from langchain_core.embeddings import Embeddings
from gradio_client import Client
from tqdm import tqdm
# Configuration
QWEN_API_URL = os.getenv("QWEN_API_URL", "https://huggingface.co./spaces/Qwen/Qwen2.5-Max-Demo")  # Must point to a Gradio Space exposing the /model_chat API
CHUNK_SIZE = 800
TOP_K_RESULTS = 150
SIMILARITY_THRESHOLD = 0.4
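# Note: LangChain's FAISS similarity_search_with_score returns an L2 distance
# by default, so LOWER scores mean closer matches; SIMILARITY_THRESHOLD is an
# upper bound on that distance, not a cosine-similarity floor.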
PASSWORD_HASH = os.getenv("PASSWORD_HASH", "abc12345")  # Despite the name, this is compared as plaintext below; set a real secret via the environment
BASE_SYSTEM_PROMPT = """ | |
Répondez en français selon ces règles : | |
1. Utilisez EXCLUSIVEMENT le contexte fourni. | |
2. Structurez la réponse en : | |
- Définition principale. | |
- Caractéristiques clés (3 points maximum). | |
- Relations avec d'autres concepts. | |
3. Si aucune information pertinente, indiquez-le clairement. | |
Contexte : | |
{context} | |
""" | |
# Configure logging to both a file and stdout
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("mtc_chat.log"),
        logging.StreamHandler()
    ]
)
class LocalEmbeddings(Embeddings):
    """Local sentence-transformers embeddings."""

    def __init__(self, model):
        self.model = model

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        embeddings = []
        for text in tqdm(texts, desc="Creating embeddings"):
            embeddings.append(self.model.encode(text).tolist())
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        return self.model.encode(text).tolist()
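# Usage sketch (hypothetical model name): any sentence-transformers model works
# here, since LangChain's FAISS store only calls embed_documents/embed_query:
#   emb = LocalEmbeddings(SentenceTransformer("all-MiniLM-L6-v2"))
#   vec = emb.embed_query("example query")  # -> List[float]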
def split_text_into_chunks(text: str) -> List[str]:
    """Split text into chunks of at most CHUNK_SIZE characters, cutting at
    sentence boundaries where possible."""
    chunks = []
    start = 0
    text_length = len(text)
    while start < text_length:
        end = min(start + CHUNK_SIZE, text_length)
        chunk = text[start:end]
        # Find the last sentence-ending punctuation or paragraph break.
        last_punct = max(
            chunk.rfind('.'),
            chunk.rfind('!'),
            chunk.rfind('?'),
            chunk.rfind('\n\n')
        )
        # Cut back to that boundary only if it falls past the halfway point,
        # so no chunk shrinks below CHUNK_SIZE // 2 characters.
        if last_punct > CHUNK_SIZE // 2:
            end = start + last_punct + 1
        chunks.append(text[start:end].strip())
        start = end  # end always advances past start here
    return chunks
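# Example: with CHUNK_SIZE = 800, a 2,000-character text yields roughly three
# chunks, each trimmed back to the last '.', '!', '?' or blank line that falls
# in its second half.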
def create_new_database(file_content: str, db_name: str, password: str, progress=gr.Progress()) -> Tuple[str, List[str]]:
    """Create a new FAISS database from an uploaded file."""
    if password != PASSWORD_HASH:
        return "Incorrect password. Database creation failed.", []
    if not file_content.strip():
        return "Uploaded file is empty. Database creation failed.", []
    if not db_name.isalnum():
        return "Database name must be alphanumeric. Database creation failed.", []
    try:
        faiss_file = f"{db_name}-index.faiss"
        pkl_file = f"{db_name}-index.pkl"
        # Refuse to overwrite an existing database.
        if os.path.exists(faiss_file) or os.path.exists(pkl_file):
            return f"Database '{db_name}' already exists.", []
        # Split the text into chunks and embed them (embed_documents shows a progress bar).
        chunks = split_text_into_chunks(file_content)
        if not chunks:
            return "No valid chunks generated. Database creation failed.", []
        logging.info(f"Creating {len(chunks)} chunks...")
        embeddings_list = embeddings.embed_documents(chunks)
        # Build the FAISS store from precomputed (text, vector) pairs.
        vector_store = FAISS.from_embeddings(
            text_embeddings=list(zip(chunks, embeddings_list)),
            embedding=embeddings
        )
        # Save under the database's name so the files match the
        # '{db_name}-index' checks above and multiple databases can coexist.
        vector_store.save_local(".", index_name=f"{db_name}-index")
        db_list = [os.path.splitext(f)[0].replace("-index", "") for f in os.listdir(".") if f.endswith(".faiss")]
        return f"Database '{db_name}' created successfully.", db_list
    except Exception as e:
        logging.error(f"Database creation failed: {str(e)}")
        return f"Error creating database: {str(e)}", []
def generate_response(user_input: str, db_name: str) -> str:
    """Generate a response using the Qwen2.5-Max demo API."""
    try:
        if not db_name:
            return "Please select a database to chat with."
        faiss_file = f"{db_name}-index.faiss"
        pkl_file = f"{db_name}-index.pkl"
        if not os.path.exists(faiss_file) or not os.path.exists(pkl_file):
            return f"Database '{db_name}' does not exist."
        vector_store = FAISS.load_local(
            ".",
            embeddings,
            index_name=f"{db_name}-index",
            allow_dangerous_deserialization=True,  # required by recent langchain_community releases to unpickle the local docstore
        )
        # Over-fetch, keep only close matches (low L2 distance), then rank by distance.
        docs_scores = vector_store.similarity_search_with_score(user_input, k=TOP_K_RESULTS * 3)
        filtered_docs = [(doc, score) for doc, score in docs_scores if score < SIMILARITY_THRESHOLD]
        filtered_docs.sort(key=lambda x: x[1])
        if not filtered_docs:
            return "Aucune correspondance trouvée. Essayez des termes plus spécifiques."  # "No matches found. Try more specific terms."
        best_docs = [doc for doc, _ in filtered_docs[:TOP_K_RESULTS]]
        context = "\n".join(f"=== Source {i+1} ===\n{doc.page_content}\n" for i, doc in enumerate(best_docs))
        client = Client(QWEN_API_URL)
        response = client.predict(
            query=user_input,
            history=[],
            system=BASE_SYSTEM_PROMPT.format(context=context),
            api_name="/model_chat"
        )
        # The demo endpoint returns a tuple whose second element is the chat
        # history as [user, assistant] pairs; take the last assistant turn.
        if isinstance(response, tuple) and len(response) >= 2:
            chat_history = response[1]
            if chat_history and len(chat_history[-1]) >= 2:
                return chat_history[-1][1]
        return "Réponse indisponible - Veuillez reformuler votre question."  # "Response unavailable - please rephrase your question."
    except Exception as e:
        logging.error(f"Error generating response: {str(e)}")
        return "Erreur de génération."  # "Generation error."
# Initialize the embedding model (GPU if available) and the Gradio app
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SentenceTransformer("cnmoro/snowflake-arctic-embed-m-v2.0-cpu", device=device)
embeddings = LocalEmbeddings(model)
with gr.Blocks() as app:
    gr.Markdown("# Knowledge Assistant")
    with gr.Tab("Create Database"):
        # Database creation UI: a minimal layout wired to create_new_database
        # (the exact widget arrangement is an assumption; the original was elided).
        file_box = gr.Textbox(label="Document text", lines=10)
        name_box = gr.Textbox(label="Database name (alphanumeric)")
        pass_box = gr.Textbox(label="Password", type="password")
        status_box = gr.Textbox(label="Status")
        db_list_box = gr.JSON(label="Available databases")
        gr.Button("Create").click(
            create_new_database,
            inputs=[file_box, name_box, pass_box],
            outputs=[status_box, db_list_box],
        )
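    # A matching chat tab is assumed: generate_response expects a question and
    # a database name, so this minimal sketch wires exactly those two inputs.
    with gr.Tab("Chat"):
        db_box = gr.Textbox(label="Database name")
        question_box = gr.Textbox(label="Question")
        answer_box = gr.Textbox(label="Réponse", lines=8)
        gr.Button("Ask").click(
            generate_response,
            inputs=[question_box, db_box],
            outputs=answer_box,
        )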
if __name__ == "__main__": | |
app.launch(server_name="0.0.0.0", server_port=7860) | |
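# To run locally (assuming this file is saved as app.py and the imports above
# are installed): `python app.py`, then open http://localhost:7860.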