import os
import re
import json
import time
from tqdm import tqdm
from pathlib import Path
import spaces
import gradio as gr
# Helper functions that don't use GPU
def safe_tokenize(text):
"""Pure regex tokenizer with no NLTK dependency"""
if not text:
return []
    # Pad punctuation with surrounding spaces so each mark becomes its own token
text = re.sub(r'([.,!?;:()\[\]{}"\'/\\])', r' \1 ', text)
# Split on whitespace and filter empty strings
return [token for token in re.split(r'\s+', text.lower()) if token]
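# Example (verifiable from the regexes above):
#   safe_tokenize("Hello, world!") -> ["hello", ",", "world", "!"]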
def detect_language(text):
"""Detect if text is primarily Arabic or English"""
    # Heuristic: treat the text as Arabic if more than half of its characters
    # fall in the Arabic Unicode block (U+0600 to U+06FF)
    arabic_chars = re.findall(r'[\u0600-\u06FF]', text)
    is_arabic = len(arabic_chars) > len(text) * 0.5
return "arabic" if is_arabic else "english"
# Comprehensive evaluation dataset
comprehensive_evaluation_data = [
# === Overview ===
{
"query": "ما هي رؤية السعودية 2030؟",
"reference": "رؤية السعودية 2030 هي خطة استراتيجية تهدف إلى تنويع الاقتصاد السعودي وتقليل الاعتماد على النفط مع تطوير قطاعات مختلفة مثل الصحة والتعليم والسياحة.",
"category": "overview",
"language": "arabic"
},
{
"query": "What is Saudi Vision 2030?",
"reference": "Saudi Vision 2030 is a strategic framework aiming to diversify Saudi Arabia's economy and reduce dependence on oil, while developing sectors like health, education, and tourism.",
"category": "overview",
"language": "english"
},
# === Economic Goals ===
{
"query": "ما هي الأهداف الاقتصادية لرؤية 2030؟",
"reference": "تشمل الأهداف الاقتصادية زيادة مساهمة القطاع الخاص إلى 65%، وزيادة الصادرات غير النفطية إلى 50% من الناتج المحلي غير النفطي، وخفض البطالة إلى 7%.",
"category": "economic",
"language": "arabic"
},
{
"query": "What are the economic goals of Vision 2030?",
"reference": "The economic goals of Vision 2030 include increasing private sector contribution from 40% to 65% of GDP, raising non-oil exports from 16% to 50%, reducing unemployment from 11.6% to 7%.",
"category": "economic",
"language": "english"
},
# === Social Goals ===
{
"query": "كيف تعزز رؤية 2030 الإرث الثقافي السعودي؟",
"reference": "تتضمن رؤية 2030 الحفاظ على الهوية الوطنية، تسجيل مواقع أثرية في اليونسكو، وتعزيز الفعاليات الثقافية.",
"category": "social",
"language": "arabic"
},
{
"query": "How does Vision 2030 aim to improve quality of life?",
"reference": "Vision 2030 plans to enhance quality of life by expanding sports facilities, promoting cultural activities, and boosting tourism and entertainment sectors.",
"category": "social",
"language": "english"
}
]
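# Each item pairs a query with a reference answer plus category and language tags,
# so generated responses can be scored per topic and per language.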
# RAG Service class
class Vision2030Service:
def __init__(self):
self.initialized = False
self.model = None
self.tokenizer = None
self.vector_store = None
self.conversation_history = []
@spaces.GPU
def initialize(self):
"""Initialize the system - ALL GPU operations must happen here"""
if self.initialized:
return True
try:
# Import all GPU-dependent libraries only inside this function
import torch
import PyPDF2
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
            from langchain_community.embeddings import HuggingFaceEmbeddings
            # Source PDFs (English and Arabic) expected alongside this script
            pdf_files = ["saudi_vision203.pdf", "saudi_vision2030_ar.pdf"]
# Process PDFs and create vector store
vector_store_dir = "vector_stores"
os.makedirs(vector_store_dir, exist_ok=True)
if os.path.exists(os.path.join(vector_store_dir, "index.faiss")):
print("Loading existing vector store...")
embedding_function = HuggingFaceEmbeddings(
model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
                # FAISS stores are pickled on disk, so loading requires explicitly
                # opting in to deserialization; acceptable here because we only
                # ever load a store this app created itself
                self.vector_store = FAISS.load_local(
                    vector_store_dir,
                    embedding_function,
                    allow_dangerous_deserialization=True
                )
else:
print("Creating new vector store...")
# Process PDFs
documents = []
for pdf_path in pdf_files:
if not os.path.exists(pdf_path):
print(f"Warning: {pdf_path} does not exist")
continue
print(f"Processing {pdf_path}...")
text = ""
with open(pdf_path, 'rb') as file:
reader = PyPDF2.PdfReader(file)
for page in reader.pages:
page_text = page.extract_text()
if page_text:
text += page_text + "\n\n"
if text.strip():
doc = Document(
page_content=text,
metadata={"source": pdf_path, "filename": os.path.basename(pdf_path)}
)
documents.append(doc)
if not documents:
raise ValueError("No documents were processed successfully.")
                # Split each document into overlapping chunks, preferring paragraph and sentence boundaries
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=50,
separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
)
chunks = []
for doc in documents:
doc_chunks = text_splitter.split_text(doc.page_content)
chunks.extend([
Document(page_content=chunk, metadata=doc.metadata)
for chunk in doc_chunks
])
# Create vector store
embedding_function = HuggingFaceEmbeddings(
model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
self.vector_store = FAISS.from_documents(chunks, embedding_function)
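                # Persist the index (index.faiss plus index.pkl) so later launches
                # can load it directly instead of re-processing the PDFs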
self.vector_store.save_local(vector_store_dir)
# Load model
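            # bfloat16 halves memory use relative to float32, and device_map="auto"
            # lets accelerate place model layers on the available GPU(s)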
model_name = "ALLaM-AI/ALLaM-7B-Instruct-preview"
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=True,
use_fast=False
)
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
device_map="auto",
)
self.initialized = True
return True
except Exception as e:
import traceback
print(f"Initialization error: {e}")
print(traceback.format_exc())
return False
@spaces.GPU
def retrieve_context(self, query, top_k=5):
"""Retrieve contexts from vector store"""
# Import must be inside the function to avoid CUDA init in main process
if not self.initialized:
return []
try:
results = self.vector_store.similarity_search_with_score(query, k=top_k)
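            # similarity_search_with_score returns (Document, distance) pairs;
            # lower distance means a closer match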
contexts = []
for doc, score in results:
contexts.append({
"content": doc.page_content,
"source": doc.metadata.get("source", "Unknown"),
"relevance_score": score
})
return contexts
except Exception as e:
print(f"Error retrieving context: {e}")
return []
@spaces.GPU
def generate_response(self, query, contexts, language="auto"):
"""Generate response using the model"""
# Import must be inside the function to avoid CUDA init in main process
import torch
if not self.initialized or self.model is None or self.tokenizer is None:
return "I'm still initializing. Please try again in a moment."
try:
# Auto-detect language if not specified
if language == "auto":
language = detect_language(query)
# Format the prompt based on language
if language == "arabic":
instruction = (
"أنت مساعد افتراضي يهتم برؤية السعودية 2030. استخدم المعلومات التالية للإجابة على السؤال. "
"إذا لم تعرف الإجابة، فقل بأمانة إنك لا تعرف."
)
else: # english
instruction = (
"You are a virtual assistant for Saudi Vision 2030. Use the following information to answer the question. "
"If you don't know the answer, honestly say you don't know."
)
# Combine retrieved contexts
context_text = "\n\n".join([f"Document: {ctx['content']}" for ctx in contexts])
# Format the prompt for ALLaM instruction format
prompt = f"""[INST] {instruction}
Context:
{context_text}
Question: {query} [/INST]"""
# Generate response
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            outputs = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=512,
                temperature=0.7,         # moderate randomness
                top_p=0.9,               # nucleus sampling
                do_sample=True,
                repetition_penalty=1.1   # discourage verbatim repetition
            )
# Decode the response
full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extract just the answer part (after the instruction)
response = full_output.split("[/INST]")[-1].strip()
            # Fall back to the full output if splitting produced an empty string
if not response:
response = full_output
return response
except Exception as e:
import traceback
print(f"Error generating response: {e}")
print(traceback.format_exc())
return f"Sorry, I encountered an error while generating a response."
@spaces.GPU
def answer_question(self, query):
"""Process a user query and return a response with sources"""
if not self.initialized:
if not self.initialize():
return "System initialization failed. Please check the logs.", []
try:
            # Add the user query to the conversation history
            self.conversation_history.append({"role": "user", "content": query})
            # Build context from the last 3 exchanges, excluding the query just
            # appended so it is not duplicated in the enhanced query below
            conversation_context = "\n".join([
                f"{'User' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}"
                for msg in self.conversation_history[:-1][-6:]
            ])
            # Prepend the conversation context so retrieval can resolve follow-up questions
            enhanced_query = f"{conversation_context}\n{query}"
# Retrieve relevant contexts
contexts = self.retrieve_context(enhanced_query, top_k=5)
# Generate response
response = self.generate_response(query, contexts)
# Add response to conversation history
self.conversation_history.append({"role": "assistant", "content": response})
            # Collect unique sources, preserving retrieval order
            sources = [ctx.get("source", "Unknown") for ctx in contexts]
            unique_sources = list(dict.fromkeys(sources))
return response, unique_sources
except Exception as e:
import traceback
print(f"Error answering question: {e}")
print(traceback.format_exc())
return f"Sorry, I encountered an error: {str(e)}", []
def reset_conversation(self):
"""Reset the conversation history"""
self.conversation_history = []
return "Conversation has been reset."
def main():
# Create the Vision 2030 service
service = Vision2030Service()
# Define theme and styling
theme = gr.themes.Soft(
primary_hue="emerald",
secondary_hue="teal",
).set(
body_background_fill="linear-gradient(to right, #f0f9ff, #e6f7ff)",
button_primary_background_fill="linear-gradient(90deg, #1e9e5a, #1d8753)",
button_primary_background_fill_hover="linear-gradient(90deg, #1d8753, #176f44)",
button_primary_text_color="white",
button_secondary_background_fill="#f0f0f0",
button_secondary_background_fill_hover="#e0e0e0",
block_title_text_weight="600",
block_border_width="2px",
block_shadow="0px 4px 6px rgba(0, 0, 0, 0.1)",
background_fill_primary="#ffffff",
)
# Build the Gradio interface with enhanced styling
with gr.Blocks(title="Vision 2030 Assistant", theme=theme, css="""
.language-toggle { margin-bottom: 20px; }
.container { border-radius: 10px; padding: 20px; box-shadow: 0 4px 6px rgba(0,0,0,0.1); }
.header-img { margin-bottom: 10px; border-radius: 10px; }
.highlight { background-color: rgba(46, 175, 125, 0.1); padding: 15px; border-radius: 8px; margin: 10px 0; }
.footer { text-align: center; margin-top: 30px; color: #666; font-size: 0.9em; }
.loading-spinner { display: inline-block; width: 20px; height: 20px; margin-right: 10px; }
.status-indicator { display: inline-flex; align-items: center; padding: 8px; border-radius: 4px; }
.status-indicator.success { background-color: rgba(46, 175, 125, 0.2); }
.status-indicator.warning { background-color: rgba(255, 190, 0, 0.2); }
.status-indicator.error { background-color: rgba(255, 76, 76, 0.2); }
.header { display: flex; justify-content: space-between; align-items: center; }
.lang-btn { min-width: 100px; }
.chat-input { background-color: white; border-radius: 8px; border: 1px solid #ddd; }
.info-box { background-color: #f8f9fa; padding: 10px; border-radius: 8px; margin-top: 10px; }
/* Style for sample question buttons */
.gradio-button.secondary {
margin-bottom: 8px;
text-align: left;
background-color: #f0f9ff;
transition: all 0.3s ease;
display: block;
width: 100%;
padding: 8px 12px;
}
.gradio-button.secondary:hover {
background-color: #e0f2fe;
transform: translateX(3px);
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
""") as demo:
# Header with stylized title (no external images)
with gr.Row():
with gr.Column():
with gr.Row():
# Add local logo image on the left
with gr.Column(scale=1, min_width=100):
gr.Image("logo.png", show_label=False, height=80)
# Title and tagline on the right
with gr.Column(scale=4):
gr.Markdown("""
# Vision 2030 Assistant
### Your interactive guide to Saudi Arabia's national transformation program
""")
# Language toggle in the header with better styling
with gr.Row(elem_classes=["language-toggle"]):
with gr.Column(scale=1):
language_toggle = gr.Radio(
choices=["English", "العربية (Arabic)", "Auto-detect"],
value="Auto-detect",
label="Interface Language",
info="Choose your preferred language",
elem_classes=["lang-btn"]
)
# Main interface with tabs
with gr.Tabs() as tabs:
# Chat Tab with enhanced design
with gr.TabItem("💬 Chat", id="chat"):
with gr.Row():
with gr.Column(scale=2):
chatbot = gr.Chatbot(
height=450,
bubble_full_width=False,
show_label=False
)
with gr.Row():
msg = gr.Textbox(
label="",
placeholder="Ask a question about Saudi Vision 2030...",
show_label=False,
elem_classes=["chat-input"],
scale=9
)
submit_btn = gr.Button("Send", variant="primary", scale=1)
with gr.Row():
clear = gr.Button("Clear History", variant="secondary")
                            thinking_indicator = gr.HTML(
                                value='<div class="loading-spinner"></div>',
                                visible=False
                            )
            # Progress Tab
            with gr.TabItem("📊 Progress", id="progress"):
                gr.Markdown("This section displays real-time progress on economic targets of Vision 2030.")