# Spaces: Running on Zero
import os | |
import re | |
import json | |
import time | |
from tqdm import tqdm | |
from pathlib import Path | |
import spaces | |
import gradio as gr | |
# Helper functions that don't use GPU | |
def safe_tokenize(text):
    """Split *text* into lowercase word and punctuation tokens using regex only.

    Punctuation marks become tokens of their own.  Besides the ASCII marks the
    Arabic comma, semicolon and question mark are separated as well, so an
    Arabic question such as "... 2030؟" yields "2030" and "؟" instead of a
    single glued token.

    Args:
        text: The string to tokenize.  ``None`` or an empty string is allowed.

    Returns:
        A list of non-empty, lowercased tokens (empty list for falsy input).
    """
    if not text:
        return []
    # Surround every punctuation mark with spaces so that it splits off.
    # \u060C = Arabic comma, \u061B = Arabic semicolon, \u061F = Arabic question mark.
    spaced = re.sub(
        r'([.,!?;:()\[\]{}"\'/\\\u060C\u061B\u061F])', r' \1 ', text
    )
    # str.split() with no argument splits on whitespace runs and drops empties.
    return spaced.lower().split()
def detect_language(text):
    """Classify *text* as ``"arabic"`` or ``"english"`` with a character heuristic.

    The text counts as Arabic when more than half of its significant
    characters fall in the Unicode Arabic block (U+0600-U+06FF).  Whitespace,
    digits and non-Arabic punctuation are ignored in the ratio; counting them
    made short Arabic questions that contain years or percentages come out
    as English.

    Args:
        text: The string to classify.  ``None`` or an empty string is allowed.

    Returns:
        ``"arabic"`` or ``"english"``; falsy input and text without any
        letters yield ``"english"``.
    """
    if not text:
        return "english"
    arabic_count = len(re.findall(r'[\u0600-\u06FF]', text))
    # Denominator: letters of any script plus every Arabic-block character,
    # so it is never smaller than the Arabic count itself.
    significant_count = sum(
        1 for ch in text if ch.isalpha() or '\u0600' <= ch <= '\u06FF'
    )
    is_arabic = arabic_count > significant_count * 0.5
    return "arabic" if is_arabic else "english"
# Comprehensive evaluation dataset | |
# Evaluation samples: each entry pairs a query with its reference answer and
# is tagged with a topic category and the language of the text.
comprehensive_evaluation_data = [
    {
        "query": sample_query,
        "reference": sample_reference,
        "category": sample_category,
        "language": sample_language,
    }
    for sample_query, sample_reference, sample_category, sample_language in (
        # === Overview ===
        (
            "ما هي رؤية السعودية 2030؟",
            "رؤية السعودية 2030 هي خطة استراتيجية تهدف إلى تنويع الاقتصاد السعودي وتقليل الاعتماد على النفط مع تطوير قطاعات مختلفة مثل الصحة والتعليم والسياحة.",
            "overview",
            "arabic",
        ),
        (
            "What is Saudi Vision 2030?",
            "Saudi Vision 2030 is a strategic framework aiming to diversify Saudi Arabia's economy and reduce dependence on oil, while developing sectors like health, education, and tourism.",
            "overview",
            "english",
        ),
        # === Economic Goals ===
        (
            "ما هي الأهداف الاقتصادية لرؤية 2030؟",
            "تشمل الأهداف الاقتصادية زيادة مساهمة القطاع الخاص إلى 65%، وزيادة الصادرات غير النفطية إلى 50% من الناتج المحلي غير النفطي، وخفض البطالة إلى 7%.",
            "economic",
            "arabic",
        ),
        (
            "What are the economic goals of Vision 2030?",
            "The economic goals of Vision 2030 include increasing private sector contribution from 40% to 65% of GDP, raising non-oil exports from 16% to 50%, reducing unemployment from 11.6% to 7%.",
            "economic",
            "english",
        ),
        # === Social Goals ===
        (
            "كيف تعزز رؤية 2030 الإرث الثقافي السعودي؟",
            "تتضمن رؤية 2030 الحفاظ على الهوية الوطنية، تسجيل مواقع أثرية في اليونسكو، وتعزيز الفعاليات الثقافية.",
            "social",
            "arabic",
        ),
        (
            "How does Vision 2030 aim to improve quality of life?",
            "Vision 2030 plans to enhance quality of life by expanding sports facilities, promoting cultural activities, and boosting tourism and entertainment sectors.",
            "social",
            "english",
        ),
    )
]
# RAG Service class | |
class Vision2030Service:
    """Retrieval-augmented question answering over the Vision 2030 PDFs.

    The service builds (or reloads) a FAISS vector store from the PDF files,
    loads the ALLaM causal language model, and answers questions by
    retrieving context chunks and prompting the model with them.  All heavy
    libraries are imported lazily inside the methods that need them, so
    constructing the object is cheap.
    """

    # PDF files that are indexed, relative to the working directory.
    PDF_FILES = ("saudi_vision203.pdf", "saudi_vision2030_ar.pdf")
    # Directory in which the FAISS index is persisted between runs.
    VECTOR_STORE_DIR = "vector_stores"
    EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    LLM_NAME = "ALLaM-AI/ALLaM-7B-Instruct-preview"

    def __init__(self):
        self.initialized = False
        self.model = None
        self.tokenizer = None
        self.vector_store = None
        self.conversation_history = []

    def initialize(self):
        """Load the vector store and the language model once.

        Returns:
            True when the service is ready (including when it already was),
            False when any step failed; the error and traceback are printed.
        """
        if self.initialized:
            return True
        try:
            self.vector_store = self._load_or_build_vector_store()
            self._load_language_model()
            self.initialized = True
            return True
        except Exception as e:
            # Top-level boundary: report the failure instead of crashing the UI.
            import traceback
            print(f"Initialization error: {e}")
            print(traceback.format_exc())
            return False

    def _create_embedding_function(self):
        """Return the multilingual sentence-embedding wrapper used by FAISS."""
        try:
            from langchain_community.embeddings import HuggingFaceEmbeddings
        except ImportError:
            # Older LangChain releases only expose the legacy import path.
            from langchain.embeddings import HuggingFaceEmbeddings
        return HuggingFaceEmbeddings(model_name=self.EMBEDDING_MODEL_NAME)

    def _load_or_build_vector_store(self):
        """Load the persisted FAISS index, or build and save it from the PDFs.

        Raises:
            ValueError: if the index must be built and no PDF yielded text.
        """
        from langchain_community.vectorstores import FAISS

        os.makedirs(self.VECTOR_STORE_DIR, exist_ok=True)
        if os.path.exists(os.path.join(self.VECTOR_STORE_DIR, "index.faiss")):
            print("Loading existing vector store...")
            # SECURITY: load_local unpickles data stored next to the index, so
            # allow_dangerous_deserialization must only be used for index
            # directories this application wrote itself.
            return FAISS.load_local(
                self.VECTOR_STORE_DIR,
                self._create_embedding_function(),
                allow_dangerous_deserialization=True,
            )

        print("Creating new vector store...")
        documents = self._read_pdf_documents()
        if not documents:
            raise ValueError("No documents were processed successfully.")
        chunks = self._split_documents(documents)
        store = FAISS.from_documents(chunks, self._create_embedding_function())
        store.save_local(self.VECTOR_STORE_DIR)
        return store

    def _read_pdf_documents(self):
        """Extract the text of every existing PDF into one Document per file."""
        import PyPDF2
        from langchain.schema import Document

        documents = []
        for pdf_path in self.PDF_FILES:
            if not os.path.exists(pdf_path):
                print(f"Warning: {pdf_path} does not exist")
                continue
            print(f"Processing {pdf_path}...")
            with open(pdf_path, 'rb') as file:
                reader = PyPDF2.PdfReader(file)
                page_texts = [page.extract_text() for page in reader.pages]
            # Pages without extractable text are skipped.
            text = "".join(
                f"{page_text}\n\n" for page_text in page_texts if page_text
            )
            if text.strip():
                documents.append(Document(
                    page_content=text,
                    metadata={
                        "source": pdf_path,
                        "filename": os.path.basename(pdf_path),
                    },
                ))
        return documents

    def _split_documents(self, documents):
        """Split documents into overlapping chunks that keep their metadata."""
        from langchain.text_splitter import RecursiveCharacterTextSplitter
        from langchain.schema import Document

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
            separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
        )
        chunks = []
        for doc in documents:
            chunks.extend(
                # Copy the metadata so chunks do not share one mutable dict.
                Document(page_content=piece, metadata=dict(doc.metadata))
                for piece in text_splitter.split_text(doc.page_content)
            )
        return chunks

    def _load_language_model(self):
        """Load the tokenizer and the causal LM into self.tokenizer/self.model."""
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.LLM_NAME,
            trust_remote_code=True,
            use_fast=False,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.LLM_NAME,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        )

    def retrieve_context(self, query, top_k=5):
        """Return up to *top_k* context chunks for *query*.

        Args:
            query: Text to search the vector store with.
            top_k: Maximum number of chunks to return.

        Returns:
            A list of dicts with the keys ``content``, ``source`` and
            ``relevance_score``; empty when the service is not initialized
            or the search fails (the error is printed).
        """
        if not self.initialized or self.vector_store is None:
            return []
        try:
            results = self.vector_store.similarity_search_with_score(query, k=top_k)
            # NOTE(review): for LangChain's FAISS store this score is
            # presumably a distance (lower = closer) - confirm before ranking
            # or thresholding on "relevance_score".
            return [
                {
                    "content": doc.page_content,
                    "source": doc.metadata.get("source", "Unknown"),
                    "relevance_score": score,
                }
                for doc, score in results
            ]
        except Exception as e:
            print(f"Error retrieving context: {e}")
            return []

    def generate_response(self, query, contexts, language="auto"):
        """Generate an answer to *query* grounded in *contexts*.

        Args:
            query: The user's question.
            contexts: Context dicts as returned by ``retrieve_context``.
            language: ``"arabic"``, ``"english"`` or ``"auto"`` to detect the
                language from the query.

        Returns:
            The model's answer, a "still initializing" notice when the model
            is not loaded, or an apology when generation fails (the error and
            traceback are printed).
        """
        if not self.initialized or self.model is None or self.tokenizer is None:
            return "I'm still initializing. Please try again in a moment."
        try:
            # Auto-detect language if not specified
            if language == "auto":
                language = detect_language(query)

            # System instruction in the language of the conversation
            if language == "arabic":
                instruction = (
                    "أنت مساعد افتراضي يهتم برؤية السعودية 2030. استخدم المعلومات التالية للإجابة على السؤال. "
                    "إذا لم تعرف الإجابة، فقل بأمانة إنك لا تعرف."
                )
            else:  # english
                instruction = (
                    "You are a virtual assistant for Saudi Vision 2030. Use the following information to answer the question. "
                    "If you don't know the answer, honestly say you don't know."
                )

            # Combine retrieved contexts
            context_text = "\n\n".join(
                f"Document: {ctx['content']}" for ctx in contexts
            )

            # The prompt must end at [/INST]: a trailing </s> is the
            # end-of-sequence marker and would tell the model that the turn
            # is already finished before it has produced an answer.
            prompt = f"""<s>[INST] {instruction}
Context:
{context_text}
Question: {query} [/INST]"""

            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            outputs = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                repetition_penalty=1.1,
            )

            # The decoded text still contains the prompt; keep what follows it.
            full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            response = full_output.split("[/INST]")[-1].strip()
            # If response is empty for some reason, return the full output
            return response or full_output
        except Exception as e:
            import traceback
            print(f"Error generating response: {e}")
            print(traceback.format_exc())
            return "Sorry, I encountered an error while generating a response."

    def answer_question(self, query):
        """Answer *query* and report which sources were consulted.

        Initializes the service on first use, records the exchange in the
        conversation history and uses the recent history to enrich retrieval.

        Returns:
            A ``(response, sources)`` tuple; ``sources`` lists each source
            once, in retrieval order, and is empty on failure.
        """
        if not self.initialized:
            if not self.initialize():
                return "System initialization failed. Please check the logs.", []
        try:
            # Add user query to conversation history
            self.conversation_history.append({"role": "user", "content": query})

            # Last six messages, which already include the query added above
            conversation_context = "\n".join(
                f"{'User' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}"
                for msg in self.conversation_history[-6:]
            )
            enhanced_query = f"{conversation_context}\n{query}"

            contexts = self.retrieve_context(enhanced_query, top_k=5)
            response = self.generate_response(query, contexts)
            self.conversation_history.append({"role": "assistant", "content": response})

            # dict.fromkeys de-duplicates while keeping retrieval order, so
            # the source list is deterministic (a set would not be).
            unique_sources = list(dict.fromkeys(
                ctx.get("source", "Unknown") for ctx in contexts
            ))
            return response, unique_sources
        except Exception as e:
            import traceback
            print(f"Error answering question: {e}")
            print(traceback.format_exc())
            return f"Sorry, I encountered an error: {str(e)}", []

    def reset_conversation(self):
        """Clear the conversation history and return a confirmation message."""
        self.conversation_history = []
        return "Conversation has been reset."
def main():
    """Build the Gradio interface of the Vision 2030 assistant.

    Creates one ``Vision2030Service`` instance, lays out the chat, analytics
    and system tabs, and wires the event handlers to it.

    Returns:
        The assembled ``gr.Blocks`` application (not yet launched).
    """
    # Create the Vision 2030 service
    service = Vision2030Service()

    # Files shown in the "PDF Documents" table and probed by check_pdfs()
    pdf_files = ("saudi_vision203.pdf", "saudi_vision2030_ar.pdf")

    # Sample questions offered as one-click buttons (English, then Arabic)
    sample_questions = (
        "What is Saudi Vision 2030?",
        "What are the economic goals of Vision 2030?",
        "How does Vision 2030 aim to improve quality of life?",
        "ما هي رؤية السعودية 2030؟",
        "ما هي الأهداف الاقتصادية لرؤية 2030؟",
        "كيف تعزز رؤية 2030 الإرث الثقافي السعودي؟",
    )

    # Define theme and styling
    theme = gr.themes.Soft(
        primary_hue="emerald",
        secondary_hue="teal",
    ).set(
        body_background_fill="linear-gradient(to right, #f0f9ff, #e6f7ff)",
        button_primary_background_fill="linear-gradient(90deg, #1e9e5a, #1d8753)",
        button_primary_background_fill_hover="linear-gradient(90deg, #1d8753, #176f44)",
        button_primary_text_color="white",
        button_secondary_background_fill="#f0f0f0",
        button_secondary_background_fill_hover="#e0e0e0",
        block_title_text_weight="600",
        block_border_width="2px",
        block_shadow="0px 4px 6px rgba(0, 0, 0, 0.1)",
        background_fill_primary="#ffffff",
    )

    # Build the Gradio interface with enhanced styling
    with gr.Blocks(title="Vision 2030 Assistant", theme=theme, css="""
        .language-toggle { margin-bottom: 20px; }
        .container { border-radius: 10px; padding: 20px; box-shadow: 0 4px 6px rgba(0,0,0,0.1); }
        .header-img { margin-bottom: 10px; border-radius: 10px; }
        .highlight { background-color: rgba(46, 175, 125, 0.1); padding: 15px; border-radius: 8px; margin: 10px 0; }
        .footer { text-align: center; margin-top: 30px; color: #666; font-size: 0.9em; }
        .loading-spinner { display: inline-block; width: 20px; height: 20px; margin-right: 10px; }
        .status-indicator { display: inline-flex; align-items: center; padding: 8px; border-radius: 4px; }
        .status-indicator.success { background-color: rgba(46, 175, 125, 0.2); }
        .status-indicator.warning { background-color: rgba(255, 190, 0, 0.2); }
        .status-indicator.error { background-color: rgba(255, 76, 76, 0.2); }
        .header { display: flex; justify-content: space-between; align-items: center; }
        .lang-btn { min-width: 100px; }
        .chat-input { background-color: white; border-radius: 8px; border: 1px solid #ddd; }
        .info-box { background-color: #f8f9fa; padding: 10px; border-radius: 8px; margin-top: 10px; }
        /* Style for sample question buttons */
        .gradio-button.secondary {
        margin-bottom: 8px;
        text-align: left;
        background-color: #f0f9ff;
        transition: all 0.3s ease;
        display: block;
        width: 100%;
        padding: 8px 12px;
        }
        .gradio-button.secondary:hover {
        background-color: #e0f2fe;
        transform: translateX(3px);
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        }
        """) as demo:
        # Header with stylized title (no external images)
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    # Add local logo image on the left
                    with gr.Column(scale=1, min_width=100):
                        gr.Image("logo.png", show_label=False, height=80)
                    # Title and tagline on the right
                    with gr.Column(scale=4):
                        gr.Markdown("""
                        # Vision 2030 Assistant
                        ### Your interactive guide to Saudi Arabia's national transformation program
                        """)

        # Language toggle in the header with better styling.
        # NOTE(review): this control is not passed to any event handler, so
        # the selection currently has no effect on the answers.
        with gr.Row(elem_classes=["language-toggle"]):
            with gr.Column(scale=1):
                language_toggle = gr.Radio(
                    choices=["English", "العربية (Arabic)", "Auto-detect"],
                    value="Auto-detect",
                    label="Interface Language",
                    info="Choose your preferred language",
                    elem_classes=["lang-btn"]
                )

        # Main interface with tabs
        with gr.Tabs() as tabs:
            # Chat Tab with enhanced design
            with gr.TabItem("💬 Chat", id="chat"):
                with gr.Row():
                    with gr.Column(scale=2):
                        chatbot = gr.Chatbot(
                            height=450,
                            bubble_full_width=False,
                            show_label=False
                        )
                        with gr.Row():
                            msg = gr.Textbox(
                                label="",
                                placeholder="Ask a question about Saudi Vision 2030...",
                                show_label=False,
                                elem_classes=["chat-input"],
                                scale=9
                            )
                            submit_btn = gr.Button("Send", variant="primary", scale=1)
                        with gr.Row():
                            clear = gr.Button("Clear History", variant="secondary")
                            thinking_indicator = gr.HTML(
                                value='<div id="thinking" style="display:none;">The assistant is thinking...</div>',
                                visible=True
                            )

                    # Sidebar with features
                    with gr.Column(scale=1):
                        gr.Markdown("### Quick Information")
                        with gr.Accordion("Vision 2030 Pillars", open=False):
                            gr.Markdown("""
                            * **Vibrant Society** - Cultural and social development
                            * **Thriving Economy** - Economic diversification
                            * **Ambitious Nation** - Effective governance
                            """)
                        with gr.Accordion("About this Assistant", open=False):
                            gr.Markdown("""
                            This assistant uses advanced NLP models to answer questions about Saudi Vision 2030 in both English and Arabic. The system retrieves information from official documents and provides relevant answers.
                            """)
                        system_status = gr.HTML(
                            value='<div class="status-indicator warning">⚠️ System initializing</div>',
                            visible=True
                        )
                        init_btn = gr.Button("Initialize System", variant="primary")

                        # Clickable sample questions, created in display order
                        gr.Markdown("### Sample Questions")
                        with gr.Group():
                            sample_buttons = [
                                gr.Button(question, variant="secondary")
                                for question in sample_questions
                            ]

            # Analytics and insights tab
            with gr.TabItem("📊 Analytics", id="analytics"):
                gr.Markdown("### Vision 2030 Progress Tracking")
                with gr.Tabs():
                    with gr.TabItem("Economic Metrics"):
                        gr.Markdown("""
                        <div class="highlight">
                        <h3>Key Economic Indicators</h3>
                        <p>This section displays real-time progress on economic targets of Vision 2030.</p>
                        </div>
                        """)
                        # NOTE(review): the figures below are static HTML,
                        # not values computed at runtime.
                        with gr.Row():
                            with gr.Column():
                                gr.HTML("""
                                <div style="background: white; padding: 15px; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);">
                                <h4>GDP Non-oil Growth</h4>
                                <div style="height: 20px; background-color: #e0e0e0; border-radius: 10px; margin: 10px 0;">
                                <div style="height: 100%; width: 68%; background: linear-gradient(to right, #1e9e5a, #63e6be); border-radius: 10px;">
                                </div>
                                </div>
                                <div style="display: flex; justify-content: space-between;">
                                <span>Target: 65%</span>
                                <span>Current: 44%</span>
                                </div>
                                </div>
                                """)
                            with gr.Column():
                                gr.HTML("""
                                <div style="background: white; padding: 15px; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);">
                                <h4>Unemployment Rate</h4>
                                <div style="height: 20px; background-color: #e0e0e0; border-radius: 10px; margin: 10px 0;">
                                <div style="height: 100%; width: 55%; background: linear-gradient(to right, #1e9e5a, #63e6be); border-radius: 10px;">
                                </div>
                                </div>
                                <div style="display: flex; justify-content: space-between;">
                                <span>Target: 7%</span>
                                <span>Current: 9.9%</span>
                                </div>
                                </div>
                                """)
                            with gr.Column():
                                gr.HTML("""
                                <div style="background: white; padding: 15px; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);">
                                <h4>SME Contribution to GDP</h4>
                                <div style="height: 20px; background-color: #e0e0e0; border-radius: 10px; margin: 10px 0;">
                                <div style="height: 100%; width: 32%; background: linear-gradient(to right, #1e9e5a, #63e6be); border-radius: 10px;">
                                </div>
                                </div>
                                <div style="display: flex; justify-content: space-between;">
                                <span>Target: 35%</span>
                                <span>Current: 22%</span>
                                </div>
                                </div>
                                """)

                    with gr.TabItem("Social Development"):
                        gr.Markdown("#### Social Initiative Progress")
                        social_chart = gr.HTML("""
                        <div style="background: white; padding: 20px; border-radius: 10px; margin-top: 15px;">
                        <h3>Quality of Life Improvement Programs</h3>
                        <div style="display: flex; height: 200px; align-items: flex-end; justify-content: space-around; margin: 30px 0;">
                        <div style="display: flex; flex-direction: column; align-items: center;">
                        <div style="width: 50px; height: 150px; background: linear-gradient(to top, #1e9e5a, #63e6be); border-radius: 5px 5px 0 0;"></div>
                        <span style="margin-top: 10px;">Tourism</span>
                        </div>
                        <div style="display: flex; flex-direction: column; align-items: center;">
                        <div style="width: 50px; height: 120px; background: linear-gradient(to top, #1e9e5a, #63e6be); border-radius: 5px 5px 0 0;"></div>
                        <span style="margin-top: 10px;">Entertainment</span>
                        </div>
                        <div style="display: flex; flex-direction: column; align-items: center;">
                        <div style="width: 50px; height: 180px; background: linear-gradient(to top, #1e9e5a, #63e6be); border-radius: 5px 5px 0 0;"></div>
                        <span style="margin-top: 10px;">Healthcare</span>
                        </div>
                        <div style="display: flex; flex-direction: column; align-items: center;">
                        <div style="width: 50px; height: 100px; background: linear-gradient(to top, #1e9e5a, #63e6be); border-radius: 5px 5px 0 0;"></div>
                        <span style="margin-top: 10px;">Housing</span>
                        </div>
                        <div style="display: flex; flex-direction: column; align-items: center;">
                        <div style="width: 50px; height: 160px; background: linear-gradient(to top, #1e9e5a, #63e6be); border-radius: 5px 5px 0 0;"></div>
                        <span style="margin-top: 10px;">Education</span>
                        </div>
                        </div>
                        </div>
                        """)

                    with gr.TabItem("Giga-Projects"):
                        gr.Markdown("#### Major Development Projects")
                        with gr.Row():
                            for project, desc in [
                                ("NEOM", "A $500 billion mega-city with advanced technologies"),
                                ("Red Sea Project", "Luxury tourism destination across 28,000 km²"),
                                ("Qiddiya", "Entertainment, sports and arts destination")
                            ]:
                                with gr.Column():
                                    gr.HTML(f"""
                                    <div style="background: white; padding: 15px; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1); height: 200px; position: relative; overflow: hidden;">
                                    <div style="position: absolute; top: 0; left: 0; width: 100%; height: 70px; background: linear-gradient(90deg, #1e9e5a, #45b08c); border-radius: 10px 10px 0 0;"></div>
                                    <div style="position: relative; padding-top: 80px; text-align: center;">
                                    <h3>{project}</h3>
                                    <p>{desc}</p>
                                    <button style="background: #1e9e5a; color: white; border: none; padding: 8px 15px; border-radius: 5px; cursor: pointer; margin-top: 15px;">Learn More</button>
                                    </div>
                                    </div>
                                    """)

            # Technical System Status with improved visualization
            with gr.TabItem("⚙️ System", id="system"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### System Diagnostics")
                        status_box = gr.Textbox(
                            label="Status",
                            value="System not initialized",
                            lines=1
                        )
                        with gr.Group():
                            gr.Markdown("### PDF Documents")
                            pdf_status = gr.Dataframe(
                                headers=["File", "Status", "Size"],
                                datatype=["str", "str", "str"],
                                col_count=(3, "fixed"),
                                value=[[pdf_file, "Checking...", ""] for pdf_file in pdf_files]
                            )
                            pdf_btn = gr.Button("Check PDF Files", variant="secondary")
                            gr.Markdown("### System Dependencies")
                            sys_status = gr.Dataframe(
                                headers=["Component", "Status"],
                                datatype=["str", "str"],
                                col_count=(2, "fixed"),
                                value=[["PyTorch", "Not checked"],
                                       ["Transformers", "Not checked"],
                                       ["LangChain", "Not checked"],
                                       ["FAISS", "Not checked"]]
                            )
                            sys_btn = gr.Button("Check Dependencies", variant="secondary")

                    # Visualization column
                    with gr.Column():
                        gr.Markdown("### System Architecture")
                        gr.HTML("""
                        <div style="background: white; padding: 20px; border-radius: 10px; margin-top: 15px;">
                        <svg viewBox="0 0 800 500" xmlns="http://www.w3.org/2000/svg">
                        <!-- User Input -->
                        <rect x="50" y="50" width="150" height="60" rx="10" fill="#e6f7ff" stroke="#1e9e5a" stroke-width="2"/>
                        <text x="125" y="85" text-anchor="middle" font-size="16">User Query</text>
                        <!-- Arrow down -->
                        <path d="M125 110 L125 160" stroke="#1e9e5a" stroke-width="3" stroke-dasharray="5,5"/>
                        <polygon points="125,170 120,160 130,160" fill="#1e9e5a"/>
                        <!-- RAG System -->
                        <rect x="50" y="170" width="150" height="60" rx="10" fill="#e6f7ff" stroke="#1e9e5a" stroke-width="2"/>
                        <text x="125" y="205" text-anchor="middle" font-size="16">RAG System</text>
                        <!-- Arrow right -->
                        <path d="M200 200 L300 200" stroke="#1e9e5a" stroke-width="3" stroke-dasharray="5,5"/>
                        <polygon points="310,200 300,195 300,205" fill="#1e9e5a"/>
                        <!-- Document Store -->
                        <rect x="310" y="170" width="150" height="60" rx="10" fill="#e6f7ff" stroke="#1e9e5a" stroke-width="2"/>
                        <text x="385" y="195" text-anchor="middle" font-size="16">Vector Store</text>
                        <text x="385" y="215" text-anchor="middle" font-size="14">(FAISS)</text>
                        <!-- Document icons -->
                        <rect x="350" y="270" width="30" height="40" fill="#e6f7ff" stroke="#1e9e5a" stroke-width="1"/>
                        <rect x="355" y="265" width="30" height="40" fill="#e6f7ff" stroke="#1e9e5a" stroke-width="1"/>
                        <rect x="360" y="260" width="30" height="40" fill="#e6f7ff" stroke="#1e9e5a" stroke-width="1"/>
                        <text x="375" y="330" text-anchor="middle" font-size="14">PDF Docs</text>
                        <!-- Arrow up -->
                        <path d="M375 260 L375 230" stroke="#1e9e5a" stroke-width="2"/>
                        <polygon points="375,230 370,240 380,240" fill="#1e9e5a"/>
                        <!-- Arrow back to RAG -->
                        <path d="M310 220 L200 220" stroke="#1e9e5a" stroke-width="3" stroke-dasharray="5,5"/>
                        <polygon points="190,220 200,215 200,225" fill="#1e9e5a"/>
                        <!-- Arrow down from RAG -->
                        <path d="M125 230 L125 280" stroke="#1e9e5a" stroke-width="3"/>
                        <polygon points="125,290 120,280 130,280" fill="#1e9e5a"/>
                        <!-- LLM -->
                        <rect x="50" y="290" width="150" height="60" rx="10" fill="#e6f7ff" stroke="#1e9e5a" stroke-width="2"/>
                        <text x="125" y="315" text-anchor="middle" font-size="16">ALLaM Model</text>
                        <text x="125" y="335" text-anchor="middle" font-size="14">(7B Params)</text>
                        <!-- Arrow down -->
                        <path d="M125 350 L125 400" stroke="#1e9e5a" stroke-width="3"/>
                        <polygon points="125,410 120,400 130,400" fill="#1e9e5a"/>
                        <!-- User Response -->
                        <rect x="50" y="410" width="150" height="60" rx="10" fill="#e6f7ff" stroke="#1e9e5a" stroke-width="2"/>
                        <text x="125" y="445" text-anchor="middle" font-size="16">Response</text>
                        </svg>
                        </div>
                        """)

                        # Memory usage visualization.
                        # NOTE(review): static HTML; the numbers are not
                        # measured from the running system.
                        gr.Markdown("### System Resources")
                        gr.HTML("""
                        <div style="background: white; padding: 15px; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1); margin-top: 15px;">
                        <h4>GPU Memory Usage</h4>
                        <div style="height: 20px; background-color: #e0e0e0; border-radius: 10px; margin: 10px 0;">
                        <div style="height: 100%; width: 72%; background: linear-gradient(to right, #1e9e5a, #ffc107); border-radius: 10px;">
                        </div>
                        </div>
                        <div style="display: flex; justify-content: space-between;">
                        <span>Total: 16GB</span>
                        <span>Used: 11.5GB</span>
                        </div>
                        <h4 style="margin-top: 20px;">CPU Usage</h4>
                        <div style="height: 20px; background-color: #e0e0e0; border-radius: 10px; margin: 10px 0;">
                        <div style="height: 100%; width: 45%; background: linear-gradient(to right, #1e9e5a, #63e6be); border-radius: 10px;">
                        </div>
                        </div>
                        <div style="display: flex; justify-content: space-between;">
                        <span>0%</span>
                        <span>45%</span>
                        <span>100%</span>
                        </div>
                        </div>
                        """)

        # Footer
        gr.HTML("""
        <div class="footer">
        <p>Vision 2030 Assistant • Powered by ALLaM-7B-Instruct • © 2025</p>
        </div>
        """)

        # JavaScript for animations and enhanced UI effects.
        # NOTE(review): no `fn` is passed to load() here and the script is a
        # statement list rather than a single function expression - confirm
        # against the installed Gradio version that it is registered and run.
        # NOTE(review): the interval started by animateThinking() is never
        # cleared, so once shown the indicator would keep animating.
        demo.load(js="""
        function setupThinking() {
        const thinking = document.getElementById('thinking');
        function animateThinking() {
        if (thinking) {
        thinking.style.display = 'block';
        let dots = '.';
        setInterval(() => {
        dots = dots.length < 3 ? dots + '.' : '.';
        thinking.innerHTML = `<div class="status-indicator">🤔 The assistant is thinking${dots}</div>`;
        }, 500);
        }
        }
        // Demo code to show the thinking animation
        document.querySelectorAll('button').forEach(btn => {
        if (btn.textContent.includes('Send')) {
        btn.addEventListener('click', () => {
        setTimeout(() => {
        animateThinking();
        }, 100);
        });
        }
        });
        }
        // Run setup when page loads
        if (document.readyState === 'complete') {
        setupThinking();
        } else {
        window.addEventListener('load', setupThinking);
        }
        """)

        # Event handlers
        def respond(message, history):
            """Answer *message* and append the exchange to the chat history.

            Returns the updated history and an empty string that clears the
            input box.  Blank messages leave the history unchanged.
            """
            history = history or []
            if not message or not message.strip():
                return history, ""
            response, sources = service.answer_question(message)
            sources_text = ", ".join(sources) if sources else "No specific sources"
            # Format the response to include sources
            full_response = f"{response}\n\nSources: {sources_text}"
            return history + [[message, full_response]], ""

        def reset_chat():
            """Clear the service history, the chat window and the input box."""
            service.reset_conversation()
            # The second output is the message textbox: it must be emptied,
            # not filled with a status text the user would have to delete.
            return [], ""

        def initialize_system():
            """Initialize the service; return (status text, status HTML)."""
            if service.initialize():
                status_html = '<div class="status-indicator success">✅ System initialized and ready</div>'
                return "System initialized successfully!", status_html
            status_html = '<div class="status-indicator error">❌ System initialization failed</div>'
            return "System initialization failed. Check logs for details.", status_html

        def fill_with(question):
            """Return a handler that puts *question* into the message box."""
            return lambda: question

        def check_pdfs():
            """Report, per expected PDF, whether it exists and its size in MB."""
            result = []
            for pdf_file in pdf_files:
                if os.path.exists(pdf_file):
                    size = os.path.getsize(pdf_file) / (1024 * 1024)  # Size in MB
                    result.append([pdf_file, "Found ✅", f"{size:.2f} MB"])
                else:
                    result.append([pdf_file, "Not found ❌", "0 MB"])
            return result

        def check_dependencies():
            """Report, per required library, whether it can be imported."""
            import importlib

            # (label, module name, whether to display the module's version)
            components = (
                ("PyTorch", "torch", True),
                ("Transformers", "transformers", True),
                ("LangChain", "langchain", True),
                ("FAISS", "faiss", False),
            )
            result = []
            for label, module_name, show_version in components:
                # Imports happen only on demand, when the button is clicked.
                try:
                    module = importlib.import_module(module_name)
                except ImportError:
                    result.append([label, "❌ Not installed"])
                else:
                    status = f"✅ {module.__version__}" if show_version else "✅ Installed"
                    result.append([label, status])
            return result

        # Connect event handlers
        msg.submit(respond, [msg, chatbot], [chatbot, msg])
        submit_btn.click(respond, [msg, chatbot], [chatbot, msg])
        clear.click(reset_chat, None, [chatbot, msg])
        init_btn.click(initialize_system, None, [status_box, system_status])

        # Connect all sample question buttons to the message input
        for button, question in zip(sample_buttons, sample_questions):
            button.click(fill_with(question), None, msg)

        pdf_btn.click(check_pdfs, None, pdf_status)
        sys_btn.click(check_dependencies, None, sys_status)

        # Initialize system on page load
        demo.load(initialize_system, None, [status_box, system_status])

    return demo
if __name__ == "__main__":
    # Script entry point: build the interface, call queue() on it, then launch it.
    demo = main()
    demo.queue()
    demo.launch()