# Generated 2022-01-19 from:
# /scratch/elec/t405-puhe/p/porjazd1/Metadata_Classification/TCN/asr_topic_speechbrain/mgb_asr/hyperparams.yaml
# yamllint disable
# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 1234
__set_seed: !apply:torch.manual_seed [!ref <seed>]

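# If True, the training loop is skipped and the model is only evaluated
# (assumption: this flag is consumed by the accompanying train script).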
skip_training: True

output_folder: output_folder_wavlm_base
label_encoder_file: !ref <output_folder>/label_encoder.txt

train_log: !ref <output_folder>/train_log.txt
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>
save_folder: !ref <output_folder>/save

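# HuggingFace hub ID of the pretrained encoder: WavLM Base+ fine-tuned for
# speaker verification. Downloaded weights are cached in <wav2vec2_folder>.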
wav2vec2_hub: microsoft/wavlm-base-plus-sv

wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint

# Feature parameters
sample_rate: 22050        # sampling rate of the source audio
new_sample_rate: 16000    # WavLM/wav2vec2 models expect 16 kHz input
window_size: 25
n_mfcc: 23

# Training params
n_epochs: 28
stopping_factor: 10

dataloader_options:
    batch_size: 10
    shuffle: false

test_dataloader_options:
    batch_size: 1
    shuffle: false

lr: 0.0001
lr_wav2vec2: 0.00001

# Freeze the entire wav2vec2 encoder (no fine-tuning)
freeze_wav2vec2: False
# Set to True to freeze only the convolutional feature extractor of the
# wav2vec2 model; we see an improvement of ~2% when freezing the CNNs.
freeze_wav2vec2_conv: True

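# Maps class labels to integer indices; presumably saved to/loaded from
# <label_encoder_file> by the train script.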
label_encoder: !new:speechbrain.dataio.encoder.CategoricalEncoder

encoder_dims: 768   # hidden size of the WavLM Base+ transformer
n_classes: 5        # number of target classes

# Wav2vec2-style encoder (here WavLM); with output_all_hiddens: True the
# hidden states of all transformer layers are returned, not only the last.
wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
    source: !ref <wav2vec2_hub>
    output_norm: True
    freeze: !ref <freeze_wav2vec2>
    freeze_feature_extractor: !ref <freeze_wav2vec2_conv>
    save_path: !ref <wav2vec2_folder>
    output_all_hiddens: True


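# Statistics pooling with return_std: False reduces to mean pooling over
# time, producing a single utterance-level embedding per input.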
avg_pool: !new:speechbrain.nnet.pooling.StatisticsPooling
    return_std: False


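# Linear classification head: utterance embedding -> per-class scores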
label_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <encoder_dims>
    n_neurons: !ref <n_classes>
    bias: False


log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True


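# Two optimizers: a larger LR for the classifier head and a smaller one
# for fine-tuning the pretrained encoder.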
opt_class: !name:torch.optim.Adam
    lr: !ref <lr>


wav2vec2_opt_class: !name:torch.optim.Adam
    lr: !ref <lr_wav2vec2>

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <n_epochs>

# Functions that compute the statistics to track during the validation step.
accuracy_computer: !name:speechbrain.utils.Accuracy.AccuracyStats


compute_cost: !name:speechbrain.nnet.losses.nll_loss


error_stats: !name:speechbrain.utils.metric_stats.MetricStats
    metric: !name:speechbrain.nnet.losses.classification_error
        reduction: batch

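# Torch modules managed by the Brain class (moved to the device and
# switched between train/eval mode automatically).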
modules:
    wav2vec2: !ref <wav2vec2>
    label_lin: !ref <label_lin>


# Only the classifier head lives in <model>; the wav2vec2 encoder is
# checkpointed separately below.
model: !new:torch.nn.ModuleList
    - [!ref <label_lin>]


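# New-bob annealing: multiply the LR by annealing_factor whenever the
# relative improvement drops below improvement_threshold.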
lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0


lr_annealing_wav2vec2: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr_wav2vec2>
    improvement_threshold: 0.0025
    annealing_factor: 0.9


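# Checkpointer: saves and restores everything needed to resume training.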
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        wav2vec2: !ref <wav2vec2>
        lr_annealing_output: !ref <lr_annealing>
        lr_annealing_wav2vec2: !ref <lr_annealing_wav2vec2>
        counter: !ref <epoch_counter>
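
# Usage sketch (hypothetical; not part of the recipe). A hyperparams file
# like this is loaded with hyperpyyaml, which resolves !ref references and
# instantiates every !new/!name object:
#
#   from hyperpyyaml import load_hyperpyyaml
#
#   with open("hyperparams.yaml") as f:
#       hparams = load_hyperpyyaml(f, overrides={"skip_training": False})
#
#   wav2vec2 = hparams["wav2vec2"]    # instantiated encoder module
#   optimizer = hparams["opt_class"](hparams["model"].parameters())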