File size: 1,707 Bytes
d77d9c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
from datasets import load_dataset
import numpy as np
import librosa
import torch

# Paths
MODEL_DIR = "./wav2vec_trained_model"

# Load the dataset
dataset = load_dataset("lewtun/music_genres_small")

# Build an id -> name mapping from the first occurrence of each genre,
# stopping as soon as 9 distinct genres have been seen.
genre_mapping = {}
for sample in dataset["train"]:
    genre_mapping.setdefault(sample["genre_id"], sample["genre"])
    if len(genre_mapping) == 9:
        break

# Load the fine-tuned classifier and the matching feature extractor.
# NOTE(review): the previous message claimed the model was loaded from
# MODEL_DIR, but it is actually fetched from the Hugging Face Hub —
# confirm which source is intended; MODEL_DIR is currently unused.
print("Loading model gastoooon/music-classifier from the Hugging Face Hub...\n")
model = Wav2Vec2ForSequenceClassification.from_pretrained("gastoooon/music-classifier")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large")

# Function for preprocessing audio for prediction
def preprocess_audio(audio_path, target_length=16000 * 180):  # default: 180 seconds at 16 kHz
    """Load an audio file and turn it into model-ready inputs.

    The file is resampled to 16 kHz, then truncated or zero-padded on the
    right so the array is exactly ``target_length`` samples long.

    Note: the default ``16000 * 180`` corresponds to 180 seconds of audio
    (the previous comment incorrectly said 30 seconds).

    Args:
        audio_path: path to an audio file readable by ``librosa.load``.
        target_length: fixed number of samples the clip is cut/padded to.

    Returns:
        The batched feature-extractor output (PyTorch tensors) ready to be
        passed to the model via ``model(**inputs)``.
    """
    sample_rate = 16000
    audio_array, sampling_rate = librosa.load(audio_path, sr=sample_rate)

    if len(audio_array) > target_length:
        # Truncate long clips.
        audio_array = audio_array[:target_length]
    else:
        # Zero-pad short clips up to the fixed length.
        padding = target_length - len(audio_array)
        audio_array = np.pad(audio_array, (0, padding), "constant")

    inputs = feature_extractor(audio_array, sampling_rate=sample_rate, return_tensors="pt", padding=True)
    return inputs


# Path to your audio file
audio_path = "./Nirvana - Come As You Are.wav"


# Preprocess audio
inputs = preprocess_audio(audio_path)

# Predict
with torch.no_grad():
    logits = model(**inputs).logits
    predicted_class = torch.argmax(logits, dim=-1).item()

# Output the result
print(f"song analized:{audio_path}")
print(f"Predicted genre: {genre_mapping[predicted_class]}")