"""Gradio app that extracts MusicGen training metadata from an audio file.

The user uploads an audio file and supplies a free-text label. The app returns
the file name, duration, an estimated tempo, a genre predicted by a
DistilHuBERT model fine-tuned on GTZAN, and a combined description string.
"""

import logging
import os

import gradio as gr
import librosa
import numpy as np
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

logger = logging.getLogger(__name__)

# Load the pre-trained model and feature extractor for genre prediction.
model_name = "sanchit-gandhi/distilhubert-finetuned-gtzan"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForAudioClassification.from_pretrained(model_name)

# Fallback genre names, in class-index order. The model's own
# ``config.id2label`` mapping takes precedence (see ``_label_for``); this list
# is kept so existing references to ``genres`` keep working.
genres = [
    "blues",
    "classical",
    "country",
    "disco",
    "hiphop",
    "jazz",
    "metal",
    "pop",
    "reggae",
    "rock",
]

# Only the first MAX_MODEL_SECONDS of audio are fed to the classifier.
# Self-attention memory grows quadratically with input length, so a full-length
# song can exhaust memory; 30 s matches the clip length of the GTZAN training
# data. Duration and tempo are still computed on the whole file.
MAX_MODEL_SECONDS = 30.0


def _label_for(index):
    """Return the genre name for class *index*.

    Prefers the model's ``id2label`` mapping so labels stay correct if
    ``model_name`` changes; falls back to ``genres``, then to the raw index.
    """
    id2label = getattr(model.config, "id2label", None) or {}
    if index in id2label:
        return id2label[index]
    return genres[index] if 0 <= index < len(genres) else str(index)


def _estimate_tempo(audio, sr):
    """Return the global tempo estimate of *audio* in BPM as a Python float."""
    # librosa >= 0.10 moved tempo estimation to ``librosa.feature.tempo`` and
    # deprecated ``librosa.beat.tempo``; fall back for older librosa versions.
    tempo_fn = getattr(librosa.feature, "tempo", None) or librosa.beat.tempo
    return float(np.atleast_1d(tempo_fn(y=audio, sr=sr))[0])


def _predict_genre(audio, sr):
    """Resample *audio* to the model's rate and return the predicted genre."""
    target_sr = feature_extractor.sampling_rate
    if sr != target_sr:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
    inputs = feature_extractor(
        audio,
        sampling_rate=target_sr,
        max_length=int(target_sr * MAX_MODEL_SECONDS),
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    # argmax over logits equals argmax over softmax probabilities.
    return _label_for(int(torch.argmax(logits, dim=-1)[0]))


def process_audio(audio_file, user_label):
    """Analyse an uploaded audio file and build MusicGen training metadata.

    Args:
        audio_file: Path to the uploaded file (``gr.Audio(type="filepath")``),
            or ``None`` when the user clicks "Classify" without uploading.
        user_label: Free-text label typed by the user; may be empty.

    Returns:
        On success, a dict with keys ``filename``, ``duration`` (seconds,
        3 decimals), ``description``, ``genre``, ``tempo`` (BPM, 2 decimals)
        and ``sample_rate`` (the file's native rate in Hz). On failure, a dict
        ``{"error": message}`` so the message is shown in the JSON output.
    """
    if not audio_file:
        return {"error": "No audio file provided. Please upload an audio file."}

    try:
        filename = os.path.basename(audio_file)
        # sr=None keeps the file's native sample rate (reported in metadata).
        audio, sr = librosa.load(audio_file, sr=None)
        if audio.size == 0:
            return {"error": f"Audio file '{filename}' contains no samples."}
        duration = librosa.get_duration(y=audio, sr=sr)
        tempo = _estimate_tempo(audio, sr)
        predicted_genre = _predict_genre(audio, sr)
    except Exception as e:
        # UI boundary: keep the app alive, but record the traceback.
        logger.exception("Failed to process audio file %r", audio_file)
        return {"error": str(e)}

    # Skip an empty label instead of producing a leading ", " in the text.
    label = (user_label or "").strip()
    parts = [label] if label else []
    parts += [predicted_genre, f"{tempo:.2f} BPM", f"{sr} Hz"]

    return {
        "filename": filename,
        "duration": round(float(duration), 3),
        "description": ", ".join(parts),
        "genre": predicted_genre,
        "tempo": round(tempo, 2),
        "sample_rate": int(sr),
    }


# Gradio interface
with gr.Blocks(theme="Surn/beeuty") as app:
    gr.Markdown("# Audio Classifier for MusicGen Fine Tuning")
    gr.Markdown(
        "Upload an audio file (preferred `.wav`), provide a label, "
        "and get metadata for MusicGen training."
    )
    with gr.Row():
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")
        label_input = gr.Textbox(
            label="Enter Label", placeholder="e.g., A calm melody"
        )
    submit_button = gr.Button("Classify")
    output_json = gr.JSON(label="Metadata Output")
    submit_button.click(
        process_audio, inputs=[audio_input, label_input], outputs=output_json
    )

# Launch the app
app.launch()