File size: 3,117 Bytes
fdc1efd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
from pathlib import Path
import numpy as np
import torch
from torchvision import transforms
from src.modeling import ASTPretrained, FeatureExtractor, PreprocessPipeline, StudentAST
# Directory holding the serialized model weights and threshold arrays,
# located next to this source file.
MODELS_FOLDER = Path(__file__).parent.joinpath("models")

# Class codes in model-output order: position i in this list names the class
# scored at index i of the model output.
# NOTE(review): these look like instrument abbreviations - confirm against the
# dataset the checkpoints were trained on.
CLASSES = [
    "tru",
    "sax",
    "vio",
    "gac",
    "org",
    "cla",
    "flu",
    "voi",
    "gel",
    "cel",
    "pia",
]
def load_model(model_type: str):
    """
    Loads a pre-trained AST model of the specified type.

    :param model_type: The type of model to load. ``"accuracy"`` selects the
        ``ASTPretrained`` checkpoint (``acc_model_ast.pth``); any other value
        falls back to the ``StudentAST`` checkpoint (``speed_model_ast.pth``).
    :type model_type: str
    :return: The loaded model with its weights restored, switched to eval mode.
    :rtype: ASTPretrained or StudentAST
    """
    if model_type == "accuracy":
        model = ASTPretrained(n_classes=11, download_weights=False)
        weights_path = MODELS_FOLDER / "acc_model_ast.pth"
    else:
        model = StudentAST(n_classes=11, hidden_size=192, num_heads=3)
        weights_path = MODELS_FOLDER / "speed_model_ast.pth"

    # Tensors are remapped onto the CPU whatever device they were saved from.
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model
def load_labels():
    """
    Loads a dictionary of class labels for the AST model.

    :return: A dictionary where the keys are the class indices (positions in
        ``CLASSES``) and the values are the class labels.
    :rtype: Dict[int, str]
    """
    # enumerate() pairs each label with its position, avoiding manual indexing.
    return dict(enumerate(CLASSES))
def load_thresholds(model_type: str):
    """
    Loads the prediction thresholds for the AST model.

    :param model_type: The type of model whose thresholds to load.
        ``"accuracy"`` selects ``acc_model_thresh.npy``; any other value
        selects ``speed_model_thresh.npy``.
    :type model_type: str
    :return: The prediction thresholds for each class.
    :rtype: np.ndarray
    """
    if model_type == "accuracy":
        filename = "acc_model_thresh.npy"
    else:
        filename = "speed_model_thresh.npy"

    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects.
    # It is kept because the stored arrays may be object arrays; if they are
    # plain numeric arrays, consider switching it off - confirm before changing.
    return np.load(MODELS_FOLDER / filename, allow_pickle=True)
class ModelServiceAST:
    """
    Bundles a loaded AST model with its labels, per-class thresholds and the
    preprocessing transform, and exposes thresholded predictions.
    """

    def __init__(self, model_type: str):
        """
        Initializes a ModelServiceAST instance with the specified model type.

        :param model_type: The type of model to load
        :type model_type: str
        """
        self.model = load_model(model_type)
        self.labels = load_labels()
        self.thresholds = load_thresholds(model_type)
        # Preprocessing and feature extraction both use a 16 kHz sample rate.
        self.transform = transforms.Compose(
            [PreprocessPipeline(target_sr=16000), FeatureExtractor(sr=16000)]
        )

    def get_prediction(self, audio):
        """
        Gets the binary predictions for the given audio.

        :param audio: The input audio, in whatever form ``self.transform``
            accepts.
        :return: A dictionary where the keys are the class labels and the values
            are binary predictions (0 or 1). A class is predicted 1 when its
            sigmoid score is greater than or equal to its threshold.
        :rtype: Dict[str, int]
        """
        processed = self.transform(audio)
        with torch.no_grad():
            # The features are transposed before the forward pass; per the
            # original author's note the model expects seq_len x num_features.
            output = torch.sigmoid(self.model(processed.mT))
            scores = output.squeeze().numpy().astype(float)

        # Use the instance's own index -> label mapping so predictions stay
        # consistent with self.labels rather than a module-level global.
        return {
            label: int(scores[index] >= self.thresholds[index])
            for index, label in self.labels.items()
        }
|