# CSM 1B FastAPI server (extracted from a Hugging Face Space page;
# "Spaces: Running" status-badge residue removed).
# Standard library.
import base64
import io
import logging
from typing import List, Optional

# Third party.
import torch
import torchaudio
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Local: CSM model loader and conversation-segment type.
from generator import load_csm_1b, Segment

# Module-level logger, one per module per the logging convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI application object; metadata shows up in the auto-generated docs.
app = FastAPI(
    title="CSM 1B API",
    description="API for Sesame's Conversational Speech Model",
    version="1.0.0",
)

# Permissive CORS so browser clients on any origin can call the API.
# NOTE(review): wildcard origins with allow_credentials=True is wide open —
# restrict allow_origins before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model handle: None until the startup event finishes loading it.
generator = None
class SegmentRequest(BaseModel):
    """One prior conversation turn supplied as context for generation."""

    speaker: int  # speaker id for this turn
    text: str  # transcript of the turn
    audio_base64: Optional[str] = None  # optional audio for the turn, base64-encoded
class GenerateAudioRequest(BaseModel):
    """Request body for speech generation."""

    text: str  # text to synthesize
    speaker: int  # target speaker id
    context: List[SegmentRequest] = []  # prior turns conditioning the output
    max_audio_length_ms: float = 10000  # cap on generated audio duration
    temperature: float = 0.9  # sampling temperature
    topk: int = 50  # top-k sampling cutoff
class AudioResponse(BaseModel):
    """Response body: generated audio as a base64-encoded WAV payload."""

    audio_base64: str  # WAV bytes, base64-encoded
    sample_rate: int  # sample rate of the encoded audio
@app.on_event("startup")  # registration was missing: without it FastAPI never
# runs this coroutine, `generator` stays None, and every request returns 503.
async def startup_event():
    """Load the CSM 1B model once at application startup.

    Populates the module-level ``generator``; prefers CUDA when available
    and falls back to CPU with a warning.
    """
    global generator
    logger.info("Loading CSM 1B model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        logger.warning("GPU not available. Using CPU, performance may be slow!")
    try:
        generator = load_csm_1b(device=device)
        logger.info(f"Model loaded successfully on device: {device}")
    except Exception as e:
        logger.error(f"Could not load model: {str(e)}")
        raise  # bare re-raise preserves the original traceback
@app.post("/generate-audio", response_model=AudioResponse)  # route decorator
# restored; it appears to have been lost in extraction — confirm the path.
async def generate_audio(request: GenerateAudioRequest):
    """Generate speech for ``request.text`` conditioned on prior turns.

    Returns an :class:`AudioResponse` carrying the WAV bytes base64-encoded.

    Raises:
        HTTPException 503: model not loaded yet.
        HTTPException 500: any failure during decoding/generation/encoding.
    """
    global generator
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded. Please try again later.")
    try:
        # Turn each context item into a Segment, decoding and resampling its
        # audio to the model's sample rate when one was supplied.
        context_segments = []
        for segment in request.context:
            if segment.audio_base64:
                audio_bytes = base64.b64decode(segment.audio_base64)
                audio_buffer = io.BytesIO(audio_bytes)
                audio_tensor, sample_rate = torchaudio.load(audio_buffer)
                # squeeze(0) drops the channel dim — assumes mono input; TODO confirm.
                audio_tensor = torchaudio.functional.resample(
                    audio_tensor.squeeze(0),
                    orig_freq=sample_rate,
                    new_freq=generator.sample_rate,
                )
            else:
                # No audio provided: use an empty placeholder tensor.
                audio_tensor = torch.zeros(0, dtype=torch.float32)
            context_segments.append(
                Segment(text=segment.text, speaker=segment.speaker, audio=audio_tensor)
            )

        audio = generator.generate(
            text=request.text,
            speaker=request.speaker,
            context=context_segments,
            max_audio_length_ms=request.max_audio_length_ms,
            temperature=request.temperature,
            topk=request.topk,
        )

        # Encode the generated waveform as WAV in memory.
        buffer = io.BytesIO()
        torchaudio.save(buffer, audio.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
        buffer.seek(0)

        # BUG FIX: the previous code passed Response-style kwargs (content,
        # media_type, headers) to AudioResponse, a pydantic model whose
        # required fields are audio_base64 and sample_rate — that raised a
        # ValidationError on every successful generation. Encode to base64
        # to match the declared response model.
        audio_base64 = base64.b64encode(buffer.read()).decode("utf-8")
        return AudioResponse(audio_base64=audio_base64, sample_rate=generator.sample_rate)
    except Exception as e:
        logger.error(f"error when building audio: {str(e)}")
        raise HTTPException(status_code=500, detail=f"error when building audio: {str(e)}")
@app.get("/health")  # route decorator restored; it appears to have been
# lost in extraction — confirm the path against the original deployment.
async def health_check():
    """Readiness probe: reports whether the model has finished loading."""
    if generator is None:
        return {"status": "not_ready", "message": "Model is loading"}
    return {"status": "ready", "message": "API is ready to serve"}