# csm-1b / app.py
# Author: alethanhson — commit bd02d7a ("fix")
import base64
import io
import logging
from typing import List, Optional
import torch
import torchaudio
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from generator import load_csm_1b, Segment
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI application instance; the route handlers below attach to it.
app = FastAPI(
    title="CSM 1B API",
    description="API for Sesame's Conversational Speech Model",
    version="1.0.0",
)
# Wide-open CORS: any origin, method, and header is accepted.
# NOTE(review): tighten allow_origins before production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global model handle; populated by the startup hook, None until loaded.
generator = None
class SegmentRequest(BaseModel):
    """One conversational context segment supplied by the client."""

    # Integer speaker id, matching the model's speaker numbering.
    speaker: int
    # Transcript text of this segment.
    text: str
    # Optional base64-encoded audio for this segment; when omitted the
    # endpoint substitutes an empty waveform.
    audio_base64: Optional[str] = None
class GenerateAudioRequest(BaseModel):
    """Request payload for the /generate-audio endpoint."""

    # Text to synthesize.
    text: str
    # Speaker id to synthesize with.
    speaker: int
    # Prior conversation segments used as generation context.
    # (Pydantic deep-copies field defaults, so the mutable [] is safe here.)
    context: List[SegmentRequest] = []
    # Cap on the generated audio length, in milliseconds.
    max_audio_length_ms: float = 10000
    # Sampling temperature passed through to the generator.
    temperature: float = 0.9
    # Top-k sampling cutoff passed through to the generator.
    topk: int = 50
class AudioResponse(BaseModel):
    """Response payload: generated audio as base64 WAV plus its sample rate."""

    # Base64-encoded WAV bytes.
    audio_base64: str
    # Sample rate (Hz) of the encoded audio.
    sample_rate: int
@app.on_event("startup")
async def startup_event():
    """Load the CSM 1B model once at application startup.

    Stores the loaded generator in the module-level ``generator`` global so
    request handlers can use it. Prefers CUDA when available and falls back
    to CPU with a warning. Re-raises any load failure so startup fails loudly.
    """
    global generator
    logger.info("Loading CSM 1B model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        logger.warning("GPU not available. Using CPU, performance may be slow!")
    try:
        generator = load_csm_1b(device=device)
        # Lazy %-style args avoid formatting when the level is disabled.
        logger.info("Model loaded successfully on device: %s", device)
    except Exception:
        # logger.exception records the full traceback; bare `raise`
        # preserves the original exception and traceback for the caller.
        logger.exception("Could not load model")
        raise
@app.post("/generate-audio", response_model=AudioResponse)
async def generate_audio(request: GenerateAudioRequest):
    """Generate speech for ``request.text`` and return it base64-encoded.

    Context segments that carry audio are base64-decoded, loaded, and
    resampled to the model's native sample rate; segments without audio get
    an empty waveform placeholder.

    Raises:
        HTTPException 503: the model has not finished loading yet.
        HTTPException 500: any failure while decoding context or generating.
    """
    global generator
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded. Please try again later.")
    try:
        context_segments = []
        for segment in request.context:
            if segment.audio_base64:
                # Decode caller-supplied audio bytes and resample to the
                # generator's native rate.
                audio_bytes = base64.b64decode(segment.audio_base64)
                audio_buffer = io.BytesIO(audio_bytes)
                audio_tensor, sample_rate = torchaudio.load(audio_buffer)
                audio_tensor = torchaudio.functional.resample(
                    audio_tensor.squeeze(0),
                    orig_freq=sample_rate,
                    new_freq=generator.sample_rate,
                )
            else:
                # No reference audio for this segment: use an empty waveform.
                audio_tensor = torch.zeros(0, dtype=torch.float32)
            context_segments.append(
                Segment(text=segment.text, speaker=segment.speaker, audio=audio_tensor)
            )
        audio = generator.generate(
            text=request.text,
            speaker=request.speaker,
            context=context_segments,
            max_audio_length_ms=request.max_audio_length_ms,
            temperature=request.temperature,
            topk=request.topk,
        )
        # Serialize to WAV in memory, then base64-encode so the payload
        # matches the declared AudioResponse model. (The previous code passed
        # Response-style kwargs — content/media_type/headers — to the Pydantic
        # model, which raised a ValidationError on every successful request.)
        buffer = io.BytesIO()
        torchaudio.save(buffer, audio.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
        buffer.seek(0)
        audio_base64 = base64.b64encode(buffer.read()).decode("utf-8")
        return AudioResponse(
            audio_base64=audio_base64,
            sample_rate=generator.sample_rate,
        )
    except Exception as e:
        # Lazy %-args; the str(e) in `detail` is user-facing and kept as-is.
        logger.error("error when building audio: %s", e)
        raise HTTPException(status_code=500, detail=f"error when building audio: {str(e)}")
@app.get("/health")
async def health_check():
    """Readiness probe: reports whether the model has finished loading."""
    model_ready = generator is not None
    if not model_ready:
        return {"status": "not_ready", "message": "Model is loading"}
    return {"status": "ready", "message": "API is ready to serve"}