nareauow committed
Commit d698901 · verified · 1 Parent(s): 2fdcc5e

Update app.py

Files changed (1)
  1. app.py +31 -287
app.py CHANGED
@@ -1,293 +1,37 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import numpy as np
-import scipy.io.wavfile as wav
-from scipy.fftpack import idct
-import gradio as gr
-import os
-import matplotlib.pyplot as plt
-from huggingface_hub import hf_hub_download
-from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
-from transformers import pipeline, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
-from datasets import load_dataset
-import soundfile as sf
-
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-print(f"Using device: {device}")
-
-# Load speech-to-text model
-try:
-    speech_recognizer = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr").to(device)
-    speech_processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr")
-    print("Speech recognition model loaded successfully!")
-except Exception as e:
-    print(f"Error loading speech recognition model: {e}")
-    speech_recognizer = None
-    speech_processor = None
-
-# Load text-to-speech models
-try:
-    # Load processor and model
-    tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
-    tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
-    tts_vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
-
-    # Load speaker embeddings
-    speaker_embeddings = torch.load("./speaker_embedding.pt").to(device)
-except Exception as e:
-    print(f"Error loading text-to-speech models: {e}")
-    tts_processor = None
-    tts_model = None
-    tts_vocoder = None
-    speaker_embeddings = None
-
-# Modele CNN
-class modele_CNN(nn.Module):
-    def __init__(self, num_classes=7, dropout=0.3):
-        super(modele_CNN, self).__init__()
-        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
-        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
-        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
-        self.pool = nn.MaxPool2d(2, 2)
-        self.fc1 = nn.Linear(64 * 1 * 62, 128)
-        self.fc2 = nn.Linear(128, num_classes)
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, x):
-        x = self.pool(F.relu(self.conv1(x)))
-        x = self.pool(F.relu(self.conv2(x)))
-        x = self.pool(F.relu(self.conv3(x)))
-        x = x.view(x.size(0), -1)
-        x = self.dropout(F.relu(self.fc1(x)))
-        x = self.fc2(x)
-        return x
-
-# Audio processor
-class AudioProcessor:
-    def Mel2Hz(self, mel): return 700 * (np.power(10, mel/2595)-1)
-    def Hz2Mel(self, freq): return 2595 * np.log10(1+freq/700)
-    def Hz2Ind(self, freq, fs, Tfft): return (freq*Tfft/fs).astype(int)
-
-    def hamming(self, T):
-        if T <= 1:
-            return np.ones(T)
-        return 0.54-0.46*np.cos(2*np.pi*np.arange(T)/(T-1))
-
-    def FiltresMel(self, fs, nf=36, Tfft=512, fmin=100, fmax=8000):
-        Indices = self.Hz2Ind(self.Mel2Hz(np.linspace(self.Hz2Mel(fmin), self.Hz2Mel(min(fmax, fs/2)), nf+2)), fs, Tfft)
-        filtres = np.zeros((int(Tfft/2), nf))
-        for i in range(nf): filtres[Indices[i]:Indices[i+2], i] = self.hamming(Indices[i+2]-Indices[i])
-        return filtres
-
-    def spectrogram(self, x, T, p, Tfft):
-        S = []
-        for i in range(0, len(x)-T, p): S.append(x[i:i+T]*self.hamming(T))
-        S = np.fft.fft(S, Tfft)
-        return np.abs(S), np.angle(S)
-
-    def mfcc(self, data, filtres, nc=13, T=256, p=64, Tfft=512):
-        data = (data[1]-np.mean(data[1]))/np.std(data[1])
-        amp, ph = self.spectrogram(data, T, p, Tfft)
-        amp_f = np.log10(np.dot(amp[:, :int(Tfft/2)], filtres)+1)
-        return idct(amp_f, n=nc, norm='ortho')
-
-    def process_audio(self, audio_data, sr, audio_length=32000):
-        if sr != 16000:
-            audio_resampled = np.interp(
-                np.linspace(0, len(audio_data), int(16000 * len(audio_data) / sr)),
-                np.arange(len(audio_data)),
-                audio_data
+with gr.Row():
+    with gr.Column():
+        # Dropdown pour sélectionner le modèle
+        model_selector = gr.Dropdown(
+            choices=["model_1.pth", "model_2.pth", "model_3.pth"],
+            value="model_3.pth",
+            label="Choisissez le modèle"
             )
-            sgn = audio_resampled
-            fs = 16000
-        else:
-            sgn = audio_data
-            fs = sr
-
-        sgn = np.array(sgn, dtype=np.float32)
-
-        if len(sgn) > audio_length:
-            sgn = sgn[:audio_length]
-        else:
-            sgn = np.pad(sgn, (0, audio_length - len(sgn)), mode='constant')
-
-        filtres = self.FiltresMel(fs)
-        sgn_features = self.mfcc([fs, sgn], filtres)
-
-        mfcc_tensor = torch.tensor(sgn_features.T, dtype=torch.float32)
-        mfcc_tensor = mfcc_tensor.unsqueeze(0).unsqueeze(0)
-
-        return mfcc_tensor
-
-# Speech recognition function
-def recognize_speech(audio_path):
-    if speech_recognizer is None or speech_processor is None:
-        return "Speech recognition model not available"
-
-    try:
-        # Read audio file
-        audio_data, sr = sf.read(audio_path)
-
-        # Resample to 16kHz if needed
-        if sr != 16000:
-            audio_data = np.interp(
-                np.linspace(0, len(audio_data), int(16000 * len(audio_data) / sr)),
-                np.arange(len(audio_data)),
-                audio_data
-            )
-            sr = 16000
-
-        # Process audio
-        inputs = speech_processor(audio_data, sampling_rate=sr, return_tensors="pt")
-        inputs = {k: v.to(device) for k, v in inputs.items()}
-
-        # Generate transcription
-        generated_ids = speech_recognizer.generate(**inputs)
-        transcription = speech_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-        return transcription
-    except Exception as e:
-        return f"Speech recognition error: {str(e)}"
-
-# Speech synthesis function
-def synthesize_speech(text):
-    if tts_processor is None or tts_model is None or tts_vocoder is None or speaker_embeddings is None:
-        return None
-
-    try:
-        # Preprocess text
-        inputs = tts_processor(text=text, return_tensors="pt").to(device)
-
-        # Generate speech with speaker embeddings
-        spectrogram = tts_model.generate_speech(inputs["input_ids"], speaker_embeddings)
-
-        # Convert to waveform
-        with torch.no_grad():
-            speech = tts_vocoder(spectrogram)
-
-        # Convert to numpy array and normalize
-        speech = speech.cpu().numpy()
-        speech = speech / np.max(np.abs(speech))
-
-        return (16000, speech.squeeze())
-    except Exception as e:
-        print(f"Speech synthesis error: {str(e)}")
-        return None
-
-# ... (keep all previous imports and class definitions)
-
-# Updated predict_speaker function to return consistent values
-def predict_speaker(audio, model, processor):
-    if audio is None:
-        return "Aucun audio détecté.", {}, "Aucun texte reconnu", "Inconnu" # Now returns 4 values
-
-    try:
-        audio_data, sr = sf.read(audio)
-        input_tensor = processor.process_audio(audio_data, sr)
-
-        device = next(model.parameters()).device
-        input_tensor = input_tensor.to(device)
-
-        with torch.no_grad():
-            output = model(input_tensor)
-            print(output) # Debug output
-            probabilities = F.softmax(output, dim=1)
-            confidence, predicted_class = torch.max(probabilities, 1)
-
-        speakers = ["George", "Jackson", "Lucas", "Nicolas", "Theo", "Yweweler", "Narimene"]
-        predicted_speaker = speakers[predicted_class.item()]
-
-        result = f"Locuteur reconnu : {predicted_speaker} (confiance : {confidence.item()*100:.2f}%)"
-
-        probs_dict = {speakers[i]: float(probs) for i, probs in enumerate(probabilities[0].cpu().numpy())}
-
-        # Recognize speech
-        recognized_text = recognize_speech(audio) if speech_recognizer else "Modèle de reconnaissance vocale non disponible"
-
-        return result, probs_dict, recognized_text, predicted_speaker # Now returns 4 values
-
-    except Exception as e:
-        return f"Erreur : {str(e)}", {}, "Erreur de reconnaissance", "Inconnu"
-
-# Updated recognize function
-def recognize(audio, selected_model):
-    model = load_model(model_filename=selected_model)
-    if model is None:
-        return "Erreur: Modèle non chargé", None, "Erreur", None
-
-    res, probs, text, speaker = predict_speaker(audio, model, processor) # Now expects 4 values
-
-    # Generate plot
-    fig = None
-    if probs:
-        fig, ax = plt.subplots(figsize=(10, 6))
-        ax.bar(probs.keys(), probs.values(), color='skyblue')
-        ax.set_ylim([0, 1])
-        ax.set_ylabel("Confiance")
-        ax.set_xlabel("Locuteurs")
-        ax.set_title("Probabilités de reconnaissance")
-        plt.xticks(rotation=45)
-        plt.tight_layout()
-
-    # Generate speech synthesis if text was recognized
-    synth_audio = None
-    if synthesizer is not None and text and "erreur" not in text.lower():
-        try:
-            synth_text = f"Le locuteur {speaker} a dit : {text}" if speaker else f"Le locuteur a dit : {text}"
-            synth_audio = synthesize_speech(synth_text)
-        except Exception as e:
-            print(f"Erreur de synthèse vocale: {e}")
-
-    return res, fig, text, synth_audio if synth_audio else None
-
-# Updated interface creation
-def create_interface():
-    processor = AudioProcessor()
-
-    with gr.Blocks(title="Reconnaissance de Locuteur") as interface:
-        gr.Markdown("# 🗣️ Reconnaissance de Locuteur")
-        gr.Markdown("Enregistrez votre voix pendant 2 secondes pour identifier qui parle.")
-
-        with gr.Row():
-            with gr.Column():
-                # Dropdown pour sélectionner le modèle
-                model_selector = gr.Dropdown(
-                    choices=["model_1.pth", "model_2.pth", "model_3.pth"],
-                    value="model_3.pth",
-                    label="Choisissez le modèle"
-                )
-
-                # Créer des onglets pour Microphone et Upload Audio
-                with gr.Tab("Microphone"):
-                    mic_input = gr.Audio(sources=["microphone"], type="filepath", label="🎙️ Enregistrer depuis le microphone")
-
-                with gr.Tab("Upload Audio"):
-                    file_input = gr.Audio(sources=["upload"], type="filepath", label="📁 Télécharger un fichier audio")
-
-                # Bouton pour démarrer la reconnaissance
-                record_btn = gr.Button("Reconnaître")
-
-            with gr.Column():
-                # Résultat, graphique et texte reconnu
-                result_text = gr.Textbox(label="Résultat")
-                plot_output = gr.Plot(label="Confiance par locuteur")
-                recognized_text = gr.Textbox(label="Texte reconnu")
-                audio_output = gr.Audio(label="Synthèse vocale", visible=False)
-
-        # Fonction de clique pour la reconnaissance
-        def recognize(audio, selected_model):
-            # Traitement audio et modèle à charger...
-            pass # Remplace ici avec ton code de traitement
-
+
+        # Créer des onglets pour Microphone et Upload Audio
+        with gr.Tab("Microphone"):
+            mic_input = gr.Audio(sources=["microphone"], type="filepath", label="🎙️ Enregistrer depuis le microphone")
+
+        with gr.Tab("Upload Audio"):
+            file_input = gr.Audio(sources=["upload"], type="filepath", label="📁 Télécharger un fichier audio")
+
+        # Bouton pour démarrer la reconnaissance
+        record_btn = gr.Button("Reconnaître")
+
+    with gr.Column():
+        # Résultat, graphique et texte reconnu
+        result_text = gr.Textbox(label="Résultat")
+        plot_output = gr.Plot(label="Confiance par locuteur")
+        recognized_text = gr.Textbox(label="Texte reconnu")
+        audio_output = gr.Audio(label="Synthèse vocale", visible=False)
+
+    # Fonction de clique pour la reconnaissance
+    def recognize(audio, selected_model):
+        # Traitement audio et modèle à charger...
+        pass # Remplace ici avec ton code de traitement
+
         # Lier le bouton "Reconnaître" à la fonction
         record_btn.click(
             fn=recognize,
             inputs=[mic_input, file_input, model_selector], # Remplacer Union par les deux inputs distincts
             outputs=[result_text, plot_output, recognized_text, audio_output]
-        )
-    return interface
-
-if __name__ == "__main__":
-    app = create_interface()
-    app.launch(share=True)
+        )
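
Note on the resulting app.py: the fragment keeps record_btn.click wired to inputs=[mic_input, file_input, model_selector] and four outputs, but the remaining recognize stub accepts only (audio, selected_model), and the create_interface() / gr.Blocks wrapper from the previous version is gone. A minimal sketch of a handler whose signature matches that wiring is shown below; it assumes the load_model, predict_speaker, and AudioProcessor helpers from the removed version of the file and is an illustration, not part of this commit.

# Hypothetical sketch (not in the commit): a click handler matching
# inputs=[mic_input, file_input, model_selector] and the four wired outputs.
def recognize(mic_audio, file_audio, selected_model):
    # Prefer the microphone recording; fall back to the uploaded file.
    audio = mic_audio if mic_audio is not None else file_audio
    if audio is None:
        return "Aucun audio détecté.", None, "", None

    # load_model and predict_speaker are assumed from the removed version of app.py.
    model = load_model(model_filename=selected_model)
    if model is None:
        return "Erreur: Modèle non chargé", None, "Erreur", None

    result, probs, text, speaker = predict_speaker(audio, model, AudioProcessor())
    return result, None, text, None  # plot and synthesized audio left out of this sketch

Returning a value for each of the four wired outputs keeps the handler's arity consistent with the interface even when the plot and speech synthesis are not regenerated.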