# File size: 3,293 Bytes
# 4e28017 e2e430e 4e28017 e2e430e 4e28017 e2e430e 4e28017 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
# Generated 2022-01-19 from:
# /scratch/elec/t405-puhe/p/porjazd1/Metadata_Classification/TCN/asr_topic_speechbrain/mgb_asr/hyperparams.yaml
# yamllint disable
# The seed must be set at the top of the yaml, before any objects with
# parameters are created.
seed: 1234
# Reference <seed> rather than repeating the literal, so the torch seed
# cannot drift from `seed` when that value is overridden.
__set_seed: !apply:torch.manual_seed [!ref <seed>]
# NOTE(review): not referenced by any object in this file; presumably read by
# the training script to decide whether to run the fit loop — confirm.
skip_training: true
# Root directory of this run; the log, label-encoder and save paths below
# are all built from it.
output_folder: output_folder_wavlm_base
# NOTE(review): presumably where the training script saves/loads the fitted
# label encoder — confirm.
label_encoder_file: !ref <output_folder>/label_encoder.txt
train_log: !ref <output_folder>/train_log.txt
# File-based training logger. Its `save_file` argument must be nested under
# the `!new:` tag, and it reuses <train_log> so the path is declared once.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>
# Directory used by the checkpointer and the wav2vec2 download cache.
save_folder: !ref <output_folder>/save
# Model id given as `source` to the HuggingFaceWav2Vec2 wrapper below.
wav2vec2_hub: microsoft/wavlm-base-plus-sv
# Local directory given to that wrapper as its `save_path`.
wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
# Feature parameters
# NOTE(review): none of these keys is referenced by an object in this file;
# they are presumably consumed by the data pipeline in the training script.
sample_rate: 22050      # presumably the native audio rate in Hz — confirm
new_sample_rate: 16000  # presumably the resampling target in Hz — confirm
window_size: 25
n_mfcc: 23
# Training params
# Upper bound on epochs; used as the `limit` of <epoch_counter>.
n_epochs: 28
# NOTE(review): not referenced by any object here; presumably read by the
# training script (e.g. for early stopping) — confirm.
stopping_factor: 10
# Dataloader options for training/validation and for testing. The keys must
# be nested: at top level `batch_size`/`shuffle` would be duplicate keys and
# both option mappings would be null.
dataloader_options:
    batch_size: 10
    shuffle: false
test_dataloader_options:
    batch_size: 1
    shuffle: false
# Learning rate of <opt_class>; also the initial value of <lr_annealing>.
lr: 0.0001
# Learning rate of <wav2vec2_opt_class>; also the initial value of
# <lr_annealing_wav2vec2>.
lr_wav2vec2: 0.00001
# Set to true to freeze the whole wav2vec2 model.
freeze_wav2vec2: false
# Set to true to freeze only the CNN (feature-extractor) part of the wav2vec2
# model. The authors report a 2% improvement with the CNN frozen.
freeze_wav2vec2_conv: true
# Categorical label encoder, built with default arguments.
# NOTE(review): presumably fitted/loaded by the training script via
# <label_encoder_file> — confirm.
label_encoder: !new:speechbrain.dataio.encoder.CategoricalEncoder
# Feature width fed to <classifier> (its `input_size`).
encoder_dims: 768
# Number of target classes (the `n_neurons` of <classifier>).
n_classes: 5
# Wav2vec2 encoder: the <wav2vec2_hub> checkpoint wrapped by SpeechBrain.
# Constructor arguments are nested under the `!new:` tag so they are passed
# to HuggingFaceWav2Vec2 rather than defined as unrelated hparams.
embedding_model: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
    source: !ref <wav2vec2_hub>
    output_norm: true
    freeze: !ref <freeze_wav2vec2>
    freeze_feature_extractor: !ref <freeze_wav2vec2_conv>
    save_path: !ref <wav2vec2_folder>
    # NOTE(review): presumably returns the hidden states of every layer; the
    # training script must reduce them to <encoder_dims> features — confirm.
    output_all_hiddens: true
# Statistics pooling that keeps only the mean (return_std: false).
avg_pool: !new:speechbrain.nnet.pooling.StatisticsPooling
    return_std: false
# Bias-free linear layer mapping <encoder_dims> features to <n_classes>.
classifier: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <encoder_dims>
    n_neurons: !ref <n_classes>
    bias: false
# Log-softmax (apply_log: true), producing the log-probabilities expected by
# the nll_loss used as <compute_cost>.
log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: true
# Optimiser factories (`!name:` yields a callable with `lr` pre-bound; the
# optimiser itself is created later). The wav2vec2 one uses the smaller
# <lr_wav2vec2>.
opt_class: !name:torch.optim.Adam
    lr: !ref <lr>
wav2vec2_opt_class: !name:torch.optim.Adam
    lr: !ref <lr_wav2vec2>
# Epoch loop bounded by <n_epochs>; also saved by the checkpointer.
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <n_epochs>
# Functions that compute the statistics to track during the validation step.
accuracy_computer: !name:speechbrain.utils.Accuracy.AccuracyStats
compute_cost: !name:speechbrain.nnet.losses.nll_loss
# `metric` is an argument of MetricStats, and `reduction` is an argument of
# classification_error, so each is nested under its own tag.
error_stats: !name:speechbrain.utils.metric_stats.MetricStats
    metric: !name:speechbrain.nnet.losses.classification_error
        reduction: batch
# Trainable modules, presumably handed to the SpeechBrain Brain class.
# The key names `wav2vec2` / `label_lin` are kept for the training script,
# but they must reference hparams that exist in this file: <embedding_model>
# and <classifier> (the same objects the checkpointer saves).
modules:
    wav2vec2: !ref <embedding_model>
    label_lin: !ref <classifier>
# ModuleList over the classifier head (the sequence holds the positional
# argument list of torch.nn.ModuleList).
model: !new:torch.nn.ModuleList
    - [!ref <classifier>]
# NewBob learning-rate schedulers, one starting from <lr> and one from
# <lr_wav2vec2>. Both use identical settings; `patient: 0` is spelled out on
# both for consistency (it is NewBobScheduler's default).
lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0
lr_annealing_wav2vec2: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr_wav2vec2>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0
# Saves and restores the objects below under <save_folder>. Recoverables are
# stored by key name, so renaming a key prevents recovering older checkpoints.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <classifier>
        wav2vec2: !ref <embedding_model>
        lr_annealing_output: !ref <lr_annealing>
        lr_annealing_wav2vec2: !ref <lr_annealing_wav2vec2>
        counter: !ref <epoch_counter>
# |