Jerich committed on
Commit
352aa0e
·
verified ·
1 Parent(s): 6a1bf6c

Implement lazy loading of ML models to fix startup timeout on HF Spaces

Browse files
Files changed (1) hide show
  1. app.py +185 -50
app.py CHANGED
@@ -3,34 +3,36 @@ import os
3
  os.environ["HOME"] = "/root"
4
  os.environ["HF_HOME"] = "/tmp/hf_cache"
5
 
6
- # Debug: Print environment variables to confirm
7
  print("HOME environment variable:", os.environ.get("HOME"))
8
  print("HF_HOME environment variable:", os.environ.get("HF_HOME"))
9
 
10
- # Now import other libraries
11
  import torch
12
  import numpy as np
13
  import soundfile as sf
14
- from transformers import (
15
- AutoModelForSeq2SeqLM,
16
- AutoTokenizer,
17
- VitsModel,
18
- AutoProcessor,
19
- AutoModelForCTC,
20
- WhisperProcessor,
21
- WhisperForConditionalGeneration
22
- )
23
- from typing import Optional, Tuple, Dict
24
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException
25
  from fastapi.responses import JSONResponse
26
  import tempfile
27
  import logging
 
 
 
 
 
 
28
 
29
  # Configure transformers logging to reduce verbosity
30
  logging.getLogger("transformers").setLevel(logging.ERROR)
31
 
32
  app = FastAPI(title="Talklas API")
33
 
 
 
 
 
 
34
  class TalklasTranslator:
35
  LANGUAGE_MAPPING = {
36
  "English": "eng",
@@ -52,72 +54,113 @@ class TalklasTranslator:
52
 
53
  def __init__(self, source_lang: str = "eng", target_lang: str = "tgl"):
54
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
55
  self.source_lang = source_lang
56
  self.target_lang = target_lang
57
  self.sample_rate = 16000
58
- self.mt_model = None # Initialize as None
59
- self.mt_tokenizer = None # Initialize as None
60
- self._initialize_stt_model()
61
- self._initialize_mt_model()
62
- self._initialize_tts_model()
63
- print("All models loaded successfully, starting FastAPI app")
 
 
 
 
 
 
 
64
 
65
  def _initialize_stt_model(self):
 
 
 
66
  try:
67
- print("Trying to load openai/whisper-tiny...")
 
68
  self.stt_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
69
  self.stt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
70
  self.stt_model.to(self.device)
71
- print("Loaded openai/whisper-tiny successfully")
 
 
72
  except Exception as e:
73
- raise RuntimeError(f"STT model initialization failed: {e}")
 
74
 
75
  def _initialize_mt_model(self):
 
 
 
76
  try:
77
- print("Trying to load facebook/nllb-200-distilled-600M...")
 
78
  self.mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
79
  self.mt_tokenizer = AutoTokenizer.from_pretrained(
80
  "facebook/nllb-200-distilled-600M",
81
  clean_up_tokenization_spaces=True
82
  )
83
  self.mt_model.to(self.device)
84
- print("Loaded NLLB translation model successfully")
 
 
85
  except Exception as e:
86
- print(f"Failed to load facebook/nllb-200-distilled-600M: {e}")
87
- print("Translation model not loaded, translation will return source text as a fallback")
88
- self.mt_model = None
89
- self.mt_tokenizer = None
90
 
91
  def _initialize_tts_model(self):
 
 
 
 
 
92
  try:
93
- print(f"Trying to load facebook/mms-tts-{self.target_lang}...")
 
94
  self.tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
95
  self.tts_tokenizer = AutoTokenizer.from_pretrained(
96
  f"facebook/mms-tts-{self.target_lang}",
97
  clean_up_tokenization_spaces=True
98
  )
99
  self.tts_model.to(self.device)
100
- print(f"Loaded TTS model facebook/mms-tts-{self.target_lang} successfully")
101
- except Exception:
102
- print(f"Failed to load facebook/mms-tts-{self.target_lang}, falling back to English TTS")
103
- self.tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
104
- self.tts_tokenizer = AutoTokenizer.from_pretrained(
105
- "facebook/mms-tts-eng",
106
- clean_up_tokenization_spaces=True
107
- )
108
- self.tts_model.to(self.device)
109
- print("Loaded fallback TTS model facebook/mms-tts-eng successfully")
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  def update_languages(self, source_lang: str, target_lang: str):
112
- print(f"Updating languages: source_lang={source_lang}, target_lang={target_lang}")
113
  self.source_lang = source_lang
114
  self.target_lang = target_lang
115
- print("Calling _initialize_tts_model...")
116
- self._initialize_tts_model()
117
- print("Languages updated successfully")
 
 
118
  return f"Languages updated to {source_lang} → {target_lang}"
119
 
120
  def speech_to_text(self, audio_path: str) -> str:
 
 
 
121
  waveform, sample_rate = sf.read(audio_path)
122
  if sample_rate != 16000:
123
  import librosa
@@ -129,9 +172,10 @@ class TalklasTranslator:
129
  return transcription
130
 
131
  def translate_text(self, text: str) -> str:
132
- if self.mt_model is None or self.mt_tokenizer is None:
133
- print("Translation model not loaded, returning source text as fallback")
134
  return text
 
135
  source_code = self.NLLB_LANGUAGE_CODES[self.source_lang]
136
  target_code = self.NLLB_LANGUAGE_CODES[self.target_lang]
137
  self.mt_tokenizer.src_lang = source_code
@@ -145,6 +189,9 @@ class TalklasTranslator:
145
  return self.mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
146
 
147
  def text_to_speech(self, text: str) -> Tuple[int, np.ndarray]:
 
 
 
148
  inputs = self.tts_tokenizer(text, return_tensors="pt", clean_up_tokenization_spaces=True).to(self.device)
149
  with torch.no_grad():
150
  output = self.tts_model(**inputs)
@@ -173,11 +220,75 @@ class TalklasTranslator:
173
  "performance": "Translation successful"
174
  }
175
 
 
176
  translator = TalklasTranslator()
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  @app.get("/health")
179
  async def health_check():
180
- return {"status": "healthy"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
  @app.post("/update-languages")
183
  async def update_languages(source_lang: str = Form(...), target_lang: str = Form(...)):
@@ -196,6 +307,17 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
196
  if source_lang not in TalklasTranslator.LANGUAGE_MAPPING or target_lang not in TalklasTranslator.LANGUAGE_MAPPING:
197
  raise HTTPException(status_code=400, detail="Invalid language selected")
198
 
 
 
 
 
 
 
 
 
 
 
 
199
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
200
  temp_file.write(await audio.read())
201
  temp_path = temp_file.name
@@ -207,6 +329,8 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
207
  )
208
  result = translator.translate_speech(temp_path)
209
  return JSONResponse(content=result)
 
 
210
  finally:
211
  os.unlink(temp_path)
212
 
@@ -217,15 +341,26 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
217
  if source_lang not in TalklasTranslator.LANGUAGE_MAPPING or target_lang not in TalklasTranslator.LANGUAGE_MAPPING:
218
  raise HTTPException(status_code=400, detail="Invalid language selected")
219
 
 
 
 
 
 
 
 
220
  translator.update_languages(
221
  TalklasTranslator.LANGUAGE_MAPPING[source_lang],
222
  TalklasTranslator.LANGUAGE_MAPPING[target_lang]
223
  )
224
- result = translator.translate_text_only(text)
225
- return JSONResponse(content=result)
 
 
 
 
226
 
227
  if __name__ == "__main__":
228
  import uvicorn
229
- print("Starting Uvicorn server...")
230
  uvicorn.run(app, host="0.0.0.0", port=8000)
231
- print("Uvicorn server started successfully")
 
3
  os.environ["HOME"] = "/root"
4
  os.environ["HF_HOME"] = "/tmp/hf_cache"
5
 
6
+ # Print environment variables to confirm
7
  print("HOME environment variable:", os.environ.get("HOME"))
8
  print("HF_HOME environment variable:", os.environ.get("HF_HOME"))
9
 
10
+ # Import libraries
11
  import torch
12
  import numpy as np
13
  import soundfile as sf
14
+ from typing import Optional, Tuple, Dict, Any
15
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
 
 
 
 
 
 
 
 
 
16
  from fastapi.responses import JSONResponse
17
  import tempfile
18
  import logging
19
+ from threading import Thread
20
+ import time
21
+
22
+ # Configure logging
23
+ logging.basicConfig(level=logging.INFO)
24
+ logger = logging.getLogger("talklas-api")
25
 
26
  # Configure transformers logging to reduce verbosity
27
  logging.getLogger("transformers").setLevel(logging.ERROR)
28
 
29
  app = FastAPI(title="Talklas API")
30
 
31
+ # Global variables to track model loading status
32
+ is_loading = False
33
+ loading_complete = False
34
+ loading_error = None
35
+
36
  class TalklasTranslator:
37
  LANGUAGE_MAPPING = {
38
  "English": "eng",
 
54
 
55
  def __init__(self, source_lang: str = "eng", target_lang: str = "tgl"):
56
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
57
+ logger.info(f"Using device: {self.device}")
58
  self.source_lang = source_lang
59
  self.target_lang = target_lang
60
  self.sample_rate = 16000
61
+
62
+ # Initialize all models as None - will be lazy loaded
63
+ self.stt_processor = None
64
+ self.stt_model = None
65
+ self.mt_model = None
66
+ self.mt_tokenizer = None
67
+ self.tts_model = None
68
+ self.tts_tokenizer = None
69
+
70
+ # Flags to track which models are loaded
71
+ self.stt_loaded = False
72
+ self.mt_loaded = False
73
+ self.tts_loaded = False
74
 
75
  def _initialize_stt_model(self):
76
+ if self.stt_loaded:
77
+ return True
78
+
79
  try:
80
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
81
+ logger.info("Loading STT model: openai/whisper-tiny...")
82
  self.stt_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
83
  self.stt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
84
  self.stt_model.to(self.device)
85
+ self.stt_loaded = True
86
+ logger.info("STT model loaded successfully")
87
+ return True
88
  except Exception as e:
89
+ logger.error(f"STT model initialization failed: {e}")
90
+ return False
91
 
92
  def _initialize_mt_model(self):
93
+ if self.mt_loaded:
94
+ return True
95
+
96
  try:
97
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
98
+ logger.info("Loading MT model: facebook/nllb-200-distilled-600M...")
99
  self.mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
100
  self.mt_tokenizer = AutoTokenizer.from_pretrained(
101
  "facebook/nllb-200-distilled-600M",
102
  clean_up_tokenization_spaces=True
103
  )
104
  self.mt_model.to(self.device)
105
+ self.mt_loaded = True
106
+ logger.info("MT model loaded successfully")
107
+ return True
108
  except Exception as e:
109
+ logger.error(f"MT model initialization failed: {e}")
110
+ return False
 
 
111
 
112
  def _initialize_tts_model(self):
113
+ if self.tts_loaded:
114
+ # Check if we need to reload for a different language
115
+ if hasattr(self, 'current_tts_lang') and self.current_tts_lang == self.target_lang:
116
+ return True
117
+
118
  try:
119
+ from transformers import VitsModel, AutoTokenizer
120
+ logger.info(f"Loading TTS model: facebook/mms-tts-{self.target_lang}...")
121
  self.tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
122
  self.tts_tokenizer = AutoTokenizer.from_pretrained(
123
  f"facebook/mms-tts-{self.target_lang}",
124
  clean_up_tokenization_spaces=True
125
  )
126
  self.tts_model.to(self.device)
127
+ self.tts_loaded = True
128
+ self.current_tts_lang = self.target_lang
129
+ logger.info(f"TTS model loaded successfully for {self.target_lang}")
130
+ return True
131
+ except Exception as e:
132
+ logger.error(f"Failed to load TTS model for {self.target_lang}: {e}")
133
+ try:
134
+ logger.info("Falling back to English TTS model...")
135
+ self.tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
136
+ self.tts_tokenizer = AutoTokenizer.from_pretrained(
137
+ "facebook/mms-tts-eng",
138
+ clean_up_tokenization_spaces=True
139
+ )
140
+ self.tts_model.to(self.device)
141
+ self.tts_loaded = True
142
+ self.current_tts_lang = "eng"
143
+ logger.info("Loaded fallback TTS model successfully")
144
+ return True
145
+ except Exception as fallback_error:
146
+ logger.error(f"Fallback TTS model initialization failed: {fallback_error}")
147
+ return False
148
 
149
  def update_languages(self, source_lang: str, target_lang: str):
150
+ logger.info(f"Updating languages: source_lang={source_lang}, target_lang={target_lang}")
151
  self.source_lang = source_lang
152
  self.target_lang = target_lang
153
+
154
+ # Only reload TTS model if target language changed
155
+ if hasattr(self, 'current_tts_lang') and self.current_tts_lang != target_lang:
156
+ self._initialize_tts_model()
157
+
158
  return f"Languages updated to {source_lang} → {target_lang}"
159
 
160
  def speech_to_text(self, audio_path: str) -> str:
161
+ if not self._initialize_stt_model():
162
+ raise Exception("STT model failed to initialize")
163
+
164
  waveform, sample_rate = sf.read(audio_path)
165
  if sample_rate != 16000:
166
  import librosa
 
172
  return transcription
173
 
174
  def translate_text(self, text: str) -> str:
175
+ if not self._initialize_mt_model():
176
+ logger.warning("Translation model not loaded, returning source text as fallback")
177
  return text
178
+
179
  source_code = self.NLLB_LANGUAGE_CODES[self.source_lang]
180
  target_code = self.NLLB_LANGUAGE_CODES[self.target_lang]
181
  self.mt_tokenizer.src_lang = source_code
 
189
  return self.mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
190
 
191
  def text_to_speech(self, text: str) -> Tuple[int, np.ndarray]:
192
+ if not self._initialize_tts_model():
193
+ raise Exception("TTS model failed to initialize")
194
+
195
  inputs = self.tts_tokenizer(text, return_tensors="pt", clean_up_tokenization_spaces=True).to(self.device)
196
  with torch.no_grad():
197
  output = self.tts_model(**inputs)
 
220
  "performance": "Translation successful"
221
  }
222
 
223
+ # Create translator instance but don't load models yet
224
  translator = TalklasTranslator()
225
 
226
def background_load_model():
    """Load the STT, MT and TTS models in order and record the outcome.

    Progress is reported through the module-level globals ``is_loading``,
    ``loading_complete`` and ``loading_error``.

    * STT or TTS failure sets ``loading_error`` and stops; ``loading_complete``
      stays ``False``.
    * MT failure is tolerated: a warning is logged and loading continues.
    * ``is_loading`` is always reset to ``False`` on exit.
    """
    global is_loading, loading_complete, loading_error

    is_loading = True
    try:
        # STT goes first: the health endpoint reports "healthy" as soon as
        # it is available.
        if not translator._initialize_stt_model():
            loading_error = "Failed to load STT model"
            return

        mt_ready = translator._initialize_mt_model()
        if not mt_ready:
            logger.warning("MT model failed to load, will use fallback")

        if not translator._initialize_tts_model():
            loading_error = "Failed to load TTS model"
            return

        loading_complete = True
        if mt_ready:
            logger.info("All models loaded successfully in background")
        else:
            # Do not claim full success when the MT model is missing.
            logger.info("Background loading finished; MT model is not loaded")

    except Exception as e:
        loading_error = str(e)
        # logger.exception keeps the traceback for post-mortem debugging.
        logger.exception(f"Error loading models in background: {e}")
    finally:
        is_loading = False
257
+
258
+ # Start background loading of models
259
+ Thread(target=background_load_model, daemon=True).start()
260
+
261
  @app.get("/health")
262
  async def health_check():
263
+ """Health check endpoint that returns detailed loading status"""
264
+ global is_loading, loading_complete, loading_error
265
+
266
+ # Check if at least the STT model is loaded (minimum requirement)
267
+ if translator.stt_loaded:
268
+ status = "healthy"
269
+ elif loading_error:
270
+ status = "error"
271
+ elif is_loading:
272
+ status = "loading"
273
+ else:
274
+ status = "not_initialized"
275
+
276
+ response = {
277
+ "status": status,
278
+ "models": {
279
+ "stt": "loaded" if translator.stt_loaded else "not_loaded",
280
+ "mt": "loaded" if translator.mt_loaded else "not_loaded",
281
+ "tts": "loaded" if translator.tts_loaded else "not_loaded",
282
+ },
283
+ "loading": is_loading,
284
+ "complete": loading_complete
285
+ }
286
+
287
+ if loading_error:
288
+ response["error"] = loading_error
289
+
290
+ # Hugging Face Spaces considers a service healthy if the health endpoint returns a 200 status
291
+ return response
292
 
293
  @app.post("/update-languages")
294
  async def update_languages(source_lang: str = Form(...), target_lang: str = Form(...)):
 
307
  if source_lang not in TalklasTranslator.LANGUAGE_MAPPING or target_lang not in TalklasTranslator.LANGUAGE_MAPPING:
308
  raise HTTPException(status_code=400, detail="Invalid language selected")
309
 
310
+ # Check if models are loaded
311
+ if not translator.stt_loaded:
312
+ if loading_error:
313
+ raise HTTPException(status_code=500, detail=f"Model loading failed: {loading_error}")
314
+ elif is_loading:
315
+ raise HTTPException(status_code=503, detail="Models are still loading, please try again later")
316
+ else:
317
+ # Try to load models now
318
+ if not translator._initialize_stt_model():
319
+ raise HTTPException(status_code=500, detail="Failed to initialize STT model")
320
+
321
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
322
  temp_file.write(await audio.read())
323
  temp_path = temp_file.name
 
329
  )
330
  result = translator.translate_speech(temp_path)
331
  return JSONResponse(content=result)
332
+ except Exception as e:
333
+ raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
334
  finally:
335
  os.unlink(temp_path)
336
 
 
341
  if source_lang not in TalklasTranslator.LANGUAGE_MAPPING or target_lang not in TalklasTranslator.LANGUAGE_MAPPING:
342
  raise HTTPException(status_code=400, detail="Invalid language selected")
343
 
344
+ # Check if models are loaded
345
+ if not translator.mt_loaded or not translator.tts_loaded:
346
+ if loading_error:
347
+ raise HTTPException(status_code=500, detail=f"Model loading failed: {loading_error}")
348
+ elif is_loading:
349
+ raise HTTPException(status_code=503, detail="Models are still loading, please try again later")
350
+
351
  translator.update_languages(
352
  TalklasTranslator.LANGUAGE_MAPPING[source_lang],
353
  TalklasTranslator.LANGUAGE_MAPPING[target_lang]
354
  )
355
+
356
+ try:
357
+ result = translator.translate_text_only(text)
358
+ return JSONResponse(content=result)
359
+ except Exception as e:
360
+ raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
361
 
362
  if __name__ == "__main__":
363
  import uvicorn
364
+ logger.info("Starting Uvicorn server...")
365
  uvicorn.run(app, host="0.0.0.0", port=8000)
366
+ logger.info("Uvicorn server started successfully")