Update hyperparams.yalm
hyperparams.yalm  (CHANGED: +18, -17)
@@ -7,7 +7,7 @@ __set_seed: !apply:torch.manual_seed [1234]
 
 skip_training: True
 
-output_folder:
+output_folder: output_folder_wavlm_base_full_data
 label_encoder_file: !ref <output_folder>/label_encoder.txt
 
 train_log: !ref <output_folder>/train_log.txt
@@ -17,6 +17,8 @@ save_folder: !ref <output_folder>/save
 
 wav2vec2_hub: microsoft/wavlm-base-plus-sv
 
+pretrained_path: Porjaz/wavlm-base-emo-fi
+
 wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
 
 # Feature parameters
@@ -52,7 +54,7 @@ encoder_dims: 768
 n_classes: 5
 
 # Wav2vec2 encoder
-embedding_model: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec2>
@@ -60,25 +62,20 @@ embedding_model: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
     save_path: !ref <wav2vec2_folder>
     output_all_hiddens: True
 
-
 avg_pool: !new:speechbrain.nnet.pooling.StatisticsPooling
     return_std: False
 
-
-classifier: !new:speechbrain.nnet.linear.Linear
+label_lin: !new:speechbrain.nnet.linear.Linear
     input_size: !ref <encoder_dims>
     n_neurons: !ref <n_classes>
     bias: False
 
-
 log_softmax: !new:speechbrain.nnet.activations.Softmax
     apply_log: True
 
-
 opt_class: !name:torch.optim.Adam
     lr: !ref <lr>
 
-
 wav2vec2_opt_class: !name:torch.optim.Adam
     lr: !ref <lr_wav2vec2>
 
@@ -88,41 +85,45 @@ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
 # Functions that compute the statistics to track during the validation step.
 accuracy_computer: !name:speechbrain.utils.Accuracy.AccuracyStats
 
-
 compute_cost: !name:speechbrain.nnet.losses.nll_loss
 
-
 error_stats: !name:speechbrain.utils.metric_stats.MetricStats
     metric: !name:speechbrain.nnet.losses.classification_error
         reduction: batch
-
 modules:
     wav2vec2: !ref <wav2vec2>
     label_lin: !ref <label_lin>
 
-
 model: !new:torch.nn.ModuleList
     - [!ref <label_lin>]
 
-
 lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
     initial_value: !ref <lr>
     improvement_threshold: 0.0025
     annealing_factor: 0.9
     patient: 0
 
-
 lr_annealing_wav2vec2: !new:speechbrain.nnet.schedulers.NewBobScheduler
     initial_value: !ref <lr_wav2vec2>
     improvement_threshold: 0.0025
     annealing_factor: 0.9
 
-
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
     checkpoints_dir: !ref <save_folder>
     recoverables:
-        model: !ref <
-        wav2vec2: !ref <
+        model: !ref <model>
+        wav2vec2: !ref <wav2vec2>
        lr_annealing_output: !ref <lr_annealing>
        lr_annealing_wav2vec2: !ref <lr_annealing_wav2vec2>
        counter: !ref <epoch_counter>
+
+
+pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
+    loadables:
+        wav2vec2: !ref <wav2vec2>
+        model: !ref <model>
+        label_encoder: !ref <label_encoder>
+    paths:
+        wav2vec2: !ref <pretrained_path>/wav2vec2.ckpt
+        model: !ref <pretrained_path>/model.ckpt
+        label_encoder: !ref <pretrained_path>/label_encoder.txt
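
For context, a minimal sketch (not part of the commit) of how a SpeechBrain recipe typically consumes this file: the YAML is resolved with hyperpyyaml, and the pretrainer block added here then fetches the checkpoints listed under paths from <pretrained_path> (Porjaz/wavlm-base-emo-fi). The snippet assumes a SpeechBrain install with the hyperpyyaml package, that the local copy keeps the repo's file name, and that the !ref <label_encoder> target is defined in an unchanged part of the file not shown in this diff.

# Usage sketch only; file name and surrounding setup are assumptions, not
# something this commit defines.
from hyperpyyaml import load_hyperpyyaml

with open("hyperparams.yalm") as fin:
    hparams = load_hyperpyyaml(fin)

# The pretrainer downloads wav2vec2.ckpt, model.ckpt and label_encoder.txt
# from <pretrained_path> and copies their parameters into the objects
# referenced under `loadables`.
pretrainer = hparams["pretrainer"]
pretrainer.collect_files()
pretrainer.load_collected()

# With skip_training: True, the loaded modules can be used directly for
# inference: frozen WavLM encoder, statistics pooling, linear head, log-softmax.
wav2vec2 = hparams["wav2vec2"]
avg_pool = hparams["avg_pool"]
label_lin = hparams["label_lin"]
log_softmax = hparams["log_softmax"]

Renaming the encoder key to wav2vec2 and the classifier to label_lin matters here because the Pretrainer and Checkpointer match loadables/recoverables by these names, so the keys in the YAML, the modules mapping, and the checkpoint files stay consistent.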