LAP-DEV commited on
Commit
68f96cf
·
verified ·
1 Parent(s): be967ac

Delete modules/whisper/whisper_base_old.py

Browse files
Files changed (1) hide show
  1. modules/whisper/whisper_base_old.py +0 -648
modules/whisper/whisper_base_old.py DELETED
@@ -1,648 +0,0 @@
1
- import os
2
- import torch
3
- import whisper
4
- import gradio as gr
5
- import torchaudio
6
- from abc import ABC, abstractmethod
7
- from typing import BinaryIO, Union, Tuple, List
8
- import numpy as np
9
- from datetime import datetime
10
- from faster_whisper.vad import VadOptions
11
- from dataclasses import astuple
12
-
13
- from modules.uvr.music_separator import MusicSeparator
14
- from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
15
- UVR_MODELS_DIR)
16
- from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, get_csv, write_file, safe_filename
17
- from modules.utils.youtube_manager import get_ytdata, get_ytaudio
18
- from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
19
- from modules.whisper.whisper_parameter import *
20
- from modules.diarize.diarizer import Diarizer
21
- from modules.vad.silero_vad import SileroVAD
22
- from modules.translation.nllb_inference import NLLBInference
23
- from modules.translation.nllb_inference import NLLB_AVAILABLE_LANGS
24
-
25
- class WhisperBase(ABC):
26
- def __init__(self,
27
- model_dir: str = WHISPER_MODELS_DIR,
28
- diarization_model_dir: str = DIARIZATION_MODELS_DIR,
29
- uvr_model_dir: str = UVR_MODELS_DIR,
30
- output_dir: str = OUTPUT_DIR,
31
- ):
32
- self.model_dir = model_dir
33
- self.output_dir = output_dir
34
- os.makedirs(self.output_dir, exist_ok=True)
35
- os.makedirs(self.model_dir, exist_ok=True)
36
- self.diarizer = Diarizer(
37
- model_dir=diarization_model_dir
38
- )
39
- self.vad = SileroVAD()
40
- self.music_separator = MusicSeparator(
41
- model_dir=uvr_model_dir,
42
- output_dir=os.path.join(output_dir, "UVR")
43
- )
44
-
45
- self.model = None
46
- self.current_model_size = None
47
- self.available_models = whisper.available_models()
48
- self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
49
- #self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
50
- self.translatable_models = whisper.available_models()
51
- self.device = self.get_device()
52
- self.available_compute_types = ["float16", "float32"]
53
- self.current_compute_type = "float16" if self.device == "cuda" else "float32"
54
-
55
- @abstractmethod
56
- def transcribe(self,
57
- audio: Union[str, BinaryIO, np.ndarray],
58
- progress: gr.Progress = gr.Progress(),
59
- *whisper_params,
60
- ):
61
- """Inference whisper model to transcribe"""
62
- pass
63
-
64
- @abstractmethod
65
- def update_model(self,
66
- model_size: str,
67
- compute_type: str,
68
- progress: gr.Progress = gr.Progress()
69
- ):
70
- """Initialize whisper model"""
71
- pass
72
-
73
- def run(self,
74
- audio: Union[str, BinaryIO, np.ndarray],
75
- progress: gr.Progress = gr.Progress(),
76
- add_timestamp: bool = True,
77
- *whisper_params,
78
- ) -> Tuple[List[dict], float]:
79
- """
80
- Run transcription with conditional pre-processing and post-processing.
81
- The VAD will be performed to remove noise from the audio input in pre-processing, if enabled.
82
- The diarization will be performed in post-processing, if enabled.
83
-
84
- Parameters
85
- ----------
86
- audio: Union[str, BinaryIO, np.ndarray]
87
- Audio input. This can be file path or binary type.
88
- progress: gr.Progress
89
- Indicator to show progress directly in gradio.
90
- add_timestamp: bool
91
- Whether to add a timestamp at the end of the filename.
92
- *whisper_params: tuple
93
- Parameters related with whisper. This will be dealt with "WhisperParameters" data class
94
-
95
- Returns
96
- ----------
97
- segments_result: List[dict]
98
- list of dicts that includes start, end timestamps and transcribed text
99
- elapsed_time: float
100
- elapsed time for running
101
- """
102
- params = WhisperParameters.as_value(*whisper_params)
103
-
104
- self.cache_parameters(
105
- whisper_params=params,
106
- add_timestamp=add_timestamp
107
- )
108
-
109
- if params.lang is None:
110
- pass
111
- elif params.lang == "Automatic Detection":
112
- params.lang = None
113
- else:
114
- language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
115
- params.lang = language_code_dict[params.lang]
116
-
117
- if params.is_bgm_separate:
118
- music, audio, _ = self.music_separator.separate(
119
- audio=audio,
120
- model_name=params.uvr_model_size,
121
- device=params.uvr_device,
122
- segment_size=params.uvr_segment_size,
123
- save_file=params.uvr_save_file,
124
- progress=progress
125
- )
126
-
127
- if audio.ndim >= 2:
128
- audio = audio.mean(axis=1)
129
- if self.music_separator.audio_info is None:
130
- origin_sample_rate = 16000
131
- else:
132
- origin_sample_rate = self.music_separator.audio_info.sample_rate
133
- audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)
134
-
135
- if params.uvr_enable_offload:
136
- self.music_separator.offload()
137
-
138
- if params.vad_filter:
139
- # Explicit value set for float('inf') from gr.Number()
140
- if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
141
- params.max_speech_duration_s = float('inf')
142
-
143
- vad_options = VadOptions(
144
- threshold=params.threshold,
145
- min_speech_duration_ms=params.min_speech_duration_ms,
146
- max_speech_duration_s=params.max_speech_duration_s,
147
- min_silence_duration_ms=params.min_silence_duration_ms,
148
- speech_pad_ms=params.speech_pad_ms
149
- )
150
-
151
- audio, speech_chunks = self.vad.run(
152
- audio=audio,
153
- vad_parameters=vad_options,
154
- progress=progress
155
- )
156
-
157
- result, elapsed_time = self.transcribe(
158
- audio,
159
- progress,
160
- *astuple(params)
161
- )
162
-
163
- if params.vad_filter:
164
- result = self.vad.restore_speech_timestamps(
165
- segments=result,
166
- speech_chunks=speech_chunks,
167
- )
168
-
169
- if params.is_diarize:
170
- result, elapsed_time_diarization = self.diarizer.run(
171
- audio=audio,
172
- use_auth_token=params.hf_token,
173
- transcribed_result=result,
174
- )
175
- elapsed_time += elapsed_time_diarization
176
- return result, elapsed_time
177
-
178
- def transcribe_file(self,
179
- files: Optional[List] = None,
180
- input_folder_path: Optional[str] = None,
181
- file_format: str = "SRT",
182
- add_timestamp: bool = True,
183
- translate_output: bool = False,
184
- translate_model: str = "",
185
- target_lang: str = "",
186
- progress=gr.Progress(),
187
- *whisper_params,
188
- ) -> list:
189
- """
190
- Write subtitle file from Files
191
-
192
- Parameters
193
- ----------
194
- files: list
195
- List of files to transcribe from gr.Files()
196
- input_folder_path: str
197
- Input folder path to transcribe from gr.Textbox(). If this is provided, `files` will be ignored and
198
- this will be used instead.
199
- file_format: str
200
- Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
201
- add_timestamp: bool
202
- Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
203
- translate_output: bool
204
- Translate output
205
- translate_model: str
206
- Translation model to use
207
- target_lang: str
208
- Target language to use
209
- progress: gr.Progress
210
- Indicator to show progress directly in gradio.
211
- *whisper_params: tuple
212
- Parameters related with whisper. This will be dealt with "WhisperParameters" data class
213
-
214
- Returns
215
- ----------
216
- result_str:
217
- Result of transcription to return to gr.Textbox()
218
- result_file_path:
219
- Output file path to return to gr.Files()
220
- """
221
- try:
222
- if input_folder_path:
223
- files = get_media_files(input_folder_path)
224
- if isinstance(files, str):
225
- files = [files]
226
- if files and isinstance(files[0], gr.utils.NamedString):
227
- files = [file.name for file in files]
228
-
229
- ## Initialization variables & start time
230
- files_info = {}
231
- files_to_download = {}
232
- time_start = datetime.now()
233
-
234
- ## Load parameters related with whisper
235
- params = WhisperParameters.as_value(*whisper_params)
236
-
237
- ## Load model to detect language
238
- model = whisper.load_model("base")
239
-
240
- for file in files:
241
-
242
- ## Detect language
243
- mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(whisper.load_audio(file))).to(model.device)
244
- _, probs = model.detect_language(mel)
245
- file_language = ""
246
- file_lang_probs = ""
247
- for key,value in whisper.tokenizer.LANGUAGES.items():
248
- if key == str(max(probs, key=probs.get)):
249
- file_language = value.capitalize()
250
- for key_prob,value_prob in probs.items():
251
- if key == key_prob:
252
- file_lang_probs = str((round(value_prob*100)))
253
- break
254
- break
255
-
256
- transcribed_segments, time_for_task = self.run(
257
- file,
258
- progress,
259
- add_timestamp,
260
- *whisper_params,
261
- )
262
-
263
- # Define source language
264
- source_lang = file_language
265
-
266
- # Translate to English using Whisper built-in functionality
267
- transcription_note = ""
268
- if params.is_translate:
269
- if source_lang != "English":
270
- transcription_note = "To English"
271
- source_lang = "English"
272
- else:
273
- transcription_note = "Already in English"
274
-
275
- # Translate the transcribed segments
276
- translation_note = ""
277
- if translate_output:
278
- if source_lang != target_lang:
279
- self.nllb_inf = NLLBInference()
280
- if source_lang in NLLB_AVAILABLE_LANGS.keys():
281
- transcribed_segments = self.nllb_inf.translate_text(
282
- input_list_dict=transcribed_segments,
283
- model_size=translate_model,
284
- src_lang=source_lang,
285
- tgt_lang=target_lang,
286
- speaker_diarization=params.is_diarize
287
- )
288
- translation_note = "To " + target_lang
289
- else:
290
- translation_note = source_lang + " not supported"
291
- else:
292
- translation_note = "Already in " + target_lang
293
-
294
- ## Get preview as txt
295
- file_name, file_ext = os.path.splitext(os.path.basename(file))
296
- subtitle = get_txt(transcribed_segments)
297
- files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "lang": file_language, "lang_prob": file_lang_probs, "input_source_file": (file_name+file_ext), "translation": translation_note, "transcription": transcription_note}
298
-
299
- ## Add output file as txt
300
- file_name, file_ext = os.path.splitext(os.path.basename(file))
301
- subtitle, file_path = self.generate_and_write_file(
302
- file_name=file_name,
303
- transcribed_segments=transcribed_segments,
304
- add_timestamp=add_timestamp,
305
- file_format="txt",
306
- output_dir=self.output_dir
307
- )
308
- files_to_download[file_name+"_txt"] = {"path": file_path}
309
-
310
- ## Add output file as srt
311
- file_name, file_ext = os.path.splitext(os.path.basename(file))
312
- subtitle, file_path = self.generate_and_write_file(
313
- file_name=file_name,
314
- transcribed_segments=transcribed_segments,
315
- add_timestamp=add_timestamp,
316
- file_format="srt",
317
- output_dir=self.output_dir
318
- )
319
- files_to_download[file_name+"_srt"] = {"path": file_path}
320
-
321
- ## Add output file as csv
322
- file_name, file_ext = os.path.splitext(os.path.basename(file))
323
- subtitle, file_path = self.generate_and_write_file(
324
- file_name=file_name,
325
- transcribed_segments=transcribed_segments,
326
- add_timestamp=add_timestamp,
327
- file_format="csv",
328
- output_dir=self.output_dir
329
- )
330
- files_to_download[file_name+"_csv"] = {"path": file_path}
331
-
332
- total_result = ''
333
- total_info = ''
334
- total_time = 0
335
- for file_name, info in files_info.items():
336
- total_result += f'{info["subtitle"]}'
337
- total_time += info["time_for_task"]
338
- total_info += f'Input file:\t\t{info["input_source_file"]}\nLanguage:\t{info["lang"]} (probability {info["lang_prob"]}%)\n'
339
-
340
- if params.is_translate:
341
- total_info += f'Translation:\t{info["transcription"]}\n\t⤷ Handled by OpenAI Whisper\n'
342
-
343
- if translate_output:
344
- total_info += f'Translation:\t{info["translation"]}\n\t⤷ Handled by Facebook NLLB\n'
345
-
346
- time_end = datetime.now()
347
- total_info += f"\nTotal processing time: {self.format_time((time_end-time_start).total_seconds())}"
348
-
349
- result_str = total_result.rstrip("\n")
350
- result_file_path = [info['path'] for info in files_to_download.values()]
351
-
352
- return [result_str,result_file_path,total_info]
353
-
354
- except Exception as e:
355
- print(f"Error transcribing file: {e}")
356
- finally:
357
- self.release_cuda_memory()
358
-
359
- def transcribe_mic(self,
360
- mic_audio: str,
361
- file_format: str = "SRT",
362
- add_timestamp: bool = True,
363
- progress=gr.Progress(),
364
- *whisper_params,
365
- ) -> list:
366
- """
367
- Write subtitle file from microphone
368
-
369
- Parameters
370
- ----------
371
- mic_audio: str
372
- Audio file path from gr.Microphone()
373
- file_format: str
374
- Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
375
- add_timestamp: bool
376
- Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
377
- progress: gr.Progress
378
- Indicator to show progress directly in gradio.
379
- *whisper_params: tuple
380
- Parameters related with whisper. This will be dealt with "WhisperParameters" data class
381
-
382
- Returns
383
- ----------
384
- result_str:
385
- Result of transcription to return to gr.Textbox()
386
- result_file_path:
387
- Output file path to return to gr.Files()
388
- """
389
- try:
390
- progress(0, desc="Loading Audio...")
391
- transcribed_segments, time_for_task = self.run(
392
- mic_audio,
393
- progress,
394
- add_timestamp,
395
- *whisper_params,
396
- )
397
- progress(1, desc="Completed!")
398
-
399
- subtitle, result_file_path = self.generate_and_write_file(
400
- file_name="Mic",
401
- transcribed_segments=transcribed_segments,
402
- add_timestamp=add_timestamp,
403
- file_format=file_format,
404
- output_dir=self.output_dir
405
- )
406
-
407
- result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
408
- return [result_str, result_file_path]
409
- except Exception as e:
410
- print(f"Error transcribing file: {e}")
411
- finally:
412
- self.release_cuda_memory()
413
-
414
- def transcribe_youtube(self,
415
- youtube_link: str,
416
- file_format: str = "SRT",
417
- add_timestamp: bool = True,
418
- progress=gr.Progress(),
419
- *whisper_params,
420
- ) -> list:
421
- """
422
- Write subtitle file from Youtube
423
-
424
- Parameters
425
- ----------
426
- youtube_link: str
427
- URL of the Youtube video to transcribe from gr.Textbox()
428
- file_format: str
429
- Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
430
- add_timestamp: bool
431
- Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
432
- progress: gr.Progress
433
- Indicator to show progress directly in gradio.
434
- *whisper_params: tuple
435
- Parameters related with whisper. This will be dealt with "WhisperParameters" data class
436
-
437
- Returns
438
- ----------
439
- result_str:
440
- Result of transcription to return to gr.Textbox()
441
- result_file_path:
442
- Output file path to return to gr.Files()
443
- """
444
- try:
445
- progress(0, desc="Loading Audio from Youtube...")
446
- yt = get_ytdata(youtube_link)
447
- audio = get_ytaudio(yt)
448
-
449
- transcribed_segments, time_for_task = self.run(
450
- audio,
451
- progress,
452
- add_timestamp,
453
- *whisper_params,
454
- )
455
-
456
- progress(1, desc="Completed!")
457
-
458
- file_name = safe_filename(yt.title)
459
- subtitle, result_file_path = self.generate_and_write_file(
460
- file_name=file_name,
461
- transcribed_segments=transcribed_segments,
462
- add_timestamp=add_timestamp,
463
- file_format=file_format,
464
- output_dir=self.output_dir
465
- )
466
- result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
467
-
468
- if os.path.exists(audio):
469
- os.remove(audio)
470
-
471
- return [result_str, result_file_path]
472
-
473
- except Exception as e:
474
- print(f"Error transcribing file: {e}")
475
- finally:
476
- self.release_cuda_memory()
477
-
478
- @staticmethod
479
- def generate_and_write_file(file_name: str,
480
- transcribed_segments: list,
481
- add_timestamp: bool,
482
- file_format: str,
483
- output_dir: str
484
- ) -> str:
485
- """
486
- Writes subtitle file
487
-
488
- Parameters
489
- ----------
490
- file_name: str
491
- Output file name
492
- transcribed_segments: list
493
- Text segments transcribed from audio
494
- add_timestamp: bool
495
- Determines whether to add a timestamp to the end of the filename.
496
- file_format: str
497
- File format to write. Supported formats: [SRT, WebVTT, txt, csv]
498
- output_dir: str
499
- Directory path of the output
500
-
501
- Returns
502
- ----------
503
- content: str
504
- Result of the transcription
505
- output_path: str
506
- output file path
507
- """
508
- if add_timestamp:
509
- #timestamp = datetime.now().strftime("%m%d%H%M%S")
510
- timestamp = datetime.now().strftime("%Y%m%d %H%M%S")
511
- output_path = os.path.join(output_dir, f"{file_name} - {timestamp}")
512
- else:
513
- output_path = os.path.join(output_dir, f"{file_name}")
514
-
515
- file_format = file_format.strip().lower()
516
- if file_format == "srt":
517
- content = get_srt(transcribed_segments)
518
- output_path += '.srt'
519
-
520
- elif file_format == "webvtt":
521
- content = get_vtt(transcribed_segments)
522
- output_path += '.vtt'
523
-
524
- elif file_format == "txt":
525
- content = get_txt(transcribed_segments)
526
- output_path += '.txt'
527
-
528
- elif file_format == "csv":
529
- content = get_csv(transcribed_segments)
530
- output_path += '.csv'
531
-
532
- write_file(content, output_path)
533
- return content, output_path
534
-
535
- @staticmethod
536
- def format_time(elapsed_time: float) -> str:
537
- """
538
- Get {hours} {minutes} {seconds} time format string
539
-
540
- Parameters
541
- ----------
542
- elapsed_time: str
543
- Elapsed time for transcription
544
-
545
- Returns
546
- ----------
547
- Time format string
548
- """
549
- hours, rem = divmod(elapsed_time, 3600)
550
- minutes, seconds = divmod(rem, 60)
551
-
552
- time_str = ""
553
-
554
- hours = round(hours)
555
- if hours:
556
- if hours == 1:
557
- time_str += f"{hours} hour "
558
- else:
559
- time_str += f"{hours} hours "
560
-
561
- minutes = round(minutes)
562
- if minutes:
563
- if minutes == 1:
564
- time_str += f"{minutes} minute "
565
- else:
566
- time_str += f"{minutes} minutes "
567
-
568
- seconds = round(seconds)
569
- if seconds == 1:
570
- time_str += f"{seconds} second"
571
- else:
572
- time_str += f"{seconds} seconds"
573
-
574
- return time_str.strip()
575
-
576
- @staticmethod
577
- def get_device():
578
- if torch.cuda.is_available():
579
- return "cuda"
580
- elif torch.backends.mps.is_available():
581
- if not WhisperBase.is_sparse_api_supported():
582
- # Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
583
- return "cpu"
584
- return "mps"
585
- else:
586
- return "cpu"
587
-
588
- @staticmethod
589
- def is_sparse_api_supported():
590
- if not torch.backends.mps.is_available():
591
- return False
592
-
593
- try:
594
- device = torch.device("mps")
595
- sparse_tensor = torch.sparse_coo_tensor(
596
- indices=torch.tensor([[0, 1], [2, 3]]),
597
- values=torch.tensor([1, 2]),
598
- size=(4, 4),
599
- device=device
600
- )
601
- return True
602
- except RuntimeError:
603
- return False
604
-
605
- @staticmethod
606
- def release_cuda_memory():
607
- """Release memory"""
608
- if torch.cuda.is_available():
609
- torch.cuda.empty_cache()
610
- torch.cuda.reset_max_memory_allocated()
611
-
612
- @staticmethod
613
- def remove_input_files(file_paths: List[str]):
614
- """Remove gradio cached files"""
615
- if not file_paths:
616
- return
617
-
618
- for file_path in file_paths:
619
- if file_path and os.path.exists(file_path):
620
- os.remove(file_path)
621
-
622
- @staticmethod
623
- def cache_parameters(
624
- whisper_params: WhisperValues,
625
- add_timestamp: bool
626
- ):
627
- """cache parameters to the yaml file"""
628
- cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
629
- cached_whisper_param = whisper_params.to_yaml()
630
- cached_yaml = {**cached_params, **cached_whisper_param}
631
- cached_yaml["whisper"]["add_timestamp"] = add_timestamp
632
-
633
- save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
634
-
635
- @staticmethod
636
- def resample_audio(audio: Union[str, np.ndarray],
637
- new_sample_rate: int = 16000,
638
- original_sample_rate: Optional[int] = None,) -> np.ndarray:
639
- """Resamples audio to 16k sample rate, standard on Whisper model"""
640
- if isinstance(audio, str):
641
- audio, original_sample_rate = torchaudio.load(audio)
642
- else:
643
- if original_sample_rate is None:
644
- raise ValueError("original_sample_rate must be provided when audio is numpy array.")
645
- audio = torch.from_numpy(audio)
646
- resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=new_sample_rate)
647
- resampled_audio = resampler(audio).numpy()
648
- return resampled_audio