# Updated generator.py with proper function order
from dataclasses import dataclass
from typing import List, Tuple
import torch
import torchaudio
import logging
import os
import time
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
from tokenizers.processors import TemplateProcessing
from app.models import Segment
from app.text_normalizer import clean_text_for_tts, TextNormalizer

# Set up logging
logger = logging.getLogger(__name__)
# Import the CSM watermarking code
try:
    from app.watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark
except ImportError:
    # Define stubs for watermarking if the module is not available
    CSM_1B_GH_WATERMARK = "CSM1B"

    def load_watermarker(device="cpu"):
        return None

    def watermark(watermarker, audio, sample_rate, key):
        return audio, sample_rate

def load_llama3_tokenizer():
    """
    Load tokenizer for Llama 3.2, using unsloth's open version
    instead of the gated meta-llama version.
    """
    try:
        # Use the unsloth version, which is not gated
        tokenizer_name = "unsloth/Llama-3.2-1B"
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        bos = tokenizer.bos_token
        eos = tokenizer.eos_token
        tokenizer._tokenizer.post_processor = TemplateProcessing(
            single=f"{bos}:0 $A:0 {eos}:0",
            pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
            special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
        )
        logger.info("Successfully loaded tokenizer from unsloth/Llama-3.2-1B")
        return tokenizer
    except Exception as e:
        logger.error(f"Error loading tokenizer from unsloth: {e}")
        # Fall back to a simpler tokenizer if needed
        try:
            from transformers import GPT2Tokenizer
            logger.warning("Falling back to GPT2Tokenizer")
            tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
            tokenizer.pad_token = tokenizer.eos_token
            return tokenizer
        except Exception as fallback_e:
            logger.error(f"Fallback tokenizer also failed: {fallback_e}")
            raise RuntimeError("Could not load any suitable tokenizer")

class Generator:
    """Generator class for CSM-1B model."""

    def __init__(self, model):
        """Initialize generator with model."""
        self._model = model
        self._model.setup_caches(1)
        self._text_tokenizer = load_llama3_tokenizer()
        device = next(model.parameters()).device
        # Load Mimi codec for audio tokenization
        try:
            logger.info("Loading Mimi audio codec...")
            # First try to import from moshi
            try:
                from moshi.models import loaders
                DEFAULT_REPO = loaders.DEFAULT_REPO
                MIMI_NAME = loaders.MIMI_NAME
                get_mimi = loaders.get_mimi
            except ImportError:
                logger.warning("moshi.models.loaders not found, using fallback")
                # Fallback values if moshi.models.loaders is not available
                DEFAULT_REPO = "kyutai/mimi"
                MIMI_NAME = "mimi-december.pt"

                # Fallback function to load Mimi
                def get_mimi(checkpoint_path, device):
                    from moshi.models.vqvae_model import MiMiModule
                    checkpoint = torch.load(checkpoint_path, map_location=device)
                    model = MiMiModule.init_from_checkpoint(checkpoint, device=device)
                    return model
            mimi_weight = hf_hub_download(DEFAULT_REPO, MIMI_NAME)
            mimi = get_mimi(mimi_weight, device=device)
            mimi.set_num_codebooks(32)
            self._audio_tokenizer = mimi
            self.sample_rate = mimi.sample_rate
            logger.info(f"Mimi codec loaded successfully with sample rate {self.sample_rate}")
        except Exception as e:
            logger.error(f"Error loading Mimi codec: {e}")
            self._audio_tokenizer = None
            self.sample_rate = 24000  # Default sample rate, set before re-raising
            raise RuntimeError(f"Failed to load Mimi codec: {e}")
        try:
            self._watermarker = load_watermarker(device=device)
            logger.info("Watermarker loaded successfully")
        except Exception as e:
            logger.warning(f"Error loading watermarker: {e}. Watermarking will be disabled.")
            self._watermarker = None
        self.device = device
        # Optimize for CUDA throughput
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            torch.cuda.empty_cache()
            logger.info("CUDA optimizations enabled")

    def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize a text segment."""
        frame_tokens = []
        frame_masks = []
        # Strip any voice instructions in square brackets to avoid them being read out
        text = self._clean_text_input(text)
        text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
        text_frame = torch.zeros(len(text_tokens), 33).long()
        text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
        text_frame[:, -1] = torch.tensor(text_tokens)
        text_frame_mask[:, -1] = True
        frame_tokens.append(text_frame.to(self.device))
        frame_masks.append(text_frame_mask.to(self.device))
        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _clean_text_input(self, text: str) -> str:
        """Clean and normalize text for TTS."""
        return clean_text_for_tts(text)

    def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize audio."""
        if self._audio_tokenizer is None:
            raise RuntimeError("Audio tokenizer not initialized")
        frame_tokens = []
        frame_masks = []
        # (K, T)
        audio = audio.to(self.device)
        audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
        # Add EOS frame
        eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
        audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
        audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
        audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
        audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
        audio_frame_mask[:, :-1] = True
        frame_tokens.append(audio_frame)
        frame_masks.append(audio_frame_mask)
        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize a segment of text and audio."""
        text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
        audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
        return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)

    def generate_quick(
        self,
        text: str,
        speaker: int,
        context: List[Segment],
        max_audio_length_ms: float = 2000,  # Short for quick generation
        temperature: float = 0.7,  # Lower for more predictable output
        topk: int = 20,  # Lower for faster beam selection
    ) -> torch.Tensor:
        """Generate audio quickly for real-time streaming."""
        # Similar to generate() but optimized for speed
        self._model.reset_caches()
        # Convert max_audio_length_ms to frames - limit for faster generation
        max_audio_frames = min(int(max_audio_length_ms / 80), 128)  # Smaller limit
        # Process text
        cleaned_text = clean_text_for_tts(text)
        # Prepare tokens
        tokens, tokens_mask = [], []
        # Add context segments (limited to 1 for speed)
        if context:
            segment_tokens, segment_tokens_mask = self._tokenize_segment(context[0])
            tokens.append(segment_tokens)
            tokens_mask.append(segment_tokens_mask)
        # Add text tokens
        gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
        tokens.append(gen_segment_tokens)
        tokens_mask.append(gen_segment_tokens_mask)
        prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
        prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
        # Generate with a larger batch size for fewer iterations
        curr_tokens = prompt_tokens.unsqueeze(0)
        curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
        curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
        # Use a larger batch size
        batch_size = 64  # Generate more frames at once
        all_samples = []
        for start_idx in range(0, max_audio_frames, batch_size):
            end_idx = min(start_idx + batch_size, max_audio_frames)
            batch_frames = end_idx - start_idx
            samples_batch = []
            for _ in range(batch_frames):
                sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
                samples_batch.append(sample)
                if torch.all(sample == 0):
                    break
                curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
                curr_tokens_mask = torch.cat(
                    [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
                ).unsqueeze(1)
                curr_pos = curr_pos[:, -1:] + 1
            all_samples.extend(samples_batch)
            if len(samples_batch) < batch_frames:
                break
        if not all_samples:
            return torch.zeros(10, device=self.device)  # Return short empty audio
        # Decode audio
        audio = self._audio_tokenizer.decode(torch.stack(all_samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
        return audio

    def generate(
        self,
        text: str,
        speaker: int,
        context: List[Segment],
        max_audio_length_ms: float = 90_000,
        temperature: float = 0.9,
        topk: int = 50,
    ) -> torch.Tensor:
        """Generate audio from text."""
        if self._audio_tokenizer is None:
            raise RuntimeError("Audio tokenizer not initialized")
        # Start timing (CUDA events on GPU, wall clock on CPU)
        use_cuda_timing = self.device.type == "cuda"
        if use_cuda_timing:
            start_time = torch.cuda.Event(enable_timing=True)
            end_time = torch.cuda.Event(enable_timing=True)
            start_time.record()
        else:
            wall_start = time.perf_counter()
        self._model.reset_caches()
        # Convert max_audio_length_ms to frames - this controls the maximum generation length
        max_audio_frames = min(int(max_audio_length_ms / 80), 1024)  # Limit to a reasonable size
        max_seq_len = 2048 - max_audio_frames
        # Check whether the text is long and should be split
        if len(text) > 200:
            logger.info(f"Long text detected ({len(text)} chars), processing in segments")
            sentences = TextNormalizer.split_into_sentences(text)
            logger.info(f"Split into {len(sentences)} segments")
            # Process sentences individually and concatenate the results
            all_audio_segments = []
            # Use the first sentence to establish the voice
            first_sentence = sentences[0]
            cleaned_text = clean_text_for_tts(first_sentence)
            # Generate the first segment
            tokens, tokens_mask = [], []
            # Add context segments for the first sentence
            for segment in context:
                segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
                tokens.append(segment_tokens)
                tokens_mask.append(segment_tokens_mask)
            # Add first sentence tokens
            gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
            tokens.append(gen_segment_tokens)
            tokens_mask.append(gen_segment_tokens_mask)
            prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
            prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
            # Check context size and truncate if needed
            if prompt_tokens.size(0) >= max_seq_len:
                logger.warning(f"Inputs too long ({prompt_tokens.size(0)} tokens), truncating to {max_seq_len - 50}")
                prompt_tokens = prompt_tokens[-max_seq_len + 50:]
                prompt_tokens_mask = prompt_tokens_mask[-max_seq_len + 50:]
            # Generate first sentence audio
            curr_tokens = prompt_tokens.unsqueeze(0)
            curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
            curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
            # Generate first segment
            first_segment_samples = []
            for start_idx in range(0, max_audio_frames, 32):
                end_idx = min(start_idx + 32, max_audio_frames)
                batch_frames = end_idx - start_idx
                samples_batch = []
                for _ in range(batch_frames):
                    sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
                    samples_batch.append(sample)
                    if torch.all(sample == 0):
                        break
                    curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
                    curr_tokens_mask = torch.cat(
                        [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
                    ).unsqueeze(1)
                    curr_pos = curr_pos[:, -1:] + 1
                first_segment_samples.extend(samples_batch)
                if len(samples_batch) < batch_frames:
                    break
            if not first_segment_samples:
                raise RuntimeError("No audio generated for first segment")
            # Decode first segment
            first_segment_audio = self._audio_tokenizer.decode(
                torch.stack(first_segment_samples).permute(1, 2, 0)
            ).squeeze(0).squeeze(0)
            all_audio_segments.append(first_segment_audio)
for i, sentence in enumerate(sentences[1:], 1): | |
logger.info(f"Generating segment {i+1}/{len(sentences)}") | |
cleaned_text = clean_text_for_tts(sentence) | |
# Create a context segment from the previous generation | |
prev_segment = Segment( | |
speaker=speaker, | |
text=sentences[i-1], | |
audio=all_audio_segments[-1] | |
) | |
# Generate with this segment as context | |
segment_tokens, segment_tokens_mask = [], [] | |
segment_tokens.append(self._tokenize_segment(prev_segment)[0]) | |
segment_tokens_mask.append(self._tokenize_segment(prev_segment)[1]) | |
# Add current segment tokens | |
current_tokens, current_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker) | |
segment_tokens.append(current_tokens) | |
segment_tokens_mask.append(current_tokens_mask) | |
segment_prompt_tokens = torch.cat(segment_tokens, dim=0).long().to(self.device) | |
segment_prompt_tokens_mask = torch.cat(segment_tokens_mask, dim=0).bool().to(self.device) | |
# Check length and truncate if needed | |
if segment_prompt_tokens.size(0) >= max_seq_len: | |
segment_prompt_tokens = segment_prompt_tokens[-max_seq_len+50:] | |
segment_prompt_tokens_mask = segment_prompt_tokens_mask[-max_seq_len+50:] | |
# Generate audio for this segment | |
curr_tokens = segment_prompt_tokens.unsqueeze(0) | |
curr_tokens_mask = segment_prompt_tokens_mask.unsqueeze(0) | |
curr_pos = torch.arange(0, segment_prompt_tokens.size(0)).unsqueeze(0).long().to(self.device) | |
# Generate segment | |
segment_samples = [] | |
for start_idx in range(0, max_audio_frames, 32): | |
end_idx = min(start_idx + 32, max_audio_frames) | |
batch_frames = end_idx - start_idx | |
samples_batch = [] | |
for i in range(batch_frames): | |
sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk) | |
samples_batch.append(sample) | |
if torch.all(sample == 0): | |
break | |
curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1) | |
curr_tokens_mask = torch.cat( | |
[torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1 | |
).unsqueeze(1) | |
curr_pos = curr_pos[:, -1:] + 1 | |
segment_samples.extend(samples_batch) | |
if len(samples_batch) < batch_frames: | |
break | |
if not segment_samples: | |
logger.warning(f"No audio generated for segment {i+1}") | |
continue | |
# Decode segment | |
segment_audio = self._audio_tokenizer.decode( | |
torch.stack(segment_samples).permute(1, 2, 0) | |
).squeeze(0).squeeze(0) | |
all_audio_segments.append(segment_audio) | |
            # Combine all segments with small pauses
            pause_samples = int(0.3 * self.sample_rate)  # 300ms pause
            pause = torch.zeros(pause_samples, device=self.device)
            audio_parts = []
            for idx, segment_audio in enumerate(all_audio_segments):
                audio_parts.append(segment_audio)
                if idx < len(all_audio_segments) - 1:
                    audio_parts.append(pause)
            audio = torch.cat(audio_parts)
            logger.info(f"Combined {len(all_audio_segments)} segments into final audio")
        else:
            # For shorter text, standard processing
            tokens, tokens_mask = [], []
            # Add context segments
            for segment in context:
                segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
                tokens.append(segment_tokens)
                tokens_mask.append(segment_tokens_mask)
            # Process text
            cleaned_text = clean_text_for_tts(text)
            gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
            tokens.append(gen_segment_tokens)
            tokens_mask.append(gen_segment_tokens_mask)
            prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
            prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
            # Check context size
            if prompt_tokens.size(0) >= max_seq_len:
                logger.warning(f"Inputs too long ({prompt_tokens.size(0)} tokens), truncating to {max_seq_len - 50}")
                prompt_tokens = prompt_tokens[-max_seq_len + 50:]
                prompt_tokens_mask = prompt_tokens_mask[-max_seq_len + 50:]
            # Generate audio - optimized batch generation
            curr_tokens = prompt_tokens.unsqueeze(0)
            curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
            curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
            # Use optimized batch generation
            batch_size = 32  # Generate this many frames at once
            all_samples = []
            for start_idx in range(0, max_audio_frames, batch_size):
                end_idx = min(start_idx + batch_size, max_audio_frames)
                batch_frames = end_idx - start_idx
                samples_batch = []
                for _ in range(batch_frames):
                    sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
                    samples_batch.append(sample)
                    if torch.all(sample == 0):
                        break
                    curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
                    curr_tokens_mask = torch.cat(
                        [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
                    ).unsqueeze(1)
                    curr_pos = curr_pos[:, -1:] + 1
                all_samples.extend(samples_batch)
                if len(samples_batch) < batch_frames:
                    logger.info(f"Early stopping at frame {start_idx + len(samples_batch)}/{max_audio_frames}")
                    break
            if not all_samples:
                raise RuntimeError("No audio generated - model produced empty output")
            # Decode audio
            audio = self._audio_tokenizer.decode(torch.stack(all_samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
        # Apply watermark
        if self._watermarker is not None:
            try:
                audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
                audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)
            except Exception as e:
                logger.warning(f"Error applying watermark: {e}. Continuing without watermark.")
        # Record execution time
        if use_cuda_timing:
            end_time.record()
            torch.cuda.synchronize()
            execution_ms = start_time.elapsed_time(end_time)
        else:
            execution_ms = (time.perf_counter() - wall_start) * 1000
        audio_length_ms = (audio.shape[0] / self.sample_rate) * 1000
        # Calculate the real-time factor (RTF)
        rtf = execution_ms / audio_length_ms
        logger.info(f"Audio generated in {execution_ms:.2f}ms, length: {audio_length_ms:.2f}ms, RTF: {rtf:.2f}x")
        return audio

# Define helper functions for multi-GPU support
def _manual_device_map(model, state_dict, strategy="balanced"):
    """Apply manual device mapping for multi-GPU setups.

    Args:
        model: The model to map
        state_dict: Model state dict
        strategy: Mapping strategy ('balanced', 'sequential')

    Returns:
        Model with weights distributed across GPUs
    """
    num_gpus = torch.cuda.device_count()
    if num_gpus <= 1:
        # No mapping needed with a single GPU
        model.load_state_dict(state_dict)
        model = model.to("cuda")
        return model
    logger.info(f"Applying manual {strategy} device mapping across {num_gpus} GPUs")
    # Get all layer names from the state dict
    layer_names = [name for name in state_dict.keys() if "layers" in name]
    backbone_layers = [name for name in layer_names if "backbone.layers" in name]
    decoder_layers = [name for name in layer_names if "decoder.layers" in name]
    # Count the backbone and decoder layers
    backbone_layer_indices = set()
    for name in backbone_layers:
        parts = name.split('.')
        if len(parts) > 2:
            try:
                backbone_layer_indices.add(int(parts[2]))
            except ValueError:
                pass
    decoder_layer_indices = set()
    for name in decoder_layers:
        parts = name.split('.')
        if len(parts) > 2:
            try:
                decoder_layer_indices.add(int(parts[2]))
            except ValueError:
                pass
    num_backbone_layers = len(backbone_layer_indices)
    num_decoder_layers = len(decoder_layer_indices)
    # Create device map
    device_map = {}
if strategy == "balanced": | |
# Distribute layers evenly across GPUs | |
layers_per_gpu = (num_backbone_layers + num_decoder_layers) // num_gpus | |
remainder = (num_backbone_layers + num_decoder_layers) % num_gpus | |
# Assign backbone layers | |
for i in backbone_layer_indices: | |
gpu_idx = min(i // layers_per_gpu, num_gpus - 1) | |
device_map[f"backbone.layers.{i}"] = f"cuda:{gpu_idx}" | |
# Assign decoder layers | |
for i in decoder_layer_indices: | |
gpu_idx = min((i + num_backbone_layers) // layers_per_gpu, num_gpus - 1) | |
device_map[f"decoder.layers.{i}"] = f"cuda:{gpu_idx}" | |
elif strategy == "sequential": | |
# Fill each GPU sequentially | |
# Backbone layers on first GPU(s) | |
backbone_per_gpu = max(1, num_backbone_layers // ((num_gpus + 1) // 2)) | |
for i in backbone_layer_indices: | |
gpu_idx = min(i // backbone_per_gpu, (num_gpus + 1) // 2 - 1) | |
device_map[f"backbone.layers.{i}"] = f"cuda:{gpu_idx}" | |
# Decoder layers on remaining GPU(s) | |
decoder_per_gpu = max(1, num_decoder_layers // (num_gpus - (num_gpus + 1) // 2 + 1)) | |
for i in decoder_layer_indices: | |
gpu_idx = min(i // decoder_per_gpu + (num_gpus + 1) // 2 - 1, num_gpus - 1) | |
device_map[f"decoder.layers.{i}"] = f"cuda:{gpu_idx}" | |
    # Assign embeddings and other components
    device_map["text_embeddings"] = "cuda:0"
    device_map["audio_embeddings"] = "cuda:0"
    device_map["projection"] = "cuda:0"
    device_map["codebook0_head"] = "cuda:0"
    device_map["audio_head"] = "cuda:0"
    # Load the state dict before distributing modules
    model.load_state_dict(state_dict)
    # Move model parts to their assigned devices
    for name, device in device_map.items():
        if "backbone.layers" in name:
            layer_idx = int(name.split('.')[-1])
            if hasattr(model.backbone, 'layers') and layer_idx < len(model.backbone.layers):
                model.backbone.layers[layer_idx] = model.backbone.layers[layer_idx].to(device)
        elif "decoder.layers" in name:
            layer_idx = int(name.split('.')[-1])
            if hasattr(model.decoder, 'layers') and layer_idx < len(model.decoder.layers):
                model.decoder.layers[layer_idx] = model.decoder.layers[layer_idx].to(device)
        elif hasattr(model, name):
            setattr(model, name, getattr(model, name).to(device))
    logger.info(f"Model distributed across GPUs with {strategy} strategy")
    return model
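
# Illustrative sketch (comments only, not executed): for a hypothetical checkpoint with
# 16 backbone layers and 4 decoder layers on 2 GPUs, the "balanced" strategy above gives
# layers_per_gpu = (16 + 4) // 2 = 10, so the resulting device_map would look roughly like:
#   backbone.layers.0  .. backbone.layers.9   -> cuda:0
#   backbone.layers.10 .. backbone.layers.15  -> cuda:1
#   decoder.layers.0   .. decoder.layers.3    -> cuda:1  (decoder indices are offset by the 16 backbone layers)
#   text_embeddings, audio_embeddings, projection, codebook0_head, audio_head -> cuda:0
# The layer counts here are assumptions for illustration; the real counts come from the
# checkpoint's state dict.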

def load_csm_1b(ckpt_path: str = "ckpt.pt", device: str = "cuda", device_map: str = None) -> Generator:
    """Load the CSM-1B model and create a generator with performance optimizations.

    Args:
        ckpt_path: Path to model checkpoint
        device: Device to load the model on ('cuda', 'cpu', or a specific CUDA device)
        device_map: Optional device mapping strategy ('auto', 'balanced', 'sequential', or None)

    Returns:
        Generator instance with optimized settings
    """
    try:
        # Import the models module for CSM
        from app.torchtune_models import Model, ModelArgs
        # Create model
        model_args = ModelArgs(
            backbone_flavor="llama-1B",
            decoder_flavor="llama-100M",
            text_vocab_size=128256,
            audio_vocab_size=2051,
            audio_num_codebooks=32,
        )
        # Load model
        logger.info(f"Loading CSM-1B model from {ckpt_path} with device={device}, device_map={device_map}")
        # Check for CUDA availability
        cuda_available = device == "cuda" and torch.cuda.is_available()
        # Set up torch for optimized inference
        if cuda_available:
            # Check whether TF32 should be enabled (faster but slightly less precise)
            enable_tf32 = os.environ.get("ENABLE_TF32", "true").lower() == "true"
            if enable_tf32:
                logger.info("Enabling TF32 for faster matrix multiplications")
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True
            # Check for available precision modes
            use_bfloat16 = torch.cuda.is_bf16_supported()
            use_float16 = not use_bfloat16 and torch.cuda.is_available()  # Fall back to float16
            if use_bfloat16:
                dtype = torch.bfloat16
                logger.info("Using bfloat16 precision for faster inference")
            elif use_float16:
                dtype = torch.float16
                logger.info("Using float16 precision for faster inference")
            else:
                dtype = torch.float32
                logger.info("Using float32 precision (mixed precision not available)")
            # Enable Flash Attention if available
            try:
                import flash_attn
                if os.environ.get("ENABLE_FLASH_ATTN", "true").lower() == "true":
                    logger.info("Flash Attention detected - enabling for faster attention")
                    os.environ["PYTORCH_FLASH_ATTENTION_ENABLED"] = "1"
            except ImportError:
                logger.info("Flash Attention not available (install flash-attn for faster inference)")
        else:
            # CPU-only mode
            dtype = torch.float32
            logger.info("Using CPU mode with float32 precision")
        # Check for quantization
        enable_quantization = os.environ.get("ENABLE_QUANTIZATION", "false").lower() == "true"
        is_quantized = False
        # Check for a multi-GPU setup
        if device_map and torch.cuda.device_count() > 1:
            logger.info(f"Using device_map={device_map} across {torch.cuda.device_count()} GPUs")
            # Create model with device map
            model = Model(model_args)
            # Load state dict
            state_dict = torch.load(ckpt_path, map_location='cpu')
            # Try quantization before device mapping if enabled
            if enable_quantization and cuda_available:
                try:
                    from bitsandbytes.nn import Linear8bitLt

                    def replace_with_8bit(model):
                        """Replace linear layers with 8-bit quantized versions"""
                        for name, module in model.named_modules():
                            if isinstance(module, torch.nn.Linear) and module.out_features > 256:
                                parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
                                parent = model
                                if parent_name:
                                    for attr in parent_name.split('.'):
                                        parent = getattr(parent, attr)
                                child_name = name.rsplit('.', 1)[1] if '.' in name else name
                                setattr(parent, child_name, Linear8bitLt.from_float(module))
                        return model

                    logger.info("Applying 8-bit quantization to linear layers")
                    model = replace_with_8bit(model)
                    is_quantized = True
                except ImportError:
                    logger.warning("bitsandbytes not available, skipping quantization")
            # Apply device mapping
            if device_map == "auto":
                # Use accelerate for automatic device mapping
                try:
                    from accelerate import init_empty_weights, load_checkpoint_and_dispatch
                    # Initialize an empty model
                    with init_empty_weights():
                        empty_model = Model(model_args)
                    # Load and dispatch the model across GPUs
                    model = load_checkpoint_and_dispatch(
                        empty_model,
                        ckpt_path,
                        device_map="auto",
                        no_split_module_classes=["TransformerLayer"],
                        # Offload to CPU if the model is very large
                        offload_folder="offload" if os.environ.get("OFFLOAD_TO_CPU", "false").lower() == "true" else None
                    )
                    logger.info("Model loaded with automatic device mapping")
                except ImportError:
                    logger.warning("accelerate package not found, falling back to manual device mapping")
                    model = _manual_device_map(model, state_dict, "balanced")
                except Exception as mapping_error:
                    logger.error(f"Auto device mapping failed: {mapping_error}, falling back to manual")
                    model = _manual_device_map(model, state_dict, "balanced")
            else:
                # Manual device mapping
                model = _manual_device_map(model, state_dict, device_map or "balanced")
        else:
            # Single GPU or CPU setup
            # Try quantization before loading if enabled (GPU only)
            if enable_quantization and cuda_available and not is_quantized:
                try:
                    # First load to CPU for quantization
                    model = Model(model_args).to("cpu")
                    state_dict = torch.load(ckpt_path, map_location="cpu")
                    model.load_state_dict(state_dict)
                    from bitsandbytes.nn import Linear8bitLt

                    def replace_with_8bit(model):
                        """Replace linear layers with 8-bit quantized versions"""
                        for name, module in model.named_modules():
                            if isinstance(module, torch.nn.Linear) and module.out_features > 256:
                                parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
                                parent = model
                                if parent_name:
                                    for attr in parent_name.split('.'):
                                        parent = getattr(parent, attr)
                                child_name = name.rsplit('.', 1)[1] if '.' in name else name
                                setattr(parent, child_name, Linear8bitLt.from_float(module))
                        return model

                    logger.info("Applying 8-bit quantization to linear layers")
                    model = replace_with_8bit(model)
                    model = model.to(device=device)
                    is_quantized = True
                except ImportError:
                    logger.warning("bitsandbytes not available, loading without quantization")
                    # Load the standard way
                    model = Model(model_args).to(device=device, dtype=dtype)
                    state_dict = torch.load(ckpt_path, map_location=device)
                    model.load_state_dict(state_dict)
                except Exception as quant_error:
                    logger.error(f"Quantization failed: {quant_error}, loading without quantization")
                    # Load the standard way
                    model = Model(model_args).to(device=device, dtype=dtype)
                    state_dict = torch.load(ckpt_path, map_location=device)
                    model.load_state_dict(state_dict)
            else:
                # Standard load without quantization
                model = Model(model_args).to(device=device, dtype=dtype)
                state_dict = torch.load(ckpt_path, map_location=device)
                model.load_state_dict(state_dict)
        # Apply torch.compile if available (PyTorch 2.0+)
        compile_mode = os.environ.get("TORCH_COMPILE_MODE", "none")
        if hasattr(torch, 'compile') and compile_mode != "none" and cuda_available:
            try:
                logger.info(f"Using torch.compile with mode '{compile_mode}' for faster inference")
                if compile_mode == "default":
                    model = torch.compile(model)
                else:
                    model = torch.compile(model, mode=compile_mode)
            except Exception as compile_error:
                logger.warning(f"torch.compile failed (requires PyTorch 2.0+): {compile_error}")
        # Try to optimize with CUDA graphs for faster inference (advanced)
        use_cuda_graphs = os.environ.get("USE_CUDA_GRAPHS", "false").lower() == "true"
        if use_cuda_graphs and cuda_available and hasattr(torch.cuda, 'CUDAGraph'):
            try:
                logger.info("Setting up CUDA graphs for repeated inference patterns")
                # This requires custom integration inside the model's forward method;
                # here we just flag that CUDA graphs should be used
                model.use_cuda_graphs = True
            except Exception as cuda_graph_error:
                logger.warning(f"CUDA graphs setup failed: {cuda_graph_error}")
                model.use_cuda_graphs = False
        # Set optimal settings for the CUDA context
        if cuda_available:
            # Set benchmark mode for hardware-specific optimizations
            torch.backends.cudnn.benchmark = True
            # Clean up the CUDA cache before creating the generator
            torch.cuda.empty_cache()
            # Ensure all CUDA work is completed to avoid launch delays
            torch.cuda.synchronize()
        # Create generator
        logger.info("Creating generator with optimized settings")
        generator = Generator(model)
        # Log memory usage if on CUDA
        if cuda_available:
            memory_allocated = torch.cuda.memory_allocated() / (1024**3)
            memory_reserved = torch.cuda.memory_reserved() / (1024**3)
            logger.info(f"Model loaded, CUDA memory: {memory_allocated:.2f}GB allocated, {memory_reserved:.2f}GB reserved")
        logger.info(f"Generator created successfully: precision={dtype}, quantized={is_quantized}")
        return generator
    except Exception as e:
        logger.error(f"Failed to load CSM-1B model: {e}")
        import traceback
        logger.error(traceback.format_exc())
        raise
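

# Minimal usage sketch, assuming a checkpoint at "ckpt.pt", the app.* modules on the
# import path, and an app.models.Segment with the speaker/text/audio fields used above.
# Illustrative only; the paths and parameter values are assumptions, not part of the server code.
if __name__ == "__main__":
    generator = load_csm_1b(
        ckpt_path="ckpt.pt",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    # Generate speech for speaker 0 with no prior conversational context
    audio = generator.generate(
        text="Hello! This is a quick test of the CSM-1B generator.",
        speaker=0,
        context=[],
        max_audio_length_ms=10_000,
    )
    # Save the mono waveform at the codec's sample rate
    torchaudio.save(
        "test_output.wav",
        audio.detach().unsqueeze(0).float().cpu(),
        int(generator.sample_rate),
    )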