Jerich committed
Commit de2c35e · verified · 1 Parent(s): f24e508

API integration for the Talklas pipeline

Files changed (1): app.py (+312 -64)
app.py CHANGED
@@ -1,10 +1,8 @@
 import os
 import torch
+import gradio as gr
 import numpy as np
 import soundfile as sf
-from fastapi import FastAPI, File, UploadFile, HTTPException
-from fastapi.responses import JSONResponse
-from pydantic import BaseModel
 from transformers import (
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
@@ -15,11 +13,18 @@ from transformers import (
     WhisperForConditionalGeneration
 )
 from typing import Optional, Tuple, Dict, List
+from flask import Flask, request, jsonify
+from flask_cors import CORS
 import base64
 import io
+import tempfile
 
-# Your existing TalklasTranslator class (unchanged)
 class TalklasTranslator:
+    """
+    Speech-to-Speech translation pipeline for Philippine languages.
+    Uses MMS/Whisper for STT, NLLB for MT, and MMS for TTS.
+    """
+
     LANGUAGE_MAPPING = {
         "English": "eng",
         "Tagalog": "tgl",
@@ -50,34 +55,45 @@ class TalklasTranslator:
         self.sample_rate = 16000
 
         print(f"Initializing Talklas Translator on {self.device}")
+
+        # Initialize models
         self._initialize_stt_model()
         self._initialize_mt_model()
         self._initialize_tts_model()
 
     def _initialize_stt_model(self):
+        """Initialize speech-to-text model with fallback to Whisper"""
        try:
            print("Loading STT model...")
            try:
+                # Try loading MMS model first
                self.stt_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
                self.stt_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
+
+                # Set language if available
                if self.source_lang in self.stt_processor.tokenizer.vocab.keys():
                    self.stt_processor.tokenizer.set_target_lang(self.source_lang)
                    self.stt_model.load_adapter(self.source_lang)
                    print(f"Loaded MMS STT model for {self.source_lang}")
                else:
                    print(f"Language {self.source_lang} not in MMS, using default")
+
            except Exception as mms_error:
                print(f"MMS loading failed: {mms_error}")
+                # Fallback to Whisper
                print("Loading Whisper as fallback...")
                self.stt_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
                self.stt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
                print("Loaded Whisper STT model")
+
            self.stt_model.to(self.device)
+
        except Exception as e:
            print(f"STT model initialization failed: {e}")
            raise RuntimeError("Could not initialize STT model")
 
    def _initialize_mt_model(self):
+        """Initialize machine translation model"""
        try:
            print("Loading NLLB Translation model...")
            self.mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
@@ -89,6 +105,7 @@ class TalklasTranslator:
            raise
 
    def _initialize_tts_model(self):
+        """Initialize text-to-speech model"""
        try:
            print("Loading TTS model...")
            try:
@@ -100,78 +117,102 @@ class TalklasTranslator:
                print("Falling back to English TTS")
                self.tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
                self.tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
+
            self.tts_model.to(self.device)
        except Exception as e:
            print(f"TTS model initialization failed: {e}")
            raise
 
    def update_languages(self, source_lang: str, target_lang: str) -> str:
+        """Update languages and reinitialize models if needed"""
        if source_lang == self.source_lang and target_lang == self.target_lang:
            return "Languages already set"
+
        self.source_lang = source_lang
        self.target_lang = target_lang
+
+        # Only reinitialize models that depend on language
        self._initialize_stt_model()
        self._initialize_tts_model()
+
        return f"Languages updated to {source_lang} → {target_lang}"
 
    def speech_to_text(self, audio_path: str) -> str:
+        """Convert speech to text using loaded STT model"""
        try:
            waveform, sample_rate = sf.read(audio_path)
+
            if sample_rate != 16000:
                import librosa
                waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
+
            inputs = self.stt_processor(
                waveform,
                sampling_rate=16000,
                return_tensors="pt"
            ).to(self.device)
+
            with torch.no_grad():
-                if isinstance(self.stt_model, WhisperForConditionalGeneration):
+                if isinstance(self.stt_model, WhisperForConditionalGeneration):  # Whisper model
                    generated_ids = self.stt_model.generate(**inputs)
                    transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-                else:
+                else:  # MMS model (Wav2Vec2ForCTC)
                    logits = self.stt_model(**inputs).logits
                    predicted_ids = torch.argmax(logits, dim=-1)
                    transcription = self.stt_processor.batch_decode(predicted_ids)[0]
+
            return transcription
+
        except Exception as e:
            print(f"Speech recognition failed: {e}")
            raise RuntimeError("Speech recognition failed")
 
    def translate_text(self, text: str) -> str:
+        """Translate text using NLLB model"""
        try:
            source_code = self.NLLB_LANGUAGE_CODES[self.source_lang]
            target_code = self.NLLB_LANGUAGE_CODES[self.target_lang]
+
            self.mt_tokenizer.src_lang = source_code
            inputs = self.mt_tokenizer(text, return_tensors="pt").to(self.device)
+
            with torch.no_grad():
                generated_tokens = self.mt_model.generate(
                    **inputs,
                    forced_bos_token_id=self.mt_tokenizer.convert_tokens_to_ids(target_code),
                    max_length=448
                )
+
            return self.mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
+
        except Exception as e:
            print(f"Translation failed: {e}")
            raise RuntimeError("Text translation failed")
 
    def text_to_speech(self, text: str) -> Tuple[int, np.ndarray]:
+        """Convert text to speech"""
        try:
            inputs = self.tts_tokenizer(text, return_tensors="pt").to(self.device)
+
            with torch.no_grad():
                output = self.tts_model(**inputs)
+
            speech = output.waveform.cpu().numpy().squeeze()
            speech = (speech * 32767).astype(np.int16)
+
            return self.tts_model.config.sampling_rate, speech
+
        except Exception as e:
            print(f"Speech synthesis failed: {e}")
            raise RuntimeError("Speech synthesis failed")
 
    def translate_speech(self, audio_path: str) -> Dict:
+        """Full speech-to-speech translation"""
        try:
            source_text = self.speech_to_text(audio_path)
            translated_text = self.translate_text(source_text)
            sample_rate, audio = self.text_to_speech(translated_text)
+
            return {
                "source_text": source_text,
                "translated_text": translated_text,
@@ -187,9 +228,11 @@ class TalklasTranslator:
            }
 
    def translate_text_only(self, text: str) -> Dict:
+        """Text-to-speech translation"""
        try:
            translated_text = self.translate_text(text)
            sample_rate, audio = self.text_to_speech(translated_text)
+
            return {
                "source_text": text,
                "translated_text": translated_text,
@@ -213,88 +256,293 @@ class TranslatorSingleton:
        cls._instance = TalklasTranslator()
        return cls._instance
 
-# FastAPI application
-app = FastAPI(title="Talklas API", description="Speech-to-Speech Translation API")
+def process_audio(audio_path, source_lang, target_lang):
+    """Process audio through the full translation pipeline"""
+    # Validate input
+    if not audio_path:
+        return None, "No audio provided", "No translation available", "Please provide audio input"
 
-class TranslationRequest(BaseModel):
-    source_lang: str
-    target_lang: str
-    text: Optional[str] = None
+    # Update languages
+    source_code = TalklasTranslator.LANGUAGE_MAPPING[source_lang]
+    target_code = TalklasTranslator.LANGUAGE_MAPPING[target_lang]
 
-@app.post("/translate/audio")
-async def translate_audio(file: UploadFile = File(...), source_lang: str = "English", target_lang: str = "Tagalog"):
-    try:
-        # Validate languages
-        if source_lang not in TalklasTranslator.LANGUAGE_MAPPING or target_lang not in TalklasTranslator.LANGUAGE_MAPPING:
-            raise HTTPException(status_code=400, detail="Invalid language selection")
-
-        # Save uploaded audio file temporarily
-        audio_path = f"temp_{file.filename}"
-        with open(audio_path, "wb") as f:
-            f.write(await file.read())
-
-        # Update languages
+    translator = TranslatorSingleton.get_instance()
+    status = translator.update_languages(source_code, target_code)
+
+    # Process the audio
+    results = translator.translate_speech(audio_path)
+
+    return results["output_audio"], results["source_text"], results["translated_text"], results["performance"]
+
+def process_text(text, source_lang, target_lang):
+    """Process text through the translation pipeline"""
+    # Validate input
+    if not text:
+        return None, "No text provided", "No translation available", "Please provide text input"
+
+    # Update languages
+    source_code = TalklasTranslator.LANGUAGE_MAPPING[source_lang]
+    target_code = TalklasTranslator.LANGUAGE_MAPPING[target_lang]
+
+    translator = TranslatorSingleton.get_instance()
+    status = translator.update_languages(source_code, target_code)
+
+    # Process the text
+    results = translator.translate_text_only(text)
+
+    return results["output_audio"], results["source_text"], results["translated_text"], results["performance"]
+
+def create_gradio_interface():
+    """Create and launch Gradio interface"""
+    # Define language options
+    languages = list(TalklasTranslator.LANGUAGE_MAPPING.keys())
+
+    # Define the interface
+    demo = gr.Blocks(title="Talklas - Speech & Text Translation")
+
+    with demo:
+        gr.Markdown("# Talklas: Speech-to-Speech Translation System")
+        gr.Markdown("### Translate between Philippine Languages and English")
+
+        with gr.Row():
+            with gr.Column():
+                source_lang = gr.Dropdown(
+                    choices=languages,
+                    value="English",
+                    label="Source Language"
+                )
+
+                target_lang = gr.Dropdown(
+                    choices=languages,
+                    value="Tagalog",
+                    label="Target Language"
+                )
+
+                language_status = gr.Textbox(label="Language Status")
+                update_btn = gr.Button("Update Languages")
+
+        with gr.Tabs():
+            with gr.TabItem("Audio Input"):
+                with gr.Row():
+                    with gr.Column():
+                        gr.Markdown("### Audio Input")
+                        audio_input = gr.Audio(
+                            type="filepath",
+                            label="Upload Audio File"
+                        )
+                        audio_translate_btn = gr.Button("Translate Audio", variant="primary")
+
+                    with gr.Column():
+                        gr.Markdown("### Output")
+                        audio_output = gr.Audio(
+                            label="Translated Speech",
+                            type="numpy",
+                            autoplay=True
+                        )
+
+            with gr.TabItem("Text Input"):
+                with gr.Row():
+                    with gr.Column():
+                        gr.Markdown("### Text Input")
+                        text_input = gr.Textbox(
+                            label="Enter text to translate",
+                            lines=3
+                        )
+                        text_translate_btn = gr.Button("Translate Text", variant="primary")
+
+                    with gr.Column():
+                        gr.Markdown("### Output")
+                        text_output = gr.Audio(
+                            label="Translated Speech",
+                            type="numpy",
+                            autoplay=True
+                        )
+
+        with gr.Row():
+            with gr.Column():
+                source_text = gr.Textbox(label="Source Text")
+                translated_text = gr.Textbox(label="Translated Text")
+                performance_info = gr.Textbox(label="Performance Metrics")
+
+        # Set up events
+        update_btn.click(
+            lambda source_lang, target_lang: TranslatorSingleton.get_instance().update_languages(
+                TalklasTranslator.LANGUAGE_MAPPING[source_lang],
+                TalklasTranslator.LANGUAGE_MAPPING[target_lang]
+            ),
+            inputs=[source_lang, target_lang],
+            outputs=[language_status]
+        )
+
+        # Audio translate button click
+        audio_translate_btn.click(
+            process_audio,
+            inputs=[audio_input, source_lang, target_lang],
+            outputs=[audio_output, source_text, translated_text, performance_info]
+        ).then(
+            None,
+            None,
+            None,
+            js="""() => {
+                const audioElements = document.querySelectorAll('audio');
+                if (audioElements.length > 0) {
+                    const lastAudio = audioElements[audioElements.length - 1];
+                    lastAudio.play().catch(error => {
+                        console.warn('Autoplay failed:', error);
+                        alert('Audio may require user interaction to play');
+                    });
+                }
+            }"""
+        )
+
+        # Text translate button click
+        text_translate_btn.click(
+            process_text,
+            inputs=[text_input, source_lang, target_lang],
+            outputs=[text_output, source_text, translated_text, performance_info]
+        ).then(
+            None,
+            None,
+            None,
+            js="""() => {
+                const audioElements = document.querySelectorAll('audio');
+                if (audioElements.length > 0) {
+                    const lastAudio = audioElements[audioElements.length - 1];
+                    lastAudio.play().catch(error => {
+                        console.warn('Autoplay failed:', error);
+                        alert('Audio may require user interaction to play');
+                    });
+                }
+            }"""
+        )
+
+    return demo
+
+# Create Flask app
+app = Flask(__name__)
+CORS(app)  # This allows cross-origin requests
+
+# Initialize the translator singleton
+translator_instance = None
+
+def get_translator():
+    global translator_instance
+    if translator_instance is None:
+        translator_instance = TalklasTranslator()
+    return translator_instance
+
+@app.route('/api/translate-speech', methods=['POST'])
+def api_translate_speech():
+    """API endpoint for speech-to-speech translation"""
+    try:
+        # Check if required data is in the request
+        if 'audio' not in request.files:
+            return jsonify({
+                "error": "No audio file provided"
+            }), 400
+
+        audio_file = request.files['audio']
+        source_lang = request.form.get('source_lang', 'English')
+        target_lang = request.form.get('target_lang', 'Tagalog')
+
+        # Save temporary audio file
+        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio:
+            audio_file.save(temp_audio.name)
+            temp_audio_path = temp_audio.name
+
+        # Get translator and update languages
+        translator = get_translator()
         source_code = TalklasTranslator.LANGUAGE_MAPPING[source_lang]
         target_code = TalklasTranslator.LANGUAGE_MAPPING[target_lang]
-        translator = TranslatorSingleton.get_instance()
         translator.update_languages(source_code, target_code)
+
         # Process the audio
-        results = translator.translate_speech(audio_path)
+        results = translator.translate_speech(temp_audio_path)
+
+        # Convert audio to base64 for transmission
+        sample_rate, audio_data = results["output_audio"]
+        audio_bytes = io.BytesIO()
+        sf.write(audio_bytes, audio_data, sample_rate, format='WAV')
+        audio_base64 = base64.b64encode(audio_bytes.getvalue()).decode('utf-8')
+
         # Clean up temporary file
-        os.remove(audio_path)
-
-        # Convert audio to base64 for response
-        sample_rate, audio = results["output_audio"]
-        buffer = io.BytesIO()
-        sf.write(buffer, audio, sample_rate, format="wav")
-        audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
-
-        return JSONResponse(content={
+        os.unlink(temp_audio_path)
+
+        return jsonify({
            "source_text": results["source_text"],
            "translated_text": results["translated_text"],
            "audio_base64": audio_base64,
            "sample_rate": sample_rate,
-            "performance": results["performance"]
+            "status": "success"
        })
+
    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
+        return jsonify({
+            "error": str(e),
+            "status": "error"
+        }), 500
 
-@app.post("/translate/text")
-async def translate_text(request: TranslationRequest):
+@app.route('/api/translate-text', methods=['POST'])
+def api_translate_text():
+    """API endpoint for text-to-speech translation"""
    try:
-        # Validate input
-        if not request.text:
-            raise HTTPException(status_code=400, detail="Text input is required")
-        if request.source_lang not in TalklasTranslator.LANGUAGE_MAPPING or request.target_lang not in TalklasTranslator.LANGUAGE_MAPPING:
-            raise HTTPException(status_code=400, detail="Invalid language selection")
-
-        # Update languages
-        source_code = TalklasTranslator.LANGUAGE_MAPPING[request.source_lang]
-        target_code = TalklasTranslator.LANGUAGE_MAPPING[request.target_lang]
-        translator = TranslatorSingleton.get_instance()
+        data = request.json
+        if not data or 'text' not in data:
+            return jsonify({
+                "error": "No text provided"
+            }), 400
+
+        text = data['text']
+        source_lang = data.get('source_lang', 'English')
+        target_lang = data.get('target_lang', 'Tagalog')
+
+        # Get translator and update languages
+        translator = get_translator()
+        source_code = TalklasTranslator.LANGUAGE_MAPPING[source_lang]
+        target_code = TalklasTranslator.LANGUAGE_MAPPING[target_lang]
        translator.update_languages(source_code, target_code)
+
        # Process the text
-        results = translator.translate_text_only(request.text)
-
-        # Convert audio to base64 for response
-        sample_rate, audio = results["output_audio"]
-        buffer = io.BytesIO()
-        sf.write(buffer, audio, sample_rate, format="wav")
-        audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
-
-        return JSONResponse(content={
+        results = translator.translate_text_only(text)
+
+        # Convert audio to base64 for transmission
+        sample_rate, audio_data = results["output_audio"]
+        audio_bytes = io.BytesIO()
+        sf.write(audio_bytes, audio_data, sample_rate, format='WAV')
+        audio_base64 = base64.b64encode(audio_bytes.getvalue()).decode('utf-8')
+
+        return jsonify({
            "source_text": results["source_text"],
            "translated_text": results["translated_text"],
            "audio_base64": audio_base64,
            "sample_rate": sample_rate,
-            "performance": results["performance"]
+            "status": "success"
        })
+
    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
+        return jsonify({
+            "error": str(e),
+            "status": "error"
+        }), 500
+
+@app.route('/api/languages', methods=['GET'])
+def get_languages():
+    """Return available languages"""
+    return jsonify({
+        "languages": list(TalklasTranslator.LANGUAGE_MAPPING.keys())
+    })
+
+# The Gradio interface for users who directly access the Hugging Face Space
+# is created by create_gradio_interface() defined above.
 
+# Run both the API server and Gradio
 if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+    # Launch Gradio in a separate thread, pinned to its own port so it does
+    # not collide with the Flask server on 7860
+    import threading
+    demo = create_gradio_interface()
+    threading.Thread(
+        target=demo.launch,
+        kwargs={"share": True, "debug": False, "server_port": 7861},
+    ).start()
+
+    # Run the Flask server
+    app.run(host='0.0.0.0', port=7860)
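
For reference, a minimal client sketch against the new REST endpoints (not part of the commit). It assumes the Flask server above is reachable at http://localhost:7860; the endpoint paths, form/JSON field names, and response keys come from the diff, while the host, file names, and sample strings are illustrative only:

import base64

import requests

BASE_URL = "http://localhost:7860"  # assumed host/port, matching app.run() above

# Text -> translated speech: POST JSON to /api/translate-text
resp = requests.post(
    f"{BASE_URL}/api/translate-text",
    json={"text": "Good morning", "source_lang": "English", "target_lang": "Tagalog"},
)
resp.raise_for_status()
result = resp.json()
print(result["source_text"], "->", result["translated_text"])

# The synthesized audio arrives base64-encoded; decode it back into a WAV file
with open("translated.wav", "wb") as f:
    f.write(base64.b64decode(result["audio_base64"]))

# Speech -> translated speech: POST multipart form data to /api/translate-speech
with open("input.wav", "rb") as audio:
    resp = requests.post(
        f"{BASE_URL}/api/translate-speech",
        files={"audio": audio},
        data={"source_lang": "Tagalog", "target_lang": "English"},
    )
resp.raise_for_status()
print(resp.json()["translated_text"])

# List the supported language names exposed by /api/languages
print(requests.get(f"{BASE_URL}/api/languages").json()["languages"])

Because the app enables CORS, the same requests can also be issued from a browser front end served from another origin.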