"""Streaming support for the TTS API.""" import asyncio import io import logging import time from typing import AsyncGenerator, Optional, List import torch import torchaudio from fastapi import APIRouter, Request, HTTPException from fastapi.responses import StreamingResponse from app.api.schemas import SpeechRequest, ResponseFormat from app.prompt_engineering import split_into_segments from app.models import Segment logger = logging.getLogger(__name__) router = APIRouter() class AudioChunker: """Handle audio chunking for streaming responses.""" def __init__(self, sample_rate: int, format: str = "mp3", chunk_size_ms: int = 200): # Smaller chunks for better streaming """ Initialize audio chunker. Args: sample_rate: Audio sample rate in Hz format: Output audio format (mp3, opus, etc.) chunk_size_ms: Size of each chunk in milliseconds """ self.sample_rate = sample_rate self.format = format.lower() self.chunk_size_samples = int(sample_rate * (chunk_size_ms / 1000)) logger.info(f"Audio chunker initialized with {chunk_size_ms}ms chunks ({self.chunk_size_samples} samples)") async def chunk_audio(self, audio: torch.Tensor, delay_ms: int = 0) -> AsyncGenerator[bytes, None]: """ Convert audio tensor to streaming chunks. Args: audio: Audio tensor to stream delay_ms: Artificial delay between chunks (for testing) Yields: Audio chunks as bytes """ # Ensure audio is on CPU if audio.is_cuda: audio = audio.cpu() # Calculate number of chunks num_samples = audio.shape[0] num_chunks = (num_samples + self.chunk_size_samples - 1) // self.chunk_size_samples logger.info(f"Streaming {num_samples} samples as {num_chunks} chunks") for i in range(num_chunks): start_idx = i * self.chunk_size_samples end_idx = min(start_idx + self.chunk_size_samples, num_samples) # Extract chunk chunk = audio[start_idx:end_idx] # Convert to bytes in requested format chunk_bytes = await self._format_chunk(chunk) # Add artificial delay if requested (for testing) if delay_ms > 0: await asyncio.sleep(delay_ms / 1000) yield chunk_bytes async def _format_chunk(self, chunk: torch.Tensor) -> bytes: """Convert audio chunk to bytes in the specified format.""" buf = io.BytesIO() # Ensure chunk is 1D and on CPU if len(chunk.shape) == 1: chunk = chunk.unsqueeze(0) # Add channel dimension # Ensure chunk is on CPU if chunk.is_cuda: chunk = chunk.cpu() # Save to buffer in specified format if self.format == "mp3": torchaudio.save(buf, chunk, self.sample_rate, format="mp3") elif self.format == "opus": torchaudio.save(buf, chunk, self.sample_rate, format="opus") elif self.format == "aac": torchaudio.save(buf, chunk, self.sample_rate, format="aac") elif self.format == "flac": torchaudio.save(buf, chunk, self.sample_rate, format="flac") elif self.format == "wav": torchaudio.save(buf, chunk, self.sample_rate, format="wav") else: # Default to mp3 torchaudio.save(buf, chunk, self.sample_rate, format="mp3") # Get bytes from buffer buf.seek(0) return buf.read() # Helper function to get speaker ID for a voice def get_speaker_id(app_state, voice): """Helper function to get speaker ID from voice name or ID""" if hasattr(app_state, "voice_speaker_map") and voice in app_state.voice_speaker_map: return app_state.voice_speaker_map[voice] # Standard voices mapping voice_to_speaker = {"alloy": 0, "echo": 1, "fable": 2, "onyx": 3, "nova": 4, "shimmer": 5} if voice in voice_to_speaker: return voice_to_speaker[voice] # Try parsing as integer try: speaker_id = int(voice) if 0 <= speaker_id < 6: return speaker_id except (ValueError, TypeError): pass # Check cloned voices if the voice cloner exists if hasattr(app_state, "voice_cloner") and app_state.voice_cloner is not None: # Check by ID if voice in app_state.voice_cloner.cloned_voices: return app_state.voice_cloner.cloned_voices[voice].speaker_id # Check by name for v_id, v_info in app_state.voice_cloner.cloned_voices.items(): if v_info.name.lower() == voice.lower(): return v_info.speaker_id # Default to alloy return 0 @router.post("/audio/speech/stream", tags=["Audio"]) async def stream_speech( request: Request, speech_request: SpeechRequest, ): """ Stream audio of text being spoken by a realistic voice. This endpoint provides an OpenAI-compatible streaming interface for TTS. """ # Check if model is loaded if not hasattr(request.app.state, "generator") or request.app.state.generator is None: raise HTTPException( status_code=503, detail="Model not loaded. Please try again later." ) # Get request parameters model = speech_request.model input_text = speech_request.input voice = speech_request.voice response_format = speech_request.response_format speed = speech_request.speed temperature = speech_request.temperature max_audio_length_ms = speech_request.max_audio_length_ms # Log the request logger.info(f"Real-time streaming speech from text ({len(input_text)} chars) with voice '{voice}'") # Check if text is empty if not input_text or len(input_text.strip()) == 0: raise HTTPException( status_code=400, detail="Input text cannot be empty" ) # Get speaker ID for the voice speaker_id = get_speaker_id(request.app.state, voice) if speaker_id is None: raise HTTPException( status_code=400, detail=f"Voice '{voice}' not found. Available voices: {request.app.state.available_voices}" ) try: # Create media type based on format media_type = { "mp3": "audio/mpeg", "opus": "audio/opus", "aac": "audio/aac", "flac": "audio/flac", "wav": "audio/wav", }.get(response_format, "audio/mpeg") # Create the chunker for streaming sample_rate = request.app.state.sample_rate chunker = AudioChunker(sample_rate, response_format) # Split text into segments using the imported function from app.prompt_engineering import split_into_segments text_segments = split_into_segments(input_text, max_chars=50) # Smaller segments for faster first response logger.info(f"Split text into {len(text_segments)} segments for incremental streaming") async def generate_streaming_audio(): # Check for cloned voice voice_info = None from_cloned_voice = False if hasattr(request.app.state, "voice_cloning_enabled") and request.app.state.voice_cloning_enabled: voice_info = request.app.state.get_voice_info(voice) from_cloned_voice = voice_info and voice_info["type"] == "cloned" if from_cloned_voice: # Use cloned voice context for first segment voice_cloner = request.app.state.voice_cloner context = voice_cloner.get_voice_context(voice_info["voice_id"]) else: # Use standard voice context from app.voice_enhancement import get_voice_segments context = get_voice_segments(voice, request.app.state.device) else: # Use standard voice context from app.voice_enhancement import get_voice_segments context = get_voice_segments(voice, request.app.state.device) # Send an empty chunk to initialize the connection yield b'' # Process each text segment incrementally and stream in real time for i, segment_text in enumerate(text_segments): try: logger.info(f"Generating segment {i+1}/{len(text_segments)}") # Generate audio for this segment - use async to avoid blocking if from_cloned_voice: # Generate with cloned voice voice_cloner = request.app.state.voice_cloner # Convert to asynchronous with asyncio.to_thread segment_audio = await asyncio.to_thread( voice_cloner.generate_speech, segment_text, voice_info["voice_id"], temperature=temperature, topk=30, max_audio_length_ms=2000 # Keep segments short for streaming ) else: # Use standard voice with generator segment_audio = await asyncio.to_thread( request.app.state.generator.generate, segment_text, speaker_id, context, max_audio_length_ms=2000, # Short for quicker generation temperature=temperature ) # Process audio quality for this segment if hasattr(request.app.state, "voice_enhancement_enabled") and request.app.state.voice_enhancement_enabled: from app.voice_enhancement import process_generated_audio segment_audio = process_generated_audio( audio=segment_audio, voice_name=voice, sample_rate=sample_rate, text=segment_text ) # Handle speed adjustment if speed != 1.0 and speed > 0: try: # Adjust speed using torchaudio effects = [["tempo", str(speed)]] audio_cpu = segment_audio.cpu() adjusted_audio, _= torchaudio.sox_effects.apply_effects_tensor( audio_cpu.unsqueeze(0), sample_rate, effects ) segment_audio = adjusted_audio.squeeze(0) except Exception as e: logger.warning(f"Failed to adjust speech speed: {e}") # Convert this segment to bytes and stream immediately buf = io.BytesIO() audio_to_save = segment_audio.unsqueeze(0) if len(segment_audio.shape) == 1 else segment_audio torchaudio.save(buf, audio_to_save.cpu(), sample_rate, format=response_format) buf.seek(0) segment_bytes = buf.read() # Stream this segment immediately yield segment_bytes # Update context with this segment for next generation context = [ Segment( text=segment_text, speaker=speaker_id, audio=segment_audio ) ] except Exception as e: logger.error(f"Error generating segment {i+1}: {e}") # Try to continue with next segment # Return streaming response return StreamingResponse( generate_streaming_audio(), media_type=media_type, headers={ "Content-Disposition": f'attachment; filename="speech.{response_format}"', "X-Accel-Buffering": "no", # Prevent buffering in nginx "Cache-Control": "no-cache, no-store, must-revalidate", # Prevent caching "Pragma": "no-cache", "Expires": "0", "Connection": "keep-alive", "Transfer-Encoding": "chunked" } ) except Exception as e: logger.error(f"Error in stream_speech: {e}") raise HTTPException(status_code=500, detail=f"Error generating speech: {str(e)}") @router.post("/audio/speech/streaming", tags=["Audio"]) async def openai_stream_speech( request: Request, speech_request: SpeechRequest, ): """ Stream audio in OpenAI-compatible streaming format. This endpoint is compatible with the OpenAI streaming TTS API. """ # Use the same logic as the stream_speech endpoint but with a different name # to maintain the OpenAI API naming convention return await stream_speech(request, speech_request)