# File size: 4,341 Bytes
# Revision: 1b97239
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# Standard library imports
import os
from typing import Annotated, List, Dict

# Related third-party imports
import torch
from faster_whisper import decode_audio
from ctc_forced_aligner import (
    generate_emissions,
    get_alignments,
    get_spans,
    load_alignment_model,
    postprocess_results,
    preprocess_text,
)


class ForcedAligner:
    """
    ForcedAligner aligns audio to a provided transcript using a pre-trained alignment model.

    Attributes
    ----------
    device : str
        Device to run the model on ('cuda' for GPU or 'cpu').
    alignment_model : torch.nn.Module or None
        The pre-trained alignment model. When running on 'cuda' it is released
        after every ``align`` call (and set to ``None``) to free GPU memory; it
        is reloaded automatically on the next ``align`` call.
    alignment_tokenizer : Any
        Tokenizer for processing text in alignment.

    Methods
    -------
    align(audio_path, transcript, language, batch_size)
        Aligns audio with a transcript and returns word-level timing information.
    """

    def __init__(self, device: Annotated[str, "Device for model ('cuda' or 'cpu')"] = None):
        """
        Initialize the ForcedAligner with the specified device and load the model.

        Parameters
        ----------
        device : str, optional
            Device for running the model, by default 'cuda' if available, otherwise 'cpu'.
        """
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self._load_model()

    def _load_model(self) -> None:
        """Load the alignment model and tokenizer onto ``self.device``."""
        # Half precision is only requested on the GPU; CPU runs use float32.
        self.alignment_model, self.alignment_tokenizer = load_alignment_model(
            self.device,
            dtype=torch.float16 if self.device == 'cuda' else torch.float32,
        )

    def align(
            self,
            audio_path: Annotated[str, "Path to the audio file"],
            transcript: Annotated[str, "Transcript of the audio content"],
            language: Annotated[str, "Language of the transcript"] = 'en',
            batch_size: Annotated[int, "Batch size for emission generation"] = 8,
    ) -> Annotated[List[Dict[str, float]], "List of word alignment data with timestamps"]:
        """
        Aligns audio with a transcript and returns word-level timing information.

        The instance can be reused: if a previous call released the model
        (which happens after every run on 'cuda'), it is reloaded first.

        Parameters
        ----------
        audio_path : str
            Path to the audio file.
        transcript : str
            Transcript text corresponding to the audio.
        language : str, optional
            Language code for the transcript, default is 'en' (English).
        batch_size : int, optional
            Batch size for generating emissions, by default 8.

        Returns
        -------
        List[Dict[str, float]]
            A list of dictionaries containing word timing information, exactly
            as produced by ``postprocess_results``.

        Raises
        ------
        FileNotFoundError
            If the specified audio file does not exist.

        Examples
        --------
        Illustrative only; the exact dictionary keys are determined by
        ``postprocess_results`` and should be confirmed against that library.

        >>> aligner = ForcedAligner()
        >>> aligner.align("path/to/audio.wav", "hello world")
        [{'word': 'hello', 'start': 0.0, 'end': 0.5}, {'word': 'world', 'start': 0.6, 'end': 1.0}]
        """
        # Validate the input before doing any (potentially expensive) model work.
        if not os.path.exists(audio_path):
            raise FileNotFoundError(
                f"The audio file at path '{audio_path}' was not found."
            )

        # A previous CUDA run releases the model (see the end of this method);
        # reload it so a second call does not fail on a missing attribute.
        if getattr(self, "alignment_model", None) is None:
            self._load_model()

        speech_array = torch.from_numpy(decode_audio(audio_path))

        # Match the waveform's dtype and device to the model before inference.
        emissions, stride = generate_emissions(
            self.alignment_model,
            speech_array.to(self.alignment_model.dtype).to(self.alignment_model.device),
            batch_size=batch_size,
        )

        tokens_starred, text_starred = preprocess_text(
            transcript,
            romanize=True,
            language=language,
        )

        segments, scores, blank_token = get_alignments(
            emissions,
            tokens_starred,
            self.alignment_tokenizer,
        )

        spans = get_spans(tokens_starred, segments, blank_token)

        word_timestamps = postprocess_results(text_starred, spans, stride, scores)

        if self.device == 'cuda':
            # Drop our reference to the model and return cached blocks to the
            # driver so other GPU workloads can use the memory. The attribute is
            # kept (as None) so the next call can detect this and reload.
            self.alignment_model = None
            torch.cuda.empty_cache()

        print(f"Word_Timestamps: {word_timestamps}")

        return word_timestamps


if __name__ == "__main__":
    # Minimal manual demo: align a sample recording against a fixed transcript
    # and print the resulting word timestamps. A missing audio file is reported
    # on stdout instead of surfacing as a traceback.
    aligner_demo = ForcedAligner()
    sample_audio, sample_text = "example_audio.wav", "This is a test transcript."
    try:
        print(aligner_demo.align(sample_audio, sample_text))
    except FileNotFoundError as missing_file:
        print(missing_file)