File size: 4,043 Bytes
6728136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import spaces
import torch
import gradio as gr
import librosa
import numpy as np
import json
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from scipy.signal import butter, lfilter

# Charger la liste des modèles depuis un fichier JSON
def load_model_list(file_path="model_list.json"):
    try:
        with open(file_path, "r") as f:
            return json.load(f)
    except Exception as e:
        raise ValueError(f"Erreur lors du chargement de la liste des modèles : {str(e)}")

# Charger les modèles depuis le fichier JSON
MODEL_LIST = load_model_list()

# Fonction pour charger le modèle et le processeur
def load_model_and_processor(model_name):
    model_info = MODEL_LIST.get(model_name)
    if not model_info:
        raise ValueError("Modèle non trouvé dans la liste.")
    model_path = model_info["model_path"]
    processor = WhisperProcessor.from_pretrained(model_path)
    model = WhisperForConditionalGeneration.from_pretrained(model_path)
    model.eval()
    return processor, model

# Nettoyage et normalisation de l'audio
def preprocess_audio(audio, sr=16000):
    # Charger l'audio
    audio_data, _ = librosa.load(audio, sr=sr)
    # Filtrage passe-bas pour réduire les bruits aigus
    b, a = butter(6, 0.1, btype="low", analog=False)
    audio_data = lfilter(b, a, audio_data)
    # Normaliser l'audio
    audio_data = librosa.util.normalize(audio_data)
    return audio_data

# Fonction pour transcrire l'audio
@spaces.GPU(duration=120)
def transcribe_audio(audio, model_name):
    try:
        # Charger le modèle et le processeur en fonction du choix
        processor, model = load_model_and_processor(model_name)
        
        # Nettoyer et normaliser l'audio
        audio_input = preprocess_audio(audio)
        
        # Prétraiter l'audio avec le processeur
        inputs = processor(audio_input, sampling_rate=16000, return_tensors="pt")
        inputs["attention_mask"] = torch.ones_like(inputs["input_features"]).to(inputs["input_features"].dtype)

        # Faire la prédiction 
        with torch.no_grad():
            predicted_ids = model.generate(
                inputs['input_features'], 
                forced_decoder_ids=None,  # Suppression du conflit
                language="fr",  # Ajustez selon votre langue cible
                task="transcribe"
            )
        
        # Convertir les IDs de prédiction en texte
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
        
        return transcription[0]
    
    except Exception as e:
        return f"Erreur de transcription : {str(e)}"

# Charger une seule fois le tableau (statique)
MODEL_TABLE = [
    [name, details.get("dataset", "Non spécifié"), details.get("performance", {}).get("WER", "Non spécifié"), details.get("performance", {}).get("CER", "Non spécifié")]
    for name, details in MODEL_LIST.items()
]

# Interface Gradio
with gr.Blocks() as app:
    # Section principale
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## Téléchargez ou enregistrez un fichier audio")
            audio_input = gr.Audio(type="filepath", label="Audio (télécharger ou enregistrer)")
            model_dropdown = gr.Dropdown(choices=list(MODEL_LIST.keys()), label="Sélectionnez un modèle", value="Wolof ASR - dofbi")
            submit_button = gr.Button("Transcrire")
        with gr.Column(scale=3):
            transcription_output = gr.Textbox(label="Transcription", lines=6)

    # Tableau statique en bas
    gr.Markdown("## Informations sur les modèles disponibles")
    gr.Dataframe(
        headers=["Nom du modèle", "Dataset utilisé", "WER", "CER"], 
        value=MODEL_TABLE,
        interactive=False,
        label="Informations sur les modèles"
    )
    
    # Action du bouton
    submit_button.click(
        fn=transcribe_audio,
        inputs=[audio_input, model_dropdown],
        outputs=transcription_output
    )

# Lancer l'application
if __name__ == "__main__":
    app.launch()