"""Streaming support for the TTS API.""" | |
import asyncio | |
import io | |
import logging | |
import time | |
from typing import AsyncGenerator, Optional, List | |
import torch | |
import torchaudio | |
from fastapi import APIRouter, Request, HTTPException | |
from fastapi.responses import StreamingResponse | |
from app.api.schemas import SpeechRequest, ResponseFormat | |
from app.prompt_engineering import split_into_segments | |
from app.model import Segment | |
logger = logging.getLogger(__name__) | |
router = APIRouter() | |

class AudioChunker:
    """Handle audio chunking for streaming responses."""

    def __init__(self,
                 sample_rate: int,
                 format: str = "mp3",
                 chunk_size_ms: int = 200):  # Smaller chunks for better streaming
        """
        Initialize the audio chunker.

        Args:
            sample_rate: Audio sample rate in Hz
            format: Output audio format (mp3, opus, etc.)
            chunk_size_ms: Size of each chunk in milliseconds
        """
        self.sample_rate = sample_rate
        self.format = format.lower()
        # e.g. at 24000 Hz, a 200 ms chunk is int(24000 * 0.2) = 4800 samples
        self.chunk_size_samples = int(sample_rate * (chunk_size_ms / 1000))
        logger.info(f"Audio chunker initialized with {chunk_size_ms}ms chunks ({self.chunk_size_samples} samples)")
    async def chunk_audio(self,
                          audio: torch.Tensor,
                          delay_ms: int = 0) -> AsyncGenerator[bytes, None]:
        """
        Convert an audio tensor to streaming chunks.

        Args:
            audio: Audio tensor to stream
            delay_ms: Artificial delay between chunks (for testing)

        Yields:
            Audio chunks as bytes
        """
        # Ensure audio is on CPU
        if audio.is_cuda:
            audio = audio.cpu()

        # Calculate the number of chunks (ceiling division)
        num_samples = audio.shape[0]
        num_chunks = (num_samples + self.chunk_size_samples - 1) // self.chunk_size_samples
        logger.info(f"Streaming {num_samples} samples as {num_chunks} chunks")

        for i in range(num_chunks):
            start_idx = i * self.chunk_size_samples
            end_idx = min(start_idx + self.chunk_size_samples, num_samples)

            # Extract the chunk
            chunk = audio[start_idx:end_idx]

            # Convert to bytes in the requested format
            chunk_bytes = await self._format_chunk(chunk)

            # Add an artificial delay if requested (for testing)
            if delay_ms > 0:
                await asyncio.sleep(delay_ms / 1000)

            yield chunk_bytes
    async def _format_chunk(self, chunk: torch.Tensor) -> bytes:
        """Convert an audio chunk to bytes in the specified format."""
        buf = io.BytesIO()

        # Ensure the chunk has a channel dimension and is on CPU
        if len(chunk.shape) == 1:
            chunk = chunk.unsqueeze(0)
        if chunk.is_cuda:
            chunk = chunk.cpu()

        # Save to the buffer in the specified format, defaulting to mp3
        fmt = self.format if self.format in ("mp3", "opus", "aac", "flac", "wav") else "mp3"
        torchaudio.save(buf, chunk, self.sample_rate, format=fmt)

        # Get the bytes from the buffer
        buf.seek(0)
        return buf.read()
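
# A minimal usage sketch of AudioChunker (hedged): the 24 kHz rate is an
# assumption for illustration; in this app the real rate comes from
# request.app.state.sample_rate.
#
#     chunker = AudioChunker(sample_rate=24000, format="wav", chunk_size_ms=200)
#     audio = torch.zeros(24000)  # one second of silence
#
#     async def demo():
#         async for chunk in chunker.chunk_audio(audio):
#             print(len(chunk))  # encoded bytes for up to 4800 samples
#
#     asyncio.run(demo())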

def get_speaker_id(app_state, voice):
    """Helper function to get a speaker ID from a voice name or ID."""
    if hasattr(app_state, "voice_speaker_map") and voice in app_state.voice_speaker_map:
        return app_state.voice_speaker_map[voice]

    # Standard voices mapping
    voice_to_speaker = {"alloy": 0, "echo": 1, "fable": 2, "onyx": 3, "nova": 4, "shimmer": 5}
    if voice in voice_to_speaker:
        return voice_to_speaker[voice]

    # Try parsing as an integer speaker ID
    try:
        speaker_id = int(voice)
        if 0 <= speaker_id < 6:
            return speaker_id
    except (ValueError, TypeError):
        pass

    # Check cloned voices if the voice cloner exists
    if hasattr(app_state, "voice_cloner") and app_state.voice_cloner is not None:
        # Check by ID
        if voice in app_state.voice_cloner.cloned_voices:
            return app_state.voice_cloner.cloned_voices[voice].speaker_id
        # Check by name
        for v_id, v_info in app_state.voice_cloner.cloned_voices.items():
            if v_info.name.lower() == voice.lower():
                return v_info.speaker_id

    # Default to alloy
    return 0
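
# Illustrative resolution examples (based on the lookups above; actual results
# depend on the app_state passed in, so these are not guaranteed):
#     get_speaker_id(app_state, "nova")  -> 4   (standard voice name)
#     get_speaker_id(app_state, "3")     -> 3   (numeric string in range 0-5)
#     get_speaker_id(app_state, "bogus") -> 0   (falls back to alloy)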

async def stream_speech(
    request: Request,
    speech_request: SpeechRequest,
):
    """
    Stream audio of text being spoken by a realistic voice.

    This endpoint provides an OpenAI-compatible streaming interface for TTS.
    """
    # Check that the model is loaded
    if not hasattr(request.app.state, "generator") or request.app.state.generator is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded. Please try again later."
        )

    # Get request parameters
    input_text = speech_request.input
    voice = speech_request.voice
    response_format = speech_request.response_format
    speed = speech_request.speed
    temperature = speech_request.temperature

    # Log the request
    logger.info(f"Real-time streaming speech from text ({len(input_text)} chars) with voice '{voice}'")

    # Reject empty input
    if not input_text or len(input_text.strip()) == 0:
        raise HTTPException(
            status_code=400,
            detail="Input text cannot be empty"
        )

    # Get the speaker ID for the voice
    speaker_id = get_speaker_id(request.app.state, voice)
    # Defensive check: get_speaker_id currently falls back to 0 rather than None
    if speaker_id is None:
        raise HTTPException(
            status_code=400,
            detail=f"Voice '{voice}' not found. Available voices: {request.app.state.available_voices}"
        )
    try:
        # Map the response format to a media type
        media_type = {
            "mp3": "audio/mpeg",
            "opus": "audio/opus",
            "aac": "audio/aac",
            "flac": "audio/flac",
            "wav": "audio/wav",
        }.get(response_format, "audio/mpeg")

        # Create the chunker for streaming
        sample_rate = request.app.state.sample_rate
        chunker = AudioChunker(sample_rate, response_format)

        # Split text into segments (split_into_segments is imported at module level)
        text_segments = split_into_segments(input_text, max_chars=50)  # Smaller segments for a faster first response
        logger.info(f"Split text into {len(text_segments)} segments for incremental streaming")
        async def generate_streaming_audio():
            # Check for a cloned voice
            voice_info = None
            from_cloned_voice = False
            if hasattr(request.app.state, "voice_cloning_enabled") and request.app.state.voice_cloning_enabled:
                voice_info = request.app.state.get_voice_info(voice)
                from_cloned_voice = voice_info and voice_info["type"] == "cloned"
                if from_cloned_voice:
                    # Use the cloned voice context for the first segment
                    voice_cloner = request.app.state.voice_cloner
                    context = voice_cloner.get_voice_context(voice_info["voice_id"])
                else:
                    # Use the standard voice context
                    from app.voice_enhancement import get_voice_segments
                    context = get_voice_segments(voice, request.app.state.device)
            else:
                # Use the standard voice context
                from app.voice_enhancement import get_voice_segments
                context = get_voice_segments(voice, request.app.state.device)

            # Send an empty chunk to initialize the connection
            yield b''

            # Process each text segment incrementally and stream in real time
            for i, segment_text in enumerate(text_segments):
                try:
                    logger.info(f"Generating segment {i+1}/{len(text_segments)}")

                    # Generate audio for this segment off the event loop
                    if from_cloned_voice:
                        # Generate with the cloned voice; run the blocking call in a thread
                        voice_cloner = request.app.state.voice_cloner
                        segment_audio = await asyncio.to_thread(
                            voice_cloner.generate_speech,
                            segment_text,
                            voice_info["voice_id"],
                            temperature=temperature,
                            topk=30,
                            max_audio_length_ms=2000  # Keep segments short for streaming
                        )
                    else:
                        # Use a standard voice with the generator
                        segment_audio = await asyncio.to_thread(
                            request.app.state.generator.generate,
                            segment_text,
                            speaker_id,
                            context,
                            max_audio_length_ms=2000,  # Short for quicker generation
                            temperature=temperature
                        )

                    # Optionally post-process audio quality for this segment
                    if hasattr(request.app.state, "voice_enhancement_enabled") and request.app.state.voice_enhancement_enabled:
                        from app.voice_enhancement import process_generated_audio
                        segment_audio = process_generated_audio(
                            audio=segment_audio,
                            voice_name=voice,
                            sample_rate=sample_rate,
                            text=segment_text
                        )

                    # Handle speed adjustment
                    if speed != 1.0 and speed > 0:
                        try:
                            # Adjust speed using torchaudio's sox tempo effect
                            effects = [["tempo", str(speed)]]
                            audio_cpu = segment_audio.cpu()
                            adjusted_audio, _ = torchaudio.sox_effects.apply_effects_tensor(
                                audio_cpu.unsqueeze(0),
                                sample_rate,
                                effects
                            )
                            segment_audio = adjusted_audio.squeeze(0)
                        except Exception as e:
                            logger.warning(f"Failed to adjust speech speed: {e}")

                    # Encode this segment to bytes and stream it immediately
                    buf = io.BytesIO()
                    audio_to_save = segment_audio.unsqueeze(0) if len(segment_audio.shape) == 1 else segment_audio
                    torchaudio.save(buf, audio_to_save.cpu(), sample_rate, format=response_format)
                    buf.seek(0)
                    segment_bytes = buf.read()

                    yield segment_bytes

                    # Use this segment as context for the next generation
                    context = [
                        Segment(
                            text=segment_text,
                            speaker=speaker_id,
                            audio=segment_audio
                        )
                    ]
                except Exception as e:
                    logger.error(f"Error generating segment {i+1}: {e}")
                    # Continue with the next segment
        # Return the streaming response
        return StreamingResponse(
            generate_streaming_audio(),
            media_type=media_type,
            headers={
                "Content-Disposition": f'attachment; filename="speech.{response_format}"',
                "X-Accel-Buffering": "no",  # Prevent buffering in nginx
                "Cache-Control": "no-cache, no-store, must-revalidate",  # Prevent caching
                "Pragma": "no-cache",
                "Expires": "0",
                "Connection": "keep-alive",
                "Transfer-Encoding": "chunked"
            }
        )
    except Exception as e:
        logger.error(f"Error in stream_speech: {e}")
        raise HTTPException(status_code=500, detail=f"Error generating speech: {str(e)}")
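
# Note: no route decorators appear in this file, so these coroutines are
# presumably registered elsewhere in the app. A hedged sketch of one way to
# wire them up (the path is an assumption, not confirmed by this file):
#
#     router.add_api_route("/v1/audio/speech/streaming", stream_speech, methods=["POST"])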

async def openai_stream_speech(
    request: Request,
    speech_request: SpeechRequest,
):
    """
    Stream audio in an OpenAI-compatible streaming format.

    This endpoint is compatible with the OpenAI streaming TTS API.
    """
    # Same logic as stream_speech, under the OpenAI API naming convention
    return await stream_speech(request, speech_request)
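
# A hedged client-side sketch for consuming the stream with httpx. The URL,
# port, and model name are assumptions for illustration; the payload fields
# mirror the SpeechRequest attributes read above.
#
#     import httpx
#
#     payload = {
#         "model": "csm-1b",  # hypothetical model name
#         "input": "Hello from the streaming TTS API.",
#         "voice": "alloy",
#         "response_format": "mp3",
#         "speed": 1.0,
#         "temperature": 0.8,
#     }
#
#     with httpx.stream("POST", "http://localhost:8000/v1/audio/speech/streaming",
#                       json=payload, timeout=None) as resp:
#         with open("speech.mp3", "wb") as f:
#             for chunk in resp.iter_bytes():
#                 f.write(chunk)  # write each streamed segment as it arrives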