import gradio as gr
import librosa
from transformers import AutoModelForAudioClassification, AutoFeatureExtractor
import torch
import numpy as np
import os
# Load the pre-trained model and feature extractor for genre prediction
model_name = "sanchit-gandhi/distilhubert-finetuned-gtzan"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForAudioClassification.from_pretrained(model_name)
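
# Optional GPU use: a minimal sketch, assuming CUDA may or may not be available.
# The inference code below reads model.device, so this line is safe to remove.
model.to("cuda" if torch.cuda.is_available() else "cpu")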

# Genre labels, read from the model config so their order always matches the logits
# (blues, classical, country, disco, hiphop, jazz, metal, pop, reggae, rock for this checkpoint)
genres = [model.config.id2label[i] for i in range(model.config.num_labels)]

# Function to process the uploaded audio file
def process_audio(audio_file, user_label):
    try:
        # Extract the filename from the uploaded file path
        filename = os.path.basename(audio_file)
        # Load the audio at its native sample rate
        audio, sr = librosa.load(audio_file, sr=None)
        # Extract the duration in seconds
        duration = librosa.get_duration(y=audio, sr=sr)
        # Estimate the tempo (in librosa >= 0.10 this lives at librosa.feature.rhythm.tempo)
        tempo = librosa.beat.tempo(y=audio, sr=sr)[0]
        # Resample to the 16 kHz rate the model expects, if needed
        target_sr = 16000
        if sr != target_sr:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
        # From here on, the audio is at target_sr either way
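        # Optional length cap: a minimal sketch, assuming very long uploads should be
        # trimmed to the 30 s clip length the GTZAN checkpoint was trained on
        # (remove this block to classify the full file).
        max_samples = 30 * target_sr
        if len(audio) > max_samples:
            audio = audio[:max_samples]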
        inputs = feature_extractor(audio, sampling_rate=target_sr, return_tensors="pt")
        # Keep the input tensors on the same device as the model
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        # Predict the genre with the model
        with torch.no_grad():
            logits = model(**inputs).logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().cpu().numpy()
        predicted_genre = genres[np.argmax(probabilities)]
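        # The full probability vector is available here if top-k genres are ever needed,
        # e.g. sorted(zip(genres, probabilities), key=lambda p: -p[1])[:3]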
        # Prepend the user-provided label to the generated description
        description = f"{user_label}, {predicted_genre}, {tempo:.2f} BPM, {sr} Hz"
        # Build the metadata dictionary (plain Python types so it serializes cleanly to JSON)
        metadata = {
            "filename": filename,
            "duration": round(float(duration), 3),
            "description": description,
            "genre": predicted_genre,
            "tempo": round(float(tempo), 2),
            "sample_rate": int(sr),
        }
        return metadata
    except Exception as e:
        return {"error": str(e)}
# Gradio interface
with gr.Blocks(theme="Surn/beeuty") as app:
    gr.Markdown("# Audio Classifier for MusicGen Fine-Tuning")
    gr.Markdown("Upload an audio file (`.wav` preferred), provide a label, and get metadata for MusicGen training.")
    with gr.Row():
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")
        label_input = gr.Textbox(label="Enter Label", placeholder="e.g., A calm melody")
    submit_button = gr.Button("Classify")
    output_json = gr.JSON(label="Metadata Output")
    submit_button.click(process_audio, inputs=[audio_input, label_input], outputs=output_json)
# Launch the app when executed as a script
if __name__ == "__main__":
    app.launch()