Refactor model structure: update import paths from 'app.modelz' to 'app.models' across multiple files for consistency, remove obsolete 'modelz' directory, and adjust Dockerfile and migration script to reflect these changes, enhancing clarity and organization in the codebase.
c27f115
# Updated generator.py with proper function order
from dataclasses import dataclass
from typing import List, Tuple
import torch
import torchaudio
import logging
import os
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
from tokenizers.processors import TemplateProcessing
from app.model import Segment
from app.text_normalizer import clean_text_for_tts
from app.text_normalizer import TextNormalizer

# Set up logging
logger = logging.getLogger(__name__)

# Import the CSM watermarking code
try:
    from app.watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark
except ImportError:
    # Define stubs for watermarking if the module is not available
    CSM_1B_GH_WATERMARK = "CSM1B"

    def load_watermarker(device="cpu"):
        return None

    def watermark(watermarker, audio, sample_rate, key):
        return audio, sample_rate
def load_llama3_tokenizer():
    """
    Load the tokenizer for Llama 3.2, using unsloth's open version
    instead of the gated meta-llama version.
    """
    try:
        # Use the unsloth version, which is not gated
        tokenizer_name = "unsloth/Llama-3.2-1B"
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        bos = tokenizer.bos_token
        eos = tokenizer.eos_token
        tokenizer._tokenizer.post_processor = TemplateProcessing(
            single=f"{bos}:0 $A:0 {eos}:0",
            pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
            special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
        )
        logger.info("Successfully loaded tokenizer from unsloth/Llama-3.2-1B")
        return tokenizer
    except Exception as e:
        logger.error(f"Error loading tokenizer from unsloth: {e}")
        # Fall back to a simpler tokenizer if needed
        try:
            from transformers import GPT2Tokenizer
            logger.warning("Falling back to GPT2Tokenizer")
            tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
            tokenizer.pad_token = tokenizer.eos_token
            return tokenizer
        except Exception as fallback_e:
            logger.error(f"Fallback tokenizer also failed: {fallback_e}")
            raise RuntimeError("Could not load any suitable tokenizer")
class Generator:
    """Generator class for the CSM-1B model."""

    def __init__(self, model):
        """Initialize the generator with a model."""
        self._model = model
        self._model.setup_caches(1)
        self._text_tokenizer = load_llama3_tokenizer()
        device = next(model.parameters()).device
        # Load the Mimi codec for audio tokenization
        try:
            logger.info("Loading Mimi audio codec...")
            from huggingface_hub import hf_hub_download
            # First try to import from moshi
            try:
                from moshi.models import loaders
                DEFAULT_REPO = loaders.DEFAULT_REPO
                MIMI_NAME = loaders.MIMI_NAME
                get_mimi = loaders.get_mimi
            except ImportError:
                logger.warning("moshi.models.loaders not found, using fallback")
                # Fallback values if moshi.models.loaders is not available
                DEFAULT_REPO = "kyutai/mimi"
                MIMI_NAME = "mimi-december.pt"

                # Fallback function to load Mimi
                def get_mimi(checkpoint_path, device):
                    from moshi.models.vqvae_model import MiMiModule
                    checkpoint = torch.load(checkpoint_path, map_location=device)
                    model = MiMiModule.init_from_checkpoint(checkpoint, device=device)
                    return model
            mimi_weight = hf_hub_download(DEFAULT_REPO, MIMI_NAME)
            mimi = get_mimi(mimi_weight, device=device)
            mimi.set_num_codebooks(32)
            self._audio_tokenizer = mimi
            self.sample_rate = mimi.sample_rate
            logger.info(f"Mimi codec loaded successfully with sample rate {self.sample_rate}")
        except Exception as e:
            logger.error(f"Error loading Mimi codec: {e}")
            self._audio_tokenizer = None
            self.sample_rate = 24000  # Default sample rate
            logger.warning(f"Using fallback sample rate: {self.sample_rate}")
            raise RuntimeError(f"Failed to load Mimi codec: {e}")
        try:
            self._watermarker = load_watermarker(device=device)
            logger.info("Watermarker loaded successfully")
        except Exception as e:
            logger.warning(f"Error loading watermarker: {e}. Watermarking will be disabled.")
            self._watermarker = None
        self.device = device
        # Optimize for CUDA throughput
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            torch.cuda.empty_cache()
            logger.info("CUDA optimizations enabled")
    def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize a text segment."""
        frame_tokens = []
        frame_masks = []
        # Strip any voice instructions in square brackets to avoid them being read out
        text = self._clean_text_input(text)
        text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
        # Each frame has 33 columns: 32 audio codebook slots plus one text slot (the last column)
        text_frame = torch.zeros(len(text_tokens), 33).long()
        text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
        text_frame[:, -1] = torch.tensor(text_tokens)
        text_frame_mask[:, -1] = True
        frame_tokens.append(text_frame.to(self.device))
        frame_masks.append(text_frame_mask.to(self.device))
        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _clean_text_input(self, text: str) -> str:
        """Clean and normalize text for TTS."""
        return clean_text_for_tts(text)
    def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize audio."""
        if self._audio_tokenizer is None:
            raise RuntimeError("Audio tokenizer not initialized")
        frame_tokens = []
        frame_masks = []
        # The codec returns codes of shape (K, T): K codebooks by T frames
        audio = audio.to(self.device)
        audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
        # Add an EOS frame
        eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
        audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
        audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
        audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
        audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
        audio_frame_mask[:, :-1] = True
        frame_tokens.append(audio_frame)
        frame_masks.append(audio_frame_mask)
        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
    def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize a segment of text and audio."""
        text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
        audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
        return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)
    def generate_quick(
        self,
        text: str,
        speaker: int,
        context: List[Segment],
        max_audio_length_ms: float = 2000,  # Short for quick generation
        temperature: float = 0.7,  # Lower for more predictable output
        topk: int = 20,  # Lower top-k for faster, more focused sampling
    ) -> torch.Tensor:
        """Generate audio quickly for real-time streaming."""
        # Similar to generate() but optimized for speed
        self._model.reset_caches()
        # Convert max_audio_length_ms to frames (80 ms per frame), limited for faster generation
        max_audio_frames = min(int(max_audio_length_ms / 80), 128)  # Smaller limit
        # Process text
        cleaned_text = clean_text_for_tts(text)
        # Prepare tokens
        tokens, tokens_mask = [], []
        # Add context segments (limited to 1 for speed)
        if context:
            segment_tokens, segment_tokens_mask = self._tokenize_segment(context[0])
            tokens.append(segment_tokens)
            tokens_mask.append(segment_tokens_mask)
        # Add text tokens
        gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
        tokens.append(gen_segment_tokens)
        tokens_mask.append(gen_segment_tokens_mask)
        prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
        prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
        # Generate with a larger batch size for fewer iterations
        curr_tokens = prompt_tokens.unsqueeze(0)
        curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
        curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
        batch_size = 64  # Generate more frames at once
        all_samples = []
        for start_idx in range(0, max_audio_frames, batch_size):
            end_idx = min(start_idx + batch_size, max_audio_frames)
            batch_frames = end_idx - start_idx
            samples_batch = []
            for _ in range(batch_frames):
                sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
                samples_batch.append(sample)
                if torch.all(sample == 0):
                    break
                curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
                curr_tokens_mask = torch.cat(
                    [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
                ).unsqueeze(1)
                curr_pos = curr_pos[:, -1:] + 1
            all_samples.extend(samples_batch)
            if len(samples_batch) < batch_frames:
                break
        if not all_samples:
            return torch.zeros(10, device=self.device)  # Return short empty audio
        # Decode audio
        audio = self._audio_tokenizer.decode(torch.stack(all_samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
        return audio
    def generate(
        self,
        text: str,
        speaker: int,
        context: List[Segment],
        max_audio_length_ms: float = 90_000,
        temperature: float = 0.9,
        topk: int = 50,
    ) -> torch.Tensor:
        """Generate audio from text."""
        if self._audio_tokenizer is None:
            raise RuntimeError("Audio tokenizer not initialized")
        # Start timing
        start_time = torch.cuda.Event(enable_timing=True)
        end_time = torch.cuda.Event(enable_timing=True)
        start_time.record()
        self._model.reset_caches()
        # Convert max_audio_length_ms to frames (80 ms per frame) - this controls the maximum generation length
        max_audio_frames = min(int(max_audio_length_ms / 80), 1024)  # Limit to a reasonable size
        max_seq_len = 2048 - max_audio_frames
        # Check whether the text is long and should be split
        if len(text) > 200:
            logger.info(f"Long text detected ({len(text)} chars), processing in segments")
            sentences = TextNormalizer.split_into_sentences(text)
            logger.info(f"Split into {len(sentences)} segments")
            # Process sentences individually and concatenate the results
            all_audio_segments = []
            # Use the first sentence to establish the voice
            first_sentence = sentences[0]
            cleaned_text = clean_text_for_tts(first_sentence)
            # Generate the first segment
            tokens, tokens_mask = [], []
            # Add context segments for the first sentence
            for segment in context:
                segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
                tokens.append(segment_tokens)
                tokens_mask.append(segment_tokens_mask)
            # Add first sentence tokens
            gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
            tokens.append(gen_segment_tokens)
            tokens_mask.append(gen_segment_tokens_mask)
            prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
            prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
            # Check context size and truncate if needed
            if prompt_tokens.size(0) >= max_seq_len:
                logger.warning(f"Inputs too long ({prompt_tokens.size(0)} tokens), truncating to {max_seq_len - 50}")
                prompt_tokens = prompt_tokens[-max_seq_len + 50:]
                prompt_tokens_mask = prompt_tokens_mask[-max_seq_len + 50:]
            # Generate first sentence audio
            curr_tokens = prompt_tokens.unsqueeze(0)
            curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
            curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
            # Generate the first segment
            first_segment_samples = []
            for start_idx in range(0, max_audio_frames, 32):
                end_idx = min(start_idx + 32, max_audio_frames)
                batch_frames = end_idx - start_idx
                samples_batch = []
                for _ in range(batch_frames):
                    sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
                    samples_batch.append(sample)
                    if torch.all(sample == 0):
                        break
                    curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
                    curr_tokens_mask = torch.cat(
                        [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
                    ).unsqueeze(1)
                    curr_pos = curr_pos[:, -1:] + 1
                first_segment_samples.extend(samples_batch)
                if len(samples_batch) < batch_frames:
                    break
            if not first_segment_samples:
                raise RuntimeError("No audio generated for first segment")
            # Decode the first segment
            first_segment_audio = self._audio_tokenizer.decode(
                torch.stack(first_segment_samples).permute(1, 2, 0)
            ).squeeze(0).squeeze(0)
            all_audio_segments.append(first_segment_audio)
            # Now process the remaining sentences, using the previous one as context
            for seg_idx, sentence in enumerate(sentences[1:], 1):
                logger.info(f"Generating segment {seg_idx + 1}/{len(sentences)}")
                cleaned_text = clean_text_for_tts(sentence)
                # Create a context segment from the previous generation
                prev_segment = Segment(
                    speaker=speaker,
                    text=sentences[seg_idx - 1],
                    audio=all_audio_segments[-1]
                )
                # Generate with this segment as context (tokenize the previous segment once)
                prev_tokens, prev_tokens_mask = self._tokenize_segment(prev_segment)
                segment_tokens = [prev_tokens]
                segment_tokens_mask = [prev_tokens_mask]
                # Add current segment tokens
                current_tokens, current_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
                segment_tokens.append(current_tokens)
                segment_tokens_mask.append(current_tokens_mask)
                segment_prompt_tokens = torch.cat(segment_tokens, dim=0).long().to(self.device)
                segment_prompt_tokens_mask = torch.cat(segment_tokens_mask, dim=0).bool().to(self.device)
                # Check length and truncate if needed
                if segment_prompt_tokens.size(0) >= max_seq_len:
                    segment_prompt_tokens = segment_prompt_tokens[-max_seq_len + 50:]
                    segment_prompt_tokens_mask = segment_prompt_tokens_mask[-max_seq_len + 50:]
                # Generate audio for this segment
                curr_tokens = segment_prompt_tokens.unsqueeze(0)
                curr_tokens_mask = segment_prompt_tokens_mask.unsqueeze(0)
                curr_pos = torch.arange(0, segment_prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
                # Generate the segment
                segment_samples = []
                for start_idx in range(0, max_audio_frames, 32):
                    end_idx = min(start_idx + 32, max_audio_frames)
                    batch_frames = end_idx - start_idx
                    samples_batch = []
                    for _ in range(batch_frames):
                        sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
                        samples_batch.append(sample)
                        if torch.all(sample == 0):
                            break
                        curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
                        curr_tokens_mask = torch.cat(
                            [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
                        ).unsqueeze(1)
                        curr_pos = curr_pos[:, -1:] + 1
                    segment_samples.extend(samples_batch)
                    if len(samples_batch) < batch_frames:
                        break
                if not segment_samples:
                    logger.warning(f"No audio generated for segment {seg_idx + 1}")
                    continue
                # Decode the segment
                segment_audio = self._audio_tokenizer.decode(
                    torch.stack(segment_samples).permute(1, 2, 0)
                ).squeeze(0).squeeze(0)
                all_audio_segments.append(segment_audio)
            # Combine all segments with small pauses
            pause_samples = int(0.3 * self.sample_rate)  # 300 ms pause
            pause = torch.zeros(pause_samples, device=self.device)
            audio_parts = []
            for i, segment_audio in enumerate(all_audio_segments):
                audio_parts.append(segment_audio)
                if i < len(all_audio_segments) - 1:
                    audio_parts.append(pause)
            audio = torch.cat(audio_parts)
            logger.info(f"Combined {len(all_audio_segments)} segments into final audio")
        else:
            # For shorter text, standard processing
            tokens, tokens_mask = [], []
            # Add context segments
            for segment in context:
                segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
                tokens.append(segment_tokens)
                tokens_mask.append(segment_tokens_mask)
            # Process text
            cleaned_text = clean_text_for_tts(text)
            gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
            tokens.append(gen_segment_tokens)
            tokens_mask.append(gen_segment_tokens_mask)
            prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
            prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
            # Check context size
            if prompt_tokens.size(0) >= max_seq_len:
                logger.warning(f"Inputs too long ({prompt_tokens.size(0)} tokens), truncating to {max_seq_len - 50}")
                prompt_tokens = prompt_tokens[-max_seq_len + 50:]
                prompt_tokens_mask = prompt_tokens_mask[-max_seq_len + 50:]
            # Generate audio using optimized batch generation
            curr_tokens = prompt_tokens.unsqueeze(0)
            curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
            curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
            batch_size = 32  # Generate this many frames at once
            all_samples = []
            for start_idx in range(0, max_audio_frames, batch_size):
                end_idx = min(start_idx + batch_size, max_audio_frames)
                batch_frames = end_idx - start_idx
                samples_batch = []
                for _ in range(batch_frames):
                    sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
                    samples_batch.append(sample)
                    if torch.all(sample == 0):
                        break
                    curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
                    curr_tokens_mask = torch.cat(
                        [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
                    ).unsqueeze(1)
                    curr_pos = curr_pos[:, -1:] + 1
                all_samples.extend(samples_batch)
                if len(samples_batch) < batch_frames:
                    logger.info(f"Early stopping at frame {start_idx + len(samples_batch)}/{max_audio_frames}")
                    break
            if not all_samples:
                raise RuntimeError("No audio generated - model produced empty output")
            # Decode audio
            audio = self._audio_tokenizer.decode(torch.stack(all_samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
        # Apply watermark
        if self._watermarker is not None:
            try:
                audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
                audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)
            except Exception as e:
                logger.warning(f"Error applying watermark: {e}. Continuing without watermark.")
        # Record execution time
        end_time.record()
        torch.cuda.synchronize()
        execution_ms = start_time.elapsed_time(end_time)
        audio_length_ms = (audio.shape[0] / self.sample_rate) * 1000
        # Calculate the real-time factor (RTF)
        rtf = execution_ms / audio_length_ms
        logger.info(f"Audio generated in {execution_ms:.2f}ms, length: {audio_length_ms:.2f}ms, RTF: {rtf:.2f}x")
        return audio
# Define helper functions for multi-GPU support
def _manual_device_map(model, state_dict, strategy="balanced"):
    """Apply manual device mapping for multi-GPU setups.

    Args:
        model: The model to map
        state_dict: Model state dict
        strategy: Mapping strategy ('balanced', 'sequential')

    Returns:
        Model with weights distributed across GPUs
    """
    num_gpus = torch.cuda.device_count()
    if num_gpus <= 1:
        # No mapping needed with a single GPU
        model.load_state_dict(state_dict)
        model = model.to("cuda")
        return model
    logger.info(f"Applying manual {strategy} device mapping across {num_gpus} GPUs")
    # Get all layer names from the state dict
    layer_names = [name for name in state_dict.keys() if "layers" in name]
    backbone_layers = [name for name in layer_names if "backbone.layers" in name]
    decoder_layers = [name for name in layer_names if "decoder.layers" in name]
    # Count the number of backbone and decoder layers
    backbone_layer_indices = set()
    for name in backbone_layers:
        parts = name.split('.')
        if len(parts) > 2:
            try:
                backbone_layer_indices.add(int(parts[2]))
            except ValueError:
                pass
    decoder_layer_indices = set()
    for name in decoder_layers:
        parts = name.split('.')
        if len(parts) > 2:
            try:
                decoder_layer_indices.add(int(parts[2]))
            except ValueError:
                pass
    num_backbone_layers = len(backbone_layer_indices)
    num_decoder_layers = len(decoder_layer_indices)
    # Create the device map
    device_map = {}
    if strategy == "balanced":
        # Distribute layers evenly across GPUs (guard against more GPUs than layers)
        layers_per_gpu = max(1, (num_backbone_layers + num_decoder_layers) // num_gpus)
        # Assign backbone layers
        for i in backbone_layer_indices:
            gpu_idx = min(i // layers_per_gpu, num_gpus - 1)
            device_map[f"backbone.layers.{i}"] = f"cuda:{gpu_idx}"
        # Assign decoder layers
        for i in decoder_layer_indices:
            gpu_idx = min((i + num_backbone_layers) // layers_per_gpu, num_gpus - 1)
            device_map[f"decoder.layers.{i}"] = f"cuda:{gpu_idx}"
    elif strategy == "sequential":
        # Fill each GPU sequentially
        # Backbone layers on the first GPU(s)
        backbone_per_gpu = max(1, num_backbone_layers // ((num_gpus + 1) // 2))
        for i in backbone_layer_indices:
            gpu_idx = min(i // backbone_per_gpu, (num_gpus + 1) // 2 - 1)
            device_map[f"backbone.layers.{i}"] = f"cuda:{gpu_idx}"
        # Decoder layers on the remaining GPU(s)
        decoder_per_gpu = max(1, num_decoder_layers // (num_gpus - (num_gpus + 1) // 2 + 1))
        for i in decoder_layer_indices:
            gpu_idx = min(i // decoder_per_gpu + (num_gpus + 1) // 2 - 1, num_gpus - 1)
            device_map[f"decoder.layers.{i}"] = f"cuda:{gpu_idx}"
    # Assign embeddings and other components
    device_map["text_embeddings"] = "cuda:0"
    device_map["audio_embeddings"] = "cuda:0"
    device_map["projection"] = "cuda:0"
    device_map["codebook0_head"] = "cuda:0"
    device_map["audio_head"] = "cuda:0"
    # Load the state dict before distributing modules
    model.load_state_dict(state_dict)
    # Move model parts to their assigned devices
    for name, device in device_map.items():
        if "backbone.layers" in name:
            layer_idx = int(name.split('.')[-1])
            if hasattr(model.backbone, 'layers') and layer_idx < len(model.backbone.layers):
                model.backbone.layers[layer_idx] = model.backbone.layers[layer_idx].to(device)
        elif "decoder.layers" in name:
            layer_idx = int(name.split('.')[-1])
            if hasattr(model.decoder, 'layers') and layer_idx < len(model.decoder.layers):
                model.decoder.layers[layer_idx] = model.decoder.layers[layer_idx].to(device)
        elif hasattr(model, name):
            setattr(model, name, getattr(model, name).to(device))
    logger.info(f"Model distributed across GPUs with {strategy} strategy")
    return model
def load_csm_1b(ckpt_path: str = "ckpt.pt", device: str = "cuda", device_map: str = None) -> Generator:
    """Load the CSM-1B model and create a generator with performance optimizations.

    Args:
        ckpt_path: Path to the model checkpoint
        device: Device to load the model on ('cuda', 'cpu', or a specific CUDA device)
        device_map: Optional device mapping strategy ('auto', 'balanced', 'sequential', or None)

    Returns:
        Generator instance with optimized settings
    """
    try:
        # Import the models module for CSM
        from app.torchtune_models import Model, ModelArgs
        # Create the model configuration
        model_args = ModelArgs(
            backbone_flavor="llama-1B",
            decoder_flavor="llama-100M",
            text_vocab_size=128256,
            audio_vocab_size=2051,
            audio_num_codebooks=32,
        )
        # Load the model
        logger.info(f"Loading CSM-1B model from {ckpt_path} with device={device}, device_map={device_map}")
        # Check for CUDA availability
        cuda_available = device == "cuda" and torch.cuda.is_available()
        # Set up torch for optimized inference
        if cuda_available:
            # Check whether TF32 should be enabled (faster but slightly less precise)
            enable_tf32 = os.environ.get("ENABLE_TF32", "true").lower() == "true"
            if enable_tf32:
                logger.info("Enabling TF32 for faster matrix multiplications")
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True
            # Check for available precision modes
            use_bfloat16 = torch.cuda.is_bf16_supported()
            use_float16 = not use_bfloat16 and torch.cuda.is_available()  # Fall back to float16
            if use_bfloat16:
                dtype = torch.bfloat16
                logger.info("Using bfloat16 precision for faster inference")
            elif use_float16:
                dtype = torch.float16
                logger.info("Using float16 precision for faster inference")
            else:
                dtype = torch.float32
                logger.info("Using float32 precision (mixed precision not available)")
            # Enable Flash Attention if available
            try:
                import flash_attn
                if os.environ.get("ENABLE_FLASH_ATTN", "true").lower() == "true":
                    logger.info("Flash Attention detected - enabling for faster attention")
                    os.environ["PYTORCH_FLASH_ATTENTION_ENABLED"] = "1"
            except ImportError:
                logger.info("Flash Attention not available (install flash-attn for faster inference)")
        else:
            # CPU-only mode
            dtype = torch.float32
            logger.info("Using CPU mode with float32 precision")
        # Check for quantization
        enable_quantization = os.environ.get("ENABLE_QUANTIZATION", "false").lower() == "true"
        is_quantized = False
        # Check for a multi-GPU setup
        if device_map and torch.cuda.device_count() > 1:
            logger.info(f"Using device_map={device_map} across {torch.cuda.device_count()} GPUs")
            # Create the model with a device map
            model = Model(model_args)
            # Load the state dict
            state_dict = torch.load(ckpt_path, map_location='cpu')
            # Try quantization before device mapping if enabled
            if enable_quantization and cuda_available:
                try:
                    from bitsandbytes.nn import Linear8bitLt

                    def replace_with_8bit(model):
                        """Replace linear layers with 8-bit quantized versions."""
                        for name, module in model.named_modules():
                            if isinstance(module, torch.nn.Linear) and module.out_features > 256:
                                parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
                                parent = model
                                if parent_name:
                                    for attr in parent_name.split('.'):
                                        parent = getattr(parent, attr)
                                child_name = name.rsplit('.', 1)[1] if '.' in name else name
                                setattr(parent, child_name, Linear8bitLt.from_float(module))
                        return model

                    logger.info("Applying 8-bit quantization to linear layers")
                    model = replace_with_8bit(model)
                    is_quantized = True
                except ImportError:
                    logger.warning("bitsandbytes not available, skipping quantization")
            # Apply device mapping
            if device_map == "auto":
                # Use accelerate for automatic device mapping
                try:
                    from accelerate import init_empty_weights, load_checkpoint_and_dispatch
                    # Initialize an empty model
                    with init_empty_weights():
                        empty_model = Model(model_args)
                    # Load and dispatch the model across GPUs
                    model = load_checkpoint_and_dispatch(
                        empty_model,
                        ckpt_path,
                        device_map="auto",
                        no_split_module_classes=["TransformerLayer"],
                        # Offload to CPU if the model is very large
                        offload_folder="offload" if os.environ.get("OFFLOAD_TO_CPU", "false").lower() == "true" else None
                    )
                    logger.info("Model loaded with automatic device mapping")
                except ImportError:
                    logger.warning("accelerate package not found, falling back to manual device mapping")
                    model = _manual_device_map(model, state_dict, "balanced")
                except Exception as mapping_error:
                    logger.error(f"Auto device mapping failed: {mapping_error}, falling back to manual")
                    model = _manual_device_map(model, state_dict, "balanced")
            else:
                # Manual device mapping
                model = _manual_device_map(model, state_dict, device_map or "balanced")
        else:
            # Single GPU or CPU setup
            # Try quantization before loading if enabled (GPU only)
            if enable_quantization and cuda_available and not is_quantized:
                try:
                    # First load to CPU for quantization
                    model = Model(model_args).to("cpu")
                    state_dict = torch.load(ckpt_path, map_location="cpu")
                    model.load_state_dict(state_dict)
                    from bitsandbytes.nn import Linear8bitLt

                    def replace_with_8bit(model):
                        """Replace linear layers with 8-bit quantized versions."""
                        for name, module in model.named_modules():
                            if isinstance(module, torch.nn.Linear) and module.out_features > 256:
                                parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
                                parent = model
                                if parent_name:
                                    for attr in parent_name.split('.'):
                                        parent = getattr(parent, attr)
                                child_name = name.rsplit('.', 1)[1] if '.' in name else name
                                setattr(parent, child_name, Linear8bitLt.from_float(module))
                        return model

                    logger.info("Applying 8-bit quantization to linear layers")
                    model = replace_with_8bit(model)
                    model = model.to(device=device)
                    is_quantized = True
                except ImportError:
                    logger.warning("bitsandbytes not available, loading without quantization")
                    # Load the standard way
                    model = Model(model_args).to(device=device, dtype=dtype)
                    state_dict = torch.load(ckpt_path, map_location=device)
                    model.load_state_dict(state_dict)
                except Exception as quant_error:
                    logger.error(f"Quantization failed: {quant_error}, loading without quantization")
                    # Load the standard way
                    model = Model(model_args).to(device=device, dtype=dtype)
                    state_dict = torch.load(ckpt_path, map_location=device)
                    model.load_state_dict(state_dict)
            else:
                # Standard load without quantization
                model = Model(model_args).to(device=device, dtype=dtype)
                state_dict = torch.load(ckpt_path, map_location=device)
                model.load_state_dict(state_dict)
        # Apply torch.compile if available (PyTorch 2.0+)
        compile_mode = os.environ.get("TORCH_COMPILE_MODE", "none")
        if hasattr(torch, 'compile') and compile_mode != "none" and cuda_available:
            try:
                logger.info(f"Using torch.compile with mode '{compile_mode}' for faster inference")
                if compile_mode == "default":
                    model = torch.compile(model)
                else:
                    model = torch.compile(model, mode=compile_mode)
            except Exception as compile_error:
                logger.warning(f"torch.compile failed (requires PyTorch 2.0+): {compile_error}")
        # Try to use CUDA graphs for faster inference (advanced)
        use_cuda_graphs = os.environ.get("USE_CUDA_GRAPHS", "false").lower() == "true"
        if use_cuda_graphs and cuda_available and hasattr(torch.cuda, 'CUDAGraph'):
            try:
                logger.info("Setting up CUDA graphs for repeated inference patterns")
                # This requires custom integration inside the model's forward method;
                # here we only flag that CUDA graphs should be used
                model.use_cuda_graphs = True
            except Exception as cuda_graph_error:
                logger.warning(f"CUDA graphs setup failed: {cuda_graph_error}")
                model.use_cuda_graphs = False
        # Set optimal settings for the CUDA context
        if cuda_available:
            # Set benchmark mode for hardware-specific optimizations
            torch.backends.cudnn.benchmark = True
            # Clean up the CUDA cache before creating the generator
            torch.cuda.empty_cache()
            # Ensure all CUDA work is completed to avoid launch delays
            torch.cuda.synchronize()
        # Create the generator
        logger.info("Creating generator with optimized settings")
        generator = Generator(model)
        # Log memory usage if on CUDA
        if cuda_available:
            memory_allocated = torch.cuda.memory_allocated() / (1024**3)
            memory_reserved = torch.cuda.memory_reserved() / (1024**3)
            logger.info(f"Model loaded, CUDA memory: {memory_allocated:.2f}GB allocated, {memory_reserved:.2f}GB reserved")
        logger.info(f"Generator created successfully: precision={dtype}, quantized={is_quantized}")
        return generator
    except Exception as e:
        logger.error(f"Failed to load CSM-1B model: {e}")
        import traceback
        logger.error(traceback.format_exc())
        raise
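# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only, not invoked by the application):
# it assumes a checkpoint at the default "ckpt.pt" path, an available CUDA
# device, and that speaker ID 0 is valid for your deployment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    generator = load_csm_1b(ckpt_path="ckpt.pt", device="cuda")
    audio = generator.generate(
        text="Hello from CSM-1B.",
        speaker=0,
        context=[],  # optionally pass previous Segment objects for voice continuity
        max_audio_length_ms=10_000,
    )
    # Save as a mono WAV at the generator's native sample rate
    torchaudio.save("output.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)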