Spaces: Sleeping
saeedbenadeeb committed
Commit 0874d87 · Parent(s): 9a0a0d8
First commit
Browse files
- .gradio/certificate.pem +31 -0
- app.py +168 -0
- datasets/TESS_Dataset.py +108 -0
- datasets/__init__.py +45 -0
- datasets/__pycache__/TESS_Dataset.cpython-311.pyc +0 -0
- datasets/__pycache__/__init__.cpython-311.pyc +0 -0
- datasets/__pycache__/audio_dataset.cpython-311.pyc +0 -0
- datasets/__pycache__/ctc_audio_dataclass.cpython-311.pyc +0 -0
- datasets/__pycache__/image_dataset.cpython-311.pyc +0 -0
- datasets/audio_dataset.py +120 -0
- datasets/ctc_audio_dataclass.py +126 -0
- datasets/image_dataset.py +62 -0
- emotion-detection +1 -0
- encoders/__init__.py +1 -0
- encoders/__pycache__/__init__.cpython-311.pyc +0 -0
- encoders/__pycache__/encoders.cpython-311.pyc +0 -0
- encoders/__pycache__/transformer.cpython-311.pyc +0 -0
- encoders/encoders.py +263 -0
- encoders/transformer.py +233 -0
- model.pth +3 -0
- models/CTCencoder.py +93 -0
- models/__init__.py +1 -0
- models/__pycache__/CTCencoder.cpython-311.pyc +0 -0
- models/__pycache__/__init__.cpython-311.pyc +0 -0
- requirements.txt +5 -0
- statics/style.css +9 -0
- upload_model.py +14 -0
- utils/__init__.py +2 -0
- utils/__pycache__/__init__.cpython-311.pyc +0 -0
- utils/__pycache__/helper_functions.cpython-311.pyc +0 -0
- utils/__pycache__/random_split.cpython-311.pyc +0 -0
- utils/helper_functions.py +70 -0
- utils/random_split.py +37 -0
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
-----BEGIN CERTIFICATE-----
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
-----END CERTIFICATE-----
app.py
ADDED
@@ -0,0 +1,168 @@
import gradio as gr
import torch
import librosa
import numpy as np
import torch.nn.functional as F
import os

from encoders.transformer import Wav2Vec2EmotionClassifier

# Define the emotions
emotions = ["happy", "sad", "angry", "neutral", "fear", "disgust", "surprise"]
label_mapping = {str(idx): emotion for idx, emotion in enumerate(emotions)}

# Load the trained model
model_path = "model.pth"
cfg = {
    "model": {
        "encoder": "Wav2Vec2Classifier",
        "optimizer": {
            "name": "Adam",
            "lr": 0.0003,
            "weight_decay": 3e-4
        },
        "l1_lambda": 0.0
    }
}
model = Wav2Vec2EmotionClassifier(num_classes=len(emotions), optimizer_cfg=cfg["model"]["optimizer"])
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
model.eval()

# Optional: we define a minimum number of samples to avoid Wav2Vec2 conv errors
MIN_SAMPLES = 10  # or 16000 if you want at least 1 second

# Preprocessing function
def preprocess_audio(file_path, sample_rate=16000):
    """
    Safely loads the file at file_path and returns a (1, samples) torch tensor.
    Returns None if the file is invalid or too short.
    """
    if not file_path or (not os.path.exists(file_path)):
        # file_path could be None or an empty string if user didn't record properly
        return None

    # Load with librosa (which merges to mono by default if multi-channel)
    waveform, sr = librosa.load(file_path, sr=sample_rate)

    # Check length
    if len(waveform) < MIN_SAMPLES:
        return None

    # Convert to torch tensor, shape (1, samples)
    waveform_tensor = torch.tensor(waveform, dtype=torch.float32).unsqueeze(0)

    return waveform_tensor

# Prediction function
def predict_emotion(audio_file):
    """
    audio_file is a file path from Gradio (type='filepath').
    """
    # Preprocess
    waveform = preprocess_audio(audio_file, sample_rate=16000)

    # If invalid or too short, return an error-like message
    if waveform is None:
        return (
            "Audio is too short or invalid. Please record/upload a longer clip.",
            ""
        )

    # Perform inference
    with torch.no_grad():
        logits = model(waveform)
        probabilities = F.softmax(logits, dim=-1).cpu().numpy()[0]

    # Get the predicted class
    predicted_class = np.argmax(probabilities)
    predicted_emotion = label_mapping[str(predicted_class)]

    # Format probabilities for visualization
    probabilities_output = [
        f"""
        <div style='display: flex; align-items: center; margin: 5px 0;'>
            <div style='width: 20%; text-align: right; margin-right: 10px; font-weight: bold;'>{emotions[i]}</div>
            <div style='flex-grow: 1; background-color: #374151; border-radius: 4px; overflow: hidden;'>
                <div style='width: {probabilities[i]*100:.2f}%; background-color: #FFA500; height: 10px;'></div>
            </div>
            <div style='width: 10%; text-align: right; margin-left: 10px;'>{probabilities[i]*100:.2f}%</div>
        </div>
        """
        for i in range(len(emotions))
    ]

    return predicted_emotion, "\n".join(probabilities_output)

# Create Gradio interface
def gradio_interface(audio):
    detected_emotion, probabilities_html = predict_emotion(audio)
    return detected_emotion, gr.HTML(probabilities_html)

# Define Gradio UI
with gr.Blocks(css="""
    body {
        background-color: #121212;
        color: white;
        font-family: Arial, sans-serif;
    }
    h1 {
        color: #FFA500;
        font-size: 48px;
        text-align: center;
        margin-bottom: 10px;
    }
    p {
        text-align: center;
        font-size: 18px;
    }
    .gradio-row {
        justify-content: center;
        align-items: center;
    }
    #submit_button {
        background-color: #FFA500 !important;
        color: black !important;
        font-size: 18px;
        padding: 10px 20px;
        margin-top: 20px;
    }
    #detected_emotion {
        font-size: 24px;
        font-weight: bold;
        text-align: center;
    }
    .probabilities-container {
        margin-top: 20px;
        padding: 10px;
        background-color: #1F2937;
        border-radius: 8px;
    }
""") as demo:
    gr.Markdown(
        """
        <div>
            <h1>Speech Emotion Recognition</h1>
            <p>🎵 Upload or record an audio file (max 1 minute) to detect emotions.</p>
            <p>Supported Emotions: 😊 Happy | 😭 Sad | 😡 Angry | 😐 Neutral | 😨 Fear | 🤢 Disgust | 😮 Surprise</p>
        </div>
        """
    )

    with gr.Row():
        with gr.Column(scale=1, elem_id="audio-block"):
            # type="filepath" means we get a temporary file path from Gradio
            audio_input = gr.Audio(label="🎤 Record or Upload Audio", type="filepath")
            submit_button = gr.Button("Submit", elem_id="submit_button")
        with gr.Column(scale=1):
            detected_emotion_label = gr.Label(label="Detected Emotion", elem_id="detected_emotion")
            probabilities_html = gr.HTML(label="Probabilities", elem_id="probabilities")

    submit_button.click(
        fn=gradio_interface,
        inputs=audio_input,
        outputs=[detected_emotion_label, probabilities_html]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(share=True)
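The prediction path above can be exercised without the Gradio UI. A minimal sketch (not part of the commit; "sample.wav" is a hypothetical local recording, and importing app loads model.pth at import time, so the checkpoint must be present):

# Hypothetical smoke test for the prediction path; sample.wav is an assumed local file.
from app import predict_emotion

emotion, bars_html = predict_emotion("sample.wav")
print("Detected emotion:", emotion)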
datasets/TESS_Dataset.py
ADDED
@@ -0,0 +1,108 @@
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import librosa
from typing import List, Tuple
import shutil
import kagglehub
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import subprocess
import zipfile
import os
# Constants (you may need to define these according to your requirements)
SAMPLE_RATE = 16000  # Define the sample rate for audio processing
DURATION = 3.0  # Duration of the audio in seconds

# Placeholder for waveform normalization
def normalize_waveform(audio: np.ndarray) -> torch.Tensor:
    # Convert to tensor if necessary
    if not isinstance(audio, torch.Tensor):
        audio = torch.tensor(audio, dtype=torch.float32)
    return (audio - torch.mean(audio)) / torch.std(audio)

class TESSRawWaveformDataset(Dataset):
    def __init__(self, root_path: str, transform=None):
        super().__init__()
        self.root_path = root_path
        self.audio_files = []
        self.labels = []
        self.emotions = ["happy", "sad", "angry", "neutral", "fear", "disgust", "surprise"]
        emotion_mapping = {e.lower(): idx for idx, e in enumerate(self.emotions)}
        self.download_dataset_if_not_exists()
        # Load file paths and labels from nested directories
        for root, dirs, files in os.walk(root_path):
            for file_name in files:
                if file_name.endswith(".wav"):
                    emotion_name = next(
                        (e for e in emotion_mapping if e in root.lower()), None
                    )
                    if emotion_name is not None:
                        self.audio_files.append(os.path.join(root, file_name))
                        self.labels.append(emotion_mapping[emotion_name])

        self.labels = np.array(self.labels, dtype=np.int64)
        self.transform = transform

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        # Load raw waveform and label
        audio_path = self.audio_files[idx]
        label = self.labels[idx]
        waveform = self.load_audio(audio_path)

        if self.transform:
            waveform = self.transform(waveform)

        return waveform, label

    @staticmethod
    def load_audio(audio_path: str) -> torch.Tensor:
        # Load audio and ensure it's at the correct sample rate
        audio, sr = librosa.load(audio_path, sr=SAMPLE_RATE, duration=DURATION)
        assert sr == SAMPLE_RATE, f"Sample rate mismatch: expected {SAMPLE_RATE}, got {sr}"
        return normalize_waveform(audio)

    def get_emotions(self) -> List[str]:
        return self.emotions

    def download_dataset_if_not_exists(self):
        if not os.path.exists(self.root_path):
            print(f"Dataset not found at {self.root_path}. Downloading...")

            # Ensure the destination directory exists
            os.makedirs(self.root_path, exist_ok=True)

            # Download dataset using curl
            dataset_zip_path = os.path.join(self.root_path, "toronto-emotional-speech-set-tess.zip")
            curl_command = [
                "curl",
                "-L",
                "-o",
                dataset_zip_path,
                "https://www.kaggle.com/api/v1/datasets/download/ejlok1/toronto-emotional-speech-set-tess",
            ]

            try:
                subprocess.run(curl_command, check=True)
                print(f"Dataset downloaded to {dataset_zip_path}.")

                # Extract the downloaded zip file
                with zipfile.ZipFile(dataset_zip_path, "r") as zip_ref:
                    zip_ref.extractall(self.root_path)
                print(f"Dataset extracted to {self.root_path}.")

                # Remove the zip file to save space
                os.remove(dataset_zip_path)
                print(f"Removed zip file: {dataset_zip_path}")

            except subprocess.CalledProcessError as e:
                print(f"Error occurred during dataset download: {e}")
                raise


# Example usage
# dataset = TESSRawWaveformDataset(root_path="./TESS", transform=None)
# print("Number of samples:", len(dataset))
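A minimal usage sketch for the class above (an assumption, not part of the commit; the ./TESS path mirrors the commented example and triggers the Kaggle download if the folder is missing):

# Hypothetical usage; indexing returns (normalized waveform tensor, integer label).
from datasets.TESS_Dataset import TESSRawWaveformDataset

dataset = TESSRawWaveformDataset(root_path="./TESS")
waveform, label = dataset[0]
print(waveform.shape, dataset.get_emotions()[label])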
datasets/__init__.py
ADDED
@@ -0,0 +1,45 @@
from typing import List
from torch.utils.data import Dataset

from .image_dataset import CustomDataset
from .audio_dataset import EmodbDataset
from .ctc_audio_dataclass import CTCEmodbDataset
from .TESS_Dataset import TESSRawWaveformDataset

__dataset_mapper__ = {
    "image": CustomDataset,
    "emodb": EmodbDataset,
    'CTCemodb': CTCEmodbDataset,
    'TESSDataset': TESSRawWaveformDataset
}

def list_datasets() -> List[str]:
    """Returns a list of available dataset names.

    Returns:
        List[str]: List of dataset names as strings.

    Example:
        >>> from datasets import list_datasets
        >>> list_datasets()
        ['CTCemodb', 'TESSDataset', 'emodb', 'image']
    """
    return sorted(__dataset_mapper__.keys())

def get_dataset_by_name(dataset: str, *args, **kwargs) -> Dataset:
    """Returns the Dataset class using the given name and arguments.

    Args:
        dataset (str): The name of the dataset.

    Returns:
        Dataset: The requested dataset instance.

    Example:
        >>> from datasets import get_dataset_by_name
        >>> dataset = get_dataset_by_name("emodb", root_path="./data/emodb")
        >>> type(dataset)
        <class 'datasets.audio_dataset.EmodbDataset'>
    """
    assert dataset in __dataset_mapper__, f"Dataset '{dataset}' not found in the mapper."
    return __dataset_mapper__[dataset](*args, **kwargs)
datasets/__pycache__/TESS_Dataset.cpython-311.pyc
ADDED
Binary file (6.95 kB).
datasets/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (1.96 kB).
datasets/__pycache__/audio_dataset.cpython-311.pyc
ADDED
Binary file (7.66 kB).
datasets/__pycache__/ctc_audio_dataclass.cpython-311.pyc
ADDED
Binary file (8.27 kB).
datasets/__pycache__/image_dataset.cpython-311.pyc
ADDED
Binary file (4.99 kB).
datasets/audio_dataset.py
ADDED
@@ -0,0 +1,120 @@
import os
import zipfile
import requests
from tqdm import tqdm
from typing import List, Tuple
import numpy as np
from torch.utils.data import Dataset
import librosa
import torch

SAMPLE_RATE = 22050
DURATION = 1.4  # second

class EmodbDataset(Dataset):
    __url__ = "http://www.emodb.bilderbar.info/download/download.zip"
    __labels__ = ("angry", "happy", "neutral", "sad")
    __suffixes__ = {
        "angry": ["Wa", "Wb", "Wc", "Wd"],
        "happy": ["Fa", "Fb", "Fc", "Fd"],
        "neutral": ["Na", "Nb", "Nc", "Nd"],
        "sad": ["Ta", "Tb", "Tc", "Td"]
    }

    def __init__(self, root_path: str = './data/emodb', transform=None):
        super().__init__()
        self.root_path = root_path
        self.audio_root_path = os.path.join(root_path, "wav")

        # Ensure the dataset is downloaded
        self._ensure_dataset()

        ids = []
        targets = []
        for audio_file in os.listdir(self.audio_root_path):
            f_name, ext = os.path.splitext(audio_file)
            if ext != ".wav":
                continue

            suffix = f_name[-2:]
            for label, suffixes in self.__suffixes__.items():
                if suffix in suffixes:
                    ids.append(os.path.join(self.audio_root_path, audio_file))
                    targets.append(self.label2id(label))
                    break

        self.ids = ids
        self.targets = np.array(targets, dtype=np.int64)
        self.transform = transform

    def _ensure_dataset(self):
        """
        Ensures the dataset is downloaded and extracted.
        """
        if not os.path.isdir(self.audio_root_path):
            print(f"Dataset not found at {self.audio_root_path}. Downloading...")
            self._download_and_extract()

    def _download_and_extract(self):
        """
        Downloads and extracts the dataset zip file.
        """
        # Ensure the root path exists
        os.makedirs(self.root_path, exist_ok=True)

        # Download the dataset
        zip_path = os.path.join(self.root_path, "emodb.zip")
        with requests.get(self.__url__, stream=True) as r:
            r.raise_for_status()
            total_size = int(r.headers.get("content-length", 0))
            with open(zip_path, "wb") as f, tqdm(
                desc="Downloading EMO-DB dataset",
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
                    bar.update(len(chunk))

        # Extract the dataset
        print("Extracting dataset...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(self.root_path)

        # Clean up the zip file
        os.remove(zip_path)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx: int) -> Tuple:
        target = self.targets[idx]
        audio = self.load_audio(self.ids[idx])  # Should return a numpy array

        if self.transform:
            audio = self.transform(audio)  # Apply transform

        return audio, target

    @staticmethod
    def id2label(idx: int) -> str:
        return EmodbDataset.__labels__[idx]

    @staticmethod
    def label2id(label: str) -> int:
        if label not in EmodbDataset.__labels__:
            raise ValueError(f"Unknown label: {label}")
        return EmodbDataset.__labels__.index(label)

    @staticmethod
    def load_audio(audio_file_path: str) -> np.ndarray:
        audio, sr = librosa.load(audio_file_path, sr=SAMPLE_RATE, duration=DURATION)
        assert SAMPLE_RATE == sr, "broken audio file"
        # Convert numpy array to PyTorch tensor
        return torch.tensor(audio, dtype=torch.float32)

    @staticmethod
    def get_labels() -> List[str]:
        return list(EmodbDataset.__labels__)
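Because load_audio truncates every clip to DURATION seconds at SAMPLE_RATE, the waveforms are fixed-length and can be batched with the default collate, provided every EMO-DB file is at least that long. A minimal sketch under that assumption (not part of the commit):

# Hypothetical usage; assumes all clips are >= DURATION seconds so the default collate can stack them.
from torch.utils.data import DataLoader
from datasets.audio_dataset import EmodbDataset

ds = EmodbDataset(root_path="./data/emodb")
loader = DataLoader(ds, batch_size=8, shuffle=True)
waveforms, targets = next(iter(loader))
print(waveforms.shape, [EmodbDataset.id2label(int(t)) for t in targets])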
datasets/ctc_audio_dataclass.py
ADDED
@@ -0,0 +1,126 @@
import os
import zipfile
import requests
from tqdm import tqdm
from typing import List, Tuple
import numpy as np
from torch.utils.data import Dataset
import librosa
import torch
from torch.nn.utils.rnn import pad_sequence

SAMPLE_RATE = 22050
DURATION = 1.4  # seconds

class CTCEmodbDataset(Dataset):
    __url__ = "http://www.emodb.bilderbar.info/download/download.zip"
    __labels__ = ("angry", "happy", "neutral", "sad")
    __suffixes__ = {
        "angry": ["Wa", "Wb", "Wc", "Wd"],
        "happy": ["Fa", "Fb", "Fc", "Fd"],
        "neutral": ["Na", "Nb", "Nc", "Nd"],
        "sad": ["Ta", "Tb", "Tc", "Td"]
    }

    def __init__(self, root_path: str = './data/emodb', transform=None):
        super().__init__()
        self.root_path = root_path
        self.audio_root_path = os.path.join(root_path, "wav")

        # Ensure the dataset is downloaded
        self._ensure_dataset()

        ids = []
        targets = []
        for audio_file in os.listdir(self.audio_root_path):
            f_name, ext = os.path.splitext(audio_file)
            if ext != ".wav":
                continue

            suffix = f_name[-2:]
            for label, suffixes in self.__suffixes__.items():
                if suffix in suffixes:
                    ids.append(os.path.join(self.audio_root_path, audio_file))
                    targets.append(self.label2id(label))  # Store as integers
                    break

        self.ids = ids
        self.targets = targets  # Target sequences as a list of lists
        self.transform = transform

    def _ensure_dataset(self):
        """
        Ensures the dataset is downloaded and extracted.
        """
        if not os.path.isdir(self.audio_root_path):
            print(f"Dataset not found at {self.audio_root_path}. Downloading...")
            self._download_and_extract()

    def _download_and_extract(self):
        """
        Downloads and extracts the dataset zip file.
        """
        os.makedirs(self.root_path, exist_ok=True)
        zip_path = os.path.join(self.root_path, "emodb.zip")
        with requests.get(self.__url__, stream=True) as r:
            r.raise_for_status()
            total_size = int(r.headers.get("content-length", 0))
            with open(zip_path, "wb") as f, tqdm(
                desc="Downloading EMO-DB dataset",
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
                    bar.update(len(chunk))

        print("Extracting dataset...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(self.root_path)

        os.remove(zip_path)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
        """
        Returns:
            x (torch.Tensor): Input sequence (audio features or waveform)
            y (torch.Tensor): Target sequence (labels or tokenized transcription)
            input_length (int): Length of input sequence
            target_length (int): Length of target sequence
        """
        target = torch.tensor([self.targets[idx]], dtype=torch.long)
        audio = self.load_audio(self.ids[idx])  # Should return a numpy array

        if self.transform:
            audio = self.transform(audio)

        # Input length (for CTC)
        input_length = audio.shape[-1]  # Last dimension is the time dimension
        target_length = len(target)  # Length of target sequence

        return audio, target, input_length, target_length

    @staticmethod
    def id2label(idx: int) -> str:
        return CTCEmodbDataset.__labels__[idx]

    @staticmethod
    def label2id(label: str) -> int:
        if label not in CTCEmodbDataset.__labels__:
            raise ValueError(f"Unknown label: {label}")
        return CTCEmodbDataset.__labels__.index(label)

    @staticmethod
    def load_audio(audio_file_path: str) -> torch.Tensor:
        audio, sr = librosa.load(audio_file_path, sr=SAMPLE_RATE, duration=DURATION)
        assert SAMPLE_RATE == sr, "broken audio file"
        return torch.tensor(audio, dtype=torch.float32)

    @staticmethod
    def get_labels() -> List[str]:
        return list(CTCEmodbDataset.__labels__)
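The module imports pad_sequence but does not define a collate function for the variable-length (audio, target, input_length, target_length) tuples it returns. A minimal sketch of one, as an assumption rather than part of the commit, shaped the way torch.nn.CTCLoss expects its targets and lengths:

import torch
from torch.nn.utils.rnn import pad_sequence

def ctc_collate(batch):
    # batch: list of (waveform, target, input_length, target_length) tuples
    waveforms, targets, input_lengths, target_lengths = zip(*batch)
    x = pad_sequence(waveforms, batch_first=True)   # (B, T_max) zero-padded waveforms
    y = torch.cat(targets)                          # targets concatenated into one 1-D tensor
    return (x, y,
            torch.tensor(input_lengths, dtype=torch.long),
            torch.tensor(target_lengths, dtype=torch.long))

# e.g. DataLoader(CTCEmodbDataset(), batch_size=8, collate_fn=ctc_collate)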
datasets/image_dataset.py
ADDED
@@ -0,0 +1,62 @@
import os
from torchvision.datasets import VisionDataset
from PIL import Image
from sklearn.model_selection import train_test_split


class CustomDataset(VisionDataset):
    def __init__(self, root_path, subset="train", transform=None, target_transform=None, split_ratios=(0.7, 0.15, 0.15), seed=42):
        super(CustomDataset, self).__init__(root_path, transform=transform, target_transform=target_transform)
        self.root = root_path
        self.subset = subset  # Can be "train", "val", or "test"
        self.split_ratios = split_ratios
        self.seed = seed

        self.classes, self.class_idx = self._find_classes()
        self.samples = self._make_dataset()

    def _find_classes(self):
        classes = [d.name for d in os.scandir(self.root) if d.is_dir()]
        classes.sort()
        class_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_idx

    def _make_dataset(self):
        samples = []
        for target_class in sorted(self.class_idx.keys()):
            class_index = self.class_idx[target_class]
            target_dir = os.path.join(self.root, target_class)
            for root, _, fnames in sorted(os.walk(target_dir)):
                for fname in sorted(fnames):
                    path = os.path.join(root, fname)
                    samples.append((path, class_index))

        # Split into train, val, and test sets
        train_samples, test_samples = train_test_split(
            samples, test_size=1 - self.split_ratios[0], random_state=self.seed, stratify=[s[1] for s in samples]
        )
        val_samples, test_samples = train_test_split(
            test_samples, test_size=self.split_ratios[2] / (self.split_ratios[1] + self.split_ratios[2]),
            random_state=self.seed, stratify=[s[1] for s in test_samples]
        )

        if self.subset == "train":
            return train_samples
        elif self.subset == "val":
            return val_samples
        elif self.subset == "test":
            return test_samples
        else:
            raise ValueError(f"Unknown subset: {self.subset}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, target = self.samples[index]
        img = Image.open(path).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        return img, target
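A minimal usage sketch (the folder layout ./data/images/<class_name>/... and the transform pipeline are assumptions, not part of the commit):

# Hypothetical usage with a standard torchvision transform pipeline.
from torchvision import transforms
from datasets.image_dataset import CustomDataset

tfm = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
train_ds = CustomDataset(root_path="./data/images", subset="train", transform=tfm)
img, label = train_ds[0]
print(img.shape, train_ds.classes[label])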
emotion-detection
ADDED
@@ -0,0 +1 @@
Subproject commit 4f5c928446aeb2dabd215f85d3f9647ac92e7e67
encoders/__init__.py
ADDED
@@ -0,0 +1 @@
from . import encoders, transformer
encoders/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (235 Bytes).
encoders/__pycache__/encoders.cpython-311.pyc
ADDED
Binary file (16.2 kB).
encoders/__pycache__/transformer.cpython-311.pyc
ADDED
Binary file (12.3 kB).
encoders/encoders.py
ADDED
@@ -0,0 +1,263 @@
import torch.optim as optim
import pytorch_lightning as pl
import timm
from torchmetrics import Accuracy, Precision, Recall, F1Score
import torch


class timm_backbones(pl.LightningModule):
    """
    PyTorch Lightning model for image classification using a timm backbone (ResNet-18 by default).

    This model uses a pre-trained timm model and fine-tunes it for a specific number of classes.

    Args:
        num_classes (int, optional): The number of classes in the dataset. Defaults to 2.
        optimizer_cfg (DictConfig, optional): A Hydra configuration object for the optimizer.

    Methods:
        forward(x): Computes the forward pass of the model.
        configure_optimizers(): Configures the optimizer for the model.
        training_step(batch, batch_idx): Performs a training step on the model.
        validation_step(batch, batch_idx): Performs a validation step on the model.
        on_validation_epoch_end(): Called at the end of each validation epoch.
        test_step(batch, batch_idx): Performs a test step on the model.

    Example:
        model = timm_backbones(encoder='resnet18', num_classes=2, optimizer_cfg=cfg.model.optimizer)
        trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
        trainer.test(model, dataloaders=test_dataloader)
    """
    def __init__(self, encoder='resnet18', num_classes=2, optimizer_cfg=None, l1_lambda=0.0):
        super().__init__()

        self.encoder = encoder
        self.model = timm.create_model(encoder, pretrained=True)
        if self.model.default_cfg["input_size"][0] == 3:  # If model expects 3 channels (input_size is (C, H, W))
            self.model.conv1 = torch.nn.Conv2d(
                in_channels=1,  # Change to single channel
                out_channels=self.model.conv1.out_channels,
                kernel_size=self.model.conv1.kernel_size,
                stride=self.model.conv1.stride,
                padding=self.model.conv1.padding,
                bias=False
            )

        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.precision = Precision(task="multiclass", num_classes=num_classes)
        self.recall = Recall(task="multiclass", num_classes=num_classes)
        self.f1 = F1Score(task="multiclass", num_classes=num_classes)

        self.l1_lambda = l1_lambda
        if hasattr(self.model, 'fc'):  # For models with 'fc' as the classification layer
            in_features = self.model.fc.in_features
            self.model.fc = torch.nn.Linear(in_features, num_classes)
        elif hasattr(self.model, 'classifier'):  # For models with 'classifier'
            in_features = self.model.classifier.in_features
            self.model.classifier = torch.nn.Linear(in_features, num_classes)
        elif hasattr(self.model, 'head'):  # For models with 'head'
            in_features = self.model.head.in_features
            self.model.head = torch.nn.Linear(in_features, num_classes)
        else:
            raise ValueError(f"Unsupported model architecture for encoder: {encoder}")

        if optimizer_cfg is not None:
            optimizer_name = optimizer_cfg.name
            optimizer_lr = optimizer_cfg.lr
            optimizer_weight_decay = optimizer_cfg.weight_decay

            if optimizer_name == 'Adam':
                self.optimizer = optim.Adam(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay)
            elif optimizer_name == 'SGD':
                self.optimizer = optim.SGD(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay)
            else:
                raise ValueError(f"Unsupported optimizer: {optimizer_name}")
        else:
            self.optimizer = None

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = self.optimizer
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

    def training_step(self, batch, batch_idx):
        x, y = batch
        y = y.long()

        # Compute predictions and loss
        logits = self(x)
        loss = torch.nn.functional.cross_entropy(logits, y)

        # Add L1 regularization
        l1_norm = sum(param.abs().sum() for param in self.parameters())
        loss += self.l1_lambda * l1_norm

        self.log('train_loss', loss, prog_bar=True, on_epoch=True, on_step=False, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y = y.long()

        logits = self(x)
        loss = torch.nn.functional.cross_entropy(logits, y)

        preds = torch.argmax(logits, dim=1)
        accuracy = self.accuracy(preds, y)  # torchmetrics metrics take (preds, target)
        precision = self.precision(preds, y)
        recall = self.recall(preds, y)
        f1 = self.f1(preds, y)

        self.log('val_loss', loss, prog_bar=True, on_epoch=True, on_step=True)
        self.log('val_acc', accuracy, prog_bar=True, on_epoch=True, on_step=True)
        self.log('val_precision', precision, prog_bar=True, on_epoch=True, on_step=True)
        self.log('val_recall', recall, prog_bar=True, on_epoch=True, on_step=True)
        self.log('val_f1', f1, prog_bar=True, on_epoch=True, on_step=True)

        return loss

    def on_validation_epoch_end(self):
        avg_loss = self.trainer.logged_metrics['val_loss_epoch']
        accuracy = self.trainer.logged_metrics['val_acc_epoch']

        self.log('val_loss', avg_loss, prog_bar=True, on_epoch=True)
        self.log('val_acc', accuracy, prog_bar=True, on_epoch=True)

        return {'Average Loss:': avg_loss, 'Accuracy:': accuracy}

    def test_step(self, batch, batch_idx):
        x, y = batch
        y = y.long()
        logits = self(x)
        loss = torch.nn.functional.cross_entropy(logits, y)

        preds = torch.argmax(logits, dim=1)
        accuracy = self.accuracy(preds, y)
        precision = self.precision(preds, y)
        recall = self.recall(preds, y)
        f1 = self.f1(preds, y)

        # Log test metrics
        self.log('test_loss', loss, prog_bar=True, logger=True)
        self.log('test_acc', accuracy, prog_bar=True, logger=True)
        self.log('test_precision', precision, prog_bar=True, logger=True)
        self.log('test_recall', recall, prog_bar=True, logger=True)
        self.log('test_f1', f1, prog_bar=True, logger=True)

        return {'test_loss': loss, 'test_accuracy': accuracy, 'test_precision': precision, 'test_recall': recall, 'test_f1': f1}


class CTCEncoderPL(pl.LightningModule):
    def __init__(self, ctc_encoder, num_classes, optimizer_cfg):
        super(CTCEncoderPL, self).__init__()
        self.ctc_encoder = ctc_encoder
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=True)
        self.optimizer_cfg = optimizer_cfg
        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.precision = Precision(task="multiclass", num_classes=num_classes)
        self.recall = Recall(task="multiclass", num_classes=num_classes)
        self.f1 = F1Score(task="multiclass", num_classes=num_classes)

        if optimizer_cfg is not None:
            optimizer_name = optimizer_cfg.name
            optimizer_lr = optimizer_cfg.lr
            optimizer_weight_decay = optimizer_cfg.weight_decay

            if optimizer_name == 'Adam':
                self.optimizer = optim.Adam(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay)
            elif optimizer_name == 'SGD':
                self.optimizer = optim.SGD(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay)
            else:
                raise ValueError(f"Unsupported optimizer: {optimizer_name}")
        else:
            self.optimizer = None

    def forward(self, x):
        return self.ctc_encoder(x)

    def training_step(self, batch, batch_idx):
        x, y, input_lengths, target_lengths = batch

        logits, input_lengths = self.ctc_encoder(x, input_lengths)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        loss = self.ctc_loss(log_probs, y, input_lengths, target_lengths)
        assert input_lengths.size(0) == x.size(0), f"input_lengths size ({input_lengths.size(0)}) must match batch size ({x.size(0)})"
        preds = torch.argmax(log_probs, dim=-1)
        self.log("train_loss", loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y, input_lengths, target_lengths = batch

        # Compute logits and adjust input lengths
        logits, input_lengths = self.ctc_encoder(x, input_lengths)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

        # Validate input_lengths size
        assert input_lengths.size(0) == logits.size(0), "Mismatch between input_lengths and batch size"

        # Compute CTC loss
        loss = self.ctc_loss(log_probs, y, input_lengths, target_lengths)

        # Compute metrics
        preds = torch.argmax(log_probs, dim=-1)
        accuracy = self.accuracy(y, preds)
        precision = self.precision(y, preds)
        recall = self.recall(y, preds)
        f1 = self.f1(y, preds)

        # Log metrics
        self.log('val_loss', loss, prog_bar=True, on_epoch=True, on_step=True)
        self.log('val_acc', accuracy, prog_bar=True, on_epoch=True, on_step=True)
        self.log('val_precision', precision, prog_bar=True, on_epoch=True, on_step=True)
        self.log('val_recall', recall, prog_bar=True, on_epoch=True, on_step=True)
        self.log('val_f1', f1, prog_bar=True, on_epoch=True, on_step=True)

        return loss

    def on_validation_epoch_end(self):
        avg_loss = self.trainer.logged_metrics['val_loss_epoch']
        accuracy = self.trainer.logged_metrics['val_acc_epoch']

        self.log('val_loss', avg_loss, prog_bar=True, on_epoch=True)
        self.log('val_acc', accuracy, prog_bar=True, on_epoch=True)

        return {'Average Loss:': avg_loss, 'Accuracy:': accuracy}

    def test_step(self, batch, batch_idx):
        x, y, input_lengths, target_lengths = batch
        logits, input_lengths = self.ctc_encoder(x, input_lengths)  # mirror validation_step; forward(x) alone lacks input_lengths
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        loss = self.ctc_loss(log_probs, y, input_lengths, target_lengths)

        preds = torch.argmax(log_probs, dim=-1)
        accuracy = self.accuracy(y, preds)
        precision = self.precision(y, preds)
        recall = self.recall(y, preds)
        f1 = self.f1(y, preds)

        self.log('test_loss', loss, prog_bar=True, logger=True)
        self.log('test_acc', accuracy, prog_bar=True, logger=True)
        self.log('test_precision', precision, prog_bar=True, logger=True)
        self.log('test_recall', recall, prog_bar=True, logger=True)
        self.log('test_f1', f1, prog_bar=True, logger=True)

        return {'test_loss': loss, 'test_accuracy': accuracy, 'test_precision': precision, 'test_recall': recall, 'test_f1': f1}

    def configure_optimizers(self):
        optimizer = self.optimizer
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

    def greedy_decode(self, log_probs):
        """
        Perform greedy decoding to get predictions from log probabilities.
        """
        preds = torch.argmax(log_probs, dim=-1)
        return preds
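timm_backbones reads optimizer_cfg through attribute access (cfg.name, cfg.lr, cfg.weight_decay), so outside Hydra a SimpleNamespace can stand in for the config. A minimal sketch (an assumption, not part of the commit; the input sizes are arbitrary):

from types import SimpleNamespace
import torch
from encoders.encoders import timm_backbones

opt_cfg = SimpleNamespace(name="Adam", lr=3e-4, weight_decay=3e-4)
model = timm_backbones(encoder="resnet18", num_classes=4, optimizer_cfg=opt_cfg)
logits = model(torch.randn(2, 1, 128, 128))   # single-channel input after the conv1 swap
print(logits.shape)                           # torch.Size([2, 4])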
encoders/transformer.py
ADDED
@@ -0,0 +1,233 @@
import pytorch_lightning as pl
import torch
from torchmetrics import Accuracy, Precision, Recall, F1Score
from transformers import Wav2Vec2Model, Wav2Vec2ForSequenceClassification
import torch.nn.functional as F


class Wav2Vec2Classifier(pl.LightningModule):
    def __init__(self, num_classes, optimizer_cfg="Adam", l1_lambda=0.0):
        super(Wav2Vec2Classifier, self).__init__()
        self.save_hyperparameters()

        # Wav2Vec2 backbone
        # self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
        self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53")

        # trying without the need to fine tune it
        for param in self.wav2vec2.parameters():
            param.requires_grad = False
        # Classification head
        self.classifier = torch.nn.Linear(self.wav2vec2.config.hidden_size, num_classes)

        # Metrics
        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.precision = Precision(task="multiclass", num_classes=num_classes)
        self.recall = Recall(task="multiclass", num_classes=num_classes)
        self.f1 = F1Score(task="multiclass", num_classes=num_classes)

        self.l1_lambda = l1_lambda
        if optimizer_cfg is not None:
            optimizer_name = optimizer_cfg.name
            optimizer_lr = optimizer_cfg.lr
            optimizer_weight_decay = optimizer_cfg.weight_decay

            if optimizer_name == 'Adam':
                self.optimizer = torch.optim.Adam(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay)
            elif optimizer_name == 'SGD':
                self.optimizer = torch.optim.SGD(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay)
            else:
                raise ValueError(f"Unsupported optimizer: {optimizer_name}")
        else:
            self.optimizer = None

    def forward(self, x, attention_mask=None):
        # Debug input shape

        # Ensure input shape is [batch_size, sequence_length]
        if x.dim() > 2:
            x = x.squeeze(-1)  # Remove unnecessary dimensions if present

        # Pass through Wav2Vec2 backbone
        output = self.wav2vec2(x, attention_mask=attention_mask)
        x = output.last_hidden_state

        # Classification head
        x = torch.mean(x, dim=1)  # Pooling
        logits = self.classifier(x)
        return logits


    def training_step(self, batch, batch_idx):
        x, attention_mask, y = batch

        # Forward pass
        logits = self(x, attention_mask=attention_mask)

        # Compute loss
        loss = F.cross_entropy(logits, y)

        # Add L1 regularization if specified
        l1_norm = sum(param.abs().sum() for param in self.parameters())
        loss += self.l1_lambda * l1_norm

        # Log metrics
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, attention_mask, y = batch  # Unpack batch

        # Forward pass
        logits = self(x, attention_mask=attention_mask)

        # Compute loss and metrics
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        accuracy = self.accuracy(preds, y)
        precision = self.precision(preds, y)
        recall = self.recall(preds, y)
        f1 = self.f1(preds, y)

        # Log metrics
        self.log("val_loss", loss, prog_bar=True, logger=True)
        self.log("val_acc", accuracy, prog_bar=True, logger=True)
        self.log("val_precision", precision, prog_bar=True, logger=True)
        self.log("val_recall", recall, prog_bar=True, logger=True)
        self.log("val_f1", f1, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, attention_mask, y = batch  # Unpack batch

        # Forward pass
        logits = self(x, attention_mask=attention_mask)

        # Compute loss and metrics
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        accuracy = self.accuracy(preds, y)
        precision = self.precision(preds, y)
        recall = self.recall(preds, y)
        f1 = self.f1(preds, y)

        # Log metrics
        self.log("test_loss", loss, prog_bar=True, logger=True)
        self.log("test_acc", accuracy, prog_bar=True, logger=True)
        self.log("test_precision", precision, prog_bar=True, logger=True)
        self.log("test_recall", recall, prog_bar=True, logger=True)
        self.log("test_f1", f1, prog_bar=True, logger=True)

        return {"test_loss": loss, "test_accuracy": accuracy}

    def configure_optimizers(self):
        optimizer = self.optimizer
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}


class Wav2Vec2EmotionClassifier(pl.LightningModule):
    def __init__(self, num_classes, learning_rate=1e-4, freeze_base=False, optimizer_cfg="AdamW"):
        super(Wav2Vec2EmotionClassifier, self).__init__()
        self.save_hyperparameters()

        # Load a pre-trained Wav2Vec2 model optimized for emotion recognition
        self.model = Wav2Vec2ForSequenceClassification.from_pretrained(
            "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim",
            num_labels=num_classes,
        )
        # Optionally freeze the Wav2Vec2 base layers
        if freeze_base:
            for param in self.model.wav2vec2.parameters():
                param.requires_grad = False

        # Metrics
        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.precision = Precision(task="multiclass", num_classes=num_classes)
        self.recall = Recall(task="multiclass", num_classes=num_classes)
        self.f1 = F1Score(task="multiclass", num_classes=num_classes)

        self.learning_rate = learning_rate
        if optimizer_cfg is not None:
            optimizer_name = optimizer_cfg['name']
            optimizer_lr = optimizer_cfg['lr']
            optimizer_weight_decay = optimizer_cfg['weight_decay']

            if optimizer_name == 'Adam':
                self.optimizer = torch.optim.Adam(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay)
            elif optimizer_name == 'SGD':
                self.optimizer = torch.optim.SGD(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay)
            elif optimizer_name == 'AdamW':
                self.optimizer = torch.optim.AdamW(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay)
            else:
                raise ValueError(f"Unsupported optimizer: {optimizer_name}")
        else:
            self.optimizer = None

    def forward(self, x, attention_mask=None):
        return self.model(x, attention_mask=attention_mask).logits

    def training_step(self, batch, batch_idx):
        x, attention_mask, y = batch

        # Forward pass
        logits = self(x, attention_mask=attention_mask)

        # Compute loss
        loss = F.cross_entropy(logits, y)

        # Log training loss
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, attention_mask, y = batch

        # Forward pass
        logits = self(x, attention_mask=attention_mask)

        # Compute loss and metrics
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)

        accuracy = self.accuracy(preds, y)
        precision = self.precision(preds, y)
        recall = self.recall(preds, y)
        f1 = self.f1(preds, y)

        # Log metrics
        self.log("val_loss", loss, prog_bar=True, logger=True)
        self.log("val_acc", accuracy, prog_bar=True, logger=True)
        self.log("val_precision", precision, prog_bar=True, logger=True)
        self.log("val_recall", recall, prog_bar=True, logger=True)
        self.log("val_f1", f1, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, attention_mask, y = batch

        # Forward pass
        logits = self(x, attention_mask=attention_mask)

        # Compute loss and metrics
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        accuracy = self.accuracy(preds, y)
        precision = self.precision(preds, y)
        recall = self.recall(preds, y)
        f1 = self.f1(preds, y)

        # Log metrics
        self.log("test_loss", loss, prog_bar=True, logger=True)
        self.log("test_acc", accuracy, prog_bar=True, logger=True)
        self.log("test_precision", precision, prog_bar=True, logger=True)
        self.log("test_recall", recall, prog_bar=True, logger=True)
        self.log("test_f1", f1, prog_bar=True, logger=True)

        return {"test_loss": loss, "test_accuracy": accuracy}

    def configure_optimizers(self):
        optimizer = self.optimizer
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}
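A minimal inference sketch for Wav2Vec2EmotionClassifier, mirroring how app.py drives it (an assumption, not part of the commit; the random waveforms stand in for real 16 kHz audio, and the pretrained weights are downloaded on first use):

import torch
from encoders.transformer import Wav2Vec2EmotionClassifier

cfg = {"name": "Adam", "lr": 3e-4, "weight_decay": 3e-4}
clf = Wav2Vec2EmotionClassifier(num_classes=7, optimizer_cfg=cfg).eval()
with torch.no_grad():
    logits = clf(torch.randn(2, 16000))   # two dummy 1-second clips
print(logits.shape)                        # torch.Size([2, 7])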
model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0116227115053abebb4951ef1bd0bd25750797f2bfe98d74df152dc2289295d6
size 658272386
models/CTCencoder.py
ADDED
@@ -0,0 +1,93 @@
import torch
import torch.nn as nn


class CTCEncoder(nn.Module):
    def __init__(self, num_classes, cnn_output_dim=256, rnn_hidden_dim=256, rnn_layers=3):
        """
        CTC encoder with a CNN feature extractor and a bidirectional LSTM for sequence modeling.

        Args:
            num_classes (int): Number of output classes for the model.
            cnn_output_dim (int): Number of output channels from the CNN.
            rnn_hidden_dim (int): Hidden size of the LSTM.
            rnn_layers (int): Number of layers in the LSTM.
        """
        super(CTCEncoder, self).__init__()

        # CNN feature extractor
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # Down-sample by 2
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # Down-sample by another 2
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, None))  # Collapse the frequency axis so the output height is 1
        )

        # Bidirectional LSTM
        self.rnn_hidden_dim = rnn_hidden_dim
        self.rnn_layers = rnn_layers
        self.cnn_output_dim = cnn_output_dim

        self.rnn = nn.LSTM(
            input_size=cnn_output_dim,  # Output channels from the CNN
            hidden_size=rnn_hidden_dim,
            num_layers=rnn_layers,
            batch_first=True,
            bidirectional=True
        )

        # Fully connected layer (2 * hidden size because the LSTM is bidirectional)
        self.fc = nn.Linear(rnn_hidden_dim * 2, num_classes)

    def compute_input_lengths(self, input_lengths):
        """
        Adjusts input lengths based on the CNN's down-sampling operations.

        Args:
            input_lengths (torch.Tensor): Original input lengths.

        Returns:
            torch.Tensor: Adjusted input lengths.
        """
        # Only the two MaxPool2d layers halve the time axis; the convolutions use
        # padding=1 and the adaptive pooling leaves the width unchanged.
        input_lengths = input_lengths // 2  # First MaxPool
        input_lengths = input_lengths // 2  # Second MaxPool
        return input_lengths

    def forward(self, x, input_lengths):
        """
        Forward pass through the encoder.

        Args:
            x (torch.Tensor): Input tensor of shape [B, 1, H, W].
            input_lengths (torch.Tensor): Lengths of the sequences in the batch.

        Returns:
            torch.Tensor: Logits of shape [B, T, num_classes].
            torch.Tensor: Adjusted input lengths.
        """
        # Feature extraction
        x = self.feature_extractor(x)  # [batch, channels, height, width]

        # Reshape for the LSTM: drop the height dimension and move time to axis 1
        x = x.squeeze(2).permute(0, 2, 1)  # [batch, seq_len, features]
        assert x.size(-1) == self.cnn_output_dim, f"Expected last dimension to be {self.cnn_output_dim}, but got {x.size(-1)}"

        # Adjust input lengths to match the down-sampled time axis
        input_lengths = self.compute_input_lengths(input_lengths)
        assert input_lengths.size(0) == x.size(0), f"input_lengths size ({input_lengths.size(0)}) must match batch size ({x.size(0)})"

        # Pass through the LSTM
        x, _ = self.rnn(x)  # [batch, seq_len, 2 * hidden_dim]

        # Fully connected output
        x = self.fc(x)  # [batch, seq_len, num_classes]
        return x, input_lengths
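A minimal usage sketch for CTCEncoder, assuming spectrogram inputs of shape [B, 1, n_mels, frames] and class index 0 reserved for the CTC blank; the concrete sizes below are illustrative, not taken from the training configuration:

import torch
import torch.nn as nn

from models.CTCencoder import CTCEncoder

# Toy batch: 2 spectrograms with 64 mel bins and up to 200 frames
specs = torch.randn(2, 1, 64, 200)
frame_lengths = torch.tensor([200, 180])

model = CTCEncoder(num_classes=30)
logits, out_lengths = model(specs, frame_lengths)  # logits: [B, T, num_classes]

# nn.CTCLoss expects log-probabilities shaped [T, B, num_classes]
log_probs = logits.log_softmax(dim=-1).permute(1, 0, 2)

targets = torch.randint(1, 30, (2, 12))  # dummy label sequences without blanks
target_lengths = torch.tensor([12, 10])

ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
loss = ctc_loss(log_probs, targets, out_lengths, target_lengths)
print(loss.item())

Because AdaptiveAvgPool2d((1, None)) collapses the frequency axis, the LSTM sees one 256-dimensional feature vector per down-sampled time frame.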
models/__init__.py
ADDED
@@ -0,0 +1 @@
from . import CTCencoder
models/__pycache__/CTCencoder.cpython-311.pyc
ADDED
Binary file (4.67 kB)
models/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (207 Bytes)
requirements.txt
ADDED
@@ -0,0 +1,5 @@
gradio
librosa
torch
transformers
numpy
statics/style.css
ADDED
@@ -0,0 +1,9 @@
#audio_input {
    border: 2px solid #4CAF50;
    border-radius: 10px;
}
#submit_button {
    background-color: #4CAF50;
    color: white;
    border-radius: 5px;
}
upload_model.py
ADDED
@@ -0,0 +1,14 @@
import shutil

from huggingface_hub import HfApi, Repository

api = HfApi()
repo_url = api.create_repo(repo_id="saeedbenadeeb/emotion-detection", exist_ok=True)

# Clone (or reuse) the local working copy of the model repository
repo = Repository(local_dir="emotion-detection", clone_from=repo_url)
repo.git_pull()

# Copy the model weights into the local repository clone
shutil.copy("model.pth", "emotion-detection")

# Commit and push the new files to the Hub
repo.push_to_hub(commit_message="Initial model upload")
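As an aside, newer versions of huggingface_hub favor direct uploads over the Repository workflow; a sketch of the same upload via HfApi.upload_file, reusing the repo id from above:

from huggingface_hub import HfApi

api = HfApi()
api.create_repo(repo_id="saeedbenadeeb/emotion-detection", exist_ok=True)

# Upload the checkpoint directly, without cloning the repository locally
api.upload_file(
    path_or_fileobj="model.pth",
    path_in_repo="model.pth",
    repo_id="saeedbenadeeb/emotion-detection",
    commit_message="Initial model upload",
)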
utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
from . import random_split
from . import helper_functions
utils/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (263 Bytes)
utils/__pycache__/helper_functions.cpython-311.pyc
ADDED
Binary file (3.61 kB)
utils/__pycache__/random_split.cpython-311.pyc
ADDED
Binary file (2.39 kB)
utils/helper_functions.py
ADDED
@@ -0,0 +1,70 @@
import os
import shutil

import torch
from torch.nn.utils.rnn import pad_sequence


def normalize_ratios(ratios):
    """Scale a list of split ratios so that they sum to 1."""
    total = sum(ratios)
    return [r / total for r in ratios]


def collate_fn_transformer(batch):
    """
    Custom collate function to handle variable-length raw waveform inputs.

    Args:
        batch: List of tuples (tensor, label), where tensor has shape [sequence_length].

    Returns:
        padded_waveforms: Padded tensor of shape [batch_size, max_seq_len].
        attention_mask: Attention mask for the padded sequences.
        labels: Tensor of shape [batch_size].
    """
    # Separate waveforms and labels
    waveforms, labels = zip(*batch)

    # Ensure waveforms are 1D tensors
    waveforms = [torch.as_tensor(waveform).squeeze() for waveform in waveforms]

    # Pad sequences to the same length
    padded_waveforms = pad_sequence(waveforms, batch_first=True)  # [batch_size, max_seq_len]

    # Create an attention mask that is 1 for real samples and 0 for padding
    attention_mask = (padded_waveforms != 0).long()

    # Convert labels to a tensor
    labels = torch.tensor(labels, dtype=torch.long)

    return padded_waveforms, attention_mask, labels


def collate_fn(batch):
    """Collate function for the CTC pipeline: stacks inputs and flattens target sequences."""
    inputs, targets, input_lengths, target_lengths = zip(*batch)
    inputs = torch.stack(inputs)  # Convert list of tensors to a batch tensor
    targets = torch.cat(targets)  # Flatten target sequences
    input_lengths = torch.tensor(input_lengths, dtype=torch.long)
    target_lengths = torch.tensor(target_lengths, dtype=torch.long)
    return inputs, targets, input_lengths, target_lengths


def save_test_data(test_dataset, dataset, save_dir):
    """Copy the audio files of the test split into save_dir, grouped by label."""
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # Delete the existing directory and its contents
        print(f"Existing test data directory '{save_dir}' removed.")

    os.makedirs(save_dir, exist_ok=True)

    for idx in test_dataset.indices:
        audio_file_path = dataset.audio_files[idx]  # Assumes the dataset exposes an `audio_files` attribute
        label = dataset.labels[idx]                 # Assumes the dataset exposes a `labels` attribute

        # Create a directory for the label if it doesn't exist
        label_dir = os.path.join(save_dir, str(label))
        os.makedirs(label_dir, exist_ok=True)

        # Copy the audio file to the label directory
        shutil.copy(audio_file_path, label_dir)

    print(f"Test data saved in {save_dir}")
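A minimal sketch of wiring collate_fn_transformer into a DataLoader; the list-based placeholder dataset below stands in for the audio dataset classes under datasets/ and simply yields (waveform, label) pairs of varying length:

import torch
from torch.utils.data import DataLoader

from utils.helper_functions import collate_fn_transformer

# Placeholder dataset: 32 random waveforms between 0.5 s and 1 s at 16 kHz
train_dataset = [
    (torch.randn(torch.randint(8000, 16000, (1,)).item()), i % 7)
    for i in range(32)
]

train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn_transformer,
)

waveforms, attention_mask, labels = next(iter(train_loader))
print(waveforms.shape, attention_mask.shape, labels.shape)  # [8, max_len] twice, then [8]

Note that the mask treats exact zeros as padding, so genuine zero-valued samples inside a waveform are masked as well.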
utils/random_split.py
ADDED
@@ -0,0 +1,37 @@
from typing import List

import torch
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split

from utils.helper_functions import normalize_ratios


def stratified_random_split(ds: torch.utils.data.Dataset, parts: List[float], targets: List[int]) -> List[torch.utils.data.Dataset]:
    """
    Perform a stratified random split on the dataset.

    Args:
        ds: PyTorch dataset to split.
        parts: List of three proportions (train, val, test); they are normalized to sum to 1.
        targets: List of labels corresponding to the dataset samples.

    Returns:
        List of PyTorch datasets for the train, validation, and test splits.
    """
    total_length = len(ds)

    # Normalize ratios so they sum to 1
    parts = normalize_ratios(parts)

    # First split: train vs. the remaining (val + test) samples, stratified by label
    indices = list(range(total_length))
    train_indices, temp_indices, _, temp_targets = train_test_split(
        indices, targets, test_size=(1 - parts[0]), stratify=targets, random_state=42
    )

    # Second split: divide the remainder into validation and test sets
    val_size = parts[1] / (parts[1] + parts[2])
    val_indices, test_indices, _, _ = train_test_split(
        temp_indices, temp_targets, test_size=(1 - val_size), stratify=temp_targets, random_state=42
    )

    return [Subset(ds, train_indices), Subset(ds, val_indices), Subset(ds, test_indices)]
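A minimal sketch of calling stratified_random_split; the toy tensors stand in for the real audio dataset and its label list:

import torch
from torch.utils.data import TensorDataset

from utils.random_split import stratified_random_split

# Toy dataset: 100 samples, 4 emotion classes
features = torch.randn(100, 16000)
labels = torch.randint(0, 4, (100,))
ds = TensorDataset(features, labels)

train_ds, val_ds, test_ds = stratified_random_split(
    ds, parts=[0.7, 0.15, 0.15], targets=labels.tolist()
)
print(len(train_ds), len(val_ds), len(test_ds))  # roughly 70 / 15 / 15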