# Local-Solution/app.py
import os
import logging
from typing import List, Tuple
import torch
import gradio as gr
from sentence_transformers import SentenceTransformer
from langchain_community.vectorstores import FAISS
from langchain_core.embeddings import Embeddings  # modern import path; langchain.embeddings.base also works on older releases
from gradio_client import Client
from tqdm import tqdm
# Configuration
QWEN_API_URL = os.getenv("QWEN_API_URL", "https://huggingface.co./spaces/Qwen/Qwen2.5-Max-Demo") # Ensure this URL points to the correct Gradio Space API endpoint.
CHUNK_SIZE = 800
TOP_K_RESULTS = 150
SIMILARITY_THRESHOLD = 0.4  # L2-distance cutoff: only results with distance below this are kept
PASSWORD_HASH = os.getenv("PASSWORD_HASH", "abc12345")  # NOTE: despite the name, this is compared as plaintext below; set a real secret via the environment
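# Illustrative launch configuration (values are placeholders, not real secrets):
#   export QWEN_API_URL="https://huggingface.co./spaces/Qwen/Qwen2.5-Max-Demo"
#   export PASSWORD_HASH="change-me"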
BASE_SYSTEM_PROMPT = """
Répondez en français selon ces règles :
1. Utilisez EXCLUSIVEMENT le contexte fourni.
2. Structurez la réponse en :
- Définition principale.
- Caractéristiques clés (3 points maximum).
- Relations avec d'autres concepts.
3. Si aucune information pertinente, indiquez-le clairement.
Contexte :
{context}
"""
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("mtc_chat.log"),
logging.StreamHandler()
]
)
class LocalEmbeddings(Embeddings):
"""Local sentence-transformers embeddings"""
def __init__(self, model):
self.model = model
def embed_documents(self, texts: List[str]) -> List[List[float]]:
embeddings = []
for text in tqdm(texts, desc="Creating embeddings"):
embeddings.append(self.model.encode(text).tolist())
return embeddings
def embed_query(self, text: str) -> List[float]:
return self.model.encode(text).tolist()
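# Illustrative usage (not executed; "all-MiniLM-L6-v2" is just a stand-in model):
#   emb = LocalEmbeddings(SentenceTransformer("all-MiniLM-L6-v2"))
#   emb.embed_query("acupuncture")       # -> list[float]
#   emb.embed_documents(["a", "b"])      # -> list[list[float]], with a tqdm bar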
def split_text_into_chunks(text: str) -> List[str]:
    """Split text into ~CHUNK_SIZE-character chunks, breaking at sentence
    boundaries (., !, ? or a blank line) when possible. Chunks do not overlap."""
    chunks = []
    start = 0
    text_length = len(text)
    while start < text_length:
        end = min(start + CHUNK_SIZE, text_length)
        window = text[start:end]
        # Prefer to cut at the last sentence-ending punctuation in the window.
        last_punct = max(
            window.rfind('.'),
            window.rfind('!'),
            window.rfind('?'),
            window.rfind('\n\n')
        )
        # Only shrink to the boundary if the chunk keeps at least half its
        # target size, and never trim the final chunk.
        if last_punct > CHUNK_SIZE // 2 and end < text_length:
            end = start + last_punct + 1
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        start = end
    return chunks
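# Illustrative behaviour (not executed): with CHUNK_SIZE = 800, a 1000-character
# text whose last period before offset 800 sits at offset 650 yields two chunks,
# text[0:651] and text[651:1000].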
def create_new_database(file_content: str, db_name: str, password: str, progress=gr.Progress()) -> Tuple[str, List[str]]:
"""Create a new FAISS database from uploaded file"""
if password != PASSWORD_HASH:
return "Incorrect password. Database creation failed.", []
if not file_content.strip():
return "Uploaded file is empty. Database creation failed.", []
if not db_name.isalnum():
return "Database name must be alphanumeric. Database creation failed.", []
try:
faiss_file = f"{db_name}-index.faiss"
pkl_file = f"{db_name}-index.pkl"
# Check if the database already exists
if os.path.exists(faiss_file) or os.path.exists(pkl_file):
return f"Database '{db_name}' already exists.", []
# Initialize embeddings and split text into chunks
chunks = split_text_into_chunks(file_content)
if not chunks:
return "No valid chunks generated. Database creation failed.", []
logging.info(f"Creating {len(chunks)} chunks...")
# Create embeddings with progress tracking
        embeddings_list = embeddings.embed_documents(chunks)  # tqdm progress comes from LocalEmbeddings
# Create FAISS database
vector_store = FAISS.from_embeddings(
text_embeddings=list(zip(chunks, embeddings_list)),
embedding=embeddings
)
        # Save FAISS database locally; index_name controls the on-disk filenames
        # ({db_name}-index.faiss / {db_name}-index.pkl), matching the checks above.
        vector_store.save_local(".", index_name=f"{db_name}-index")
db_list = [os.path.splitext(f)[0].replace("-index", "") for f in os.listdir(".") if f.endswith(".faiss")]
return f"Database '{db_name}' created successfully.", db_list
except Exception as e:
logging.error(f"Database creation failed: {str(e)}")
return f"Error creating database: {str(e)}", []
def generate_response(user_input: str, db_name: str) -> str:
"""Generate response using Qwen2.5-Max Demo API"""
try:
if not db_name:
return "Please select a database to chat with."
faiss_file = f"{db_name}-index.faiss"
pkl_file = f"{db_name}-index.pkl"
if not os.path.exists(faiss_file) or not os.path.exists(pkl_file):
return f"Database '{db_name}' does not exist."
        # index_name must match the name used in save_local; recent
        # langchain_community releases also require opting in to pickle loading.
        vector_store = FAISS.load_local(
            ".",
            embeddings,
            index_name=f"{db_name}-index",
            allow_dangerous_deserialization=True,
        )
        # FAISS returns L2 distances (lower = more similar), so keep results
        # strictly below the threshold, then sort best-first.
        docs_scores = vector_store.similarity_search_with_score(user_input, k=TOP_K_RESULTS * 3)
        filtered_docs = [(doc, score) for doc, score in docs_scores if score < SIMILARITY_THRESHOLD]
        filtered_docs.sort(key=lambda x: x[1])
if not filtered_docs:
return "Aucune correspondance trouvée. Essayez des termes plus spécifiques."
best_docs = [doc for doc, _ in filtered_docs[:TOP_K_RESULTS]]
context = "\n".join(f"=== Source {i+1} ===\n{doc.page_content}\n" for i, doc in enumerate(best_docs))
client = Client(QWEN_API_URL)
response = client.predict(
query=user_input,
history=[],
system=BASE_SYSTEM_PROMPT.format(context=context),
api_name="/model_chat"
)
        # The demo Space is expected to return a tuple whose second element is
        # the chat history as [[user, assistant], ...]; take the latest
        # assistant turn.
        if isinstance(response, tuple) and len(response) >= 2:
            chat_history = response[1]
            if chat_history and len(chat_history[-1]) >= 2:
                return chat_history[-1][1]
        return "Réponse indisponible - Veuillez reformuler votre question."  # "Answer unavailable - please rephrase your question."
except Exception as e:
logging.error(f"Error generating response: {str(e)}")
return "Erreur de génération."
# Initialize models and Gradio app
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SentenceTransformer("cnmoro/snowflake-arctic-embed-m-v2.0-cpu", device=device)
embeddings = LocalEmbeddings(model)
with gr.Blocks() as app:
gr.Markdown("# Knowledge Assistant")
with gr.Tab("Create Database"):
# Database creation UI setup
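    # Hedged sketch: the original never wires generate_response into the UI, so
    # this Chat tab and its component names are assumptions.
    with gr.Tab("Chat"):
        db_choice = gr.Textbox(label="Database name")
        question = gr.Textbox(label="Votre question")
        answer = gr.Textbox(label="Réponse", interactive=False)
        ask_btn = gr.Button("Ask")
        ask_btn.click(generate_response, inputs=[question, db_choice], outputs=answer)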
if __name__ == "__main__":
app.launch(server_name="0.0.0.0", server_port=7860)