"""Streaming support for the TTS API."""
import asyncio
import io
import logging
import time
from typing import AsyncGenerator, Optional, List
import torch
import torchaudio
from fastapi import APIRouter, Request, HTTPException
from fastapi.responses import StreamingResponse
from app.api.schemas import SpeechRequest, ResponseFormat
from app.prompt_engineering import split_into_segments
from app.model import Segment
logger = logging.getLogger(__name__)
router = APIRouter()
class AudioChunker:
    """Handle audio chunking for streaming responses."""

    def __init__(self,
                 sample_rate: int,
                 format: str = "mp3",
                 chunk_size_ms: int = 200):  # Smaller chunks for lower streaming latency
        """
        Initialize the audio chunker.

        Args:
            sample_rate: Audio sample rate in Hz
            format: Output audio format (mp3, opus, etc.)
            chunk_size_ms: Size of each chunk in milliseconds
        """
        self.sample_rate = sample_rate
        self.format = format.lower()
        self.chunk_size_samples = int(sample_rate * (chunk_size_ms / 1000))
        logger.info(f"Audio chunker initialized with {chunk_size_ms}ms chunks ({self.chunk_size_samples} samples)")
    async def chunk_audio(self,
                          audio: torch.Tensor,
                          delay_ms: int = 0) -> AsyncGenerator[bytes, None]:
        """
        Convert an audio tensor to streaming chunks.

        Args:
            audio: Audio tensor to stream
            delay_ms: Artificial delay between chunks (for testing)

        Yields:
            Audio chunks as bytes
        """
        # Ensure audio is on CPU
        if audio.is_cuda:
            audio = audio.cpu()

        # Calculate the number of chunks, rounding up
        num_samples = audio.shape[0]
        num_chunks = (num_samples + self.chunk_size_samples - 1) // self.chunk_size_samples
        logger.info(f"Streaming {num_samples} samples as {num_chunks} chunks")

        for i in range(num_chunks):
            start_idx = i * self.chunk_size_samples
            end_idx = min(start_idx + self.chunk_size_samples, num_samples)

            # Extract the chunk and convert it to bytes in the requested format
            chunk = audio[start_idx:end_idx]
            chunk_bytes = await self._format_chunk(chunk)

            # Add an artificial delay if requested (for testing)
            if delay_ms > 0:
                await asyncio.sleep(delay_ms / 1000)

            yield chunk_bytes
    async def _format_chunk(self, chunk: torch.Tensor) -> bytes:
        """Convert an audio chunk to bytes in the specified format."""
        buf = io.BytesIO()

        # Ensure the chunk has a channel dimension and is on CPU
        if len(chunk.shape) == 1:
            chunk = chunk.unsqueeze(0)
        if chunk.is_cuda:
            chunk = chunk.cpu()

        # Save to the buffer in the specified format, defaulting to mp3
        fmt = self.format if self.format in ("mp3", "opus", "aac", "flac", "wav") else "mp3"
        torchaudio.save(buf, chunk, self.sample_rate, format=fmt)

        # Get bytes from the buffer
        buf.seek(0)
        return buf.read()
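
# Usage sketch for AudioChunker (illustrative, commented out so the module
# stays import-safe; the 24,000 Hz rate is an assumption for demonstration):
#
#   async def _demo() -> None:
#       chunker = AudioChunker(sample_rate=24000, format="wav", chunk_size_ms=200)
#       silence = torch.zeros(24000)  # one second of silence
#       async for chunk in chunker.chunk_audio(silence):
#           print(f"chunk: {len(chunk)} bytes")
#
#   asyncio.run(_demo())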
# Helper function to get the speaker ID for a voice
def get_speaker_id(app_state, voice):
    """Resolve a voice name or ID to a speaker ID."""
    if hasattr(app_state, "voice_speaker_map") and voice in app_state.voice_speaker_map:
        return app_state.voice_speaker_map[voice]

    # Standard voices mapping
    voice_to_speaker = {"alloy": 0, "echo": 1, "fable": 2, "onyx": 3, "nova": 4, "shimmer": 5}
    if voice in voice_to_speaker:
        return voice_to_speaker[voice]

    # Try parsing as an integer
    try:
        speaker_id = int(voice)
        if 0 <= speaker_id < 6:
            return speaker_id
    except (ValueError, TypeError):
        pass

    # Check cloned voices if the voice cloner exists
    if hasattr(app_state, "voice_cloner") and app_state.voice_cloner is not None:
        # Check by ID
        if voice in app_state.voice_cloner.cloned_voices:
            return app_state.voice_cloner.cloned_voices[voice].speaker_id
        # Check by name
        for v_id, v_info in app_state.voice_cloner.cloned_voices.items():
            if v_info.name.lower() == voice.lower():
                return v_info.speaker_id

    # Default to alloy
    return 0
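
# Resolution examples (illustrative): get_speaker_id(state, "nova") -> 4 via the
# standard mapping; get_speaker_id(state, "3") -> 3 via integer parsing; an
# unrecognized voice that is not cloned falls back to 0 ("alloy").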
@router.post("/audio/speech/stream", tags=["Audio"])
async def stream_speech(
request: Request,
speech_request: SpeechRequest,
):
"""
Stream audio of text being spoken by a realistic voice.
This endpoint provides an OpenAI-compatible streaming interface for TTS.
"""
# Check if model is loaded
if not hasattr(request.app.state, "generator") or request.app.state.generator is None:
raise HTTPException(
status_code=503,
detail="Model not loaded. Please try again later."
)
# Get request parameters
model = speech_request.model
input_text = speech_request.input
voice = speech_request.voice
response_format = speech_request.response_format
speed = speech_request.speed
temperature = speech_request.temperature
max_audio_length_ms = speech_request.max_audio_length_ms
# Log the request
logger.info(f"Real-time streaming speech from text ({len(input_text)} chars) with voice '{voice}'")
# Check if text is empty
if not input_text or len(input_text.strip()) == 0:
raise HTTPException(
status_code=400,
detail="Input text cannot be empty"
)
# Get speaker ID for the voice
speaker_id = get_speaker_id(request.app.state, voice)
if speaker_id is None:
raise HTTPException(
status_code=400,
detail=f"Voice '{voice}' not found. Available voices: {request.app.state.available_voices}"
)
try:
# Create media type based on format
media_type = {
"mp3": "audio/mpeg",
"opus": "audio/opus",
"aac": "audio/aac",
"flac": "audio/flac",
"wav": "audio/wav",
}.get(response_format, "audio/mpeg")
# Create the chunker for streaming
sample_rate = request.app.state.sample_rate
chunker = AudioChunker(sample_rate, response_format)
# Split text into segments using the imported function
from app.prompt_engineering import split_into_segments
text_segments = split_into_segments(input_text, max_chars=50) # Smaller segments for faster first response
logger.info(f"Split text into {len(text_segments)} segments for incremental streaming")
        async def generate_streaming_audio():
            # Check for a cloned voice
            voice_info = None
            from_cloned_voice = False
            if hasattr(request.app.state, "voice_cloning_enabled") and request.app.state.voice_cloning_enabled:
                voice_info = request.app.state.get_voice_info(voice)
                from_cloned_voice = voice_info and voice_info["type"] == "cloned"
                if from_cloned_voice:
                    # Use the cloned voice context for the first segment
                    voice_cloner = request.app.state.voice_cloner
                    context = voice_cloner.get_voice_context(voice_info["voice_id"])
                else:
                    # Use the standard voice context
                    from app.voice_enhancement import get_voice_segments
                    context = get_voice_segments(voice, request.app.state.device)
            else:
                # Use the standard voice context
                from app.voice_enhancement import get_voice_segments
                context = get_voice_segments(voice, request.app.state.device)

            # Send an empty chunk to initialize the connection
            yield b''

            # Process each text segment incrementally and stream in real time
            for i, segment_text in enumerate(text_segments):
                try:
                    logger.info(f"Generating segment {i+1}/{len(text_segments)}")

                    # Generate audio for this segment in a worker thread
                    # (asyncio.to_thread) so the event loop is not blocked
                    if from_cloned_voice:
                        # Generate with the cloned voice
                        voice_cloner = request.app.state.voice_cloner
                        segment_audio = await asyncio.to_thread(
                            voice_cloner.generate_speech,
                            segment_text,
                            voice_info["voice_id"],
                            temperature=temperature,
                            topk=30,
                            max_audio_length_ms=2000  # Keep segments short for streaming
                        )
                    else:
                        # Use a standard voice with the generator
                        segment_audio = await asyncio.to_thread(
                            request.app.state.generator.generate,
                            segment_text,
                            speaker_id,
                            context,
                            max_audio_length_ms=2000,  # Short for quicker generation
                            temperature=temperature
                        )

                    # Process audio quality for this segment if enhancement is enabled
                    if hasattr(request.app.state, "voice_enhancement_enabled") and request.app.state.voice_enhancement_enabled:
                        from app.voice_enhancement import process_generated_audio
                        segment_audio = process_generated_audio(
                            audio=segment_audio,
                            voice_name=voice,
                            sample_rate=sample_rate,
                            text=segment_text
                        )

                    # Handle speed adjustment
                    if speed != 1.0 and speed > 0:
                        try:
                            # Adjust speed using torchaudio's sox tempo effect
                            effects = [["tempo", str(speed)]]
                            audio_cpu = segment_audio.cpu()
                            adjusted_audio, _ = torchaudio.sox_effects.apply_effects_tensor(
                                audio_cpu.unsqueeze(0),
                                sample_rate,
                                effects
                            )
                            segment_audio = adjusted_audio.squeeze(0)
                        except Exception as e:
                            logger.warning(f"Failed to adjust speech speed: {e}")

                    # Convert this segment to bytes in the requested format
                    buf = io.BytesIO()
                    audio_to_save = segment_audio.unsqueeze(0) if len(segment_audio.shape) == 1 else segment_audio
                    torchaudio.save(buf, audio_to_save.cpu(), sample_rate, format=response_format)
                    buf.seek(0)
                    segment_bytes = buf.read()

                    # Stream this segment immediately
                    yield segment_bytes

                    # Use this segment as context for the next generation so
                    # consecutive segments keep a consistent voice
                    context = [
                        Segment(
                            text=segment_text,
                            speaker=speaker_id,
                            audio=segment_audio
                        )
                    ]
                except Exception as e:
                    logger.error(f"Error generating segment {i+1}: {e}")
                    # Continue with the next segment rather than aborting the stream
        # Return the streaming response
        return StreamingResponse(
            generate_streaming_audio(),
            media_type=media_type,
            headers={
                "Content-Disposition": f'attachment; filename="speech.{response_format}"',
                "X-Accel-Buffering": "no",  # Prevent buffering in nginx
                "Cache-Control": "no-cache, no-store, must-revalidate",  # Prevent caching
                "Pragma": "no-cache",
                "Expires": "0",
                "Connection": "keep-alive",
                "Transfer-Encoding": "chunked"
            }
        )
    except Exception as e:
        logger.error(f"Error in stream_speech: {e}")
        raise HTTPException(status_code=500, detail=f"Error generating speech: {str(e)}")
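
# Client-side sketch (illustrative; assumes the app is served at
# http://localhost:8000 with this router mounted at the root, and that the
# optional `httpx` package is installed -- it is not a dependency of this
# module; the exact payload fields depend on the SpeechRequest schema):
#
#   import httpx
#
#   payload = {"input": "Hello from the streaming endpoint!",
#              "voice": "alloy",
#              "response_format": "mp3"}
#   with httpx.stream("POST", "http://localhost:8000/audio/speech/stream",
#                     json=payload, timeout=None) as response:
#       with open("speech.mp3", "wb") as f:
#           for chunk in response.iter_bytes():
#               f.write(chunk)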
@router.post("/audio/speech/streaming", tags=["Audio"])
async def openai_stream_speech(
request: Request,
speech_request: SpeechRequest,
):
"""
Stream audio in OpenAI-compatible streaming format.
This endpoint is compatible with the OpenAI streaming TTS API.
"""
# Use the same logic as the stream_speech endpoint but with a different name
# to maintain the OpenAI API naming convention
return await stream_speech(request, speech_request) |
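
# Quick smoke test for the alias endpoint (illustrative; host, port, and payload
# fields are assumptions -- adjust to your deployment and SpeechRequest schema):
#
#   curl -X POST http://localhost:8000/audio/speech/streaming \
#        -H "Content-Type: application/json" \
#        -d '{"input": "Hello!", "voice": "alloy", "response_format": "mp3"}' \
#        --output speech.mp3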