File size: 1,910 Bytes
e218cb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35adc06
e218cb8
 
 
 
35adc06
e218cb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from pathlib import Path

import torch
import whisperx
from whisperx.audio import SAMPLE_RATE

from src.data.chapters import sec_to_hms

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Turn TensorFloat-32 off for both cuDNN and CUDA matmul kernels so float32
# math keeps full precision, making results more consistent across runs/GPUs.
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False


class ASRProcessor:
    """
    Automatic Speech Recognition processor using WhisperX.

    Transcribes audio files and returns time-aligned transcription segments,
    rendered as one ``"<timestamp>: <text>"`` line per segment.

    The transcription model is loaded once in ``__init__``. Alignment models are
    loaded lazily, once per detected language, and cached on the instance so
    repeated calls do not reload them. Cached models stay resident on ``device``
    for the lifetime of the instance.
    """

    def __init__(self, model_name="large-v2", compute_type="float16"):
        """
        Load the WhisperX transcription model on the module-level ``device``.

        Args:
            model_name: Model identifier passed to ``whisperx.load_model``.
            compute_type: Precision passed to ``whisperx.load_model``.
                NOTE(review): ``"float16"`` is typically only supported on GPU;
                CPU runs likely need e.g. ``"int8"`` -- confirm for your setup.
        """
        self.model_name = model_name
        self.model = whisperx.load_model(model_name, device, compute_type=compute_type)
        # Maps language code -> (align_model, metadata); filled by _get_align_model.
        self._align_models = {}

    def _get_align_model(self, language):
        """Return the ``(align_model, metadata)`` pair for ``language``, loading it once."""
        if language not in self._align_models:
            self._align_models[language] = whisperx.load_align_model(
                language_code=language, device=device
            )
        return self._align_models[language]

    @staticmethod
    def _format_segments(segments):
        """Render segments as newline-terminated ``"<sec_to_hms(start)>: <text>"`` lines."""
        lines = [
            f"{sec_to_hms(segment['start'])}: {segment['text'].strip()}"
            for segment in segments
        ]
        # Trailing newline is kept even for an empty transcript (result is "\n").
        return "\n".join(lines) + "\n"

    def get_asr(self, audio_file, return_duration=True, *, batch_size=1):
        """
        Transcribe an audio file and align the segments with WhisperX.

        Args:
            audio_file: Path (str or PathLike) of the audio file to transcribe.
            return_duration: If True (default), also return the audio duration.
            batch_size: Batch size forwarded to the model's ``transcribe`` call.
                Defaults to 1, the previously hard-coded value.

        Returns:
            The transcript: one line per aligned segment, formatted as
            ``"<sec_to_hms(start)>: <stripped text>"`` and each terminated by a
            newline. When ``return_duration`` is True, a tuple
            ``(transcript, duration)``, where ``duration`` is the number of
            loaded audio samples divided by ``SAMPLE_RATE`` (seconds).

        Raises:
            FileNotFoundError: If ``audio_file`` does not exist.
        """
        if not Path(audio_file).exists():
            # Explicit raise instead of `assert`, which is stripped under `python -O`.
            raise FileNotFoundError(f"File {audio_file} does not exist")

        audio = whisperx.load_audio(audio_file)
        duration = audio.shape[0] / SAMPLE_RATE

        result = self.model.transcribe(audio, batch_size=batch_size)

        # Align using the model matching the language reported by transcription.
        model_a, metadata = self._get_align_model(result["language"])
        aligned_result = whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio,
            device,
            return_char_alignments=False,
        )

        transcript = self._format_segments(aligned_result["segments"])
        if return_duration:
            return transcript, duration
        return transcript