File size: 4,322 Bytes
00613da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
from typing import Any, Dict

import gradio as gr
import librosa
import numpy as np
import torch
from transformers import WavLMForSequenceClassification


def feature_extract_simple(
    wav,
    sr=16_000,
    win_len=15.0,
    win_stride=15.0,
    do_normalize=False,
) -> np.ndarray:
    """Simple feature extraction for WavLM.
    Parameters
    ----------
    wav : str or array-like
        path to the wav file, or array-like
    sr : int, optional
        sample rate, by default 16_000
    win_len : float, optional
        window length, by default 15.0
    win_stride : float, optional
        window stride, by default 15.0
    do_normalize: bool, optional
        whether to normalize the input, by default False.
    Returns
    -------
    np.ndarray
        batched input to WavLM
    """
    if type(wav) == str:
        signal, _ = librosa.core.load(wav, sr=sr)
    else:
        try:
            signal = np.array(wav).squeeze()
        except Exception as e:
            print(e)
            raise RuntimeError
    batched_input = []
    stride = int(win_stride * sr)
    l = int(win_len * sr)
    if len(signal) / sr > win_len:
        for i in range(0, len(signal), stride):
            if i + int(win_len * sr) > len(signal):
                # padding the last chunk to make it the same length as others
                chunked = np.pad(signal[i:], (0, l - len(signal[i:])))
            else:
                chunked = signal[i : i + l]
            if do_normalize:
                chunked = (chunked - np.mean(chunked)) / (np.std(chunked) + 1e-7)
            batched_input.append(chunked)
            if i + int(win_len * sr) > len(signal):
                break
    else:
        if do_normalize:
            signal = (signal - np.mean(signal)) / (np.std(signal) + 1e-7)
        batched_input.append(signal)
    return np.stack(batched_input)  # [N, T]


def infer(model, inputs) -> torch.Tensor:
    """Run ``model`` on a batch and return per-label sigmoid probabilities.

    Parameters
    ----------
    model : callable
        called as ``model(inputs)``; the result must expose a ``.logits``
        tensor (as Hugging Face sequence-classification outputs do).
    inputs : torch.Tensor
        batched model input; presumably ``[N, T]`` as produced by
        ``feature_extract_simple`` -- TODO confirm for other callers.

    Returns
    -------
    torch.Tensor
        ``sigmoid(logits)``, same shape as the logits. Each logit is squashed
        independently (sigmoid, not softmax), so rows need not sum to 1.
    """
    # Inference only: skip autograd bookkeeping so no gradient graph is kept
    # alive for every request.
    with torch.no_grad():
        output = model(inputs)
        return torch.sigmoid(output.logits)


def predict(audio_file) -> Dict[str, Any]:
    """Classify an audio file for the Gradio interface.

    Reads the module-level globals ``model`` and ``labels``, which are only
    bound inside this file's ``if __name__ == "__main__":`` section.

    Parameters
    ----------
    audio_file : str or None
        path to the uploaded/recorded audio (the interface uses
        ``gr.Audio(type="filepath")``), or ``None`` when nothing was given.

    Returns
    -------
    Dict[str, Any]
        ``{label: probability}`` sorted by descending probability. Audio that
        spans several windows is scored per window, and each label reports
        its highest score across windows. ``None`` input yields
        ``{"No prediction available": 0.0}``; any failure yields a single
        ``{"Error": message}`` entry.
    """
    if audio_file is None:
        return {"No prediction available": 0.0}

    try:
        input_np = feature_extract_simple(audio_file, sr=16000, do_normalize=True)
        input_pt = torch.Tensor(input_np)

        probs = infer(model, input_pt)
        # One row of len(labels) probabilities per audio window.
        probs_list = probs.reshape(-1, len(labels)).detach().tolist()

        if not probs_list:
            return {"Error": "No segments detected in audio"}

        # Keep every value numeric: the sort below compares values, so mixing
        # in a text entry (e.g. a note) would raise TypeError.
        # Max over windows so content anywhere in the clip is reflected, not
        # only what occurs in the first window.
        per_label_max = [max(column) for column in zip(*probs_list)]
        results = {label: float(p) for label, p in zip(labels, per_label_max)}

        # Sort by confidence score
        return dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
    except Exception as e:
        # UI boundary: report the failure instead of crashing the app.
        # NOTE(review): gr.Label presumably expects numeric confidences; a
        # string value may not render well -- consider raising gr.Error.
        return {"Error": str(e)}


if __name__ == "__main__":
    # `model` and `labels` are bound at module level (not inside a main()
    # function) because predict() reads them as globals.
    model_path = "Roblox/voice-safety-classifier-v2"
    labels = [
        "Discrimination",
        "Harassment",
        "Sexual",
        "IllegalAndRegulated",
        "DatingAndRomantic",
        "Profanity",
    ]

    model = WavLMForSequenceClassification.from_pretrained(
        model_path, num_labels=len(labels)
    )
    model.eval()  # inference mode (e.g. disables dropout)

    # Derived from `labels` so the text cannot drift from the model config.
    category_list = ", ".join(labels[:-1]) + f", and {labels[-1]}"
    # Built without leading indentation: the description is rendered as
    # Markdown, where an indented paragraph after a blank line becomes a
    # code block.
    description = (
        "This app uses the Roblox Voice Safety Classifier v2 model to identify "
        "potentially unsafe content in audio.\n"
        "Upload or record an audio file to get started. The model classifies "
        f"audio into categories including {category_list}.\n"
        "\n"
        "The model processes audio in 15-second chunks and returns probability "
        "scores for each category."
    )

    demo = gr.Interface(
        fn=predict,
        inputs=gr.Audio(type="filepath", label="Upload or record audio"),
        outputs=gr.Label(num_top_classes=len(labels), label="Classification Results"),
        title="Voice Safety Classifier",
        description=description,
        examples=[],
        flagging_mode="never",
    )

    demo.launch()