import os
import logging
from typing import List, Tuple

import torch
import gradio as gr
from sentence_transformers import SentenceTransformer
from langchain_community.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
from gradio_client import Client
from tqdm import tqdm

# Configuration
QWEN_API_URL = os.getenv("QWEN_API_URL", "https://huggingface.co./spaces/Qwen/Qwen2.5-Max-Demo")  # Must point to the Gradio Space API endpoint
CHUNK_SIZE = 800
TOP_K_RESULTS = 150
SIMILARITY_THRESHOLD = 0.4
PASSWORD_HASH = os.getenv("PASSWORD_HASH", "abc12345")  # Set via environment variable; note it is compared as plain text despite the name

BASE_SYSTEM_PROMPT = """
Répondez en français selon ces règles :

1. Utilisez EXCLUSIVEMENT le contexte fourni.
2. Structurez la réponse en :
   - Définition principale.
   - Caractéristiques clés (3 points maximum).
   - Relations avec d'autres concepts.
3. Si aucune information pertinente, indiquez-le clairement.

Contexte :
{context}
"""

# Configure logging to both a file and the console
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("mtc_chat.log"),
        logging.StreamHandler()
    ]
)


class LocalEmbeddings(Embeddings):
    """Local sentence-transformers embeddings wrapped in the LangChain Embeddings interface"""

    def __init__(self, model):
        self.model = model

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        embeddings = []
        for text in tqdm(texts, desc="Creating embeddings"):
            embeddings.append(self.model.encode(text).tolist())
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        return self.model.encode(text).tolist()


def split_text_into_chunks(text: str) -> List[str]:
    """Split text into consecutive chunks, preferring to end each chunk at a sentence boundary"""
    chunks = []
    start = 0
    text_length = len(text)

    while start < text_length:
        end = min(start + CHUNK_SIZE, text_length)
        chunk = text[start:end]

        # Find the last sentence-ending punctuation or paragraph break in the window
        last_punct = max(
            chunk.rfind('.'),
            chunk.rfind('!'),
            chunk.rfind('?'),
            chunk.rfind('\n\n')
        )

        # Cut at that punctuation only if it leaves at least half a window,
        # so chunks never shrink to a few characters
        if last_punct > CHUNK_SIZE // 2:
            end = start + last_punct + 1

        chunks.append(text[start:end].strip())
        start = end  # no overlap between consecutive chunks

    return chunks
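# Illustrative sanity check (not in the original script; shown only to document
# the chunker's contract): every chunk fits within CHUNK_SIZE and, when the
# text contains sentence punctuation, chunks end on a sentence boundary.
#
#   >>> sample = "Première phrase. Deuxième phrase. " * 100
#   >>> chunks = split_text_into_chunks(sample)
#   >>> all(len(c) <= CHUNK_SIZE for c in chunks)
#   True
#   >>> chunks[0].endswith('.')
#   True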
def create_new_database(file_content: str, db_name: str, password: str,
                        progress=gr.Progress()) -> Tuple[str, List[str]]:
    """Create a new FAISS database from an uploaded file"""
    if password != PASSWORD_HASH:
        return "Incorrect password. Database creation failed.", []
    if not file_content.strip():
        return "Uploaded file is empty. Database creation failed.", []
    if not db_name.isalnum():
        return "Database name must be alphanumeric. Database creation failed.", []

    try:
        faiss_file = f"{db_name}-index.faiss"
        pkl_file = f"{db_name}-index.pkl"

        # Check whether the database already exists
        if os.path.exists(faiss_file) or os.path.exists(pkl_file):
            return f"Database '{db_name}' already exists.", []

        # Split the text into chunks
        chunks = split_text_into_chunks(file_content)
        if not chunks:
            return "No valid chunks generated. Database creation failed.", []

        logging.info(f"Creating {len(chunks)} chunks...")

        # Create embeddings with progress tracking
        embeddings_list = [embeddings.embed_query(chunk) for chunk in tqdm(chunks, desc="Creating embeddings")]

        # Build the FAISS index from precomputed (text, embedding) pairs
        vector_store = FAISS.from_embeddings(
            text_embeddings=list(zip(chunks, embeddings_list)),
            embedding=embeddings
        )

        # Save locally under the same name the existence check and loader use
        vector_store.save_local(".", index_name=f"{db_name}-index")

        db_list = [os.path.splitext(f)[0].replace("-index", "")
                   for f in os.listdir(".") if f.endswith(".faiss")]
        return f"Database '{db_name}' created successfully.", db_list

    except Exception as e:
        logging.error(f"Database creation failed: {str(e)}")
        return f"Error creating database: {str(e)}", []


def generate_response(user_input: str, db_name: str) -> str:
    """Generate a response through the Qwen2.5-Max demo Space API"""
    try:
        if not db_name:
            return "Please select a database to chat with."

        faiss_file = f"{db_name}-index.faiss"
        pkl_file = f"{db_name}-index.pkl"
        if not os.path.exists(faiss_file) or not os.path.exists(pkl_file):
            return f"Database '{db_name}' does not exist."

        # Load the named index saved by create_new_database; recent
        # langchain_community releases require opting in to pickle deserialization
        vector_store = FAISS.load_local(
            ".", embeddings,
            index_name=f"{db_name}-index",
            allow_dangerous_deserialization=True
        )

        # Retrieve extra candidates, then keep only close matches
        # (FAISS returns L2 distances, so lower scores are better)
        docs_scores = vector_store.similarity_search_with_score(user_input, k=TOP_K_RESULTS * 3)
        filtered_docs = [(doc, score) for doc, score in docs_scores if score < SIMILARITY_THRESHOLD]
        filtered_docs.sort(key=lambda x: x[1])

        if not filtered_docs:
            return "Aucune correspondance trouvée. Essayez des termes plus spécifiques."

        best_docs = [doc for doc, _ in filtered_docs[:TOP_K_RESULTS]]
        context = "\n".join(
            f"=== Source {i+1} ===\n{doc.page_content}\n"
            for i, doc in enumerate(best_docs)
        )

        client = Client(QWEN_API_URL)
        response = client.predict(
            query=user_input,
            history=[],
            system=BASE_SYSTEM_PROMPT.format(context=context),
            api_name="/model_chat"
        )

        # The Space returns a tuple whose second element is the chat history;
        # the answer is the last turn's second field
        if isinstance(response, tuple) and len(response) >= 2:
            chat_history = response[1]
            if chat_history and len(chat_history[-1]) >= 2:
                return chat_history[-1][1]

        return "Réponse indisponible - Veuillez reformuler votre question."

    except Exception as e:
        logging.error(f"Error generating response: {str(e)}")
        return "Erreur de génération."


# Initialize models and Gradio app
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SentenceTransformer("cnmoro/snowflake-arctic-embed-m-v2.0-cpu", device=device)
embeddings = LocalEmbeddings(model)

with gr.Blocks() as app:
    gr.Markdown("# Knowledge Assistant")

    with gr.Tab("Create Database"):
        # Database creation UI setup (elided in the original; see the sketch below)
        ...

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)
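# --- Hypothetical tab wiring (illustration only) ---
# The original UI setup is elided above; the component names and layout below
# are assumptions, showing one plausible way to connect the two callbacks:
#
#     with gr.Tab("Create Database"):
#         file_in = gr.File(label="Text file", type="filepath")
#         name_in = gr.Textbox(label="Database name")
#         pwd_in = gr.Textbox(label="Password", type="password")
#         status = gr.Textbox(label="Status")
#         db_pick = gr.Dropdown(label="Databases", choices=[])
#
#         def on_create(path, name, pwd):
#             with open(path, encoding="utf-8") as fh:
#                 msg, dbs = create_new_database(fh.read(), name, pwd)
#             return msg, gr.update(choices=dbs)
#
#         gr.Button("Create").click(on_create,
#                                   inputs=[file_in, name_in, pwd_in],
#                                   outputs=[status, db_pick])
#
#     with gr.Tab("Chat"):
#         question = gr.Textbox(label="Question")
#         answer = gr.Textbox(label="Réponse")
#         gr.Button("Envoyer").click(generate_response,
#                                    inputs=[question, db_pick],
#                                    outputs=answer)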