"""Predict the music genre of a local audio file with a fine-tuned Wav2Vec2 classifier.

The class-id -> genre-name mapping is rebuilt from the ``lewtun/music_genres_small``
dataset. NOTE(review): this assumes the classifier was trained with the same
``genre_id`` numbering as that dataset -- confirm against the training script.
"""

from datasets import load_dataset
import librosa
import numpy as np
import torch
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification

# Paths / model identifiers
MODEL_DIR = "./wav2vec_trained_model"  # local checkpoint dir; NOT used for loading below
MODEL_ID = "gastonduault/music-classifier"  # Hub model actually loaded
FEATURE_EXTRACTOR_ID = "facebook/wav2vec2-large"

# Audio preprocessing settings
SAMPLE_RATE = 16000  # Hz; audio is resampled to this rate on load
CLIP_SECONDS = 180  # clips are truncated / zero-padded to this duration

# Load the dataset (only used to recover human-readable genre names)
dataset = load_dataset("lewtun/music_genres_small")

# Retrieve the label names. Reading just the two label columns avoids decoding
# the audio of every example, and keeps the first name seen for each id.
train_split = dataset["train"]
genre_mapping = {}
for genre_id, genre in zip(train_split["genre_id"], train_split["genre"]):
    genre_mapping.setdefault(genre_id, genre)

# Report the source the model is really loaded from (previously printed MODEL_DIR,
# which was misleading because it is never passed to from_pretrained).
print(f"Loading model from {MODEL_ID}...\n")
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_ID)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(FEATURE_EXTRACTOR_ID)


def preprocess_audio(audio_path, target_length=SAMPLE_RATE * CLIP_SECONDS):
    """Load an audio file and turn it into model-ready PyTorch inputs.

    The file is decoded and resampled to ``SAMPLE_RATE`` (16 kHz) by librosa, then
    truncated or zero-padded at the end to exactly ``target_length`` samples
    (180 seconds at 16 kHz by default).

    Args:
        audio_path: Path to an audio file that librosa can decode.
        target_length: Number of samples to truncate / pad the signal to.

    Returns:
        The feature extractor's output with ``return_tensors="pt"``, suitable for
        ``model(**inputs)``.
    """
    audio_array, _ = librosa.load(audio_path, sr=SAMPLE_RATE)
    # Truncate first; slicing is a no-op when the clip is already short enough.
    audio_array = audio_array[:target_length]
    shortfall = target_length - len(audio_array)
    if shortfall > 0:
        audio_array = np.pad(audio_array, (0, shortfall), "constant")
    return feature_extractor(
        audio_array, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True
    )


# Path to your audio file
audio_path = "./Nirvana - Come As You Are.wav"

# Preprocess audio
inputs = preprocess_audio(audio_path)

# Predict (no gradients needed for inference)
with torch.no_grad():
    logits = model(**inputs).logits
predicted_class = torch.argmax(logits, dim=-1).item()

# Output the result; guard against a class id missing from the dataset-derived mapping.
predicted_genre = genre_mapping.get(predicted_class, f"unknown (class id {predicted_class})")
print(f"song analyzed: {audio_path}")
print(f"Predicted genre: {predicted_genre}")