# app.py
import logging

import gradio as gr
import numpy as np
import torch
from PIL import Image, UnidentifiedImageError
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    VisionEncoderDecoderModel,
    pipeline,
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hardware configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {DEVICE}")

# Map UI language codes to NLLB-200 (FLORES-200) codes
LANG_MAP = {"en": "eng_Latn", "ar": "arb_Arab", "ory": "ory_Orya"}


class PoliceTranslator:
    def __init__(self):
        self._init_models()

    def _init_models(self):
        """Initialize models with a consistent dtype."""
        try:
            # Document OCR (Nougat)
            logger.info("Initializing document model...")
            self.doc_processor = AutoProcessor.from_pretrained("facebook/nougat-base")
            self.doc_model = VisionEncoderDecoderModel.from_pretrained(
                "facebook/nougat-base",
                torch_dtype=torch.float32,  # Force float32 for stability
            ).to(DEVICE)

            # Translation model (NLLB)
            logger.info("Initializing translation model...")
            self.translator = pipeline(
                "translation",
                model="facebook/nllb-200-distilled-600M",
                device=0 if DEVICE == "cuda" else -1,
                torch_dtype=torch.float32,  # Match document model dtype
            )

            # Speech recognition (Whisper)
            logger.info("Initializing speech model...")
            self.speech_processor = AutoProcessor.from_pretrained("openai/whisper-medium")
            self.speech_model = AutoModelForSpeechSeq2Seq.from_pretrained(
                "openai/whisper-medium",
                torch_dtype=torch.float32,  # Consistent dtype
            ).to(DEVICE)

            logger.info("All models initialized successfully")
        except Exception as e:
            logger.error(f"Initialization failed: {e}")
            raise

    def process_documents(self, files, src_lang, tgt_lang):
        """OCR document images with Nougat, then translate with NLLB."""
        try:
            if not files:
                return {"error": "No documents provided"}

            # Validate inputs (gr.File may yield paths or file-like objects)
            images = []
            for file in files:
                path = getattr(file, "name", file)
                try:
                    images.append(Image.open(path).convert("RGB"))
                except (UnidentifiedImageError, OSError) as e:
                    logger.error(f"Invalid image file: {e}")
                    return {"error": f"Invalid image format: {path}"}

            # Process with explicit dtype for consistency
            inputs = self.doc_processor(
                images=images,
                return_tensors="pt",
            ).to(DEVICE, dtype=torch.float32)

            with torch.no_grad():
                outputs = self.doc_model.generate(**inputs, max_new_tokens=512)
            texts = self.doc_processor.batch_decode(outputs, skip_special_tokens=True)

            translations = self.translator(
                texts,
                src_lang=LANG_MAP.get(src_lang, src_lang),
                tgt_lang=LANG_MAP.get(tgt_lang, tgt_lang),
            )
            return {
                "results": [
                    {"original": text, "translated": trans["translation_text"]}
                    for text, trans in zip(texts, translations)
                ]
            }
        except Exception as e:
            logger.error(f"Document error: {e}")
            return {"error": str(e)}

    def process_speech(self, audio, src_lang, tgt_lang):
        """Transcribe a single audio clip with Whisper, then translate with NLLB."""
        try:
            if audio is None:
                return {"error": "No audio provided"}

            # gr.Audio(type="numpy") yields a (sample_rate, ndarray) tuple
            sample_rate, data = audio
            data = np.asarray(data, dtype=np.float32)
            if data.ndim > 1:  # Downmix stereo to mono
                data = data.mean(axis=1)
            peak = np.max(np.abs(data))
            if peak > 0:  # Normalize to [-1, 1] as Whisper expects
                data = data / peak

            if sample_rate != 16000:
                # Naive linear resample to Whisper's required 16 kHz
                target_len = int(len(data) * 16000 / sample_rate)
                data = np.interp(
                    np.linspace(0, len(data), num=target_len, endpoint=False),
                    np.arange(len(data)),
                    data,
                ).astype(np.float32)

            # Process with explicit dtype for consistency
            inputs = self.speech_processor(
                data,
                sampling_rate=16000,
                return_tensors="pt",
            ).to(DEVICE, dtype=torch.float32)

            with torch.no_grad():
                outputs = self.speech_model.generate(**inputs, max_new_tokens=256)
            transcription = self.speech_processor.batch_decode(outputs, skip_special_tokens=True)[0]

            translation = self.translator(
                transcription,
                src_lang=LANG_MAP.get(src_lang, src_lang),
                tgt_lang=LANG_MAP.get(tgt_lang, tgt_lang),
            )
            return {
                "results": [
                    {
                        "transcription": transcription,
                        "translation": translation[0]["translation_text"],
                    }
                ]
            }
        except Exception as e:
            logger.error(f"Audio error: {e}")
            return {"error": str(e)}


# Initialize translator
try:
    translator = PoliceTranslator()
except Exception as e:
    logger.error(f"Failed to initialize: {e}")
    raise

# Gradio interface
with gr.Blocks(title="Police Translation Assistant", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🚨 Police Translation Assistant")

    with gr.Tab("📄 Document Translation"):
        gr.Markdown("### Upload document images (PNG/JPG)")
        with gr.Row():
            doc_input = gr.File(file_count="multiple", file_types=["image"], label="Documents")
            doc_src = gr.Dropdown(
                label="Source Language",
                choices=["en", "ar", "ory"],
                value="en",
            )
            doc_tgt = gr.Dropdown(
                label="Target Language",
                choices=["ory", "en", "ar"],
                value="ory",
            )
        doc_btn = gr.Button("Translate Documents", variant="primary")
        doc_output = gr.JSON(label="Results")

    with gr.Tab("🎙️ Audio Translation"):
        gr.Markdown("### Upload an audio recording (WAV/MP3)")
        with gr.Row():
            # type="numpy" delivers (sample_rate, ndarray) to process_speech
            audio_input = gr.Audio(sources=["upload"], type="numpy", label="Audio File")
            audio_src = gr.Dropdown(
                label="Spoken Language",
                choices=["en", "ar"],
                value="en",
            )
            audio_tgt = gr.Dropdown(
                label="Target Language",
                choices=["ory", "en", "ar"],
                value="ory",
            )
        audio_btn = gr.Button("Translate Audio", variant="primary")
        audio_output = gr.JSON(label="Results")

    doc_btn.click(
        translator.process_documents,
        inputs=[doc_input, doc_src, doc_tgt],
        outputs=doc_output,
    )
    audio_btn.click(
        translator.process_speech,
        inputs=[audio_input, audio_src, audio_tgt],
        outputs=audio_output,
    )

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)