lucas-ventura commited on
Commit
35adc06
·
verified ·
1 Parent(s): c08a35a

Rename tools/extract/asr.py to tools/extract/asr_whisperx.py

Browse files
tools/extract/{asr.py → asr_whisperx.py} RENAMED
@@ -1,4 +1,3 @@
1
- import inspect
2
  from pathlib import Path
3
 
4
  import torch
@@ -22,22 +21,12 @@ class ASRProcessor:
22
 
23
  def __init__(self, model_name="large-v2", compute_type="float16"):
24
  self.model_name = model_name
25
- # Check if whisperx.load_model accepts compute_type parameter
26
-
27
- if "compute_type" in inspect.signature(whisperx.load_model).parameters:
28
- self.model = whisperx.load_model(
29
- model_name, device, compute_type=compute_type
30
- )
31
- else:
32
- self.model = whisperx.load_model(model_name, device)
33
 
34
  def get_asr(self, audio_file, return_duration=True):
35
  assert Path(audio_file).exists(), f"File {audio_file} does not exist"
36
  audio = whisperx.load_audio(audio_file)
37
- if "batch_size" in inspect.signature(self.model.transcribe).parameters:
38
- result = self.model.transcribe(audio, batch_size=1)
39
- else:
40
- result = self.model.transcribe(audio)
41
  language = result["language"]
42
  duration = audio.shape[0] / SAMPLE_RATE
43
 
 
 
1
  from pathlib import Path
2
 
3
  import torch
 
21
 
22
  def __init__(self, model_name="large-v2", compute_type="float16"):
23
  self.model_name = model_name
24
+ self.model = whisperx.load_model(model_name, device, compute_type=compute_type)
 
 
 
 
 
 
 
25
 
26
  def get_asr(self, audio_file, return_duration=True):
27
  assert Path(audio_file).exists(), f"File {audio_file} does not exist"
28
  audio = whisperx.load_audio(audio_file)
29
+ result = self.model.transcribe(audio, batch_size=1)
 
 
 
30
  language = result["language"]
31
  duration = audio.shape[0] / SAMPLE_RATE
32