from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
from datasets import load_dataset
import numpy as np
import librosa
import torch

# Fine-tuned genre-classifier checkpoint. from_pretrained accepts either this
# Hub id or a local directory such as "./wav2vec_trained_model".
MODEL_DIR = "gastoooon/music-classifier"

# The dataset is only needed to recover the genre_id -> genre name mapping
# used to decode the model's integer predictions.
dataset = load_dataset("lewtun/music_genres_small")

# Build the genre_id -> genre name mapping from the training split, stopping
# early once 9 distinct genre ids have been seen.
genre_mapping = {}
for example in dataset["train"]:
    genre_id = example["genre_id"]
    genre = example["genre"]
    if genre_id not in genre_mapping:
        genre_mapping[genre_id] = genre
    if len(genre_mapping) == 9:
        break

print(f"Loading model from {MODEL_DIR}...\n")
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_DIR)
# The feature extractor only normalizes and batches raw audio, so the base
# wav2vec2 checkpoint's extractor is reused alongside the fine-tuned classifier.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large")
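
# If the checkpoint's config carries id2label metadata (an assumption; it is
# not guaranteed for this checkpoint), the dataset scan above could be replaced
# with a direct read of the config:
#   genre_mapping = {int(i): name for i, name in model.config.id2label.items()}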


def preprocess_audio(audio_path, target_length=16000 * 180):
    # Decode and resample to the 16 kHz rate wav2vec2 models expect.
    audio_array, sampling_rate = librosa.load(audio_path, sr=16000)

    # Truncate or zero-pad to a fixed 180-second window so every clip yields
    # an input tensor of the same shape.
    if len(audio_array) > target_length:
        audio_array = audio_array[:target_length]
    else:
        padding = target_length - len(audio_array)
        audio_array = np.pad(audio_array, (0, padding), "constant")

    inputs = feature_extractor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)
    return inputs
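
# Note: 180 s at 16 kHz is ~2.9 M samples (~9,000 transformer frames), which
# can be memory-hungry on small GPUs. One alternative, sketched here with a
# hypothetical predict_chunked helper, is to classify fixed 30 s chunks and
# average their logits:
#
#   def predict_chunked(audio_path, chunk_len=16000 * 30):
#       audio, _ = librosa.load(audio_path, sr=16000)
#       chunks = [audio[i:i + chunk_len] for i in range(0, len(audio), chunk_len)]
#       with torch.no_grad():
#           logits = [model(**feature_extractor(c, sampling_rate=16000,
#                                               return_tensors="pt")).logits
#                     for c in chunks]
#       return torch.stack(logits).mean(dim=0)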


audio_path = "./Nirvana - Come As You Are.wav"

inputs = preprocess_audio(audio_path)

# Run inference without gradient tracking; the argmax over the logits is the
# predicted label id, which is assumed to line up with the dataset's genre_id.
with torch.no_grad():
    logits = model(**inputs).logits
predicted_class = torch.argmax(logits, dim=-1).item()

print(f"Analyzed song: {audio_path}")
print(f"Predicted genre: {genre_mapping[predicted_class]}")