jamiya / app /api /streaming.py
jameszokah's picture
Refactor model structure: update import paths from 'app.modelz' to 'app.models' across multiple files for consistency, remove obsolete 'modelz' directory, and adjust Dockerfile and migration script to reflect these changes, enhancing clarity and organization in the codebase.
c27f115
"""Streaming support for the TTS API."""
import asyncio
import io
import logging
import time
from typing import AsyncGenerator, Optional, List
import torch
import torchaudio
from fastapi import APIRouter, Request, HTTPException
from fastapi.responses import StreamingResponse
from app.api.schemas import SpeechRequest, ResponseFormat
from app.prompt_engineering import split_into_segments
from app.model import Segment
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router on which the streaming endpoints defined in this module are registered.
router = APIRouter()
class AudioChunker:
    """Handle audio chunking for streaming responses.

    The audio is cut into fixed-duration pieces and every piece is encoded
    separately with ``torchaudio.save`` into its own in-memory buffer.
    """

    # Formats handed to torchaudio.save as-is; anything else falls back to mp3.
    _SUPPORTED_FORMATS = frozenset({"mp3", "opus", "aac", "flac", "wav"})
    _FALLBACK_FORMAT = "mp3"

    def __init__(self,
                 sample_rate: int,
                 format: str = "mp3",
                 chunk_size_ms: int = 200):  # Smaller chunks for better streaming
        """
        Initialize audio chunker.

        Args:
            sample_rate: Audio sample rate in Hz
            format: Output audio format (mp3, opus, aac, flac, wav); matched
                case-insensitively, unknown values are encoded as mp3
            chunk_size_ms: Size of each chunk in milliseconds

        Raises:
            ValueError: If ``sample_rate`` or ``chunk_size_ms`` is not positive.
        """
        if sample_rate <= 0:
            raise ValueError(f"sample_rate must be positive, got {sample_rate}")
        if chunk_size_ms <= 0:
            raise ValueError(f"chunk_size_ms must be positive, got {chunk_size_ms}")

        self.sample_rate = sample_rate
        self.format = format.lower()
        # Clamp to one sample: a very short chunk at a low sample rate would
        # otherwise round down to 0 and make chunk_audio divide by zero.
        self.chunk_size_samples = max(1, int(sample_rate * (chunk_size_ms / 1000)))
        logger.info(f"Audio chunker initialized with {chunk_size_ms}ms chunks ({self.chunk_size_samples} samples)")

    async def chunk_audio(self,
                          audio: torch.Tensor,
                          delay_ms: int = 0) -> AsyncGenerator[bytes, None]:
        """
        Convert audio tensor to streaming chunks.

        Args:
            audio: Audio tensor to stream. Chunking runs along the last axis,
                so a 1-D ``(samples,)`` tensor behaves as before; 2-D input is
                presumably ``(channels, samples)`` -- confirm against callers.
            delay_ms: Artificial delay between chunks (for testing)

        Yields:
            Audio chunks as bytes
        """
        # .cpu() is a no-op for tensors already on the CPU and also covers
        # non-CUDA accelerators that an is_cuda check would miss.
        audio = audio.cpu()

        num_samples = audio.shape[-1]
        # Ceiling division: the final chunk may be shorter than the rest.
        num_chunks = (num_samples + self.chunk_size_samples - 1) // self.chunk_size_samples
        logger.info(f"Streaming {num_samples} samples as {num_chunks} chunks")

        for start_idx in range(0, num_samples, self.chunk_size_samples):
            # Slicing past the end is safe; the last chunk is simply shorter.
            chunk = audio[..., start_idx:start_idx + self.chunk_size_samples]

            # Convert to bytes in requested format
            chunk_bytes = await self._format_chunk(chunk)

            # Add artificial delay if requested (for testing)
            if delay_ms > 0:
                await asyncio.sleep(delay_ms / 1000)

            yield chunk_bytes

    async def _format_chunk(self, chunk: torch.Tensor) -> bytes:
        """Convert audio chunk to bytes in the specified format.

        Unknown formats are encoded as mp3.
        """
        buf = io.BytesIO()

        # Add a leading channel dimension to 1-D audio before saving.
        if chunk.dim() == 1:
            chunk = chunk.unsqueeze(0)
        chunk = chunk.cpu()

        if self.format in self._SUPPORTED_FORMATS:
            audio_format = self.format
        else:
            audio_format = self._FALLBACK_FORMAT

        torchaudio.save(buf, chunk, self.sample_rate, format=audio_format)
        return buf.getvalue()
# Helper function to get speaker ID for a voice
def get_speaker_id(app_state, voice):
    """Resolve a voice name or ID to a numeric speaker ID.

    Lookup order:
      1. ``app_state.voice_speaker_map`` (when present and not None)
      2. the built-in voice names (alloy, echo, fable, onyx, nova, shimmer)
      3. ``voice`` parsed as an integer in the range 0-5
      4. ``app_state.voice_cloner.cloned_voices``, first by key, then by
         case-insensitive name

    Args:
        app_state: Object carrying the optional ``voice_speaker_map`` and
            ``voice_cloner`` attributes.
        voice: Voice name or ID. Non-string values are accepted; they simply
            skip the name-based comparison.

    Returns:
        The resolved speaker ID, or 0 (the "alloy" voice) when nothing matches.
    """
    custom_map = getattr(app_state, "voice_speaker_map", None)
    if custom_map is not None and voice in custom_map:
        return custom_map[voice]

    # Standard voices mapping
    voice_to_speaker = {"alloy": 0, "echo": 1, "fable": 2, "onyx": 3, "nova": 4, "shimmer": 5}
    if voice in voice_to_speaker:
        return voice_to_speaker[voice]

    # Try parsing as integer
    try:
        speaker_id = int(voice)
    except (ValueError, TypeError):
        speaker_id = None
    if speaker_id is not None and 0 <= speaker_id < 6:
        return speaker_id

    # Check cloned voices if the voice cloner exists
    voice_cloner = getattr(app_state, "voice_cloner", None)
    if voice_cloner is not None:
        cloned_voices = voice_cloner.cloned_voices

        # Check by ID
        if voice in cloned_voices:
            return cloned_voices[voice].speaker_id

        # Check by name; only strings can be compared case-insensitively, so
        # an int or None voice must not reach .lower().
        if isinstance(voice, str):
            wanted = voice.lower()
            for v_info in cloned_voices.values():
                name = getattr(v_info, "name", None)
                if isinstance(name, str) and name.lower() == wanted:
                    return v_info.speaker_id

    # Default to alloy
    return 0
def _apply_speed(audio: torch.Tensor, sample_rate: int, speed: float) -> torch.Tensor:
    """Apply the sox ``tempo`` effect to ``audio``; return it unchanged on failure."""
    try:
        effects = [["tempo", str(speed)]]
        adjusted_audio, _ = torchaudio.sox_effects.apply_effects_tensor(
            audio.cpu().unsqueeze(0),
            sample_rate,
            effects,
        )
        return adjusted_audio.squeeze(0)
    except Exception as e:
        # Best effort: a failed speed change must not drop the segment.
        logger.warning(f"Failed to adjust speech speed: {e}")
        return audio


def _encode_segment(audio: torch.Tensor, sample_rate: int, audio_format: str) -> bytes:
    """Encode one audio tensor into bytes with ``torchaudio.save``."""
    buf = io.BytesIO()
    # 1-D audio gets a leading channel dimension before saving.
    audio_to_save = audio.unsqueeze(0) if audio.dim() == 1 else audio
    torchaudio.save(buf, audio_to_save.cpu(), sample_rate, format=audio_format)
    return buf.getvalue()


@router.post("/audio/speech/stream", tags=["Audio"])
async def stream_speech(
    request: Request,
    speech_request: SpeechRequest,
):
    """
    Stream audio of text being spoken by a realistic voice.

    This endpoint provides an OpenAI-compatible streaming interface for TTS.
    The input text is split into short segments; each segment is generated,
    optionally enhanced and speed-adjusted, encoded on its own, and yielded
    to the client as soon as it is ready.

    Args:
        request: Incoming request; ``request.app.state`` supplies the
            generator, sample rate, device and optional voice-cloning /
            voice-enhancement settings.
        speech_request: Parsed request body (input text, voice, response
            format, speed, temperature).

    Returns:
        A ``StreamingResponse`` whose body is the sequence of encoded segments.

    Raises:
        HTTPException: 503 when no generator is loaded, 400 for empty input
            or an unresolved voice, 500 when the response cannot be set up.
    """
    app_state = request.app.state

    # Check if model is loaded
    if getattr(app_state, "generator", None) is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded. Please try again later."
        )

    # Get request parameters
    input_text = speech_request.input
    voice = speech_request.voice
    speed = speech_request.speed
    temperature = speech_request.temperature

    # Validate first so the log line below never calls len() on missing input.
    if not input_text or not input_text.strip():
        raise HTTPException(
            status_code=400,
            detail="Input text cannot be empty"
        )

    # Log the request
    logger.info(f"Real-time streaming speech from text ({len(input_text)} chars) with voice '{voice}'")

    # Get speaker ID for the voice
    speaker_id = get_speaker_id(app_state, voice)
    if speaker_id is None:
        raise HTTPException(
            status_code=400,
            detail=f"Voice '{voice}' not found. Available voices: {request.app.state.available_voices}"
        )

    try:
        media_types = {
            "mp3": "audio/mpeg",
            "opus": "audio/opus",
            "aac": "audio/aac",
            "flac": "audio/flac",
            "wav": "audio/wav",
        }

        # NOTE(review): response_format may be a str-based Enum member rather
        # than a plain string -- confirm against SpeechRequest. Unwrapping
        # .value keeps the filename and encoder argument a bare format name.
        requested_format = speech_request.response_format
        requested_format = getattr(requested_format, "value", requested_format)
        if isinstance(requested_format, str):
            requested_format = requested_format.lower()

        if requested_format in media_types:
            response_format = requested_format
        else:
            # Unknown formats already got the audio/mpeg media type; encode
            # them as mp3 too so the payload matches the declared type.
            logger.warning(f"Unsupported response format '{requested_format}', falling back to mp3")
            response_format = "mp3"
        media_type = media_types[response_format]

        sample_rate = app_state.sample_rate

        # Smaller segments for faster first response
        text_segments = split_into_segments(input_text, max_chars=50)
        logger.info(f"Split text into {len(text_segments)} segments for incremental streaming")

        async def generate_streaming_audio():
            # Check for cloned voice
            voice_info = None
            from_cloned_voice = False
            if getattr(app_state, "voice_cloning_enabled", False):
                voice_info = app_state.get_voice_info(voice)
                from_cloned_voice = bool(voice_info) and voice_info["type"] == "cloned"

            if from_cloned_voice:
                # Use cloned voice context for first segment
                context = app_state.voice_cloner.get_voice_context(voice_info["voice_id"])
            else:
                # Use standard voice context (imported here, as in the rest
                # of this handler, rather than at module import time)
                from app.voice_enhancement import get_voice_segments
                context = get_voice_segments(voice, app_state.device)

            # Send an empty chunk to initialize the connection
            yield b''

            # Process each text segment incrementally and stream in real time
            for i, segment_text in enumerate(text_segments):
                try:
                    logger.info(f"Generating segment {i+1}/{len(text_segments)}")

                    # Generation runs in a worker thread so it does not block
                    # the event loop.
                    if from_cloned_voice:
                        segment_audio = await asyncio.to_thread(
                            app_state.voice_cloner.generate_speech,
                            segment_text,
                            voice_info["voice_id"],
                            temperature=temperature,
                            topk=30,
                            max_audio_length_ms=2000,  # Keep segments short for streaming
                        )
                    else:
                        segment_audio = await asyncio.to_thread(
                            app_state.generator.generate,
                            segment_text,
                            speaker_id,
                            context,
                            max_audio_length_ms=2000,  # Short for quicker generation
                            temperature=temperature,
                        )

                    # Process audio quality for this segment
                    if getattr(app_state, "voice_enhancement_enabled", False):
                        from app.voice_enhancement import process_generated_audio
                        segment_audio = process_generated_audio(
                            audio=segment_audio,
                            voice_name=voice,
                            sample_rate=sample_rate,
                            text=segment_text,
                        )

                    # Handle speed adjustment (skipped for missing, non-positive or 1.0 speed)
                    if speed is not None and speed > 0 and speed != 1.0:
                        segment_audio = _apply_speed(segment_audio, sample_rate, speed)

                    # Convert this segment to bytes and stream immediately
                    yield _encode_segment(segment_audio, sample_rate, response_format)

                    # Update context with this segment for next generation
                    context = [
                        Segment(
                            text=segment_text,
                            speaker=speaker_id,
                            audio=segment_audio,
                        )
                    ]
                except Exception as e:
                    # Keep going with the next segment, but record the
                    # traceback so the failure can be diagnosed.
                    logger.error(f"Error generating segment {i+1}: {e}", exc_info=True)

        # Return streaming response
        return StreamingResponse(
            generate_streaming_audio(),
            media_type=media_type,
            headers={
                "Content-Disposition": f'attachment; filename="speech.{response_format}"',
                "X-Accel-Buffering": "no",  # Prevent buffering in nginx
                "Cache-Control": "no-cache, no-store, must-revalidate",  # Prevent caching
                "Pragma": "no-cache",
                "Expires": "0",
                "Connection": "keep-alive",
                "Transfer-Encoding": "chunked",
            },
        )
    except Exception as e:
        logger.error(f"Error in stream_speech: {e}")
        raise HTTPException(status_code=500, detail=f"Error generating speech: {str(e)}") from e
@router.post("/audio/speech/streaming", tags=["Audio"])
async def openai_stream_speech(
    request: Request,
    speech_request: SpeechRequest,
):
    """
    Stream audio in OpenAI-compatible streaming format.

    Alias route: the request and parsed body are handed to ``stream_speech``
    unchanged, and whatever that handler returns is returned from here. It
    exists only so the same behavior is reachable under the ``/streaming``
    path name.
    """
    streaming_response = await stream_speech(request, speech_request)
    return streaming_response