Spaces:
Running
Running
import gradio as gr | |
import librosa | |
from transformers import AutoModelForAudioClassification, AutoFeatureExtractor | |
import torch | |
import numpy as np | |
import os | |
# Hugging Face Hub checkpoint used for genre prediction; the classifier and
# its matching feature extractor are both loaded from this one identifier.
model_name = "sanchit-gandhi/distilhubert-finetuned-gtzan"
model = AutoModelForAudioClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

# Genre names; process_audio indexes this list with the argmax of the model
# output, so the order here has to match the model's class indices.
genres = [
    "blues",
    "classical",
    "country",
    "disco",
    "hiphop",
    "jazz",
    "metal",
    "pop",
    "reggae",
    "rock",
]
def _estimate_tempo(audio, sr):
    """Return the estimated global tempo of *audio* in BPM as a plain float."""
    # librosa 0.10 moved tempo() under librosa.feature and deprecated the
    # librosa.beat alias; prefer the new location and only fall back to the
    # old one on releases that do not have it yet.
    tempo_fn = getattr(librosa.feature, "tempo", None) or librosa.beat.tempo
    return float(np.ravel(tempo_fn(y=audio, sr=sr))[0])


def process_audio(audio_file, user_label):
    """Build a metadata record (label, genre, tempo, ...) for one audio file.

    Args:
        audio_file: Path of the audio file on disk. ``None`` or an empty
            string means nothing was uploaded.
        user_label: Free-text label that leads the generated description.
            Surrounding whitespace is stripped; an empty label is left out
            of the description instead of producing a leading ", ".

    Returns:
        On success, a dict with the keys ``filename``, ``duration``,
        ``description``, ``genre``, ``tempo`` and ``sample_rate``
        (``sample_rate`` is the file's native rate, not the 16 kHz rate fed
        to the model). On failure, ``{"error": <message>}``; this function
        does not raise, so the message can be shown directly in the UI.
    """
    # Fail early with a readable message instead of the TypeError that
    # os.path.basename(None) would otherwise surface to the user.
    if not audio_file:
        return {"error": "No audio file provided."}

    label = (user_label or "").strip()

    try:
        filename = os.path.basename(audio_file)

        # sr=None keeps the file's native sample rate; it is reported in the
        # metadata, so resampling must not happen at load time.
        audio, sr = librosa.load(audio_file, sr=None)
        duration = librosa.get_duration(y=audio, sr=sr)
        tempo = _estimate_tempo(audio, sr)

        # The classifier is fed 16 kHz audio; resample only when needed.
        target_sr = 16000
        model_audio = audio
        if sr != target_sr:
            model_audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
        inputs = feature_extractor(
            model_audio, sampling_rate=target_sr, return_tensors="pt"
        )

        with torch.no_grad():
            logits = model(**inputs).logits
        # Softmax is monotonic, so the argmax of the raw logits selects the
        # same class without the extra normalisation step.
        predicted_genre = genres[int(torch.argmax(logits, dim=-1).item())]

        description = ", ".join(
            part
            for part in (label, predicted_genre, f"{tempo:.2f} BPM", f"{sr} Hz")
            if part
        )

        # float(...) turns NumPy scalars into plain Python floats so the
        # record serialises cleanly as JSON.
        return {
            "filename": filename,
            "duration": float(np.round(duration, 3)),
            "description": description,
            "genre": predicted_genre,
            "tempo": float(np.round(tempo, 2)),
            "sample_rate": sr,
        }
    except Exception as e:
        # UI boundary: report decoding/inference failures to the user rather
        # than letting the request crash.
        return {"error": str(e)}
# Gradio UI: an audio upload and a label box side by side, a button that runs
# process_audio on them, and a JSON panel that shows the returned metadata.
with gr.Blocks(theme="Surn/beeuty") as app:
    gr.Markdown("# Audio Classifier for MusicGen Fine Tuning")
    gr.Markdown(
        "Upload a audio file (preferred `.wav`), provide a label, and get metadata for MusicGen training."
    )

    with gr.Row():
        # type="filepath" makes Gradio hand process_audio a path on disk
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")
        label_input = gr.Textbox(
            label="Enter Label",
            placeholder="e.g., A calm melody",
        )

    submit_button = gr.Button("Classify")
    output_json = gr.JSON(label="Metadata Output")

    # Wire the button: (audio path, label) -> process_audio -> JSON panel
    submit_button.click(
        fn=process_audio,
        inputs=[audio_input, label_input],
        outputs=output_json,
    )

# Start the web server for the app
app.launch()