# File size: 2,924 Bytes
# 8df571c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
import librosa
from transformers import AutoModelForAudioClassification, AutoFeatureExtractor
import torch
import numpy as np
import os

# Load the pre-trained model and feature extractor for genre prediction.
# Both are loaded once, when this module is executed, and are then reused as
# module-level globals by `process_audio` on every call.
model_name = "sanchit-gandhi/distilhubert-finetuned-gtzan"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForAudioClassification.from_pretrained(model_name)

# List of genres the model can predict. `process_audio` indexes this list with
# the argmax of the model's output probabilities, so position i must be the
# name of class i.
# NOTE(review): assumes this order matches the checkpoint's class indices
# (`model.config.id2label`) — confirm against the model config.
genres = ["blues", "classical", "country", "disco", "hiphop", "jazz", "metal", "pop", "reggae", "rock"]

# Helper and entry point used to process the uploaded audio file
def _estimate_tempo(audio, sr):
    """Return the estimated global tempo of *audio* in BPM as a builtin float.

    ``librosa.beat.tempo`` is deprecated since librosa 0.10 in favour of
    ``librosa.feature.tempo``. The new location is preferred when it exists
    and the old one is used as a fallback, so either librosa version works.
    """
    tempo_fn = getattr(librosa.feature, "tempo", None) or librosa.beat.tempo
    # The estimator returns an array; take its first element as the scalar tempo.
    return float(np.ravel(tempo_fn(y=audio, sr=sr))[0])


def process_audio(audio_file, user_label):
    """Build a metadata dictionary for one uploaded audio file.

    Args:
        audio_file: Path to the audio file on disk. May be ``None`` or empty
            when nothing was uploaded.
        user_label: Free-text label supplied by the user. It is placed at the
            start of the description; an empty or ``None`` label is omitted.

    Returns:
        On success, a dict with the keys ``filename``, ``duration`` (seconds,
        rounded to 3 decimals), ``description``, ``genre``, ``tempo`` (BPM,
        rounded to 2 decimals) and ``sample_rate`` (native rate of the file).
        On failure, a dict with the single key ``error`` holding the message.
        All values are builtin Python types so the result is JSON-friendly.
    """
    # Guard clause: without it, a missing upload surfaces as an opaque
    # TypeError message from os.path.basename(None).
    if not audio_file:
        return {"error": "No audio file provided. Please upload an audio file."}

    try:
        # Extract filename from the uploaded file path
        filename = os.path.basename(audio_file)

        # Load the audio file with its native sample rate (sr=None disables
        # librosa's default resampling)
        audio, sr = librosa.load(audio_file, sr=None)

        # Duration and tempo are measured on the audio at its native rate
        duration = float(librosa.get_duration(y=audio, sr=sr))
        tempo = _estimate_tempo(audio, sr)

        # Preprocess audio for the model: it is always fed 16 kHz audio, so
        # resample only when the native rate differs
        target_sr = 16000
        if sr != target_sr:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
        inputs = feature_extractor(audio, sampling_rate=target_sr, return_tensors="pt")

        # Predict genre using the model (inference only, no gradients needed)
        with torch.no_grad():
            logits = model(**inputs).logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().numpy()
        predicted_genre = genres[int(np.argmax(probabilities))]

        # Use the user-provided label as the start of the description; skip it
        # when empty so the description never begins with ", " or "None, "
        label = (user_label or "").strip()
        parts = [label] if label else []
        parts += [predicted_genre, f"{tempo:.2f} BPM", f"{sr} Hz"]
        description = ", ".join(parts)

        # Create metadata dictionary using builtin numeric types
        return {
            "filename": filename,
            "duration": round(duration, 3),
            "description": description,
            "genre": predicted_genre,
            "tempo": round(tempo, 2),
            "sample_rate": int(sr),
        }
    except Exception as e:
        # UI boundary: report any failure to the user as data instead of
        # letting the exception propagate out of the callback
        return {"error": str(e)}

# Gradio interface: an audio upload and a free-text label are passed to
# `process_audio`, and the dictionary it returns is shown in a JSON viewer.
with gr.Blocks(theme="Surn/beeuty") as app:
    gr.Markdown("# Audio Classifier for MusicGen Fine Tuning")
    gr.Markdown("Upload an audio file (preferred `.wav`), provide a label, and get metadata for MusicGen training.")

    # Inputs sit side by side in one row
    with gr.Row():
        # type="filepath": `process_audio` treats this value as a path on disk
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")
        label_input = gr.Textbox(label="Enter Label", placeholder="e.g., A calm melody")

    submit_button = gr.Button("Classify")

    output_json = gr.JSON(label="Metadata Output")

    # Clicking the button runs `process_audio(audio_input, label_input)` and
    # writes its return value to the JSON output
    submit_button.click(process_audio, inputs=[audio_input, label_input], outputs=output_json)

# Launch the app
app.launch()