# app.py
import torch
import gradio as gr
import librosa  # used to load and resample uploaded audio to the 16 kHz Whisper expects
from transformers import (
    AutoProcessor,
    VisionEncoderDecoderModel,
    AutoModelForSpeechSeq2Seq,
    pipeline,
)
import numpy as np
from PIL import Image, UnidentifiedImageError
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hardware configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {DEVICE}")
class PoliceTranslator:
    # Map the UI's short language codes to NLLB-200 codes
    # (in NLLB-200, Odia is "ory_Orya", Arabic "ara_Arab", English "eng_Latn")
    LANG_MAP = {"en": "eng_Latn", "ar": "ara_Arab", "or": "ory_Orya", "ory": "ory_Orya"}

    def __init__(self):
        self._init_models()

    def _init_models(self):
        """Initialize models with consistent dtypes."""
        try:
            # Document processing (Nougat)
            logger.info("Initializing document model...")
            self.doc_processor = AutoProcessor.from_pretrained("facebook/nougat-base")
            self.doc_model = VisionEncoderDecoderModel.from_pretrained(
                "facebook/nougat-base",
                torch_dtype=torch.float32,  # force float32 for stability
            ).to(DEVICE)

            # Translation model (NLLB)
            logger.info("Initializing translation model...")
            self.translator = pipeline(
                "translation",
                model="facebook/nllb-200-distilled-600M",
                device=0 if DEVICE == "cuda" else -1,
                torch_dtype=torch.float32,  # match document model dtype
            )

            # Speech processing (Whisper)
            logger.info("Initializing speech model...")
            self.speech_processor = AutoProcessor.from_pretrained("openai/whisper-medium")
            self.speech_model = AutoModelForSpeechSeq2Seq.from_pretrained(
                "openai/whisper-medium",
                torch_dtype=torch.float32,  # consistent dtype
            ).to(DEVICE)

            logger.info("All models initialized successfully")
        except Exception as e:
            logger.error(f"Initialization failed: {str(e)}")
            raise
    def process_documents(self, files, src_lang, tgt_lang):
        """OCR uploaded document images with Nougat, then translate the extracted text."""
        try:
            # Validate input files
            images = []
            for file in files:
                try:
                    img = Image.open(file).convert("RGB")
                    images.append(img)
                except (UnidentifiedImageError, OSError) as e:
                    logger.error(f"Invalid image file: {str(e)}")
                    return {"error": f"Invalid image file: {file}"}

            # Process with dtype consistency
            inputs = self.doc_processor(
                images=images,
                return_tensors="pt",
            ).to(DEVICE, dtype=torch.float32)  # explicit dtype

            with torch.no_grad():
                outputs = self.doc_model.generate(**inputs, max_new_tokens=512)
            texts = self.doc_processor.batch_decode(outputs, skip_special_tokens=True)

            # Map UI codes to NLLB-200 codes, falling back to the raw value
            translations = self.translator(
                texts,
                src_lang=self.LANG_MAP.get(src_lang, src_lang),
                tgt_lang=self.LANG_MAP.get(tgt_lang, tgt_lang),
            )
            return {
                "results": [
                    {"original": text, "translated": trans["translation_text"]}
                    for text, trans in zip(texts, translations)
                ]
            }
        except Exception as e:
            logger.error(f"Document error: {str(e)}")
            return {"error": str(e)}
    def process_speech(self, audio_path, src_lang, tgt_lang):
        """Transcribe an uploaded audio file with Whisper, then translate the transcription."""
        try:
            if audio_path is None:
                return {"error": "No audio file provided"}

            # gr.Audio(type="filepath") hands us a single file path;
            # load and resample it to the 16 kHz rate Whisper expects
            try:
                waveform, _ = librosa.load(audio_path, sr=16000)
            except Exception as e:
                logger.error(f"Audio processing error: {str(e)}")
                return {"error": "Invalid audio format"}

            # Process with dtype consistency
            inputs = self.speech_processor(
                waveform,
                sampling_rate=16000,
                return_tensors="pt",
            ).to(DEVICE, dtype=torch.float32)  # explicit dtype

            with torch.no_grad():
                # Hint the spoken language so Whisper transcribes rather than guesses
                outputs = self.speech_model.generate(
                    **inputs, max_new_tokens=256, language=src_lang, task="transcribe"
                )
            transcriptions = self.speech_processor.batch_decode(outputs, skip_special_tokens=True)

            # Reuse the NLLB code mapping, as in process_documents
            translations = self.translator(
                transcriptions,
                src_lang=self.LANG_MAP.get(src_lang, src_lang),
                tgt_lang=self.LANG_MAP.get(tgt_lang, tgt_lang),
            )
            return {
                "results": [
                    {"transcription": text, "translation": trans["translation_text"]}
                    for text, trans in zip(transcriptions, translations)
                ]
            }
        except Exception as e:
            logger.error(f"Audio error: {str(e)}")
            return {"error": str(e)}
# Initialize translator
try:
    translator = PoliceTranslator()
except Exception as e:
    logger.error(f"Failed to initialize: {str(e)}")
    raise
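# The translator can also be exercised directly, without the UI below.
# A minimal sketch (the file name is hypothetical, for illustration only):
#
#   result = translator.process_documents(["report_scan.png"], "en", "ory")
#   if "results" in result:
#       print(result["results"][0]["translated"])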
# Gradio interface
with gr.Blocks(title="Police Translation Assistant", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🚨 Police Translation Assistant")

    with gr.Tab("📄 Document Translation"):
        gr.Markdown("### Upload document images (PNG/JPG)")
        with gr.Row():
            doc_input = gr.File(file_count="multiple", file_types=["image"], label="Documents")
            doc_src = gr.Dropdown(
                label="Source Language",
                choices=["en", "ar", "ory"],
                value="en",
            )
            doc_tgt = gr.Dropdown(
                label="Target Language",
                choices=["ory", "en", "ar"],
                value="ory",
            )
        doc_btn = gr.Button("Translate Documents", variant="primary")
        doc_output = gr.JSON(label="Results")

    with gr.Tab("🎙️ Audio Translation"):
        gr.Markdown("### Upload an audio recording (WAV/MP3)")
        with gr.Row():
            audio_input = gr.Audio(sources=["upload"], type="filepath", label="Audio File")
            audio_src = gr.Dropdown(
                label="Spoken Language",
                choices=["en", "ar"],
                value="en",
            )
            audio_tgt = gr.Dropdown(
                label="Target Language",
                choices=["ory", "en", "ar"],
                value="ory",
            )
        audio_btn = gr.Button("Translate Audio", variant="primary")
        audio_output = gr.JSON(label="Results")
    doc_btn.click(
        translator.process_documents,
        inputs=[doc_input, doc_src, doc_tgt],
        outputs=doc_output,
    )
    audio_btn.click(
        translator.process_speech,
        inputs=[audio_input, audio_src, audio_tgt],
        outputs=audio_output,
    )

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)
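# A minimal sketch of the dependencies this Space assumes (versions are not
# given in the source, so the list below is unpinned):
#
#   requirements.txt:
#     torch
#     transformers
#     gradio
#     pillow
#     numpy
#     librosa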