"""Advanced voice enhancement and consistency system for CSM-1B.""" | |
import os | |
import torch | |
import torchaudio | |
import numpy as np | |
import soundfile as sf | |
from typing import Dict, List, Optional, Tuple | |
import logging | |
from dataclasses import dataclass | |
from scipy import signal | |
# Setup logging | |
logger = logging.getLogger(__name__) | |
# Define persistent paths | |
VOICE_REFERENCES_DIR = "/app/voice_references" | |
VOICE_PROFILES_DIR = "/app/voice_profiles" | |
# Ensure directories exist | |
os.makedirs(VOICE_REFERENCES_DIR, exist_ok=True) | |
os.makedirs(VOICE_PROFILES_DIR, exist_ok=True) | |
@dataclass
class VoiceProfile:
    """Detailed voice profile with acoustic characteristics."""
    name: str
    speaker_id: int
    # Acoustic parameters
    pitch_range: Tuple[float, float]  # Min/max pitch in Hz
    intensity_range: Tuple[float, float]  # Min/max intensity (volume)
    spectral_tilt: float  # Brightness vs. darkness
    prosody_pattern: str  # Pattern of intonation and rhythm
    speech_rate: float  # Relative speech rate (1.0 = normal)
    formant_shift: float  # Formant frequency shift (1.0 = no shift)
    # Reference audio
    reference_segments: List[torch.Tensor]
    # Normalization parameters
    target_rms: float = 0.2
    target_peak: float = 0.95

    def get_enhancement_params(self) -> Dict:
        """Get parameters for enhancing generated audio."""
        return {
            "target_rms": self.target_rms,
            "target_peak": self.target_peak,
            "pitch_range": self.pitch_range,
            "formant_shift": self.formant_shift,
            "speech_rate": self.speech_rate,
            "spectral_tilt": self.spectral_tilt,
        }

# Voice profiles with carefully tuned parameters
VOICE_PROFILES = {
    "alloy": VoiceProfile(
        name="alloy",
        speaker_id=0,
        pitch_range=(85, 180),  # Hz - balanced range
        intensity_range=(0.15, 0.3),  # moderate intensity
        spectral_tilt=0.0,  # neutral tilt
        prosody_pattern="balanced",
        speech_rate=1.0,  # normal rate
        formant_shift=1.0,  # no shift
        reference_segments=[],
        target_rms=0.2,
        target_peak=0.95,
    ),
    "echo": VoiceProfile(
        name="echo",
        speaker_id=1,
        pitch_range=(75, 165),  # Hz - lower, resonant
        intensity_range=(0.2, 0.35),  # slightly stronger
        spectral_tilt=-0.2,  # more low frequencies
        prosody_pattern="deliberate",
        speech_rate=0.95,  # slightly slower
        formant_shift=0.95,  # slightly lower formants
        reference_segments=[],
        target_rms=0.22,  # slightly louder
        target_peak=0.95,
    ),
    "fable": VoiceProfile(
        name="fable",
        speaker_id=2,
        pitch_range=(120, 250),  # Hz - higher range
        intensity_range=(0.15, 0.28),  # moderate intensity
        spectral_tilt=0.2,  # more high frequencies
        prosody_pattern="animated",
        speech_rate=1.05,  # slightly faster
        formant_shift=1.05,  # slightly higher formants
        reference_segments=[],
        target_rms=0.19,
        target_peak=0.95,
    ),
    "onyx": VoiceProfile(
        name="onyx",
        speaker_id=3,
        pitch_range=(65, 150),  # Hz - deeper range
        intensity_range=(0.18, 0.32),  # moderate-strong
        spectral_tilt=-0.3,  # more low frequencies
        prosody_pattern="authoritative",
        speech_rate=0.93,  # slightly slower
        formant_shift=0.9,  # lower formants
        reference_segments=[],
        target_rms=0.23,  # stronger
        target_peak=0.95,
    ),
    "nova": VoiceProfile(
        name="nova",
        speaker_id=4,
        pitch_range=(90, 200),  # Hz - warm midrange
        intensity_range=(0.15, 0.27),  # moderate
        spectral_tilt=-0.1,  # slightly warm
        prosody_pattern="flowing",
        speech_rate=1.0,  # normal rate
        formant_shift=1.0,  # no shift
        reference_segments=[],
        target_rms=0.2,
        target_peak=0.95,
    ),
    "shimmer": VoiceProfile(
        name="shimmer",
        speaker_id=5,
        pitch_range=(140, 280),  # Hz - brighter, higher
        intensity_range=(0.15, 0.25),  # moderate-light
        spectral_tilt=0.3,  # more high frequencies
        prosody_pattern="light",
        speech_rate=1.07,  # slightly faster
        formant_shift=1.1,  # higher formants
        reference_segments=[],
        target_rms=0.18,  # slightly softer
        target_peak=0.95,
    ),
}

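# A quick sketch of how a profile is consumed downstream (illustrative only;
# the values come straight from the table above):
#
#     params = VOICE_PROFILES["onyx"].get_enhancement_params()
#     params["target_rms"]     # 0.23 - onyx is tuned slightly louder
#     params["spectral_tilt"]  # -0.3 - darker voice, so enhance_audio below
#                              # derives a larger clarity boost (1.15)
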
# Voice-specific prompt templates - crafted to establish voice identity clearly
VOICE_PROMPTS = {
    "alloy": [
        "Hello, I'm Alloy. I speak with a balanced, natural tone that's easy to understand.",
        "This is Alloy speaking. My voice is designed to be clear and conversational.",
        "Alloy here - I have a neutral, friendly voice with balanced tone qualities.",
    ],
    "echo": [
        "Hello, I'm Echo. I speak with a resonant, deeper voice that carries well.",
        "This is Echo speaking. My voice has a rich, resonant quality with depth.",
        "Echo here - My voice is characterized by its warm, resonant tones.",
    ],
    "fable": [
        "Hello, I'm Fable. I speak with a bright, higher-pitched voice that's full of energy.",
        "This is Fable speaking. My voice is characterized by its clear, bright quality.",
        "Fable here - My voice is light, articulate, and slightly higher-pitched.",
    ],
    "onyx": [
        "Hello, I'm Onyx. I speak with a deep, authoritative voice that commands attention.",
        "This is Onyx speaking. My voice has a powerful, deep quality with gravitas.",
        "Onyx here - My voice is characterized by its depth and commanding presence.",
    ],
    "nova": [
        "Hello, I'm Nova. I speak with a warm, pleasant mid-range voice that's easy to listen to.",
        "This is Nova speaking. My voice has a smooth, harmonious quality.",
        "Nova here - My voice is characterized by its warm, friendly mid-tones.",
    ],
    "shimmer": [
        "Hello, I'm Shimmer. I speak with a light, bright voice that's expressive and clear.",
        "This is Shimmer speaking. My voice has an airy, higher-pitched quality.",
        "Shimmer here - My voice is characterized by its bright, crystalline tones.",
    ],
}

def initialize_voice_profiles():
    """Initialize voice profiles with default settings.

    Loads existing voice profiles from persistent storage if available;
    otherwise the defaults defined above are used.
    """
    global VOICE_PROFILES

    # Try to load existing profiles from persistent storage
    profile_path = os.path.join(VOICE_PROFILES_DIR, "voice_profiles.pt")
    if os.path.exists(profile_path):
        try:
            logger.info(f"Loading voice profiles from {profile_path}")
            # map_location ensures CUDA-saved tensors load on CPU-only hosts
            saved_profiles = torch.load(profile_path, map_location="cpu")

            # Update existing profiles with saved data
            for name, data in saved_profiles.items():
                if name in VOICE_PROFILES:
                    VOICE_PROFILES[name].reference_segments = list(
                        data.get("reference_segments", [])
                    )
            logger.info(f"Loaded voice profiles for {len(saved_profiles)} voices")
        except Exception as e:
            logger.error(f"Error loading voice profiles: {e}")
            logger.info("Using default voice profiles")
    else:
        logger.info("No saved voice profiles found, using defaults")

    # Ensure all voices have at least empty reference segments
    for name, profile in VOICE_PROFILES.items():
        if not hasattr(profile, "reference_segments"):
            profile.reference_segments = []

    logger.info(f"Voice profiles initialized for {len(VOICE_PROFILES)} voices")
    return VOICE_PROFILES

def normalize_audio(audio: torch.Tensor, target_rms: float = 0.2, target_peak: float = 0.95) -> torch.Tensor:
    """Apply professional-grade normalization to audio.

    Args:
        audio: Audio tensor
        target_rms: Target RMS level for normalization
        target_peak: Target peak level for limiting

    Returns:
        Normalized audio tensor
    """
    # Ensure audio is on CPU for processing
    audio_cpu = audio.detach().cpu()

    # Handle silent audio
    if audio_cpu.abs().max() < 1e-6:
        logger.warning("Audio is nearly silent, returning original")
        return audio

    # Calculate current RMS
    current_rms = torch.sqrt(torch.mean(audio_cpu ** 2))

    # Apply RMS normalization
    if current_rms > 0:
        gain = target_rms / current_rms
        normalized = audio_cpu * gain
    else:
        normalized = audio_cpu

    # Apply peak limiting
    current_peak = normalized.abs().max()
    if current_peak > target_peak:
        normalized = normalized * (target_peak / current_peak)

    # Return to original device
    return normalized.to(audio.device)

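# Worked example (a sketch with the defaults above): a 0.05-amplitude sine has
# RMS ~0.035, so the gain applied is 0.2 / 0.035 ~ 5.7 and the peak lands at
# ~0.28, safely under the 0.95 limiter threshold. If the gained peak did
# exceed target_peak, the whole signal would be rescaled by
# target_peak / current_peak.
#
#     quiet = 0.05 * torch.sin(torch.linspace(0, 2 * torch.pi * 440, 24_000))
#     leveled = normalize_audio(quiet)  # RMS -> 0.2, peak ~0.28
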
def apply_anti_muffling(audio: torch.Tensor, sample_rate: int, clarity_boost: float = 1.2) -> torch.Tensor:
    """Apply anti-muffling to improve clarity.

    Args:
        audio: Audio tensor
        sample_rate: Audio sample rate
        clarity_boost: Amount of high-frequency boost (1.0 = no boost)

    Returns:
        Processed audio tensor
    """
    # Convert to numpy for filtering
    audio_np = audio.detach().cpu().numpy()
    try:
        # Approximate a high-shelf boost: scipy.signal has no one-step shelf
        # design, so high-pass the signal with a second-order Butterworth
        # filter, scale it by the clarity boost, and mix it back in.
        cutoff = 2000  # Hz
        b, a = signal.butter(2, cutoff / (sample_rate / 2), btype='high', analog=False)

        # filtfilt gives zero-phase filtering, avoiding phase distortion
        boosted = signal.filtfilt(b, a, audio_np, axis=0) * clarity_boost

        # Mix with the original to maintain some warmth
        mix_ratio = 0.7  # 70% processed, 30% original
        processed = mix_ratio * boosted + (1 - mix_ratio) * audio_np
    except Exception as e:
        logger.warning(f"Audio enhancement failed, using original: {e}")
        # Return original audio if enhancement fails
        return audio

    # Convert back to a tensor on the original device
    return torch.tensor(processed, dtype=audio.dtype, device=audio.device)

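# Usage sketch (assumes mono audio at 24 kHz; the rate is an assumption, not
# something this helper checks). With clarity_boost = 1.3, content above
# ~2 kHz ends up at roughly 0.7 * 1.3 + 0.3 = 1.21x its original level, while
# lows drop toward 0.3x plus whatever leaks through the filter's transition
# band - hence the "mix for warmth" above.
#
#     crisp = apply_anti_muffling(audio, sample_rate=24_000, clarity_boost=1.3)
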
def enhance_audio(audio: torch.Tensor, sample_rate: int, voice_profile: VoiceProfile) -> torch.Tensor:
    """Apply comprehensive audio enhancement based on a voice profile.

    Args:
        audio: Audio tensor
        sample_rate: Audio sample rate
        voice_profile: Voice profile containing enhancement parameters

    Returns:
        Enhanced audio tensor
    """
    if audio is None or audio.numel() == 0:
        logger.error("Cannot enhance empty audio")
        return audio

    try:
        # Step 1: Normalize audio levels
        params = voice_profile.get_enhancement_params()
        normalized = normalize_audio(
            audio,
            target_rms=params["target_rms"],
            target_peak=params["target_peak"],
        )

        # Step 2: Apply anti-muffling based on spectral tilt. A positive tilt
        # already means a brighter voice, so darker voices (negative tilt)
        # receive the larger clarity boost.
        clarity_boost = 1.0 + max(0, -params["spectral_tilt"]) * 0.5
        clarified = apply_anti_muffling(
            normalized,
            sample_rate,
            clarity_boost=clarity_boost,
        )

        # Log the enhancement
        logger.debug(
            f"Enhanced audio for {voice_profile.name}: "
            f"RMS: {audio.pow(2).mean().sqrt().item():.3f}->{clarified.pow(2).mean().sqrt().item():.3f}, "
            f"Peak: {audio.abs().max().item():.3f}->{clarified.abs().max().item():.3f}"
        )
        return clarified
    except Exception as e:
        logger.error(f"Error in audio enhancement: {e}")
        return audio  # Return original audio if enhancement fails

def validate_generated_audio(
    audio: torch.Tensor,
    voice_name: str,
    sample_rate: int,
    min_expected_duration: float = 0.5,
) -> Tuple[bool, torch.Tensor, str]:
    """Validate and fix generated audio.

    Args:
        audio: Audio tensor to validate
        voice_name: Name of the voice used
        sample_rate: Audio sample rate
        min_expected_duration: Minimum expected duration in seconds

    Returns:
        Tuple of (is_valid, fixed_audio, message)
    """
    if audio is None:
        return False, torch.zeros(1), "Audio is None"

    # Check for NaN values
    if torch.isnan(audio).any():
        logger.warning(f"Audio for {voice_name} contains NaN values, replacing with zeros")
        audio = torch.where(torch.isnan(audio), torch.zeros_like(audio), audio)

    # Check audio duration
    duration = audio.shape[0] / sample_rate
    if duration < min_expected_duration:
        logger.warning(f"Audio for {voice_name} is too short ({duration:.2f}s < {min_expected_duration}s)")
        return False, audio, f"Audio too short: {duration:.2f}s"

    # Check for silence - this can indicate generation problems
    rms = torch.sqrt(torch.mean(audio ** 2))
    if rms < 0.01:  # Very low RMS indicates near silence
        logger.warning(f"Audio for {voice_name} is nearly silent (RMS: {rms:.6f})")
        return False, audio, f"Audio nearly silent: RMS = {rms:.6f}"

    # Check whether the audio cuts off abruptly - this detects premature
    # stopping. Calculate RMS over the last 100 ms.
    last_samples = int(0.1 * sample_rate)
    if audio.shape[0] > last_samples:
        end_rms = torch.sqrt(torch.mean(audio[-last_samples:] ** 2))
        if end_rms > 0.1:  # High RMS at the end suggests an abrupt cutoff
            logger.warning(f"Audio for {voice_name} may have cut off prematurely (end RMS: {end_rms:.3f})")
            return True, audio, "Audio may have cut off prematurely"

    return True, audio, "Audio validation passed"

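# Typical call pattern (a sketch): the boolean drives any retry logic, while
# the returned tensor - possibly NaN-scrubbed - is kept either way.
#
#     ok, audio, msg = validate_generated_audio(audio, "nova", sample_rate)
#     if not ok:
#         logger.warning(f"Retrying generation: {msg}")
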
def create_voice_segments(app_state, regenerate: bool = False):
    """Create high-quality voice reference segments.

    Args:
        app_state: Application state containing the generator
        regenerate: Whether to regenerate existing references
    """
    generator = app_state.generator
    if not generator:
        logger.error("Cannot create voice segments: generator not available")
        return

    # Use the persistent directory for voice reference segments
    os.makedirs(VOICE_REFERENCES_DIR, exist_ok=True)

    for voice_name, profile in VOICE_PROFILES.items():
        voice_dir = os.path.join(VOICE_REFERENCES_DIR, voice_name)
        os.makedirs(voice_dir, exist_ok=True)

        # Check if we already have references
        if not regenerate and profile.reference_segments:
            logger.info(f"Voice {voice_name} already has {len(profile.reference_segments)} reference segments")
            continue

        # Get prompts for this voice
        prompts = VOICE_PROMPTS[voice_name]

        # Generate reference segments
        logger.info(f"Generating reference segments for voice: {voice_name}")
        reference_segments = []
        for i, prompt in enumerate(prompts):
            ref_path = os.path.join(voice_dir, f"{voice_name}_ref_{i}.wav")

            # Skip generation if the file exists and we're not regenerating
            if not regenerate and os.path.exists(ref_path):
                try:
                    # Load the existing reference
                    audio_tensor, sr = torchaudio.load(ref_path)
                    if sr != generator.sample_rate:
                        audio_tensor = torchaudio.functional.resample(
                            audio_tensor.squeeze(0), orig_freq=sr, new_freq=generator.sample_rate
                        )
                    else:
                        audio_tensor = audio_tensor.squeeze(0)
                    reference_segments.append(audio_tensor.to(generator.device))
                    logger.info(f"Loaded existing reference {i+1}/{len(prompts)} for {voice_name}")
                    continue
                except Exception as e:
                    logger.warning(f"Failed to load existing reference {i+1} for {voice_name}: {e}")

            try:
                # Use a lower temperature for more stability in reference samples
                logger.info(f"Generating reference {i+1}/{len(prompts)} for {voice_name}: '{prompt}'")

                # We want references to be as clean as possible
                audio = generator.generate(
                    text=prompt,
                    speaker=profile.speaker_id,
                    context=[],  # No context for initial samples, to prevent voice bleed
                    max_audio_length_ms=6000,  # Shorter for more control
                    temperature=0.7,  # Lower temperature for more stability
                    topk=30,  # More focused sampling
                )

                # Validate and enhance the audio
                is_valid, audio, message = validate_generated_audio(
                    audio, voice_name, generator.sample_rate
                )
                if is_valid:
                    # Enhance the audio
                    audio = enhance_audio(audio, generator.sample_rate, profile)

                    # Save the reference to persistent storage
                    torchaudio.save(ref_path, audio.unsqueeze(0).cpu(), generator.sample_rate)
                    reference_segments.append(audio)
                    logger.info(f"Generated reference {i+1} for {voice_name}: {message}")
                else:
                    logger.warning(f"Invalid reference for {voice_name}: {message}")
                    # Try again with a different prompt if one remains
                    if i < len(prompts) - 1:
                        logger.info("Trying again with the next prompt")
                        continue
            except Exception as e:
                logger.error(f"Error generating reference for {voice_name}: {e}")

        # Update the voice profile with references
        if reference_segments:
            VOICE_PROFILES[voice_name].reference_segments = reference_segments
            logger.info(f"Updated {voice_name} with {len(reference_segments)} reference segments")

    # Save the updated profiles to persistent storage
    save_voice_profiles()

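# Intended startup wiring (a sketch; `app_state` is whatever object the app
# exposes with a `.generator` attribute - hypothetical here):
#
#     initialize_voice_profiles()
#     create_voice_segments(app_state)                   # build missing references
#     create_voice_segments(app_state, regenerate=True)  # or force a rebuild
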
def get_voice_segments(voice_name: str, device: torch.device) -> List:
    """Get context segments for a given voice.

    Args:
        voice_name: Name of the voice to use
        device: Device to place tensors on

    Returns:
        List of context segments
    """
    from app.models import Segment

    if voice_name not in VOICE_PROFILES:
        logger.warning(f"Voice {voice_name} not found, defaulting to alloy")
        voice_name = "alloy"

    profile = VOICE_PROFILES[voice_name]

    # If we don't have reference segments yet, try loading them from
    # persistent storage. References are saved at the generator's sample
    # rate by create_voice_segments, so no resampling is done here.
    if not profile.reference_segments:
        try:
            voice_dir = os.path.join(VOICE_REFERENCES_DIR, voice_name)
            if os.path.exists(voice_dir):
                reference_segments = []
                prompts = VOICE_PROMPTS[voice_name]
                for i, prompt in enumerate(prompts):
                    ref_path = os.path.join(voice_dir, f"{voice_name}_ref_{i}.wav")
                    if os.path.exists(ref_path):
                        audio_tensor, sr = torchaudio.load(ref_path)
                        audio_tensor = audio_tensor.squeeze(0)
                        reference_segments.append(audio_tensor)
                if reference_segments:
                    profile.reference_segments = reference_segments
                    logger.info(f"Loaded {len(reference_segments)} reference segments for {voice_name}")
        except Exception as e:
            logger.error(f"Error loading reference segments for {voice_name}: {e}")

    # Create context segments from references
    context = []
    if profile.reference_segments:
        for i, ref_audio in enumerate(profile.reference_segments):
            # Use the corresponding prompt if available, otherwise a generic one
            text = (
                VOICE_PROMPTS[voice_name][i]
                if i < len(VOICE_PROMPTS[voice_name])
                else f"Voice reference for {voice_name}"
            )
            context.append(
                Segment(
                    speaker=profile.speaker_id,
                    text=text,
                    audio=ref_audio.to(device),
                )
            )

    logger.info(f"Returning {len(context)} context segments for {voice_name}")
    return context

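# Hypothetical caller (assumes a `generator` with the same interface used in
# create_voice_segments above):
#
#     context = get_voice_segments("shimmer", generator.device)
#     audio = generator.generate(
#         text="Hello there!",
#         speaker=VOICE_PROFILES["shimmer"].speaker_id,
#         context=context,
#         max_audio_length_ms=10_000,
#     )
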
def save_voice_profiles():
    """Save voice profiles to persistent storage."""
    os.makedirs(VOICE_PROFILES_DIR, exist_ok=True)
    profile_path = os.path.join(VOICE_PROFILES_DIR, "voice_profiles.pt")

    # Create a serializable version of the profiles
    serializable_profiles = {}
    for name, profile in VOICE_PROFILES.items():
        serializable_profiles[name] = {
            'reference_segments': [seg.cpu() for seg in profile.reference_segments]
        }

    # Save to persistent storage
    torch.save(serializable_profiles, profile_path)
    logger.info(f"Saved voice profiles to {profile_path}")

def process_generated_audio(
    audio: torch.Tensor,
    voice_name: str,
    sample_rate: int,
    text: str,
) -> torch.Tensor:
    """Process generated audio for consistency and quality.

    Args:
        audio: Audio tensor
        voice_name: Name of the voice used
        sample_rate: Audio sample rate
        text: Text that was spoken

    Returns:
        Processed audio tensor
    """
    # Validate the audio
    is_valid, audio, message = validate_generated_audio(audio, voice_name, sample_rate)
    if not is_valid:
        logger.warning(f"Generated audio validation issue: {message}")

    # Get the voice profile for enhancement
    profile = VOICE_PROFILES.get(voice_name, VOICE_PROFILES["alloy"])

    # Enhance the audio based on the voice profile
    enhanced = enhance_audio(audio, sample_rate, profile)

    # Log the enhancement
    original_duration = audio.shape[0] / sample_rate
    enhanced_duration = enhanced.shape[0] / sample_rate
    logger.info(
        f"Processed audio for '{voice_name}': "
        f"Duration: {original_duration:.2f}s->{enhanced_duration:.2f}s, "
        f"RMS: {audio.pow(2).mean().sqrt().item():.3f}->{enhanced.pow(2).mean().sqrt().item():.3f}"
    )
    return enhanced

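
if __name__ == "__main__":
    # Minimal offline smoke test - an illustrative sketch, not part of the API.
    # It exercises only the DSP helpers, which need no model; the 24 kHz sample
    # rate is an assumption, not something the module enforces.
    logging.basicConfig(level=logging.INFO)
    initialize_voice_profiles()

    sample_rate = 24_000
    t = torch.linspace(0.0, 1.0, sample_rate)
    tone = 0.05 * torch.sin(2 * torch.pi * 220.0 * t)  # quiet 220 Hz test tone

    processed = process_generated_audio(tone, "alloy", sample_rate, "smoke test")
    print(
        f"RMS {tone.pow(2).mean().sqrt().item():.3f} -> "
        f"{processed.pow(2).mean().sqrt().item():.3f}, "
        f"peak {processed.abs().max().item():.3f}"
    )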