duault commited on
Commit
53e20fe
·
verified ·
1 Parent(s): 0608af7

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +13 -12
README.md CHANGED
@@ -20,21 +20,22 @@ This model classifies music genres based on audio signals. It was fine-tuned on
20
  - **Validation Accuracy**: 75%
21
  - **F1 Score**: 74%
22
  - **Validation Loss**: 0.77
23
-
24
- ## Usage
25
  ```python
26
  from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
27
  import torch
28
 
29
- # Load the model and feature extractor
30
- model = Wav2Vec2ForSequenceClassification.from_pretrained("username/repo-name")
31
- feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("username/repo-name")
32
 
33
- # Prepare input
34
- audio = ... # Your audio array
35
- inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
36
 
37
- # Make predictions
38
- logits = model(**inputs).logits
39
- predicted_class = torch.argmax(logits, dim=-1).item()
40
- print(predicted_class)
 
 
20
  - **Validation Accuracy**: 75%
21
  - **F1 Score**: 74%
22
  - **Validation Loss**: 0.77
23
+
24
+ ## Example Usage
25
  ```python
26
  from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
27
  import torch
28
 
29
+ # Load model and feature extractor
30
+ model = Wav2Vec2ForSequenceClassification.from_pretrained("username/repo_name")
31
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("username/repo_name")
32
 
33
+ # Process audio file
34
+ audio_path = "path/to/audio.wav"
35
+ audio_input = feature_extractor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)
36
 
37
+ # Predict
38
+ with torch.no_grad():
39
+ logits = model(audio_input["input_values"])
40
+ predicted_class = torch.argmax(logits.logits, dim=-1)
41
+ print(predicted_class)