import functools | |
import uuid | |
import numpy as np | |
from fastapi import ( | |
File, | |
UploadFile, | |
) | |
import gradio as gr | |
from fastapi import APIRouter, BackgroundTasks, Depends, Response, status | |
from typing import List, Dict | |
from sqlalchemy.orm import Session | |
from datetime import datetime | |
from modules.whisper.data_classes import * | |
from modules.utils.paths import BACKEND_CACHE_DIR | |
from modules.whisper.faster_whisper_inference import FasterWhisperInference | |
from backend.common.audio import read_audio | |
from backend.common.models import QueueResponse | |
from backend.common.config_loader import load_server_config | |
from backend.db.task.dao import ( | |
add_task_to_db, | |
get_db_session, | |
update_task_status_in_db | |
) | |
from backend.db.task.models import TaskStatus, TaskType | |
# Router for all /transcription endpoints; mounted on the main FastAPI app elsewhere.
transcription_router = APIRouter(prefix="/transcription", tags=["Transcription"])
@functools.lru_cache
def get_pipeline() -> 'FasterWhisperInference':
    """Build and return the Whisper inference pipeline, cached per process.

    Without caching, every transcription request would construct a new
    ``FasterWhisperInference`` and reload the model; ``functools`` is already
    imported at the top of the file for exactly this decorator.

    Returns:
        FasterWhisperInference: inferencer configured from the server config's
        ``whisper`` section (model size and compute type).
    """
    config = load_server_config()["whisper"]
    inferencer = FasterWhisperInference(
        output_dir=BACKEND_CACHE_DIR
    )
    inferencer.update_model(
        model_size=config["model_size"],
        compute_type=config["compute_type"]
    )
    return inferencer
def run_transcription(
    audio: np.ndarray,
    params: TranscriptionPipelineParams,
    identifier: str,
) -> List[Segment]:
    """Run the transcription pipeline and persist task progress to the DB.

    Intended to run as a FastAPI background task: marks the task row
    ``IN_PROGRESS``, runs the cached Whisper pipeline, then stores the
    resulting segments and elapsed time under ``COMPLETED``.

    Args:
        audio: decoded audio samples to transcribe.
        params: full pipeline parameters (whisper/vad/bgm/diarization).
        identifier: task UUID previously created by ``add_task_to_db``.

    Returns:
        The transcription segments as plain dicts (``model_dump`` output).

    Raises:
        Re-raises any pipeline exception after marking the task ``FAILED``.
    """
    update_task_status_in_db(
        identifier=identifier,
        update_data={
            "uuid": identifier,
            "status": TaskStatus.IN_PROGRESS,
            "updated_at": datetime.utcnow()
        },
    )
    try:
        segments, elapsed_time = get_pipeline().run(
            audio,
            gr.Progress(),
            "SRT",
            False,
            *params.to_list()
        )
    except Exception as e:
        # BackgroundTasks swallows exceptions, so without this the task row
        # would stay IN_PROGRESS forever on a pipeline failure.
        # NOTE(review): assumes the task model accepts an "error" field and
        # TaskStatus defines FAILED — confirm against backend.db.task.models.
        update_task_status_in_db(
            identifier=identifier,
            update_data={
                "uuid": identifier,
                "status": TaskStatus.FAILED,
                "error": str(e),
                "updated_at": datetime.utcnow()
            },
        )
        raise
    segments = [seg.model_dump() for seg in segments]
    update_task_status_in_db(
        identifier=identifier,
        update_data={
            "uuid": identifier,
            "status": TaskStatus.COMPLETED,
            "result": segments,
            "updated_at": datetime.utcnow(),
            "duration": elapsed_time
        },
    )
    return segments
@transcription_router.post(
    "/",
    response_model=QueueResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Transcribe Audio",
)
async def transcription(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio or video file to transcribe."),
    whisper_params: WhisperParams = Depends(),
    vad_params: VadParams = Depends(),
    bgm_separation_params: BGMSeparationParams = Depends(),
    diarization_params: DiarizationParams = Depends(),
) -> QueueResponse:
    """Queue a transcription task for an uploaded audio/video file.

    Decodes the upload, records a ``QUEUED`` task row, schedules
    ``run_transcription`` as a background task, and returns immediately with
    the task identifier so the client can poll for the result.

    NOTE(review): the route decorator was missing, so this endpoint was never
    registered on ``transcription_router`` (the ``status`` import was unused);
    restored here.
    """
    if not isinstance(file, np.ndarray):
        audio, info = await read_audio(file=file)
    else:
        # Already-decoded audio bypasses read_audio; no duration metadata
        # is available in that case. NOTE(review): this branch would fail
        # below at ``file.filename`` — presumably only reached internally
        # with array input; confirm callers.
        audio, info = file, None

    params = TranscriptionPipelineParams(
        whisper=whisper_params,
        vad=vad_params,
        bgm_separation=bgm_separation_params,
        diarization=diarization_params
    )

    identifier = add_task_to_db(
        status=TaskStatus.QUEUED,
        file_name=file.filename,
        audio_duration=info.duration if info else None,
        language=params.whisper.lang,
        task_type=TaskType.TRANSCRIPTION,
        task_params=params.to_dict(),
    )

    # Heavy work is deferred; the request returns as soon as the row exists.
    background_tasks.add_task(
        run_transcription,
        audio=audio,
        params=params,
        identifier=identifier,
    )

    return QueueResponse(identifier=identifier, status=TaskStatus.QUEUED, message="Transcription task has queued")