# Nayera-2025's picture
# Rename app.py to DS_app.py
# 4adc004 verified
# app.py
import torch
import gradio as gr
from transformers import (
AutoProcessor,
VisionEncoderDecoderModel,
AutoModelForSpeechSeq2Seq,
pipeline
)
import numpy as np
from PIL import Image, UnidentifiedImageError
import logging
# Logging: emit INFO and above through the root handler, and keep a
# module-scoped logger for everything defined in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hardware selection: run on the GPU when PyTorch reports one, else on CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
logger.info(f"Using device: {DEVICE}")
class PoliceTranslator:
    """OCR, speech-to-text and translation models behind two entry points.

    ``process_documents`` runs Nougat on page images and translates the
    decoded text; ``process_speech`` runs Whisper on audio clips and
    translates the transcriptions. Both return a JSON-serialisable dict:
    ``{"results": [...]}`` on success or ``{"error": "<message>"}`` on
    failure, so they can be wired straight into a UI callback.
    """

    # Sampling rate handed to the Whisper processor for every clip.
    WHISPER_SAMPLE_RATE = 16000

    # Short codes used by the UI -> FLORES-200 codes used by the NLLB tokenizer.
    # "or" is kept as an alias because earlier versions of this map used it.
    NLLB_CODES = {
        "en": "eng_Latn",
        "ar": "arb_Arab",
        "or": "ory_Orya",
        "ory": "ory_Orya",
    }

    def __init__(self):
        self._init_models()

    def _init_models(self):
        """Load the Nougat, NLLB and Whisper checkpoints onto ``DEVICE``.

        All three are loaded in float32 so inputs and weights share a dtype.

        Raises:
            Exception: whatever the underlying loaders raise; the failure is
                logged with its traceback and then re-raised unchanged.
        """
        try:
            # Document processing (Nougat)
            logger.info("Initializing document model...")
            self.doc_processor = AutoProcessor.from_pretrained("facebook/nougat-base")
            self.doc_model = VisionEncoderDecoderModel.from_pretrained(
                "facebook/nougat-base",
                torch_dtype=torch.float32,
            ).to(DEVICE)

            # Translation model (NLLB)
            logger.info("Initializing translation model...")
            self.translator = pipeline(
                "translation",
                model="facebook/nllb-200-distilled-600M",
                device=0 if DEVICE == "cuda" else -1,
                torch_dtype=torch.float32,
            )

            # Speech processing (Whisper)
            logger.info("Initializing speech model...")
            self.speech_processor = AutoProcessor.from_pretrained("openai/whisper-medium")
            self.speech_model = AutoModelForSpeechSeq2Seq.from_pretrained(
                "openai/whisper-medium",
                torch_dtype=torch.float32,
            ).to(DEVICE)

            logger.info("All models initialized successfully")
        except Exception:
            logger.exception("Initialization failed")
            raise

    @staticmethod
    def _nllb_code(lang):
        """Map a short UI code to its NLLB code; unknown codes pass through.

        Passing unknown values through lets callers supply a full NLLB code
        (for example ``"fra_Latn"``) directly.
        """
        return PoliceTranslator.NLLB_CODES.get(lang, lang)

    @staticmethod
    def _load_audio(clip):
        """Return ``clip`` as a mono float32 waveform at ``WHISPER_SAMPLE_RATE``.

        Args:
            clip: either a ``(sample_rate, samples)`` pair, or a file path
                (a ``str`` or an object exposing the path as ``.name``).

        Raises:
            ValueError: if the clip has no samples or a non-positive rate.
            OSError: if a path cannot be opened.
        """
        target_rate = PoliceTranslator.WHISPER_SAMPLE_RATE

        if not (isinstance(clip, (tuple, list)) and len(clip) == 2):
            # File path: let ffmpeg decode, down-mix and resample in one step.
            # Imported here so the module-level import block stays untouched.
            from transformers.pipelines.audio_utils import ffmpeg_read

            path = getattr(clip, "name", clip)
            with open(path, "rb") as handle:
                return ffmpeg_read(handle.read(), target_rate)

        rate, samples = clip
        raw = np.asarray(samples)
        if raw.size == 0:
            raise ValueError("Audio clip contains no samples")
        if rate <= 0:
            raise ValueError(f"Invalid sample rate: {rate}")

        # Integer PCM is scaled into [-1, 1]; float input is used as given.
        if np.issubdtype(raw.dtype, np.integer):
            wave = raw.astype(np.float32) / float(np.iinfo(raw.dtype).max)
        else:
            wave = raw.astype(np.float32)

        # assumes 2-D input is (samples, channels) -- TODO confirm against caller
        if wave.ndim > 1:
            wave = wave.mean(axis=1)

        if rate != target_rate:
            # Linear interpolation: no anti-aliasing filter, but it keeps the
            # waveform consistent with the rate declared to the processor.
            new_length = max(1, int(round(wave.shape[0] * target_rate / rate)))
            source_times = np.arange(wave.shape[0]) / rate
            target_times = np.arange(new_length) / target_rate
            wave = np.interp(target_times, source_times, wave)

        return wave.astype(np.float32)

    def _translate(self, texts, src_lang, tgt_lang):
        """Translate ``texts`` with the NLLB pipeline after normalising codes."""
        return self.translator(
            texts,
            src_lang=self._nllb_code(src_lang),
            tgt_lang=self._nllb_code(tgt_lang),
        )

    def process_documents(self, files, src_lang, tgt_lang):
        """OCR page images and translate the recognised text.

        Args:
            files: one image or a list of images, each given as a path string
                or as an object exposing its path as ``.name``.
            src_lang: source language, short code or full NLLB code.
            tgt_lang: target language, short code or full NLLB code.

        Returns:
            ``{"results": [{"original": ..., "translated": ...}, ...]}`` with
            one entry per image, or ``{"error": message}`` if the input is
            empty, an image cannot be read, or any model call fails.
        """
        if not files:
            return {"error": "No documents provided"}
        if isinstance(files, str) or hasattr(files, "name"):
            files = [files]

        try:
            images = []
            for file in files:
                path = getattr(file, "name", file)
                try:
                    # The context manager closes the file handle once the
                    # converted copy has been made.
                    with Image.open(path) as source:
                        images.append(source.convert("RGB"))
                except (UnidentifiedImageError, OSError) as e:
                    logger.error(f"Invalid image file: {str(e)}")
                    return {"error": f"Invalid image format: {path}"}

            inputs = self.doc_processor(
                images=images,
                return_tensors="pt",
            ).to(DEVICE, dtype=torch.float32)

            with torch.no_grad():
                outputs = self.doc_model.generate(**inputs, max_new_tokens=512)

            texts = self.doc_processor.batch_decode(outputs, skip_special_tokens=True)
            translations = self._translate(texts, src_lang, tgt_lang)

            return {
                "results": [
                    {"original": text, "translated": trans["translation_text"]}
                    for text, trans in zip(texts, translations)
                ]
            }
        except Exception as e:
            logger.exception("Document error")
            return {"error": str(e)}

    def process_speech(self, audio_files, src_lang, tgt_lang):
        """Transcribe audio clips and translate the transcriptions.

        Args:
            audio_files: a single clip or a list of clips. Each clip is a
                file path or a ``(sample_rate, samples)`` pair.
            src_lang: spoken language, short code or full NLLB code. It is
                used for the translation step only.
            tgt_lang: target language, short code or full NLLB code.

        Returns:
            ``{"results": [{"transcription": ..., "translation": ...}, ...]}``
            with one entry per clip, or ``{"error": message}`` if the input
            is empty, a clip cannot be decoded, or any model call fails.
        """
        if audio_files is None or (
            isinstance(audio_files, (str, list, tuple)) and not audio_files
        ):
            return {"error": "No audio provided"}

        # A lone path or a lone (rate, samples) tuple is treated as one clip.
        is_single_pair = (
            isinstance(audio_files, tuple)
            and len(audio_files) == 2
            and isinstance(audio_files[0], (int, np.integer))
        )
        if isinstance(audio_files, str) or hasattr(audio_files, "name") or is_single_pair:
            clips = [audio_files]
        else:
            clips = list(audio_files)

        try:
            audio_data = []
            for clip in clips:
                try:
                    audio_data.append(self._load_audio(clip))
                except Exception as e:
                    logger.error(f"Audio processing error: {str(e)}")
                    return {"error": "Invalid audio format"}

            # No ``padding=True`` here: the feature extractor's default pads
            # each clip to Whisper's fixed 30-second window, which is the
            # input length the encoder expects.
            inputs = self.speech_processor(
                audio_data,
                sampling_rate=self.WHISPER_SAMPLE_RATE,
                return_tensors="pt",
            ).to(DEVICE, dtype=torch.float32)

            with torch.no_grad():
                outputs = self.speech_model.generate(**inputs, max_new_tokens=256)

            transcriptions = self.speech_processor.batch_decode(
                outputs, skip_special_tokens=True
            )
            translations = self._translate(transcriptions, src_lang, tgt_lang)

            return {
                "results": [
                    {"transcription": heard, "translation": trans["translation_text"]}
                    for heard, trans in zip(transcriptions, translations)
                ]
            }
        except Exception as e:
            logger.exception("Audio error")
            return {"error": str(e)}
# Build the shared translator once at import time. A failure here is logged
# and re-raised, so the module does not finish importing without its models.
try:
    translator = PoliceTranslator()
except Exception as init_error:
    failure_message = f"Failed to initialize: {str(init_error)}"
    logger.error(failure_message)
    raise
# Gradio interface: one tab per workflow, each with its inputs, a trigger
# button and a JSON panel showing the dict returned by the translator.
with gr.Blocks(title="Police Translation Assistant", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🚨 Police Translation Assistant")

    with gr.Tab("📄 Document Translation"):
        gr.Markdown("### Upload document images (PNG/JPG)")
        with gr.Row():
            doc_input = gr.File(file_count="multiple", file_types=["image"], label="Documents")
            doc_src = gr.Dropdown(
                label="Source Language",
                choices=["en", "ar", "ory"],
                value="en",
            )
            doc_tgt = gr.Dropdown(
                label="Target Language",
                choices=["ory", "en", "ar"],
                value="ory",
            )
        doc_btn = gr.Button("Translate Documents", variant="primary")
        doc_output = gr.JSON(label="Results")

    with gr.Tab("🎙️ Audio Translation"):
        gr.Markdown("### Upload audio recordings (WAV/MP3)")
        with gr.Row():
            audio_input = gr.Audio(sources=["upload"], type="filepath", label="Audio Files")
            audio_src = gr.Dropdown(
                label="Spoken Language",
                choices=["en", "ar"],
                value="en",
            )
            audio_tgt = gr.Dropdown(
                label="Target Language",
                # Was `choices[...]` (missing `=`), which is a SyntaxError.
                choices=["ory", "en", "ar"],
                value="ory",
            )
        audio_btn = gr.Button("Translate Audio", variant="primary")
        audio_output = gr.JSON(label="Results")

    # Event wiring stays inside the Blocks context so the handlers are
    # registered on this app.
    doc_btn.click(
        translator.process_documents,
        inputs=[doc_input, doc_src, doc_tgt],
        outputs=doc_output,
    )
    audio_btn.click(
        translator.process_speech,
        inputs=[audio_input, audio_src, audio_tgt],
        outputs=audio_output,
    )
# Script entry point: serve the UI on every network interface, port 7860.
if __name__ == "__main__":
    app.launch(
        server_port=7860,
        server_name="0.0.0.0",
    )