# chapter-llama/tools/extract/asr_whisperx.py
# Renamed from tools/extract/asr.py (commit 35adc06, by lucas-ventura)
from pathlib import Path
import torch
import whisperx
from whisperx.audio import SAMPLE_RATE
from src.data.chapters import sec_to_hms
# Select GPU when available; disable TF32 so float32 matmuls/convolutions are
# not silently downcast on Ampere+ GPUs, keeping transcription results
# reproducible across runs and hardware.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
class ASRProcessor:
    """
    Automatic Speech Recognition processor using WhisperX.

    Transcribes audio files and returns time-aligned transcription segments,
    formatted as one "HH:MM:SS: text" line per segment.
    """

    def __init__(self, model_name="large-v2", compute_type="float16"):
        """
        Load the WhisperX transcription model.

        Args:
            model_name: Whisper checkpoint name (e.g. "large-v2").
            compute_type: Inference precision passed to WhisperX
                (e.g. "float16", "int8").
        """
        self.model_name = model_name
        self.model = whisperx.load_model(model_name, device, compute_type=compute_type)
        # Alignment models cached per language code so repeated calls for the
        # same language do not reload the model from disk on every transcription.
        self._align_models = {}

    def _get_align_model(self, language):
        """Return (model, metadata) for `language`, loading and caching on first use."""
        if language not in self._align_models:
            self._align_models[language] = whisperx.load_align_model(
                language_code=language, device=device
            )
        return self._align_models[language]

    def get_asr(self, audio_file, return_duration=True):
        """
        Transcribe an audio file into a time-stamped transcript.

        Args:
            audio_file: Path to an audio file readable by WhisperX.
            return_duration: If True (default), also return the audio
                duration in seconds.

        Returns:
            The transcript as "HH:MM:SS: text" lines joined by newlines with a
            trailing newline; when `return_duration` is True, a
            (transcript, duration) tuple instead.

        Raises:
            FileNotFoundError: If `audio_file` does not exist.
        """
        # Explicit raise instead of `assert`: assertions are stripped under -O.
        if not Path(audio_file).exists():
            raise FileNotFoundError(f"File {audio_file} does not exist")

        audio = whisperx.load_audio(audio_file)
        result = self.model.transcribe(audio, batch_size=1)
        language = result["language"]
        duration = audio.shape[0] / SAMPLE_RATE

        # Align the transcription to get accurate per-segment timestamps.
        model_a, metadata = self._get_align_model(language)
        aligned_result = whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio,
            device,
            return_char_alignments=False,
        )

        # Format each aligned segment as "HH:MM:SS: text".
        lines = [
            f"{sec_to_hms(segment['start'])}: {segment['text'].strip()}"
            for segment in aligned_result["segments"]
        ]
        transcript = "\n".join(lines) + "\n"

        if return_duration:
            return transcript, duration
        return transcript