File size: 30,763 Bytes
383520d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68b189e
383520d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74c62a2
 
 
9b31d36
74c62a2
 
383520d
23beeea
2acc39d
 
 
 
 
 
 
 
74c62a2
 
 
383520d
 
 
2acc39d
383520d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2acc39d
383520d
 
 
 
 
 
7e09504
383520d
 
7e09504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383520d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2acc39d
383520d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2acc39d
383520d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23beeea
 
383520d
 
23beeea
383520d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74c62a2
 
 
 
 
be2a132
 
 
 
 
383520d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23beeea
383520d
 
 
 
 
23beeea
383520d
 
 
 
 
 
 
68b189e
 
 
 
 
 
 
 
 
 
383520d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
"""
CSM-1B TTS API main application.
Provides an OpenAI-compatible API for the CSM-1B text-to-speech model.
"""
import os
import time
import tempfile
import logging
from logging.handlers import RotatingFileHandler
import traceback
import asyncio
import glob
import torch
import uvicorn
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from app.api.routes import router as api_router
from app.db import connect_to_mongo, close_mongo_connection

# Setup logging
# One shared format string is applied to both the console and the rotating
# file handler so log lines look the same wherever they end up.
log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

# The log directory must exist before the file handler opens its file.
os.makedirs("logs", exist_ok=True)

# Console handler (stderr) and size-rotated file handler (10MB x 5 backups)
console_handler = logging.StreamHandler()
file_handler = RotatingFileHandler(
    "logs/csm_tts_api.log",
    backupCount=5,
    maxBytes=10*1024*1024,  # 10MB
)
for _log_handler in (console_handler, file_handler):
    _log_handler.setFormatter(logging.Formatter(log_format))
del _log_handler

# Configure root logger with both handlers at INFO level
logging.basicConfig(
    handlers=[console_handler, file_handler],
    format=log_format,
    level=logging.INFO,
)

# Module-level logger used throughout this file
logger = logging.getLogger(__name__)
logger.info("Starting CSM-1B TTS API")

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager for startup and shutdown events.

    Everything before the ``yield`` runs once before the server accepts
    requests; everything after it runs once when the server stops.

    Startup:
      * calls ``connect_to_mongo()`` and ``init_db()`` from ``app.db``
      * creates the working directories under ``/app``
      * picks the torch device (and optional ``CSM_DEVICE_MAP`` strategy)
      * downloads ``ckpt.pt`` from Hugging Face if it is not present
      * loads the model and the optional voice subsystems (enhancement,
        memory, cloning, prompt templates), each guarded so a failure only
        disables that feature
      * publishes lookup helpers and metadata on ``app.state``
      * starts background tasks (voice reference generation, periodic backup)

    If loading the model fails, the error is logged, ``app.state.generator``
    stays ``None`` and the application still starts.

    Shutdown:
      * cancels the background tasks started here
      * clears the CUDA cache, saves voice profiles and voice memories
      * removes ``csm_tts_*`` files from the system temp directory
      * calls ``close_mongo_connection()``

    Args:
        app: The FastAPI application whose ``state`` is populated.

    Yields:
        None. Control returns to the server while the application runs.
    """
    # STARTUP EVENT
    logger.info("Starting application initialization")
    app.state.startup_time = time.time()
    app.state.generator = None  # Will be populated later if model loads
    app.state.logger = logger  # Make logger available to routes

    # Feature flags default to "off" so they are always defined on app.state,
    # even when model loading fails before the subsystems are initialized.
    app.state.voice_enhancement_enabled = False
    app.state.voice_memory_enabled = False
    app.state.voice_cloning_enabled = False

    # asyncio only keeps weak references to tasks; hold strong references so
    # background work cannot be garbage-collected mid-flight, and so it can be
    # cancelled on shutdown.
    background_tasks = set()
    app.state.background_tasks = background_tasks

    def spawn_background(coro):
        """Schedule ``coro`` as a task that is tracked until it completes."""
        task = asyncio.create_task(coro)
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)
        return task

    # Initialize database
    # The app is created with ``lifespan=``, in which case FastAPI does not run
    # ``@app.on_event`` handlers, so the Mongo connection is managed here.
    from app.db import init_db, connect_to_mongo, close_mongo_connection
    logger.info("Connecting to MongoDB...")
    await connect_to_mongo()
    logger.info("Initializing database...")
    await init_db()
    logger.info("Database initialized")

    # Create necessary directories - use persistent locations
    APP_DIR = "/app"
    required_subdirs = (
        "models",
        "tokenizers",
        "voice_memories",
        "voice_references",
        "voice_profiles",
        "cloned_voices",
        "audio_cache",
        "static",
        "storage/audio",  # For audio files
        "storage/text",   # For text files
        "audiobooks",
    )
    for subdir in required_subdirs:
        os.makedirs(os.path.join(APP_DIR, subdir), exist_ok=True)

    # Set tokenizer cache
    try:
        os.environ["TRANSFORMERS_CACHE"] = os.path.join(APP_DIR, "tokenizers")
        logger.info(f"Set tokenizer cache to: {os.environ['TRANSFORMERS_CACHE']}")
    except Exception as e:
        logger.error(f"Error setting tokenizer cache: {e}")

    # Probe optional audio dependencies (nothing is installed here; a missing
    # package only produces a warning)
    try:
        import scipy
        import soundfile
        logger.info("Audio processing dependencies available")
    except ImportError as e:
        logger.warning(f"Audio processing dependency missing: {e}. Some audio enhancements may not work.")
        logger.warning("Consider installing: pip install scipy soundfile")

    # Check CUDA availability
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        device_count = torch.cuda.device_count()
        device_name = torch.cuda.get_device_name(0) if device_count > 0 else "unknown"
        logger.info(f"CUDA is available: {device_count} device(s). Using {device_name}")
        # Report CUDA memory
        if hasattr(torch.cuda, 'get_device_properties'):
            total_memory = torch.cuda.get_device_properties(0).total_memory
            logger.info(f"Total CUDA memory: {total_memory / (1024**3):.2f} GB")
    else:
        logger.warning("CUDA is not available. Using CPU (this will be slow)")

    # Determine device and device mapping
    device = "cuda" if cuda_available else "cpu"
    device_map = os.environ.get("CSM_DEVICE_MAP", None)  # Options: "auto", "balanced", "sequential"
    if device_map and cuda_available:
        if torch.cuda.device_count() > 1:
            logger.info(f"Using device mapping strategy: {device_map} across {torch.cuda.device_count()} GPUs")
        else:
            logger.info("Device mapping requested but only one GPU available, ignoring device_map")
            device_map = None
    else:
        device_map = None

    logger.info(f"Using device: {device}")
    app.state.device = device
    app.state.device_map = device_map

    # Check if model file exists
    model_path = os.path.join(APP_DIR, "models", "ckpt.pt")
    if not os.path.exists(model_path):
        # Try to download at runtime if not present
        logger.info("Model not found. Attempting to download...")
        try:
            from huggingface_hub import hf_hub_download, login
            # Check for token in environment
            hf_token = os.environ.get("HF_TOKEN", "").strip()
            if hf_token:
                logger.info("Logging in to Hugging Face using provided token")
                try:
                    login(token=hf_token)
                    logger.info("Successfully logged in to Hugging Face")
                except Exception as e:
                    logger.error(f"Error logging in to Hugging Face: {e}")
                    logger.error("Will attempt to download model without authentication")
            else:
                logger.warning("No Hugging Face token provided. Some models may not be accessible")

            # Attempt to download the model
            try:
                logger.info("Downloading CSM-1B model from Hugging Face...")
                download_start = time.time()
                model_path = hf_hub_download(
                    repo_id="sesame/csm-1b",
                    filename="ckpt.pt",
                    local_dir=os.path.join(APP_DIR, "models"),
                    token=hf_token if hf_token else None
                )
                download_time = time.time() - download_start
                logger.info(f"Model downloaded to {model_path} in {download_time:.2f} seconds")
            except Exception as e:
                error_stack = traceback.format_exc()
                logger.error(f"Error downloading model: {str(e)}\n{error_stack}")
                logger.error("Please ensure you have a valid Hugging Face token with access to the model")
                logger.error("Starting without model - API will return 503 Service Unavailable")
        except Exception as e:
            # Reached e.g. when huggingface_hub itself cannot be imported
            error_stack = traceback.format_exc()
            logger.error(f"Error downloading model: {str(e)}\n{error_stack}")
            logger.error("Please build the image with HF_TOKEN to download the model")
            logger.error("Starting without model - API will return 503 Service Unavailable")
    else:
        logger.info(f"Found existing model at {model_path}")
        logger.info(f"Model size: {os.path.getsize(model_path) / (1024 * 1024):.2f} MB")

    # Load the model. Any failure in this section is logged and leaves
    # app.state.generator as None; the application still starts.
    try:
        logger.info("Loading CSM-1B model...")
        load_start = time.time()
        from app.generator import load_csm_1b
        app.state.generator = load_csm_1b(model_path, device, device_map)
        load_time = time.time() - load_start
        logger.info(f"Model loaded successfully in {load_time:.2f} seconds")

        # Store sample rate in app state
        app.state.sample_rate = app.state.generator.sample_rate
        logger.info(f"Model sample rate: {app.state.sample_rate} Hz")

        # Initialize voice enhancement system (this will create proper voice profiles)
        logger.info("Initializing voice enhancement system...")
        try:
            from app.voice_enhancement import initialize_voice_profiles, save_voice_profiles
            initialize_voice_profiles()
            app.state.voice_enhancement_enabled = True
            logger.info("Voice profiles initialized successfully")
        except Exception as e:
            error_stack = traceback.format_exc()
            logger.error(f"Error initializing voice profiles: {str(e)}\n{error_stack}")
            logger.warning("Voice enhancement features will be limited")
            app.state.voice_enhancement_enabled = False

        # Initialize voice memory system for consistent generation
        logger.info("Initializing voice memory system...")
        try:
            from app.voice_memory import initialize_voices
            initialize_voices(app.state.sample_rate)
            app.state.voice_memory_enabled = True
            logger.info("Voice memory system initialized")
        except Exception as e:
            logger.warning(f"Error initializing voice memory: {e}")
            app.state.voice_memory_enabled = False

        # Initialize voice cloning system
        try:
            logger.info("Initializing voice cloning system...")
            import app.voice_cloning as voice_cloning_module
            from app.voice_cloning import VoiceCloner, CLONED_VOICES_DIR
            # Update the cloned voices directory to use the persistent volume
            app.state.cloned_voices_dir = os.path.join(APP_DIR, "cloned_voices")  # Store path in app state for access
            os.makedirs(app.state.cloned_voices_dir, exist_ok=True)
            # Assign through the module object: rebinding the imported name
            # would only change a local variable and leave the module constant
            # untouched. Modules that already copied the constant with
            # ``from ... import`` are not affected by this assignment.
            voice_cloning_module.CLONED_VOICES_DIR = app.state.cloned_voices_dir

            # Initialize the voice cloner with proper device
            app.state.voice_cloner = VoiceCloner(app.state.generator, device=device)

            # Make sure existing voices are loaded
            app.state.voice_cloner._load_existing_voices()

            # Log the available voices
            cloned_voices = app.state.voice_cloner.list_voices()
            logger.info(f"Voice cloning system initialized with {len(cloned_voices)} existing voices")
            for voice in cloned_voices:
                logger.info(f"  - {voice.name} (ID: {voice.id}, Speaker ID: {voice.speaker_id})")

            # Flag for voice cloning availability
            app.state.voice_cloning_enabled = True
        except Exception as e:
            error_stack = traceback.format_exc()
            logger.error(f"Error initializing voice cloning: {e}\n{error_stack}")
            logger.warning("Voice cloning features will not be available")
            app.state.voice_cloning_enabled = False

        # Create prompt templates for consistent generation
        logger.info("Setting up prompt engineering templates...")
        try:
            from app.prompt_engineering import initialize_templates
            app.state.prompt_templates = initialize_templates()
            logger.info("Prompt templates initialized")
        except Exception as e:
            error_stack = traceback.format_exc()
            logger.error(f"Error initializing prompt templates: {e}\n{error_stack}")
            logger.warning("Voice consistency features will be limited")

        # Generate voice reference samples as a separate task so that startup
        # itself does not wait for it.
        # NOTE(review): create_voice_segments is called synchronously inside
        # the coroutine, so it presumably occupies the event loop while it
        # runs - confirm whether it should be moved to a worker thread.
        async def generate_samples_async():
            try:
                logger.info("Starting voice reference generation (background task)...")
                from app.voice_enhancement import create_voice_segments
                create_voice_segments(app.state)
                logger.info("Voice reference generation completed")
            except Exception as e:
                error_stack = traceback.format_exc()
                logger.error(f"Error in voice reference generation: {str(e)}\n{error_stack}")

        # Start as a background task
        spawn_background(generate_samples_async())

        # Snapshot the cloned voices once; the registries below are all built
        # from the same list.
        standard_voices = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
        cloning_ready = app.state.voice_cloning_enabled and hasattr(app.state, "voice_cloner")
        registered_clones = list(app.state.voice_cloner.list_voices()) if cloning_ready else []

        # Initialize voice cache for all voices (standard + cloned)
        app.state.voice_cache = {voice: [] for voice in standard_voices}

        # Create mapping from voice name/id to speaker_id for easy lookup
        app.state.voice_speaker_map = {
            "alloy": 0, "echo": 1, "fable": 2, "onyx": 3, "nova": 4, "shimmer": 5
        }

        # Compile voice information for API
        app.state.available_voices = standard_voices.copy()

        for voice in registered_clones:
            # Cache entries by ID and also by name for more flexible lookup
            app.state.voice_cache[voice.id] = []
            app.state.voice_cache[voice.name] = []

            # Speaker lookup by ID, by name and by stringified speaker_id
            app.state.voice_speaker_map[voice.id] = voice.speaker_id
            app.state.voice_speaker_map[voice.name] = voice.speaker_id
            app.state.voice_speaker_map[str(voice.speaker_id)] = voice.speaker_id

            app.state.available_voices.append(voice.id)
            app.state.available_voices.append(voice.name)

        # Store model information for API endpoints
        app.state.model_info = {
            "name": "CSM-1B",
            "device": device,
            "device_map": device_map,
            "sample_rate": app.state.sample_rate,
            "standard_voices": standard_voices,
            "cloned_voices": [v.id for v in registered_clones],
            "voice_enhancement_enabled": app.state.voice_enhancement_enabled,
            "voice_memory_enabled": app.state.voice_memory_enabled,
            "voice_cloning_enabled": app.state.voice_cloning_enabled,
            "streaming_enabled": True
        }

        # Create a function to access all voices in a standardized format
        def get_all_available_voices():
            """Return every voice as a dict with voice_id, name and description.

            The six standard voices come first, followed by the cloned voices
            currently reported by the voice cloner (queried on each call).
            """
            # Standard voices with fixed descriptions
            all_voices = [
                {"voice_id": "alloy", "name": "Alloy", "description": "Balanced and natural"},
                {"voice_id": "echo", "name": "Echo", "description": "Resonant and deeper"},
                {"voice_id": "fable", "name": "Fable", "description": "Bright and higher-pitched"},
                {"voice_id": "onyx", "name": "Onyx", "description": "Deep and authoritative"},
                {"voice_id": "nova", "name": "Nova", "description": "Warm and smooth"},
                {"voice_id": "shimmer", "name": "Shimmer", "description": "Light and airy"}
            ]

            # Add cloned voices if available
            if app.state.voice_cloning_enabled and hasattr(app.state, "voice_cloner"):
                for voice in app.state.voice_cloner.list_voices():
                    all_voices.append({
                        "voice_id": voice.id,
                        "name": voice.name,
                        "description": voice.description or f"Cloned voice: {voice.name}"
                    })

            return all_voices

        app.state.get_all_voices = get_all_available_voices

        # Add helper function to lookup voice info
        def get_voice_info(voice_identifier):
            """Look up voice information based on name, ID, or speaker_id.

            Returns a dict with ``type`` ("standard" or "cloned"),
            ``voice_id``, ``name`` and ``speaker_id``, or ``None`` when
            nothing matches. Standard voices are checked first, then cloned
            voices by ID, by name and finally by integer speaker_id.
            """
            # Check standard voices
            if voice_identifier in standard_voices:
                return {
                    "type": "standard",
                    "voice_id": voice_identifier,
                    "name": voice_identifier,
                    "speaker_id": standard_voices.index(voice_identifier)
                }

            # Look for cloned voice
            if not app.state.voice_cloning_enabled or not hasattr(app.state, "voice_cloner"):
                return None

            known_clones = app.state.voice_cloner.cloned_voices

            # Check by ID
            if voice_identifier in known_clones:
                voice = known_clones[voice_identifier]
                return {
                    "type": "cloned",
                    "voice_id": voice.id,
                    "name": voice.name,
                    "speaker_id": voice.speaker_id
                }

            # Check by name
            for voice in known_clones.values():
                if voice.name == voice_identifier:
                    return {
                        "type": "cloned",
                        "voice_id": voice.id,
                        "name": voice.name,
                        "speaker_id": voice.speaker_id
                    }

            # Check by speaker_id (string representation)
            try:
                speaker_id = int(voice_identifier)
            except (ValueError, TypeError):
                # Not a number, so it cannot be a speaker_id
                return None

            # Check if any cloned voice has this speaker_id
            for voice in known_clones.values():
                if voice.speaker_id == speaker_id:
                    return {
                        "type": "cloned",
                        "voice_id": voice.id,
                        "name": voice.name,
                        "speaker_id": speaker_id
                    }

            # No match found
            return None

        app.state.get_voice_info = get_voice_info

        # Set up audio cache
        app.state.audio_cache_enabled = os.environ.get("ENABLE_AUDIO_CACHE", "true").lower() == "true"
        if app.state.audio_cache_enabled:
            app.state.audio_cache_dir = os.path.join(APP_DIR, "audio_cache")
            logger.info(f"Audio cache enabled, cache dir: {app.state.audio_cache_dir}")

        # Log GPU utilization after model loading
        if cuda_available:
            memory_allocated = torch.cuda.memory_allocated() / (1024**3)
            memory_reserved = torch.cuda.memory_reserved() / (1024**3)
            logger.info(f"GPU memory: {memory_allocated:.2f} GB allocated, {memory_reserved:.2f} GB reserved")

            if torch.cuda.device_count() > 1 and device_map:
                logger.info("Multi-GPU setup active with the following memory usage:")
                for i in range(torch.cuda.device_count()):
                    memory_allocated = torch.cuda.memory_allocated(i) / (1024**3)
                    memory_reserved = torch.cuda.memory_reserved(i) / (1024**3)
                    logger.info(f"GPU {i}: {memory_allocated:.2f} GB allocated, {memory_reserved:.2f} GB reserved")

        # Set up scheduled tasks
        try:
            # Create a background task for periodic voice profile backup
            async def periodic_voice_profile_backup(interval_hours=6):
                """Periodically save voice profiles to persistent storage."""
                while True:
                    try:
                        # Wait for the specified interval
                        await asyncio.sleep(interval_hours * 3600)

                        # Log the backup
                        timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
                        logger.info(f"Scheduled voice profile backup started at {timestamp}")

                        # Save voice profiles
                        if getattr(app.state, "voice_enhancement_enabled", False):
                            from app.voice_enhancement import save_voice_profiles
                            save_voice_profiles()
                            logger.info("Voice profiles saved successfully")

                        # Save voice memories
                        if getattr(app.state, "voice_memory_enabled", False):
                            from app.voice_memory import VOICE_MEMORIES
                            for memory in VOICE_MEMORIES.values():
                                memory.save()
                            logger.info("Voice memories saved successfully")

                    except asyncio.CancelledError:
                        # Let shutdown cancellation end the loop
                        raise
                    except Exception as e:
                        logger.error(f"Error in periodic voice profile backup: {e}")

            # Start the scheduled task
            spawn_background(periodic_voice_profile_backup(interval_hours=6))
            logger.info("Started scheduled voice profile backup task")

        except Exception as e:
            logger.warning(f"Failed to set up scheduled tasks: {e}")

        logger.info(f"CSM-1B TTS API is ready on {device} with sample rate {app.state.sample_rate}")
        logger.info(f"Standard voices: {standard_voices}")
        cloned_count = len(registered_clones)
        logger.info(f"Cloned voices: {cloned_count}")

    except Exception as e:
        error_stack = traceback.format_exc()
        logger.error(f"Error loading model: {str(e)}\n{error_stack}")
        app.state.generator = None

    # Calculate total startup time
    startup_time = time.time() - app.state.startup_time
    logger.info(f"Application startup completed in {startup_time:.2f} seconds")

    yield  # This is where the application runs

    # SHUTDOWN EVENT
    logger.info("Application shutdown initiated")

    # Stop background work first so a scheduled backup cannot overlap with the
    # final saves below.
    pending_tasks = [task for task in background_tasks if not task.done()]
    for task in pending_tasks:
        task.cancel()
    if pending_tasks:
        await asyncio.gather(*pending_tasks, return_exceptions=True)

    # Clean up model resources
    if getattr(app.state, "generator", None) is not None:
        try:
            # Clean up CUDA memory if available
            if torch.cuda.is_available():
                logger.info("Clearing CUDA cache")
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        except Exception as e:
            logger.error(f"Error during CUDA cleanup: {e}")

    # Save voice profiles if they've been updated
    try:
        if getattr(app.state, "voice_enhancement_enabled", False):
            from app.voice_enhancement import save_voice_profiles
            logger.info("Saving voice profiles...")
            save_voice_profiles()
            logger.info("Voice profiles saved successfully")
    except Exception as e:
        logger.error(f"Error saving voice profiles: {e}")

    # Save voice memories if they've been updated
    try:
        if getattr(app.state, "voice_memory_enabled", False):
            from app.voice_memory import VOICE_MEMORIES
            logger.info("Saving voice memories...")
            for memory in VOICE_MEMORIES.values():
                memory.save()
            logger.info("Voice memories saved successfully")
    except Exception as e:
        logger.error(f"Error saving voice memories: {e}")

    # Clean up any temporary files (best effort)
    try:
        for temp_file in glob.glob(os.path.join(tempfile.gettempdir(), "csm_tts_*")):
            try:
                os.remove(temp_file)
                logger.info(f"Removed temporary file: {temp_file}")
            except OSError as e:
                # e.g. already deleted, a directory, or not ours to delete
                logger.debug(f"Could not remove temporary file {temp_file}: {e}")
    except Exception as e:
        logger.warning(f"Error cleaning up temporary files: {e}")

    # Close the database connection opened during startup
    try:
        await close_mongo_connection()
    except Exception as e:
        logger.error(f"Error closing MongoDB connection: {e}")

    logger.info("Application shutdown complete")
    
# Initialize FastAPI app
# Startup/shutdown work is delegated to the ``lifespan`` context manager.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="CSM-1B TTS API",
    description="OpenAI-compatible TTS API using the CSM-1B model from Sesame",
)

# Add CORS middleware
# NOTE(review): every origin is allowed together with credentials - confirm
# this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Create static and other required directories
for _required_dir in ("/app/static", "/app/cloned_voices"):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir

# Mount the static files directory
app.mount("/static", StaticFiles(directory="/app/static"), name="static")

# Every router is exposed twice: under the native prefix and under the
# OpenAI-compatible one.
_ROUTE_PREFIXES = ("/api/v1", "/v1")


def _include_under_all_prefixes(router):
    """Register ``router`` on the app once per entry in ``_ROUTE_PREFIXES``."""
    for _prefix in _ROUTE_PREFIXES:
        app.include_router(router, prefix=_prefix)


# Core API routes
_include_under_all_prefixes(api_router)

# Add voice cloning routes
from app.api.voice_cloning_routes import router as voice_cloning_router
_include_under_all_prefixes(voice_cloning_router)

# Add streaming routes
from app.api.streaming import router as streaming_router
_include_under_all_prefixes(streaming_router)

# Add audiobook routes
from app.api.audiobook_routes import router as audiobook_router
_include_under_all_prefixes(audiobook_router)

# Add realtime conversation routes
from app.api.realtime import router as realtime_router
_include_under_all_prefixes(realtime_router)

# Middleware for request timing
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Middleware to track request processing time.

    Measures how long the downstream handler takes and reports it, in
    seconds, in the ``X-Process-Time`` response header.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that forwards the request down the stack and
            returns the response.

    Returns:
        The downstream response with the timing header added.
    """
    # perf_counter is monotonic, so the measured duration cannot be distorted
    # (or turn negative) when the wall clock is adjusted during a request.
    start_time = time.perf_counter()
    response = await call_next(request)
    process_time = time.perf_counter() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    logger.debug(f"Request to {request.url.path} processed in {process_time:.3f} seconds")
    return response

# Health check endpoint
@app.get("/health", include_in_schema=False)
async def health_check(request: Request):
    """Report the API's health as a JSON-serializable dict.

    The status is "healthy" only when a generator is present on the
    application state. The payload also carries uptime, device, the standard
    and cloned voices, sample rate, feature flags and, when CUDA is
    available, the allocated/reserved GPU memory in GB.
    """
    state = request.app.state

    # Model status: a generator must be present and not None
    generator = getattr(state, "generator", None)
    model_status = "unhealthy" if generator is None else "healthy"

    # Uptime falls back to "now" when the startup time was never recorded
    uptime = time.time() - getattr(state, "startup_time", time.time())

    # Get voice information
    standard_voices = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
    cloned_voices = []
    cloner = getattr(state, "voice_cloner", None)
    if cloner is not None:
        for entry in cloner.list_voices():
            cloned_voices.append(
                {"id": entry.id, "name": entry.name, "speaker_id": entry.speaker_id}
            )

    # Get CUDA memory stats if available
    if torch.cuda.is_available():
        gib = 1024**3
        cuda_stats = {
            "allocated_gb": torch.cuda.memory_allocated() / gib,
            "reserved_gb": torch.cuda.memory_reserved() / gib
        }
    else:
        cuda_stats = None

    # Enhancements are reported from the model_info published at startup
    if hasattr(state, "model_info") and state.model_info.get("voice_enhancement_enabled", False):
        enhancements = "enabled"
    else:
        enhancements = "disabled"

    report = {
        "status": model_status,
        "uptime": f"{uptime:.2f} seconds",
        "device": getattr(state, "device", "unknown"),
        "model": "CSM-1B",
        "standard_voices": standard_voices,
        "cloned_voices": cloned_voices,
        "cloned_voices_count": len(cloned_voices),
        "sample_rate": getattr(state, "sample_rate", 0),
        "enhancements": enhancements,
        "streaming": "enabled",
        "cuda": cuda_stats,
        "version": "1.0.0"
    }
    return report

# Version endpoint
@app.get("/version", include_in_schema=False)
async def version():
    """Version endpoint that returns API version information.

    Returns:
        A dict with the API and model versions plus feature indicators.
        ``voice_cloning`` is "enabled" only when a voice cloner exists on the
        application state and the ``voice_cloning_enabled`` flag is set.
    """
    # The cloner object can exist even though its initialization failed part
    # way through; in that case startup sets voice_cloning_enabled to False,
    # so the flag has to be consulted as well.
    cloning_active = bool(
        getattr(app.state, "voice_cloning_enabled", False)
        and hasattr(app.state, "voice_cloner")
    )
    return {
        "api_version": "1.0.0",
        "model_version": "CSM-1B",
        "compatible_with": "OpenAI TTS v1",
        "enhancements": "voice consistency and audio quality v1.0",
        "voice_cloning": "enabled" if cloning_active else "disabled",
        "streaming": "enabled"
    }

# Voice cloning UI endpoint
@app.get("/voice-cloning", include_in_schema=False)
async def voice_cloning_ui():
    """Voice cloning UI endpoint.

    Returns:
        The static ``voice-cloning.html`` page.

    Raises:
        HTTPException: 404 when the page is not present on disk.
    """
    page_path = "/app/static/voice-cloning.html"
    # Check up front: a FileResponse for a missing file only fails while the
    # response is being sent, which surfaces as a server error.
    if not os.path.isfile(page_path):
        raise HTTPException(status_code=404, detail="Voice cloning UI not found")
    return FileResponse(page_path)

# Streaming demo endpoint
@app.get("/streaming-demo", include_in_schema=False)
async def streaming_demo():
    """Streaming TTS demo endpoint.

    Returns:
        The static ``streaming-demo.html`` page.

    Raises:
        HTTPException: 404 when the page is not present on disk.
    """
    page_path = "/app/static/streaming-demo.html"
    # Check up front: a FileResponse for a missing file only fails while the
    # response is being sent, which surfaces as a server error.
    if not os.path.isfile(page_path):
        raise HTTPException(status_code=404, detail="Streaming demo not found")
    return FileResponse(page_path)

@app.get("/", include_in_schema=False)
async def root():
    """Redirect requests for the bare root URL to the ``/docs`` page."""
    logger.debug("Root endpoint accessed, redirecting to docs")
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect

@app.on_event("startup")
async def startup_db_client():
    """Initialize MongoDB connection on startup.

    Awaits ``connect_to_mongo()`` imported from ``app.db``.

    NOTE(review): the application is constructed with ``lifespan=lifespan``.
    Per the FastAPI documentation, ``on_event`` handlers are not run when a
    lifespan handler is supplied, so this hook is presumably never invoked -
    confirm where the connection is actually established.
    """
    await connect_to_mongo()

@app.on_event("shutdown")
async def shutdown_db_client():
    """Close MongoDB connection on shutdown.

    Awaits ``close_mongo_connection()`` imported from ``app.db``.

    NOTE(review): the application is constructed with ``lifespan=lifespan``.
    Per the FastAPI documentation, ``on_event`` handlers are not run when a
    lifespan handler is supplied, so this hook is presumably never invoked -
    confirm where the connection is actually closed.
    """
    await close_mongo_connection()

if __name__ == "__main__":
    # Script entry point: read settings from the environment and start uvicorn.

    # Get port from environment or use default
    port = int(os.environ.get("PORT", 7860))

    # Development mode flag (enables uvicorn auto-reload)
    dev_mode = os.environ.get("DEV_MODE", "false").lower() == "true"

    # Log level (default to INFO, but can be overridden).
    # Only names understood by both the logging module and uvicorn are
    # accepted; anything else would raise in setLevel() or in uvicorn's
    # log-level lookup, so it falls back to INFO with a warning.
    supported_log_levels = ("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG")
    log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
    if log_level not in supported_log_levels:
        logger.warning(
            f"Unsupported LOG_LEVEL {log_level!r}; falling back to INFO "
            f"(supported: {', '.join(supported_log_levels)})"
        )
        log_level = "INFO"
    logging.getLogger().setLevel(log_level)

    # Check for audio enhancement and voice cloning flags
    # NOTE(review): in this file these two flags are only logged; confirm
    # whether another module reads the same environment variables.
    enable_enhancements = os.environ.get("ENABLE_ENHANCEMENTS", "true").lower() == "true"
    enable_voice_cloning = os.environ.get("ENABLE_VOICE_CLONING", "true").lower() == "true"

    if not enable_enhancements:
        logger.warning("Voice enhancements disabled by environment variable")
    if not enable_voice_cloning:
        logger.warning("Voice cloning disabled by environment variable")

    logger.info(f"Voice enhancements: {'enabled' if enable_enhancements else 'disabled'}")
    logger.info(f"Voice cloning: {'enabled' if enable_voice_cloning else 'disabled'}")
    logger.info("Streaming: enabled")
    logger.info(f"Log level: {log_level}")

    if dev_mode:
        logger.info(f"Running in development mode with auto-reload enabled on port {port}")
    else:
        logger.info(f"Running in production mode on port {port}")

    # The two modes differ only in whether auto-reload is on
    uvicorn.run(
        "app.main:app",
        host="0.0.0.0",
        port=port,
        reload=dev_mode,
        log_level=log_level.lower()
    )