File size: 7,220 Bytes
ef9531a
 
 
 
 
 
 
 
 
 
720b3b3
ef9531a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
720b3b3
ef9531a
 
720b3b3
ef9531a
 
 
720b3b3
ef9531a
 
720b3b3
 
ef9531a
 
 
 
720b3b3
ef9531a
 
 
720b3b3
ef9531a
 
 
720b3b3
ef9531a
 
720b3b3
ef9531a
 
 
 
 
 
720b3b3
ef9531a
720b3b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef9531a
 
720b3b3
ef9531a
 
720b3b3
 
 
 
 
 
 
 
 
 
ef9531a
 
 
 
 
 
 
 
720b3b3
ef9531a
 
 
720b3b3
ef9531a
720b3b3
 
 
 
 
 
 
 
 
 
ef9531a
720b3b3
ef9531a
 
 
720b3b3
 
ef9531a
 
720b3b3
ef9531a
 
720b3b3
 
 
 
 
 
ef9531a
 
 
 
 
 
 
 
720b3b3
ef9531a
 
 
 
 
 
 
 
 
 
 
 
 
720b3b3
 
ef9531a
720b3b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef9531a
720b3b3
 
 
 
 
 
 
 
 
 
 
 
 
ef9531a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
# app.py
import torch
import gradio as gr
from transformers import (
    AutoProcessor,
    VisionEncoderDecoderModel,
    AutoModelForSpeechSeq2Seq,
    pipeline
)
import numpy as np
from PIL import Image, UnidentifiedImageError
import logging

# Logging: install a root handler that emits INFO and above, then grab this
# module's named logger for everything below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hardware configuration: run on the GPU whenever PyTorch can see one,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
logger.info(f"Using device: {DEVICE}")

class PoliceTranslator:
    """OCR / speech transcription plus translation backed by Hugging Face models.

    Three models are loaded once at construction time:
    Nougat (page images -> text), Whisper (audio -> text) and
    NLLB-200 (text -> text translation).
    """

    # Short UI language codes -> NLLB-200 (FLORES-200) codes. Values not listed
    # here are passed through unchanged, so full NLLB codes such as "fra_Latn"
    # keep working.
    NLLB_LANG_CODES = {
        "en": "eng_Latn",
        "ar": "arb_Arab",   # Modern Standard Arabic
        "or": "ory_Orya",   # Odia, ISO 639-1 code
        "ory": "ory_Orya",  # Odia, ISO 639-3 code (used by the UI dropdowns)
    }

    # Sampling rate passed to the Whisper processor; all audio is converted
    # to this rate before feature extraction.
    WHISPER_SAMPLING_RATE = 16000

    def __init__(self):
        self._init_models()

    def _init_models(self):
        """Load the document, translation and speech models onto DEVICE.

        All models are loaded in float32 so inputs can be cast to one dtype.

        Raises:
            Exception: whatever the underlying loader raised; it is logged
                (with traceback) and re-raised.
        """
        try:
            # Document processing (Nougat)
            logger.info("Initializing document model...")
            self.doc_processor = AutoProcessor.from_pretrained("facebook/nougat-base")
            self.doc_model = VisionEncoderDecoderModel.from_pretrained(
                "facebook/nougat-base",
                torch_dtype=torch.float32  # Force float32 for stability
            ).to(DEVICE)

            # Translation model (NLLB)
            logger.info("Initializing translation model...")
            self.translator = pipeline(
                "translation",
                model="facebook/nllb-200-distilled-600M",
                device=0 if DEVICE == "cuda" else -1,
                torch_dtype=torch.float32  # Match document model dtype
            )

            # Speech processing (Whisper)
            logger.info("Initializing speech model...")
            self.speech_processor = AutoProcessor.from_pretrained("openai/whisper-medium")
            self.speech_model = AutoModelForSpeechSeq2Seq.from_pretrained(
                "openai/whisper-medium",
                torch_dtype=torch.float32  # Consistent dtype
            ).to(DEVICE)

            logger.info("All models initialized successfully")

        except Exception as e:
            logger.exception(f"Initialization failed: {str(e)}")
            raise

    @classmethod
    def _nllb_code(cls, lang):
        """Return the NLLB-200 code for ``lang``; unknown values pass through."""
        return cls.NLLB_LANG_CODES.get(lang, lang)

    @staticmethod
    def _file_path(file):
        """Return the filesystem path of an uploaded file.

        Uploads may be plain path strings or file-like wrappers exposing the
        path as ``.name``; both are accepted.
        """
        return getattr(file, "name", file)

    def process_documents(self, files, src_lang, tgt_lang):
        """OCR uploaded page images with Nougat and translate the text with NLLB.

        Args:
            files: Uploaded image files (a list, or a single upload), each a
                path string or a wrapper exposing ``.name``.
            src_lang: Source language, a short UI code ("en", "ar", "or",
                "ory") or a full NLLB code.
            tgt_lang: Target language, same conventions as ``src_lang``.

        Returns:
            ``{"results": [{"original": ..., "translated": ...}, ...]}`` with
            one entry per image, or ``{"error": message}`` on failure.
        """
        try:
            if not files:
                return {"error": "No documents provided"}
            # A single upload (file_count="single") arrives unwrapped.
            if isinstance(files, str) or hasattr(files, "name"):
                files = [files]

            images = []
            for file in files:
                path = self._file_path(file)
                try:
                    # The context manager closes the file handle; convert()
                    # returns a fully loaded copy that does not depend on it.
                    with Image.open(path) as img:
                        images.append(img.convert("RGB"))
                except (UnidentifiedImageError, OSError) as e:
                    logger.error(f"Invalid image file: {str(e)}")
                    return {"error": f"Invalid image format: {path}"}

            # Cast floating-point inputs to float32 to match the model dtype.
            inputs = self.doc_processor(
                images=images,
                return_tensors="pt"
            ).to(DEVICE, dtype=torch.float32)

            with torch.no_grad():
                outputs = self.doc_model.generate(**inputs, max_new_tokens=512)

            texts = self.doc_processor.batch_decode(outputs, skip_special_tokens=True)

            # NOTE(review): the pipeline uses the model's default generation
            # length; very long pages may come back truncated - confirm.
            translations = self.translator(
                texts,
                src_lang=self._nllb_code(src_lang),
                tgt_lang=self._nllb_code(tgt_lang),
            )

            return {
                "results": [
                    {"original": text, "translated": result["translation_text"]}
                    for text, result in zip(texts, translations)
                ]
            }

        except Exception as e:
            logger.exception(f"Document error: {str(e)}")
            return {"error": str(e)}

    @staticmethod
    def _as_clip_list(audio_files):
        """Normalize the audio argument into a list of clips.

        A lone ``(sample_rate, samples)`` tuple or a lone path string counts
        as one clip; ``None`` yields an empty list; any other iterable is
        treated as a collection of clips.
        """
        if audio_files is None:
            return []
        if isinstance(audio_files, str) or (
            isinstance(audio_files, tuple)
            and len(audio_files) == 2
            and isinstance(audio_files[0], (int, np.integer))
        ):
            return [audio_files]
        return list(audio_files)

    @staticmethod
    def _read_wav(path):
        """Read a PCM WAV file with the standard-library ``wave`` module.

        Returns:
            ``(sample_rate, samples)`` where ``samples`` has shape
            ``(n_frames, n_channels)``.

        Raises:
            wave.Error / ValueError: for non-WAV input (e.g. MP3) or sample
                widths other than 8/16/32-bit PCM.
        """
        import wave

        with wave.open(path, "rb") as wav:
            sample_rate = wav.getframerate()
            n_channels = wav.getnchannels()
            width = wav.getsampwidth()
            frames = wav.readframes(wav.getnframes())
        # WAV PCM is little-endian; 8-bit samples are unsigned.
        dtypes = {1: np.dtype("u1"), 2: np.dtype("<i2"), 4: np.dtype("<i4")}
        if width not in dtypes:
            raise ValueError(f"unsupported WAV sample width: {width} bytes")
        samples = np.frombuffer(frames, dtype=dtypes[width]).reshape(-1, n_channels)
        return sample_rate, samples

    @classmethod
    def _prepare_audio(cls, clip):
        """Convert one clip into a mono float32 array at WHISPER_SAMPLING_RATE.

        Args:
            clip: ``(sample_rate, samples)`` or a path to a PCM WAV file.
                Multi-channel samples are assumed to be laid out as
                ``(n_samples, n_channels)``.

        Raises:
            ValueError: if the clip contains no samples.
        """
        if isinstance(clip, str):
            sample_rate, samples = cls._read_wav(clip)
        else:
            sample_rate, samples = clip
        sample_rate = int(sample_rate)
        samples = np.asarray(samples)

        # Scale integer PCM into [-1, 1]; float input is assumed to already be
        # in that range.
        if np.issubdtype(samples.dtype, np.integer):
            info = np.iinfo(samples.dtype)
            if info.min < 0:
                samples = samples.astype(np.float32) / float(-info.min)
            else:
                # Unsigned PCM (e.g. 8-bit WAV) is centred on its mid value.
                half = (float(info.max) + 1.0) / 2.0
                samples = (samples.astype(np.float32) - half) / half
        else:
            samples = samples.astype(np.float32)

        if samples.ndim > 1:
            samples = samples.mean(axis=1)  # downmix channels to mono
        if samples.size == 0:
            raise ValueError("empty audio clip")

        target = cls.WHISPER_SAMPLING_RATE
        if sample_rate != target:
            # Linear-interpolation resample (no anti-aliasing filter).
            n_out = int(round(samples.shape[0] * target / sample_rate))
            old_t = np.arange(samples.shape[0]) / sample_rate
            new_t = np.arange(n_out) / target
            samples = np.interp(new_t, old_t, samples).astype(np.float32)
        return samples

    def process_speech(self, audio_files, src_lang, tgt_lang):
        """Transcribe audio with Whisper and translate the transcripts with NLLB.

        Args:
            audio_files: A single clip or a list of clips. A clip is a
                ``(sample_rate, samples)`` tuple or a path to a PCM WAV file.
            src_lang: Spoken language (short UI code or NLLB code), used as
                the NLLB source language.
            tgt_lang: Target language (short UI code or NLLB code).

        Returns:
            ``{"results": [{"transcription": ..., "translation": ...}, ...]}``
            or ``{"error": message}`` on failure.
        """
        try:
            clips = self._as_clip_list(audio_files)
            if not clips:
                return {"error": "No audio provided"}

            audio_data = []
            for clip in clips:
                try:
                    audio_data.append(self._prepare_audio(clip))
                except Exception as e:
                    logger.error(f"Audio processing error: {str(e)}")
                    return {"error": "Invalid audio format"}

            # Leave padding at the feature extractor's default (pad to
            # Whisper's fixed input window); padding=True pads only to the
            # longest clip, which the Whisper encoder rejects for short audio.
            inputs = self.speech_processor(
                audio_data,
                sampling_rate=self.WHISPER_SAMPLING_RATE,
                return_tensors="pt",
            ).to(DEVICE, dtype=torch.float32)

            with torch.no_grad():
                outputs = self.speech_model.generate(**inputs, max_new_tokens=256)

            transcriptions = self.speech_processor.batch_decode(outputs, skip_special_tokens=True)

            translations = self.translator(
                transcriptions,
                src_lang=self._nllb_code(src_lang),
                tgt_lang=self._nllb_code(tgt_lang),
            )

            return {
                "results": [
                    {"transcription": text, "translation": result["translation_text"]}
                    for text, result in zip(transcriptions, translations)
                ]
            }

        except Exception as e:
            logger.exception(f"Audio error: {str(e)}")
            return {"error": str(e)}

# Build the shared translator once at import time. A model-loading failure is
# fatal: it is logged and re-raised so the app never starts without models.
try:
    translator = PoliceTranslator()
except Exception as init_err:
    logger.error(f"Failed to initialize: {init_err!s}")
    raise

# Gradio interface: one tab per input modality, each wired to the shared
# `translator` instance and rendering its result dict as JSON.
with gr.Blocks(title="Police Translation Assistant", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🚨 Police Translation Assistant")

    with gr.Tab("📄 Document Translation"):
        gr.Markdown("### Upload document images (PNG/JPG)")
        with gr.Row():
            doc_input = gr.File(file_count="multiple", file_types=["image"], label="Documents")
            doc_src = gr.Dropdown(
                label="Source Language",
                choices=["en", "ar", "ory"],
                value="en"
            )
            doc_tgt = gr.Dropdown(
                label="Target Language",
                choices=["ory", "en", "ar"],
                value="ory"
            )
        doc_btn = gr.Button("Translate Documents", variant="primary")
        doc_output = gr.JSON(label="Results")

    with gr.Tab("🎙️ Audio Translation"):
        gr.Markdown("### Upload audio recordings (WAV/MP3)")
        with gr.Row():
            # type="numpy" has Gradio decode the upload (WAV or MP3) and pass
            # a (sample_rate, samples) tuple instead of a file path.
            audio_input = gr.Audio(sources=["upload"], type="numpy", label="Audio Files")
            audio_src = gr.Dropdown(
                label="Spoken Language",
                choices=["en", "ar"],
                value="en"
            )
            audio_tgt = gr.Dropdown(
                label="Target Language",
                choices=["ory", "en", "ar"],
                value="ory"
            )
        audio_btn = gr.Button("Translate Audio", variant="primary")
        audio_output = gr.JSON(label="Results")

    doc_btn.click(
        translator.process_documents,
        inputs=[doc_input, doc_src, doc_tgt],
        outputs=doc_output
    )

    audio_btn.click(
        translator.process_speech,
        inputs=[audio_input, audio_src, audio_tgt],
        outputs=audio_output
    )

if __name__ == "__main__":
    # Listen on all interfaces so the app is reachable from outside the host.
    app.launch(server_name="0.0.0.0", server_port=7860)