saeedbenadeeb committed
Commit 0874d87 · 1 Parent(s): 9a0a0d8

First commit

.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
app.py ADDED
@@ -0,0 +1,168 @@
1
+ import gradio as gr
2
+ import torch
3
+ import librosa
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+ import os
7
+
8
+ from encoders.transformer import Wav2Vec2EmotionClassifier
9
+
10
+ # Define the emotions
11
+ emotions = ["happy", "sad", "angry", "neutral", "fear", "disgust", "surprise"]
12
+ label_mapping = {str(idx): emotion for idx, emotion in enumerate(emotions)}
13
+
14
+ # Load the trained model
15
+ model_path = "model.pth"
16
+ cfg = {
17
+ "model": {
18
+ "encoder": "Wav2Vec2Classifier",
19
+ "optimizer": {
20
+ "name": "Adam",
21
+ "lr": 0.0003,
22
+ "weight_decay": 3e-4
23
+ },
24
+ "l1_lambda": 0.0
25
+ }
26
+ }
27
+ model = Wav2Vec2EmotionClassifier(num_classes=len(emotions), optimizer_cfg=cfg["model"]["optimizer"])
28
+ model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
29
+ model.eval()
30
+
31
+ # Optional: we define a minimum number of samples to avoid Wav2Vec2 conv errors
32
+ MIN_SAMPLES = 10 # or 16000 if you want at least 1 second
33
+
34
+ # Preprocessing function
35
+ def preprocess_audio(file_path, sample_rate=16000):
36
+ """
37
+ Safely loads the file at file_path and returns a (1, samples) torch tensor.
38
+ Returns None if the file is invalid or too short.
39
+ """
40
+ if not file_path or (not os.path.exists(file_path)):
41
+ # file_path could be None or an empty string if user didn't record properly
42
+ return None
43
+
44
+ # Load with librosa (which merges to mono by default if multi-channel)
45
+ waveform, sr = librosa.load(file_path, sr=sample_rate)
46
+
47
+ # Check length
48
+ if len(waveform) < MIN_SAMPLES:
49
+ return None
50
+
51
+ # Convert to torch tensor, shape (1, samples)
52
+ waveform_tensor = torch.tensor(waveform, dtype=torch.float32).unsqueeze(0)
53
+
54
+ return waveform_tensor
55
+
56
+ # Prediction function
57
+ def predict_emotion(audio_file):
58
+ """
59
+ audio_file is a file path from Gradio (type='filepath').
60
+ """
61
+ # Preprocess
62
+ waveform = preprocess_audio(audio_file, sample_rate=16000)
63
+
64
+ # If invalid or too short, return an error-like message
65
+ if waveform is None:
66
+ return (
67
+ "Audio is too short or invalid. Please record/upload a longer clip.",
68
+ ""
69
+ )
70
+
71
+ # Perform inference
72
+ with torch.no_grad():
73
+ logits = model(waveform)
74
+ probabilities = F.softmax(logits, dim=-1).cpu().numpy()[0]
75
+
76
+ # Get the predicted class
77
+ predicted_class = np.argmax(probabilities)
78
+ predicted_emotion = label_mapping[str(predicted_class)]
79
+
80
+ # Format probabilities for visualization
81
+ probabilities_output = [
82
+ f"""
83
+ <div style='display: flex; align-items: center; margin: 5px 0;'>
84
+ <div style='width: 20%; text-align: right; margin-right: 10px; font-weight: bold;'>{emotions[i]}</div>
85
+ <div style='flex-grow: 1; background-color: #374151; border-radius: 4px; overflow: hidden;'>
86
+ <div style='width: {probabilities[i]*100:.2f}%; background-color: #FFA500; height: 10px;'></div>
87
+ </div>
88
+ <div style='width: 10%; text-align: right; margin-left: 10px;'>{probabilities[i]*100:.2f}%</div>
89
+ </div>
90
+ """
91
+ for i in range(len(emotions))
92
+ ]
93
+
94
+ return predicted_emotion, "\n".join(probabilities_output)
95
+
96
+ # Create Gradio interface
97
+ def gradio_interface(audio):
98
+ detected_emotion, probabilities_html = predict_emotion(audio)
99
+ return detected_emotion, gr.HTML(probabilities_html)
100
+
101
+ # Define Gradio UI
102
+ with gr.Blocks(css="""
103
+ body {
104
+ background-color: #121212;
105
+ color: white;
106
+ font-family: Arial, sans-serif;
107
+ }
108
+ h1 {
109
+ color: #FFA500;
110
+ font-size: 48px;
111
+ text-align: center;
112
+ margin-bottom: 10px;
113
+ }
114
+ p {
115
+ text-align: center;
116
+ font-size: 18px;
117
+ }
118
+ .gradio-row {
119
+ justify-content: center;
120
+ align-items: center;
121
+ }
122
+ #submit_button {
123
+ background-color: #FFA500 !important;
124
+ color: black !important;
125
+ font-size: 18px;
126
+ padding: 10px 20px;
127
+ margin-top: 20px;
128
+ }
129
+ #detected_emotion {
130
+ font-size: 24px;
131
+ font-weight: bold;
132
+ text-align: center;
133
+ }
134
+ .probabilities-container {
135
+ margin-top: 20px;
136
+ padding: 10px;
137
+ background-color: #1F2937;
138
+ border-radius: 8px;
139
+ }
140
+ """) as demo:
141
+ gr.Markdown(
142
+ """
143
+ <div>
144
+ <h1>Speech Emotion Recognition</h1>
145
+ <p>🎵 Upload or record an audio file (max 1 minute) to detect emotions.</p>
146
+ <p>Supported Emotions: 😊 Happy | 😭 Sad | 😡 Angry | 😐 Neutral | 😨 Fear | 🤢 Disgust | 😮 Surprise</p>
147
+ </div>
148
+ """
149
+ )
150
+
151
+ with gr.Row():
152
+ with gr.Column(scale=1, elem_id="audio-block"):
153
+ # type="filepath" means we get a temporary file path from Gradio
154
+ audio_input = gr.Audio(label="🎤 Record or Upload Audio", type="filepath")
155
+ submit_button = gr.Button("Submit", elem_id="submit_button")
156
+ with gr.Column(scale=1):
157
+ detected_emotion_label = gr.Label(label="Detected Emotion", elem_id="detected_emotion")
158
+ probabilities_html = gr.HTML(label="Probabilities", elem_id="probabilities")
159
+
160
+ submit_button.click(
161
+ fn=gradio_interface,
162
+ inputs=audio_input,
163
+ outputs=[detected_emotion_label, probabilities_html]
164
+ )
165
+
166
+ # Launch the app
167
+ if __name__ == "__main__":
168
+ demo.launch(share=True)
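
A minimal local smoke test for the prediction path (a sketch, assuming `model.pth` and the `encoders/` package sit next to `app.py`; `sample.wav` is a hypothetical short speech recording — importing `app` loads the model at import time):

```python
# Sketch: exercise the prediction path without the Gradio UI.
# Assumes model.pth and the encoders/ package are available locally;
# "sample.wav" is a placeholder path for any short recording.
from app import predict_emotion

emotion, probabilities_html = predict_emotion("sample.wav")
print("Predicted emotion:", emotion)
```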
datasets/TESS_Dataset.py ADDED
@@ -0,0 +1,108 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from torch.utils.data import Dataset
5
+ import librosa
6
+ from typing import List, Tuple
7
+ import shutil
8
+ import kagglehub
9
+ from transformers import Wav2Vec2Processor, Wav2Vec2Model
10
+ import subprocess
11
+ import zipfile
12
+ import os
13
+ # Constants (you may need to define these according to your requirements)
14
+ SAMPLE_RATE = 16000 # Define the sample rate for audio processing
15
+ DURATION = 3.0 # Duration of the audio in seconds
16
+
17
+ # Placeholder for waveform normalization
18
+ def normalize_waveform(audio: np.ndarray) -> torch.Tensor:
19
+ # Convert to tensor if necessary
20
+ if not isinstance(audio, torch.Tensor):
21
+ audio = torch.tensor(audio, dtype=torch.float32)
22
+ return (audio - torch.mean(audio)) / torch.std(audio)
23
+
24
+ class TESSRawWaveformDataset(Dataset):
25
+ def __init__(self, root_path: str, transform=None):
26
+ super().__init__()
27
+ self.root_path = root_path
28
+ self.audio_files = []
29
+ self.labels = []
30
+ self.emotions = ["happy", "sad", "angry", "neutral", "fear", "disgust", "surprise"]
31
+ emotion_mapping = {e.lower(): idx for idx, e in enumerate(self.emotions)}
32
+ self.download_dataset_if_not_exists()
33
+ # Load file paths and labels from nested directories
34
+ for root, dirs, files in os.walk(root_path):
35
+ for file_name in files:
36
+ if file_name.endswith(".wav"):
37
+ emotion_name = next(
38
+ (e for e in emotion_mapping if e in root.lower()), None
39
+ )
40
+ if emotion_name is not None:
41
+ self.audio_files.append(os.path.join(root, file_name))
42
+ self.labels.append(emotion_mapping[emotion_name])
43
+
44
+ self.labels = np.array(self.labels, dtype=np.int64)
45
+ self.transform = transform
46
+
47
+ def __len__(self):
48
+ return len(self.audio_files)
49
+
50
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
51
+ # Load raw waveform and label
52
+ audio_path = self.audio_files[idx]
53
+ label = self.labels[idx]
54
+ waveform = self.load_audio(audio_path)
55
+
56
+ if self.transform:
57
+ waveform = self.transform(waveform)
58
+
59
+ return waveform, label
60
+
61
+ @staticmethod
62
+ def load_audio(audio_path: str) -> torch.Tensor:
63
+ # Load audio and ensure it's at the correct sample rate
64
+ audio, sr = librosa.load(audio_path, sr=SAMPLE_RATE, duration=DURATION)
65
+ assert sr == SAMPLE_RATE, f"Sample rate mismatch: expected {SAMPLE_RATE}, got {sr}"
66
+ return normalize_waveform(audio)
67
+
68
+ def get_emotions(self) -> List[str]:
69
+ return self.emotions
70
+
71
+ def download_dataset_if_not_exists(self):
72
+ if not os.path.exists(self.root_path):
73
+ print(f"Dataset not found at {self.root_path}. Downloading...")
74
+
75
+ # Ensure the destination directory exists
76
+ os.makedirs(self.root_path, exist_ok=True)
77
+
78
+ # Download dataset using curl
79
+ dataset_zip_path = os.path.join(self.root_path, "toronto-emotional-speech-set-tess.zip")
80
+ curl_command = [
81
+ "curl",
82
+ "-L",
83
+ "-o",
84
+ dataset_zip_path,
85
+ "https://www.kaggle.com/api/v1/datasets/download/ejlok1/toronto-emotional-speech-set-tess",
86
+ ]
87
+
88
+ try:
89
+ subprocess.run(curl_command, check=True)
90
+ print(f"Dataset downloaded to {dataset_zip_path}.")
91
+
92
+ # Extract the downloaded zip file
93
+ with zipfile.ZipFile(dataset_zip_path, "r") as zip_ref:
94
+ zip_ref.extractall(self.root_path)
95
+ print(f"Dataset extracted to {self.root_path}.")
96
+
97
+ # Remove the zip file to save space
98
+ os.remove(dataset_zip_path)
99
+ print(f"Removed zip file: {dataset_zip_path}")
100
+
101
+ except subprocess.CalledProcessError as e:
102
+ print(f"Error occurred during dataset download: {e}")
103
+ raise
104
+
105
+
106
+ # Example usage
107
+ # dataset = TESSRawWaveformDataset(root_path="./TESS", transform=None)
108
+ # print("Number of samples:", len(dataset))
datasets/__init__.py ADDED
@@ -0,0 +1,45 @@
1
+ from typing import List
2
+ from torch.utils.data import Dataset
3
+
4
+ from .image_dataset import CustomDataset
5
+ from .audio_dataset import EmodbDataset
6
+ from .ctc_audio_dataclass import CTCEmodbDataset
7
+ from .TESS_Dataset import TESSRawWaveformDataset
8
+
9
+ __dataset_mapper__ = {
10
+ "image": CustomDataset,
11
+ "emodb": EmodbDataset,
12
+ 'CTCemodb': CTCEmodbDataset,
13
+ 'TESSDataset': TESSRawWaveformDataset
14
+ }
15
+
16
+ def list_datasets() -> List[str]:
17
+ """Returns a list of available dataset names.
18
+
19
+ Returns:
20
+ List[str]: List of dataset names as strings.
21
+
22
+ Example:
23
+ >>> from datasets import list_datasets
24
+ >>> list_datasets()
25
+ ['CTCemodb', 'TESSDataset', 'emodb', 'image']
26
+ """
27
+ return sorted(__dataset_mapper__.keys())
28
+
29
+ def get_dataset_by_name(dataset: str, *args, **kwargs) -> Dataset:
30
+ """Returns the Dataset class using the given name and arguments.
31
+
32
+ Args:
33
+ dataset (str): The name of the dataset.
34
+
35
+ Returns:
36
+ Dataset: The requested dataset instance.
37
+
38
+ Example:
39
+ >>> from datasets import get_dataset_by_name
40
+ >>> dataset = get_dataset_by_name("emodb", root_path="./data/emodb")
41
+ >>> type(dataset)
42
+ <class 'datasets.audio_dataset.EmodbDataset'>
43
+ """
44
+ assert dataset in __dataset_mapper__, f"Dataset '{dataset}' not found in the mapper."
45
+ return __dataset_mapper__[dataset](*args, **kwargs)
datasets/__pycache__/TESS_Dataset.cpython-311.pyc ADDED
Binary file (6.95 kB)
 
datasets/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.96 kB)
 
datasets/__pycache__/audio_dataset.cpython-311.pyc ADDED
Binary file (7.66 kB)
 
datasets/__pycache__/ctc_audio_dataclass.cpython-311.pyc ADDED
Binary file (8.27 kB)
 
datasets/__pycache__/image_dataset.cpython-311.pyc ADDED
Binary file (4.99 kB)
 
datasets/audio_dataset.py ADDED
@@ -0,0 +1,120 @@
1
+ import os
2
+ import zipfile
3
+ import requests
4
+ from tqdm import tqdm
5
+ from typing import List, Tuple
6
+ import numpy as np
7
+ from torch.utils.data import Dataset
8
+ import librosa
9
+ import torch
10
+
11
+ SAMPLE_RATE = 22050
12
+ DURATION = 1.4 # second
13
+
14
+ class EmodbDataset(Dataset):
15
+ __url__ = "http://www.emodb.bilderbar.info/download/download.zip"
16
+ __labels__ = ("angry", "happy", "neutral", "sad")
17
+ __suffixes__ = {
18
+ "angry": ["Wa", "Wb", "Wc", "Wd"],
19
+ "happy": ["Fa", "Fb", "Fc", "Fd"],
20
+ "neutral": ["Na", "Nb", "Nc", "Nd"],
21
+ "sad": ["Ta", "Tb", "Tc", "Td"]
22
+ }
23
+
24
+ def __init__(self, root_path: str = './data/emodb', transform=None):
25
+ super().__init__()
26
+ self.root_path = root_path
27
+ self.audio_root_path = os.path.join(root_path, "wav")
28
+
29
+ # Ensure the dataset is downloaded
30
+ self._ensure_dataset()
31
+
32
+ ids = []
33
+ targets = []
34
+ for audio_file in os.listdir(self.audio_root_path):
35
+ f_name, ext = os.path.splitext(audio_file)
36
+ if ext != ".wav":
37
+ continue
38
+
39
+ suffix = f_name[-2:]
40
+ for label, suffixes in self.__suffixes__.items():
41
+ if suffix in suffixes:
42
+ ids.append(os.path.join(self.audio_root_path, audio_file))
43
+ targets.append(self.label2id(label))
44
+ break
45
+
46
+ self.ids = ids
47
+ self.targets = np.array(targets, dtype=np.int64)
48
+ self.transform = transform
49
+
50
+ def _ensure_dataset(self):
51
+ """
52
+ Ensures the dataset is downloaded and extracted.
53
+ """
54
+ if not os.path.isdir(self.audio_root_path):
55
+ print(f"Dataset not found at {self.audio_root_path}. Downloading...")
56
+ self._download_and_extract()
57
+
58
+ def _download_and_extract(self):
59
+ """
60
+ Downloads and extracts the dataset zip file.
61
+ """
62
+ # Ensure the root path exists
63
+ os.makedirs(self.root_path, exist_ok=True)
64
+
65
+ # Download the dataset
66
+ zip_path = os.path.join(self.root_path, "emodb.zip")
67
+ with requests.get(self.__url__, stream=True) as r:
68
+ r.raise_for_status()
69
+ total_size = int(r.headers.get("content-length", 0))
70
+ with open(zip_path, "wb") as f, tqdm(
71
+ desc="Downloading EMO-DB dataset",
72
+ total=total_size,
73
+ unit="B",
74
+ unit_scale=True,
75
+ unit_divisor=1024,
76
+ ) as bar:
77
+ for chunk in r.iter_content(chunk_size=8192):
78
+ f.write(chunk)
79
+ bar.update(len(chunk))
80
+
81
+ # Extract the dataset
82
+ print("Extracting dataset...")
83
+ with zipfile.ZipFile(zip_path, "r") as zip_ref:
84
+ zip_ref.extractall(self.root_path)
85
+
86
+ # Clean up the zip file
87
+ os.remove(zip_path)
88
+
89
+ def __len__(self):
90
+ return len(self.ids)
91
+
92
+ def __getitem__(self, idx: int) -> Tuple:
93
+ target = self.targets[idx]
94
+ audio = self.load_audio(self.ids[idx]) # Should return a numpy array
95
+
96
+ if self.transform:
97
+ audio = self.transform(audio) # Apply transform
98
+
99
+ return audio, target
100
+
101
+ @staticmethod
102
+ def id2label(idx: int) -> str:
103
+ return EmodbDataset.__labels__[idx]
104
+
105
+ @staticmethod
106
+ def label2id(label: str) -> int:
107
+ if label not in EmodbDataset.__labels__:
108
+ raise ValueError(f"Unknown label: {label}")
109
+ return EmodbDataset.__labels__.index(label)
110
+
111
+ @staticmethod
112
+ def load_audio(audio_file_path: str) -> np.ndarray:
113
+ audio, sr = librosa.load(audio_file_path, sr=SAMPLE_RATE, duration=DURATION)
114
+ assert SAMPLE_RATE == sr, "broken audio file"
115
+ # Convert numpy array to PyTorch tensor
116
+ return torch.tensor(audio, dtype=torch.float32)
117
+
118
+ @staticmethod
119
+ def get_labels() -> List[str]:
120
+ return list(EmodbDataset.__labels__)
datasets/ctc_audio_dataclass.py ADDED
@@ -0,0 +1,126 @@
1
+ import os
2
+ import zipfile
3
+ import requests
4
+ from tqdm import tqdm
5
+ from typing import List, Tuple
6
+ import numpy as np
7
+ from torch.utils.data import Dataset
8
+ import librosa
9
+ import torch
10
+ from torch.nn.utils.rnn import pad_sequence
11
+
12
+ SAMPLE_RATE = 22050
13
+ DURATION = 1.4 # seconds
14
+
15
+ class CTCEmodbDataset(Dataset):
16
+ __url__ = "http://www.emodb.bilderbar.info/download/download.zip"
17
+ __labels__ = ("angry", "happy", "neutral", "sad")
18
+ __suffixes__ = {
19
+ "angry": ["Wa", "Wb", "Wc", "Wd"],
20
+ "happy": ["Fa", "Fb", "Fc", "Fd"],
21
+ "neutral": ["Na", "Nb", "Nc", "Nd"],
22
+ "sad": ["Ta", "Tb", "Tc", "Td"]
23
+ }
24
+
25
+ def __init__(self, root_path: str = './data/emodb', transform=None):
26
+ super().__init__()
27
+ self.root_path = root_path
28
+ self.audio_root_path = os.path.join(root_path, "wav")
29
+
30
+ # Ensure the dataset is downloaded
31
+ self._ensure_dataset()
32
+
33
+ ids = []
34
+ targets = []
35
+ for audio_file in os.listdir(self.audio_root_path):
36
+ f_name, ext = os.path.splitext(audio_file)
37
+ if ext != ".wav":
38
+ continue
39
+
40
+ suffix = f_name[-2:]
41
+ for label, suffixes in self.__suffixes__.items():
42
+ if suffix in suffixes:
43
+ ids.append(os.path.join(self.audio_root_path, audio_file))
44
+ targets.append(self.label2id(label)) # Store as integers
45
+ break
46
+
47
+ self.ids = ids
48
+ self.targets = targets # Target sequences as a list of lists
49
+ self.transform = transform
50
+
51
+ def _ensure_dataset(self):
52
+ """
53
+ Ensures the dataset is downloaded and extracted.
54
+ """
55
+ if not os.path.isdir(self.audio_root_path):
56
+ print(f"Dataset not found at {self.audio_root_path}. Downloading...")
57
+ self._download_and_extract()
58
+
59
+ def _download_and_extract(self):
60
+ """
61
+ Downloads and extracts the dataset zip file.
62
+ """
63
+ os.makedirs(self.root_path, exist_ok=True)
64
+ zip_path = os.path.join(self.root_path, "emodb.zip")
65
+ with requests.get(self.__url__, stream=True) as r:
66
+ r.raise_for_status()
67
+ total_size = int(r.headers.get("content-length", 0))
68
+ with open(zip_path, "wb") as f, tqdm(
69
+ desc="Downloading EMO-DB dataset",
70
+ total=total_size,
71
+ unit="B",
72
+ unit_scale=True,
73
+ unit_divisor=1024,
74
+ ) as bar:
75
+ for chunk in r.iter_content(chunk_size=8192):
76
+ f.write(chunk)
77
+ bar.update(len(chunk))
78
+
79
+ print("Extracting dataset...")
80
+ with zipfile.ZipFile(zip_path, "r") as zip_ref:
81
+ zip_ref.extractall(self.root_path)
82
+
83
+ os.remove(zip_path)
84
+
85
+ def __len__(self):
86
+ return len(self.ids)
87
+
88
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
89
+ """
90
+ Returns:
91
+ x (torch.Tensor): Input sequence (audio features or waveform)
92
+ y (torch.Tensor): Target sequence (labels or tokenized transcription)
93
+ input_length (int): Length of input sequence
94
+ target_length (int): Length of target sequence
95
+ """
96
+ target = torch.tensor([self.targets[idx]], dtype=torch.long)
97
+ audio = self.load_audio(self.ids[idx]) # Should return a numpy array
98
+
99
+ if self.transform:
100
+ audio = self.transform(audio)
101
+
102
+ # Input length (for CTC)
103
+ input_length = audio.shape[-1] # Last dimension is the time dimension
104
+ target_length = len(target) # Length of target sequence
105
+
106
+ return audio, target, input_length, target_length
107
+
108
+ @staticmethod
109
+ def id2label(idx: int) -> str:
110
+ return CTCEmodbDataset.__labels__[idx]
111
+
112
+ @staticmethod
113
+ def label2id(label: str) -> int:
114
+ if label not in CTCEmodbDataset.__labels__:
115
+ raise ValueError(f"Unknown label: {label}")
116
+ return CTCEmodbDataset.__labels__.index(label)
117
+
118
+ @staticmethod
119
+ def load_audio(audio_file_path: str) -> torch.Tensor:
120
+ audio, sr = librosa.load(audio_file_path, sr=SAMPLE_RATE, duration=DURATION)
121
+ assert SAMPLE_RATE == sr, "broken audio file"
122
+ return torch.tensor(audio, dtype=torch.float32)
123
+
124
+ @staticmethod
125
+ def get_labels() -> List[str]:
126
+ return list(CTCEmodbDataset.__labels__)
datasets/image_dataset.py ADDED
@@ -0,0 +1,62 @@
1
+ import os
2
+ from torchvision.datasets import VisionDataset
3
+ from PIL import Image
4
+ from sklearn.model_selection import train_test_split
5
+
6
+
7
+ class CustomDataset(VisionDataset):
8
+ def __init__(self, root_path, subset="train", transform=None, target_transform=None, split_ratios=(0.7, 0.15, 0.15), seed=42):
9
+ super(CustomDataset, self).__init__(root_path, transform=transform, target_transform=target_transform)
10
+ self.root = root_path
11
+ self.subset = subset # Can be "train", "val", or "test"
12
+ self.split_ratios = split_ratios
13
+ self.seed = seed
14
+
15
+ self.classes, self.class_idx = self._find_classes()
16
+ self.samples = self._make_dataset()
17
+
18
+ def _find_classes(self):
19
+ classes = [d.name for d in os.scandir(self.root) if d.is_dir()]
20
+ classes.sort()
21
+ class_idx = {cls_name: i for i, cls_name in enumerate(classes)}
22
+ return classes, class_idx
23
+
24
+ def _make_dataset(self):
25
+ samples = []
26
+ for target_class in sorted(self.class_idx.keys()):
27
+ class_index = self.class_idx[target_class]
28
+ target_dir = os.path.join(self.root, target_class)
29
+ for root, _, fnames in sorted(os.walk(target_dir)):
30
+ for fname in sorted(fnames):
31
+ path = os.path.join(root, fname)
32
+ samples.append((path, class_index))
33
+
34
+ # Split into train, val, and test sets
35
+ train_samples, test_samples = train_test_split(
36
+ samples, test_size=1 - self.split_ratios[0], random_state=self.seed, stratify=[s[1] for s in samples]
37
+ )
38
+ val_samples, test_samples = train_test_split(
39
+ test_samples, test_size=self.split_ratios[2] / (self.split_ratios[1] + self.split_ratios[2]),
40
+ random_state=self.seed, stratify=[s[1] for s in test_samples]
41
+ )
42
+
43
+ if self.subset == "train":
44
+ return train_samples
45
+ elif self.subset == "val":
46
+ return val_samples
47
+ elif self.subset == "test":
48
+ return test_samples
49
+ else:
50
+ raise ValueError(f"Unknown subset: {self.subset}")
51
+
52
+ def __len__(self):
53
+ return len(self.samples)
54
+
55
+ def __getitem__(self, index):
56
+ path, target = self.samples[index]
57
+ img = Image.open(path).convert("RGB")
58
+
59
+ if self.transform is not None:
60
+ img = self.transform(img)
61
+
62
+ return img, target
emotion-detection ADDED
@@ -0,0 +1 @@
1
+ Subproject commit 4f5c928446aeb2dabd215f85d3f9647ac92e7e67
encoders/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import encoders, transformer
encoders/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (235 Bytes)
 
encoders/__pycache__/encoders.cpython-311.pyc ADDED
Binary file (16.2 kB)
 
encoders/__pycache__/transformer.cpython-311.pyc ADDED
Binary file (12.3 kB)
 
encoders/encoders.py ADDED
@@ -0,0 +1,263 @@
1
+ import torch.optim as optim
2
+ import pytorch_lightning as pl
3
+ import timm
4
+ from torchmetrics import Accuracy, Precision, Recall, F1Score
5
+ import torch
6
+
7
+
8
+
9
+ class timm_backbones(pl.LightningModule):
10
+ """
11
+ PyTorch Lightning model for image classification built on a timm backbone.
12
+
13
+ This model loads a pre-trained backbone (default: ResNet-18) from timm and fine-tunes it for a specific number of classes.
14
+
15
+ Args:
16
+ num_classes (int, optional): The number of classes in the dataset. Defaults to 2.
17
+ optimizer_cfg (DictConfig, optional): A Hydra configuration object for the optimizer.
18
+
19
+ Methods:
20
+ forward(x): Computes the forward pass of the model.
21
+ configure_optimizers(): Configures the optimizer for the model.
22
+ training_step(batch, batch_idx): Performs a training step on the model.
23
+ validation_step(batch, batch_idx): Performs a validation step on the model.
24
+ on_validation_epoch_end(): Called at the end of each validation epoch.
25
+ test_step(batch, batch_idx): Performs a test step on the model.
26
+
27
+ Example:
28
+ model = timm_backbones(encoder='resnet18', num_classes=2, optimizer_cfg=cfg.model.optimizer)
29
+ trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
30
+ trainer.test(model, dataloaders=test_dataloader)
31
+ """
32
+ def __init__(self, encoder='resnet18', num_classes=2, optimizer_cfg=None, l1_lambda=0.0):
33
+ super().__init__()
34
+
35
+ self.encoder = encoder
36
+ self.model = timm.create_model(encoder, pretrained=True)
37
+ if self.model.default_cfg["input_size"][0] == 3: # If model expects 3 input channels
38
+ self.model.conv1 = torch.nn.Conv2d(
39
+ in_channels=1, # Change to single channel
40
+ out_channels=self.model.conv1.out_channels,
41
+ kernel_size=self.model.conv1.kernel_size,
42
+ stride=self.model.conv1.stride,
43
+ padding=self.model.conv1.padding,
44
+ bias=False
45
+ )
46
+
47
+ self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
48
+ self.precision = Precision(task="multiclass", num_classes=num_classes)
49
+ self.recall = Recall(task="multiclass", num_classes=num_classes)
50
+ self.f1 = F1Score(task="multiclass", num_classes=num_classes)
51
+
52
+ self.l1_lambda = l1_lambda
53
+ if hasattr(self.model, 'fc'): # For models with 'fc' as the classification layer
54
+ in_features = self.model.fc.in_features
55
+ self.model.fc = torch.nn.Linear(in_features, num_classes)
56
+ elif hasattr(self.model, 'classifier'): # For models with 'classifier'
57
+ in_features = self.model.classifier.in_features
58
+ self.model.classifier = torch.nn.Linear(in_features, num_classes)
59
+ elif hasattr(self.model, 'head'): # For models with 'head'
60
+ in_features = self.model.head.in_features
61
+ self.model.head = torch.nn.Linear(in_features, num_classes)
62
+ else:
63
+ raise ValueError(f"Unsupported model architecture for encoder: {encoder}")
64
+
65
+ if optimizer_cfg is not None:
66
+ optimizer_name = optimizer_cfg.name
67
+ optimizer_lr = optimizer_cfg.lr
68
+ optimizer_weight_decay = optimizer_cfg.weight_decay
69
+
70
+ if optimizer_name == 'Adam':
71
+ self.optimizer = optim.Adam(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay)
72
+ elif optimizer_name == 'SGD':
73
+ self.optimizer = optim.SGD(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay)
74
+ else:
75
+ raise ValueError(f"Unsupported optimizer: {optimizer_name}")
76
+ else:
77
+ self.optimizer = None
78
+
79
+ def forward(self, x):
80
+ return self.model(x)
81
+
82
+ def configure_optimizers(self):
83
+ optimizer = self.optimizer
84
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5)
85
+ return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}
86
+
87
+ def training_step(self, batch, batch_idx):
88
+ x, y = batch
89
+ y = y.long()
90
+
91
+ # Compute predictions and loss
92
+ logits = self(x)
93
+ loss = torch.nn.functional.cross_entropy(logits, y)
94
+
95
+ # Add L1 regularization
96
+ l1_norm = sum(param.abs().sum() for param in self.parameters())
97
+ loss += self.l1_lambda * l1_norm
98
+
99
+ self.log('train_loss', loss, prog_bar=True, on_epoch=True, on_step=False, logger=True)
100
+
101
+ return loss
102
+
103
+ def validation_step(self, batch, batch_idx):
104
+ x, y = batch
105
+ y = y.long()
106
+
107
+ logits = self(x)
108
+ loss = torch.nn.functional.cross_entropy(logits, y)
109
+
110
+ preds = torch.argmax(logits, dim=1)
111
+ accuracy = self.accuracy(preds, y)
112
+ precision = self.precision(preds, y)
113
+ recall = self.recall(preds, y)
114
+ f1 = self.f1(preds, y)
115
+
116
+ self.log('val_loss', loss, prog_bar=True, on_epoch=True, on_step=True)
117
+ self.log('val_acc', accuracy, prog_bar=True, on_epoch=True, on_step=True)
118
+ self.log('val_precision', precision, prog_bar=True, on_epoch=True, on_step=True)
119
+ self.log('val_recall', recall, prog_bar=True, on_epoch=True, on_step=True)
120
+ self.log('val_f1', f1, prog_bar=True, on_epoch=True, on_step=True)
121
+
122
+ return loss
123
+
124
+ def on_validation_epoch_end(self):
125
+ avg_loss = self.trainer.logged_metrics['val_loss_epoch']
126
+ accuracy = self.trainer.logged_metrics['val_acc_epoch']
127
+
128
+ self.log('val_loss', avg_loss, prog_bar=True, on_epoch=True)
129
+ self.log('val_acc', accuracy, prog_bar=True, on_epoch=True)
130
+
131
+ return {'Average Loss:': avg_loss, 'Accuracy:': accuracy}
132
+
133
+ def test_step(self, batch, batch_idx):
134
+ x, y = batch
135
+ y = y.long()
136
+ logits = self(x)
137
+ loss = torch.nn.functional.cross_entropy(logits, y)
138
+
139
+ preds = torch.argmax(logits, dim=1)
140
+ accuracy = self.accuracy(preds, y)
141
+ precision = self.precision(preds, y)
142
+ recall = self.recall(preds, y)
143
+ f1 = self.f1(preds, y)
144
+
145
+ # Log test metrics
146
+ self.log('test_loss', loss, prog_bar=True, logger=True)
147
+ self.log('test_acc', accuracy, prog_bar=True, logger=True)
148
+ self.log('test_precision', precision, prog_bar=True, logger=True)
149
+ self.log('test_recall', recall, prog_bar=True, logger=True)
150
+ self.log('test_f1', f1, prog_bar=True, logger=True)
151
+
152
+ return {'test_loss': loss, 'test_accuracy': accuracy, 'test_precision': precision, 'test_recall': recall, 'test_f1': f1}
153
+
154
+
155
+
156
+ class CTCEncoderPL(pl.LightningModule):
157
+ def __init__(self, ctc_encoder, num_classes, optimizer_cfg):
158
+ super(CTCEncoderPL, self).__init__()
159
+ self.ctc_encoder = ctc_encoder
160
+ self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=True)
161
+ self.optimizer_cfg = optimizer_cfg
162
+ self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
163
+ self.precision = Precision(task="multiclass", num_classes=num_classes)
164
+ self.recall = Recall(task="multiclass", num_classes=num_classes)
165
+ self.f1 = F1Score(task="multiclass", num_classes=num_classes)
166
+
167
+
168
+ if optimizer_cfg is not None:
169
+ optimizer_name = optimizer_cfg.name
170
+ optimizer_lr = optimizer_cfg.lr
171
+ optimizer_weight_decay = optimizer_cfg.weight_decay
172
+
173
+ if optimizer_name == 'Adam':
174
+ self.optimizer = optim.Adam(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay)
175
+ elif optimizer_name == 'SGD':
176
+ self.optimizer = optim.SGD(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay)
177
+ else:
178
+ raise ValueError(f"Unsupported optimizer: {optimizer_name}")
179
+ else:
180
+ self.optimizer = None
181
+ def forward(self, x):
182
+ return self.ctc_encoder(x)
183
+
184
+ def training_step(self, batch, batch_idx):
185
+ x, y, input_lengths, target_lengths = batch
186
+
187
+ logits, input_lengths = self.ctc_encoder(x, input_lengths)
188
+ log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
189
+ loss = self.ctc_loss(log_probs, y, input_lengths, target_lengths)
190
+ assert input_lengths.size(0) == x.size(0), f"input_lengths size ({input_lengths.size(0)}) must match batch size ({x.size(0)})"
191
+ preds = torch.argmax(log_probs, dim=-1)
192
+ self.log("train_loss", loss, on_epoch=True)
193
+ return loss
194
+
195
+ def validation_step(self, batch, batch_idx):
196
+ x, y, input_lengths, target_lengths = batch
197
+
198
+ # Compute logits and adjust input lengths
199
+ logits, input_lengths = self.ctc_encoder(x, input_lengths)
200
+ log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
201
+
202
+ # Validate input_lengths size
203
+ assert input_lengths.size(0) == logits.size(0), "Mismatch between input_lengths and batch size"
204
+
205
+ # Compute CTC loss
206
+ loss = self.ctc_loss(log_probs, y, input_lengths, target_lengths)
207
+
208
+ # Compute metrics
209
+ preds = torch.argmax(log_probs, dim=-1)
210
+ accuracy = self.accuracy(preds, y)
211
+ precision = self.precision(preds, y)
212
+ recall = self.recall(preds, y)
213
+ f1 = self.f1(preds, y)
214
+
215
+ # Log metrics
216
+ self.log('val_loss', loss, prog_bar=True, on_epoch=True, on_step=True)
217
+ self.log('val_acc', accuracy, prog_bar=True, on_epoch=True, on_step=True)
218
+ self.log('val_precision', precision, prog_bar=True, on_epoch=True, on_step=True)
219
+ self.log('val_recall', recall, prog_bar=True, on_epoch=True, on_step=True)
220
+ self.log('val_f1', f1, prog_bar=True, on_epoch=True, on_step=True)
221
+
222
+ return loss
223
+
224
+ def on_validation_epoch_end(self):
225
+ avg_loss = self.trainer.logged_metrics['val_loss_epoch']
226
+ accuracy = self.trainer.logged_metrics['val_acc_epoch']
227
+
228
+ self.log('val_loss', avg_loss, prog_bar=True, on_epoch=True)
229
+ self.log('val_acc', accuracy, prog_bar=True, on_epoch=True)
230
+
231
+ return {'Average Loss:': avg_loss, 'Accuracy:': accuracy}
232
+
233
+ def test_step(self, batch, batch_idx):
234
+ x, y, input_lengths, target_lengths = batch
235
+ logits = self(x)
236
+ log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
237
+ loss = self.ctc_loss(log_probs, y, input_lengths, target_lengths)
238
+
239
+ preds = torch.argmax(log_probs, dim=-1)
240
+ accuracy = self.accuracy(preds, y)
241
+ precision = self.precision(preds, y)
242
+ recall = self.recall(preds, y)
243
+ f1 = self.f1(preds, y)
244
+
245
+ self.log('test_loss', loss, prog_bar=True, logger=True)
246
+ self.log('test_acc', accuracy, prog_bar=True, logger=True)
247
+ self.log('test_precision', precision, prog_bar=True, logger=True)
248
+ self.log('test_recall', recall, prog_bar=True, logger=True)
249
+ self.log('test_f1', f1, prog_bar=True, logger=True)
250
+
251
+ return {'test_loss': loss, 'test_accuracy': accuracy, 'test_precision': precision, 'test_recall': recall, 'test_f1': f1}
252
+
253
+ def configure_optimizers(self):
254
+ optimizer = self.optimizer
255
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5)
256
+ return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}
257
+
258
+ def greedy_decode(self, log_probs):
259
+ """
260
+ Perform greedy decoding to get predictions from log probabilities.
261
+ """
262
+ preds = torch.argmax(log_probs, dim=-1)
263
+ return preds
encoders/transformer.py ADDED
@@ -0,0 +1,233 @@
1
+ import pytorch_lightning as pl
2
+ import torch
3
+ from torchmetrics import Accuracy, Precision, Recall, F1Score
4
+ from transformers import Wav2Vec2Model, Wav2Vec2ForSequenceClassification
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class Wav2Vec2Classifier(pl.LightningModule):
9
+ def __init__(self, num_classes, optimizer_cfg=None, l1_lambda=0.0):
10
+ super(Wav2Vec2Classifier, self).__init__()
11
+ self.save_hyperparameters()
12
+
13
+ # Wav2Vec2 backbone
14
+ # self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
15
+ self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53")
16
+
17
+ # Freeze the Wav2Vec2 backbone so only the classification head is trained
18
+ for param in self.wav2vec2.parameters():
19
+ param.requires_grad = False
20
+ # Classification head
21
+ self.classifier = torch.nn.Linear(self.wav2vec2.config.hidden_size, num_classes)
22
+
23
+ # Metrics
24
+ self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
25
+ self.precision = Precision(task="multiclass", num_classes=num_classes)
26
+ self.recall = Recall(task="multiclass", num_classes=num_classes)
27
+ self.f1 = F1Score(task="multiclass", num_classes=num_classes)
28
+
29
+ self.l1_lambda = l1_lambda
30
+ if optimizer_cfg is not None:
31
+ optimizer_name = optimizer_cfg.name
32
+ optimizer_lr = optimizer_cfg.lr
33
+ optimizer_weight_decay = optimizer_cfg.weight_decay
34
+
35
+ if optimizer_name == 'Adam':
36
+ self.optimizer = torch.optim.Adam(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay)
37
+ elif optimizer_name == 'SGD':
38
+ self.optimizer = torch.optim.SGD(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay)
39
+ else:
40
+ raise ValueError(f"Unsupported optimizer: {optimizer_name}")
41
+ else:
42
+ self.optimizer = None
43
+
44
+ def forward(self, x, attention_mask=None):
45
+ # Debug input shape
46
+
47
+ # Ensure input shape is [batch_size, sequence_length]
48
+ if x.dim() > 2:
49
+ x = x.squeeze(-1) # Remove unnecessary dimensions if present
50
+
51
+ # Pass through Wav2Vec2 backbone
52
+ output = self.wav2vec2(x, attention_mask=attention_mask)
53
+ x = output.last_hidden_state
54
+
55
+ # Classification head
56
+ x = torch.mean(x, dim=1) # Pooling
57
+ logits = self.classifier(x)
58
+ return logits
59
+
60
+
61
+ def training_step(self, batch, batch_idx):
62
+ x, attention_mask, y = batch
63
+
64
+ # Forward pass
65
+ logits = self(x, attention_mask=attention_mask)
66
+
67
+ # Compute loss
68
+ loss = F.cross_entropy(logits, y)
69
+
70
+ # Add L1 regularization if specified
71
+ l1_norm = sum(param.abs().sum() for param in self.parameters())
72
+ loss += self.l1_lambda * l1_norm
73
+
74
+ # Log metrics
75
+ self.log("train_loss", loss, prog_bar=True, logger=True)
76
+ return loss
77
+
78
+ def validation_step(self, batch, batch_idx):
79
+ x, attention_mask, y = batch # Unpack batch
80
+
81
+ # Forward pass
82
+ logits = self(x, attention_mask=attention_mask)
83
+
84
+
85
+ # Compute loss and metrics
86
+ loss = F.cross_entropy(logits, y)
87
+ preds = torch.argmax(logits, dim=1)
88
+ accuracy = self.accuracy(preds, y)
89
+ precision = self.precision(preds, y)
90
+ recall = self.recall(preds, y)
91
+ f1 = self.f1(preds, y)
92
+
93
+ # Log metrics
94
+ self.log("val_loss", loss, prog_bar=True, logger=True)
95
+ self.log("val_acc", accuracy, prog_bar=True, logger=True)
96
+ self.log("val_precision", precision, prog_bar=True, logger=True)
97
+ self.log("val_recall", recall, prog_bar=True, logger=True)
98
+ self.log("val_f1", f1, prog_bar=True, logger=True)
99
+ return loss
100
+
101
+ def test_step(self, batch, batch_idx):
102
+ x, attention_mask, y = batch # Unpack batch
103
+
104
+ # Forward pass
105
+ logits = self(x, attention_mask=attention_mask)
106
+
107
+
108
+ # Compute loss and metrics
109
+ loss = F.cross_entropy(logits, y)
110
+ preds = torch.argmax(logits, dim=1)
111
+ accuracy = self.accuracy(preds, y)
112
+ precision = self.precision(preds, y)
113
+ recall = self.recall(preds, y)
114
+ f1 = self.f1(preds, y)
115
+
116
+ # Log metrics
117
+ self.log("test_loss", loss, prog_bar=True, logger=True)
118
+ self.log("test_acc", accuracy, prog_bar=True, logger=True)
119
+ self.log("test_precision", precision, prog_bar=True, logger=True)
120
+ self.log("test_recall", recall, prog_bar=True, logger=True)
121
+ self.log("test_f1", f1, prog_bar=True, logger=True)
122
+
123
+ return {"test_loss": loss, "test_accuracy": accuracy}
124
+
125
+ def configure_optimizers(self):
126
+ optimizer = self.optimizer
127
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5)
128
+ return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}
129
+
130
+
131
+ class Wav2Vec2EmotionClassifier(pl.LightningModule):
132
+ def __init__(self, num_classes, learning_rate=1e-4, freeze_base=False, optimizer_cfg=None):
133
+ super(Wav2Vec2EmotionClassifier, self).__init__()
134
+ self.save_hyperparameters()
135
+
136
+ # Load a pre-trained Wav2Vec2 model optimized for emotion recognition
137
+ self.model = Wav2Vec2ForSequenceClassification.from_pretrained(
138
+ "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim",
139
+ num_labels=num_classes,
140
+ )
141
+ # Optionally freeze the Wav2Vec2 base layers
142
+ if freeze_base:
143
+ for param in self.model.wav2vec2.parameters():
144
+ param.requires_grad = False
145
+
146
+ # Metrics
147
+ self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
148
+ self.precision = Precision(task="multiclass", num_classes=num_classes)
149
+ self.recall = Recall(task="multiclass", num_classes=num_classes)
150
+ self.f1 = F1Score(task="multiclass", num_classes=num_classes)
151
+
152
+ self.learning_rate = learning_rate
153
+ if optimizer_cfg is not None:
154
+ optimizer_name = optimizer_cfg['name']
155
+ optimizer_lr = optimizer_cfg['lr']
156
+ optimizer_weight_decay = optimizer_cfg['weight_decay']
157
+
158
+ if optimizer_name == 'Adam':
159
+ self.optimizer = torch.optim.Adam(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay)
160
+ elif optimizer_name == 'SGD':
161
+ self.optimizer = torch.optim.SGD(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay)
162
+ elif optimizer_name == 'AdamW':
163
+ self.optimizer = torch.optim.AdamW(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay)
164
+ else:
165
+ raise ValueError(f"Unsupported optimizer: {optimizer_name}")
166
+ else:
167
+ self.optimizer = None
168
+
169
+ def forward(self, x, attention_mask=None):
170
+ return self.model(x, attention_mask=attention_mask).logits
171
+
172
+ def training_step(self, batch, batch_idx):
173
+ x, attention_mask, y = batch
174
+
175
+ # Forward pass
176
+ logits = self(x, attention_mask=attention_mask)
177
+
178
+ # Compute loss
179
+ loss = F.cross_entropy(logits, y)
180
+
181
+ # Log training loss
182
+ self.log("train_loss", loss, prog_bar=True, logger=True)
183
+ return loss
184
+
185
+ def validation_step(self, batch, batch_idx):
186
+ x, attention_mask, y = batch
187
+
188
+ # Forward pass
189
+ logits = self(x, attention_mask=attention_mask)
190
+
191
+ # Compute loss and metrics
192
+ loss = F.cross_entropy(logits, y)
193
+ preds = torch.argmax(logits, dim=1)
194
+
195
+ accuracy = self.accuracy(preds, y)
196
+ precision = self.precision(preds, y)
197
+ recall = self.recall(preds, y)
198
+ f1 = self.f1(preds, y)
199
+
200
+ # Log metrics
201
+ self.log("val_loss", loss, prog_bar=True, logger=True)
202
+ self.log("val_acc", accuracy, prog_bar=True, logger=True)
203
+ self.log("val_precision", precision, prog_bar=True, logger=True)
204
+ self.log("val_recall", recall, prog_bar=True, logger=True)
205
+ self.log("val_f1", f1, prog_bar=True, logger=True)
206
+ return loss
207
+
208
+ def test_step(self, batch, batch_idx):
209
+ x, attention_mask, y = batch
210
+
211
+ # Forward pass
212
+ logits = self(x, attention_mask=attention_mask)
213
+
214
+ # Compute loss and metrics
215
+ loss = F.cross_entropy(logits, y)
216
+ preds = torch.argmax(logits, dim=1)
217
+ accuracy = self.accuracy(preds, y)
218
+ precision = self.precision(preds, y)
219
+ recall = self.recall(preds, y)
220
+ f1 = self.f1(preds, y)
221
+
222
+ # Log metrics
223
+ self.log("test_loss", loss, prog_bar=True, logger=True)
224
+ self.log("test_acc", accuracy, prog_bar=True, logger=True)
225
+ self.log("test_precision", precision, prog_bar=True, logger=True)
226
+ self.log("test_recall", recall, prog_bar=True, logger=True)
227
+ self.log("test_f1", f1, prog_bar=True, logger=True)
228
+
229
+ return {"test_loss": loss, "test_accuracy": accuracy}
230
+ def configure_optimizers(self):
231
+ optimizer = self.optimizer
232
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5)
233
+ return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}
model.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0116227115053abebb4951ef1bd0bd25750797f2bfe98d74df152dc2289295d6
3
+ size 658272386
models/CTCencoder.py ADDED
@@ -0,0 +1,93 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class CTCEncoder(nn.Module):
5
+ def __init__(self, num_classes, cnn_output_dim=256, rnn_hidden_dim=256, rnn_layers=3):
6
+ """
7
+ CTC Encoder with a CNN feature extractor and LSTM for sequence modeling.
8
+
9
+ Args:
10
+ num_classes (int): Number of output classes for the model.
11
+ cnn_output_dim (int): Number of output channels from the CNN.
12
+ rnn_hidden_dim (int): Hidden size of the LSTM.
13
+ rnn_layers (int): Number of layers in the LSTM.
14
+ """
15
+ super(CTCEncoder, self).__init__()
16
+
17
+ # CNN Feature Extractor
18
+ self.feature_extractor = nn.Sequential(
19
+ nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
20
+ nn.ReLU(),
21
+ nn.MaxPool2d(kernel_size=2, stride=2), # Down-sample by 2
22
+ nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
23
+ nn.ReLU(),
24
+ nn.MaxPool2d(kernel_size=2, stride=2), # Down-sample by another 2
25
+ nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
26
+ nn.ReLU(),
27
+ nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
28
+ nn.ReLU(),
29
+ nn.AdaptiveAvgPool2d((1, None)) # Ensure output height is 1
30
+ )
31
+
32
+ # Bidirectional LSTM
33
+ self.rnn_hidden_dim = rnn_hidden_dim
34
+ self.rnn_layers = rnn_layers
35
+ self.cnn_output_dim = cnn_output_dim
36
+
37
+ self.rnn = nn.LSTM(
38
+ input_size=cnn_output_dim, # Output channels from CNN
39
+ hidden_size=rnn_hidden_dim,
40
+ num_layers=rnn_layers,
41
+ batch_first=True,
42
+ bidirectional=True
43
+ )
44
+
45
+ # Fully connected layer
46
+ self.fc = nn.Linear(rnn_hidden_dim * 2, num_classes)
47
+
48
+ def compute_input_lengths(self, input_lengths):
49
+ """
50
+ Adjusts input lengths based on the CNN's down-sampling operations.
51
+
52
+ Args:
53
+ input_lengths (torch.Tensor): Original input lengths.
54
+
55
+ Returns:
56
+ torch.Tensor: Adjusted input lengths.
57
+ """
58
+ # Account for down-sampling by MaxPool layers (factor of 2 for each MaxPool)
59
+ input_lengths = input_lengths // 2 # First MaxPool
60
+ input_lengths = input_lengths // 2 # Second MaxPool
61
+ input_lengths = input_lengths // 2 # Third pooling layer or additional down-sampling
62
+ return input_lengths
63
+
64
+ def forward(self, x, input_lengths):
65
+ """
66
+ Forward pass through the encoder.
67
+
68
+ Args:
69
+ x (torch.Tensor): Input tensor of shape [B, 1, H, W].
70
+ input_lengths (torch.Tensor): Lengths of the sequences in the batch.
71
+
72
+ Returns:
73
+ torch.Tensor: Logits of shape [B, T, num_classes].
74
+ torch.Tensor: Adjusted input lengths.
75
+ """
76
+ # Feature extraction
77
+ x = self.feature_extractor(x) # [Batch_Size, Channels, Height, Width]
78
+ print(f"Shape after CNN: {x.shape}") # Debug the shape
79
+
80
+ # Reshape for LSTM
81
+ x = x.squeeze(2).permute(0, 2, 1) # [Batch_Size, Sequence_Length, Features]
82
+ assert x.size(-1) == 256, f"Expected last dimension to be 256, but got {x.size(-1)}"
83
+
84
+ # Adjust input lengths
85
+ input_lengths = self.compute_input_lengths(input_lengths)
86
+ assert input_lengths.size(0) == x.size(0), f"input_lengths size ({input_lengths.size(0)}) must match batch size ({x.size(0)})"
87
+
88
+ # Pass through LSTM
89
+ x, _ = self.rnn(x) # [Batch_Size, Sequence_Length, 2 * Hidden_Dim]
90
+
91
+ # Fully connected output
92
+ x = self.fc(x) # [Batch_Size, Sequence_Length, Num_Classes]
93
+ return x, input_lengths
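
A shape-check sketch for `CTCEncoder`, assuming dummy single-channel spectrogram input of shape `[batch, 1, mel_bins, frames]` (values are placeholders, not real features):

```python
# Sketch: verify output shapes with random input.
import torch

encoder = CTCEncoder(num_classes=5)
x = torch.randn(2, 1, 64, 128)      # [batch, channels, mel_bins, frames]
lengths = torch.tensor([128, 100])  # original frame counts per sample
logits, adjusted_lengths = encoder(x, lengths)
print(logits.shape, adjusted_lengths)  # logits: [batch, time_steps, num_classes]
```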
models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import CTCencoder
models/__pycache__/CTCencoder.cpython-311.pyc ADDED
Binary file (4.67 kB)
 
models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (207 Bytes)
 
requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ gradio
2
+ librosa
3
+ torch
4
+ transformers
5
+ numpy
6
+ pytorch-lightning
7
+ torchmetrics
8
+ timm
statics/style.css ADDED
@@ -0,0 +1,9 @@
1
+ #audio_input {
2
+ border: 2px solid #4CAF50;
3
+ border-radius: 10px;
4
+ }
5
+ #submit_button {
6
+ background-color: #4CAF50;
7
+ color: white;
8
+ border-radius: 5px;
9
+ }
upload_model.py ADDED
@@ -0,0 +1,14 @@
1
+ from huggingface_hub import HfApi, HfFolder, Repository
2
+
3
+ api = HfApi()
4
+ repo_url = api.create_repo(repo_id="saeedbenadeeb/emotion-detection", exist_ok=True)
5
+
6
+ repo = Repository(local_dir="emotion-detection", clone_from=repo_url)
7
+ repo.git_pull()
8
+
9
+ # Copy model files to the repo directory
10
+ import shutil
11
+ shutil.copy("model.pth", "emotion-detection")
12
+
13
+ # Add files and push
14
+ repo.push_to_hub(commit_message="Initial model upload")
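
`Repository` is deprecated in recent `huggingface_hub` releases; a minimal alternative sketch that uploads only the checkpoint via `HfApi.upload_file` (same repo id as above) could look like this:

```python
# Sketch: upload the checkpoint without the deprecated Repository class.
# Assumes you are already authenticated (e.g. via `huggingface-cli login`).
from huggingface_hub import HfApi

api = HfApi()
api.create_repo(repo_id="saeedbenadeeb/emotion-detection", exist_ok=True)
api.upload_file(
    path_or_fileobj="model.pth",
    path_in_repo="model.pth",
    repo_id="saeedbenadeeb/emotion-detection",
    commit_message="Initial model upload",
)
```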
utils/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from . import random_split
2
+ from . import helper_functions
utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (263 Bytes)
 
utils/__pycache__/helper_functions.cpython-311.pyc ADDED
Binary file (3.61 kB)
 
utils/__pycache__/random_split.cpython-311.pyc ADDED
Binary file (2.39 kB)
 
utils/helper_functions.py ADDED
@@ -0,0 +1,70 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import os
4
+ import shutil
5
+ def normalize_ratios(ratios):
6
+ total = sum(ratios)
7
+ return [r / total for r in ratios]
8
+
9
+ from torch.nn.utils.rnn import pad_sequence
10
+
11
+ def collate_fn_transformer(batch):
12
+ """
13
+ Custom collate function to handle variable-length raw waveform inputs.
14
+ Args:
15
+ batch: List of tuples (tensor, label), where tensor has shape [sequence_length].
16
+ Returns:
17
+ padded_waveforms: Padded tensor of shape [batch_size, max_seq_len].
18
+ attention_mask: Attention mask for padded sequences.
19
+ labels: Tensor of shape [batch_size].
20
+ """
21
+ # Separate waveforms and labels
22
+ waveforms, labels = zip(*batch)
23
+
24
+ # Ensure waveforms are 1D tensors
25
+ waveforms = [torch.tensor(waveform).squeeze() for waveform in waveforms]
26
+
27
+ # Pad sequences to the same length
28
+ padded_waveforms = pad_sequence(waveforms, batch_first=True) # [batch_size, max_seq_len]
29
+
30
+
31
+ # Create attention mask
32
+ attention_mask = (padded_waveforms != 0).long() # Mask for non-padded values
33
+ # In the training loop or DataLoader debug
34
+
35
+
36
+ # Convert labels to a tensor
37
+ labels = torch.tensor(labels, dtype=torch.long)
38
+
39
+ return padded_waveforms, attention_mask, labels
40
+
41
+ def collate_fn(batch):
42
+ inputs, targets, input_lengths, target_lengths = zip(*batch)
43
+ inputs = torch.stack(inputs) # Convert list of tensors to a batch tensor
44
+ targets = torch.cat(targets) # Flatten target sequences
45
+ input_lengths = torch.tensor(input_lengths, dtype=torch.long)
46
+ target_lengths = torch.tensor(target_lengths, dtype=torch.long)
47
+ return inputs, targets, input_lengths, target_lengths
48
+
49
+
50
+
51
+ def save_test_data(test_dataset, dataset, save_dir):
52
+
53
+ if os.path.exists(save_dir):
54
+ shutil.rmtree(save_dir) # Delete the existing directory and its contents
55
+ print(f"Existing test data directory '{save_dir}' removed.")
56
+
57
+ os.makedirs(save_dir, exist_ok=True)
58
+
59
+ for idx in test_dataset.indices:
60
+ audio_file_path = dataset.audio_files[idx] # Assuming dataset has `audio_files` attribute
61
+ label = dataset.labels[idx] # Assuming dataset has `labels` attribute
62
+
63
+ # Create a directory for the label if it doesn't exist
64
+ label_dir = os.path.join(save_dir, str(label))
65
+ os.makedirs(label_dir, exist_ok=True)
66
+
67
+ # Copy the audio file to the label directory
68
+ shutil.copy(audio_file_path, label_dir)
69
+
70
+ print(f"Test data saved in {save_dir}")
utils/random_split.py ADDED
@@ -0,0 +1,37 @@
1
+ from typing import List
2
+ import torch
3
+ from torch.utils.data import Subset
4
+ from sklearn.model_selection import train_test_split
5
+ from utils.helper_functions import normalize_ratios
6
+
7
+ def stratified_random_split(ds: torch.utils.data.Dataset, parts: List[float], targets: List[int]) -> List[torch.utils.data.Dataset]:
8
+ """
9
+ Perform a stratified random split on the dataset.
10
+
11
+ Args:
12
+ ds: PyTorch dataset to split.
13
+ parts: List of proportions that sum to 1.
14
+ targets: List of labels corresponding to dataset samples.
15
+
16
+ Returns:
17
+ List of PyTorch datasets corresponding to the splits.
18
+ """
19
+ total_length = len(ds)
20
+
21
+ # Normalize ratios
22
+ parts = normalize_ratios(parts)
23
+
24
+ lengths = list(map(lambda p: int(p * total_length), parts))
25
+ left_over = total_length - sum(lengths)
26
+ lengths[0] += left_over # Adjust first split to account for leftover
27
+
28
+ indices = list(range(total_length))
29
+ train_indices, temp_indices, _, temp_targets = train_test_split(
30
+ indices, targets, test_size=(1 - parts[0]), stratify=targets, random_state=42
31
+ )
32
+ val_size = parts[1] / (parts[1] + parts[2])
33
+ val_indices, test_indices, _, _ = train_test_split(
34
+ temp_indices, temp_targets, test_size=(1 - val_size), stratify=temp_targets, random_state=42
35
+ )
36
+
37
+ return [Subset(ds, train_indices), Subset(ds, val_indices), Subset(ds, test_indices)]
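
A usage sketch for `stratified_random_split`, assuming `dataset` exposes an integer label list (e.g. the `labels` attribute of `TESSRawWaveformDataset`):

```python
# Sketch: 70/15/15 stratified split into train/val/test subsets.
train_ds, val_ds, test_ds = stratified_random_split(
    dataset, parts=[0.7, 0.15, 0.15], targets=list(dataset.labels)
)
print(len(train_ds), len(val_ds), len(test_ds))
```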