imvladikon commited on
Commit
f6079f4
1 Parent(s): 8b7806c
README.md ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - he
4
+ tags:
5
+ - automatic-speech-recognition
6
+ - robust-speech-event
7
+ - he
8
+ - generated_from_trainer
9
+ model-index:
10
+ - name: wav2vec2-xls-r-300m-hebrew
11
+ results: []
12
+ ---
13
+
14
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
15
+ should probably proofread and complete it, then remove this comment. -->
16
+
17
+ # wav2vec2-xls-r-300m-hebrew
18
+
19
+ This model is a fine-tuned version of [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) on the private dataset with stats:
20
+
21
+ | split |size | n_samples | duration(hrs)| |
22
+ |---|---|---|---|---|
23
+ |train|4.19gb| 20306 | 28 | |
24
+ |dev |1.05gb| 5076 | 7 | |
25
+
26
+
27
+ It achieves the following results on the evaluation set:
28
+ - Loss: 0.5438
29
+ - Wer: 0.1773
30
+
31
+ ## Model description
32
+
33
+ More information needed
34
+
35
+ ## Intended uses & limitations
36
+
37
+ More information needed
38
+
39
+ ## Training and evaluation data
40
+
41
+ More information needed
42
+
43
+ ## Training procedure
44
+
45
+ ### Training hyperparameters
46
+
47
+ The following hyperparameters were used during training:
48
+ - learning_rate: 0.0003
49
+ - train_batch_size: 8
50
+ - eval_batch_size: 8
51
+ - seed: 42
52
+ - distributed_type: multi-GPU
53
+ - num_devices: 2
54
+ - gradient_accumulation_steps: 4
55
+ - total_train_batch_size: 64
56
+ - total_eval_batch_size: 16
57
+ - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
58
+ - lr_scheduler_type: linear
59
+ - lr_scheduler_warmup_steps: 1000
60
+ - num_epochs: 100.0
61
+ - mixed_precision_training: Native AMP
62
+
63
+ ### Training results
64
+
65
+ | Training Loss | Epoch | Step | Validation Loss | Wer |
66
+ |:-------------:|:-----:|:-----:|:---------------:|:------:|
67
+ | No log | 3.15 | 1000 | 0.5203 | 0.4333 |
68
+ | 1.4284 | 6.31 | 2000 | 0.4816 | 0.3951 |
69
+ | 1.4284 | 9.46 | 3000 | 0.4315 | 0.3546 |
70
+ | 1.283 | 12.62 | 4000 | 0.4278 | 0.3404 |
71
+ | 1.283 | 15.77 | 5000 | 0.4090 | 0.3054 |
72
+ | 1.1777 | 18.93 | 6000 | 0.3893 | 0.3006 |
73
+ | 1.1777 | 22.08 | 7000 | 0.3968 | 0.2857 |
74
+ | 1.0994 | 25.24 | 8000 | 0.3892 | 0.2751 |
75
+ | 1.0994 | 28.39 | 9000 | 0.4061 | 0.2690 |
76
+ | 1.0323 | 31.54 | 10000 | 0.4114 | 0.2507 |
77
+ | 1.0323 | 34.7 | 11000 | 0.4021 | 0.2508 |
78
+ | 0.9623 | 37.85 | 12000 | 0.4032 | 0.2378 |
79
+ | 0.9623 | 41.01 | 13000 | 0.4148 | 0.2374 |
80
+ | 0.9077 | 44.16 | 14000 | 0.4350 | 0.2323 |
81
+ | 0.9077 | 47.32 | 15000 | 0.4515 | 0.2246 |
82
+ | 0.8573 | 50.47 | 16000 | 0.4474 | 0.2180 |
83
+ | 0.8573 | 53.63 | 17000 | 0.4649 | 0.2171 |
84
+ | 0.8083 | 56.78 | 18000 | 0.4455 | 0.2102 |
85
+ | 0.8083 | 59.94 | 19000 | 0.4587 | 0.2092 |
86
+ | 0.769 | 63.09 | 20000 | 0.4794 | 0.2012 |
87
+ | 0.769 | 66.25 | 21000 | 0.4845 | 0.2007 |
88
+ | 0.7308 | 69.4 | 22000 | 0.4937 | 0.2008 |
89
+ | 0.7308 | 72.55 | 23000 | 0.4920 | 0.1895 |
90
+ | 0.6927 | 75.71 | 24000 | 0.5179 | 0.1911 |
91
+ | 0.6927 | 78.86 | 25000 | 0.5202 | 0.1877 |
92
+ | 0.6622 | 82.02 | 26000 | 0.5266 | 0.1840 |
93
+ | 0.6622 | 85.17 | 27000 | 0.5351 | 0.1854 |
94
+ | 0.6315 | 88.33 | 28000 | 0.5373 | 0.1811 |
95
+ | 0.6315 | 91.48 | 29000 | 0.5331 | 0.1792 |
96
+ | 0.6075 | 94.64 | 30000 | 0.5390 | 0.1779 |
97
+ | 0.6075 | 97.79 | 31000 | 0.5459 | 0.1773 |
98
+
99
+
100
+ ### Framework versions
101
+
102
+ - Transformers 4.17.0.dev0
103
+ - Pytorch 1.10.2+cu102
104
+ - Datasets 1.18.2.dev0
105
+ - Tokenizers 0.11.0
added_tokens.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"<s>": 30, "</s>": 31}
all_results.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "epoch": 100.0,
3
+ "eval_loss": 0.5438345074653625,
4
+ "eval_runtime": 140.268,
5
+ "eval_samples": 5076,
6
+ "eval_samples_per_second": 36.188,
7
+ "eval_steps_per_second": 2.267,
8
+ "eval_wer": 0.177349387392344,
9
+ "train_loss": 0.8928292760036721,
10
+ "train_runtime": 80759.6589,
11
+ "train_samples": 20306,
12
+ "train_samples_per_second": 25.144,
13
+ "train_steps_per_second": 0.393
14
+ }
config.json ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "imvladikon/wav2vec2-xls-r-300m-hebrew",
3
+ "activation_dropout": 0.1,
4
+ "adapter_kernel_size": 3,
5
+ "adapter_stride": 2,
6
+ "add_adapter": false,
7
+ "apply_spec_augment": true,
8
+ "architectures": [
9
+ "Wav2Vec2ForCTC"
10
+ ],
11
+ "attention_dropout": 0.0,
12
+ "bos_token_id": 1,
13
+ "classifier_proj_size": 256,
14
+ "codevector_dim": 768,
15
+ "contrastive_logits_temperature": 0.1,
16
+ "conv_bias": true,
17
+ "conv_dim": [
18
+ 512,
19
+ 512,
20
+ 512,
21
+ 512,
22
+ 512,
23
+ 512,
24
+ 512
25
+ ],
26
+ "conv_kernel": [
27
+ 10,
28
+ 3,
29
+ 3,
30
+ 3,
31
+ 3,
32
+ 2,
33
+ 2
34
+ ],
35
+ "conv_stride": [
36
+ 5,
37
+ 2,
38
+ 2,
39
+ 2,
40
+ 2,
41
+ 2,
42
+ 2
43
+ ],
44
+ "ctc_loss_reduction": "mean",
45
+ "ctc_zero_infinity": false,
46
+ "diversity_loss_weight": 0.1,
47
+ "do_stable_layer_norm": true,
48
+ "eos_token_id": 2,
49
+ "feat_extract_activation": "gelu",
50
+ "feat_extract_dropout": 0.0,
51
+ "feat_extract_norm": "layer",
52
+ "feat_proj_dropout": 0.0,
53
+ "feat_quantizer_dropout": 0.0,
54
+ "final_dropout": 0.0,
55
+ "hidden_act": "gelu",
56
+ "hidden_dropout": 0.0,
57
+ "hidden_size": 1024,
58
+ "initializer_range": 0.02,
59
+ "intermediate_size": 4096,
60
+ "layer_norm_eps": 1e-05,
61
+ "layerdrop": 0.0,
62
+ "mask_feature_length": 64,
63
+ "mask_feature_min_masks": 0,
64
+ "mask_feature_prob": 0.25,
65
+ "mask_time_length": 10,
66
+ "mask_time_min_masks": 2,
67
+ "mask_time_prob": 0.75,
68
+ "model_type": "wav2vec2",
69
+ "num_adapter_layers": 3,
70
+ "num_attention_heads": 16,
71
+ "num_codevector_groups": 2,
72
+ "num_codevectors_per_group": 320,
73
+ "num_conv_pos_embedding_groups": 16,
74
+ "num_conv_pos_embeddings": 128,
75
+ "num_feat_extract_layers": 7,
76
+ "num_hidden_layers": 24,
77
+ "num_negatives": 100,
78
+ "output_hidden_size": 1024,
79
+ "pad_token_id": 29,
80
+ "proj_codevector_dim": 768,
81
+ "tdnn_dilation": [
82
+ 1,
83
+ 2,
84
+ 3,
85
+ 1,
86
+ 1
87
+ ],
88
+ "tdnn_dim": [
89
+ 512,
90
+ 512,
91
+ 512,
92
+ 512,
93
+ 1500
94
+ ],
95
+ "tdnn_kernel": [
96
+ 5,
97
+ 3,
98
+ 3,
99
+ 1,
100
+ 1
101
+ ],
102
+ "torch_dtype": "float32",
103
+ "transformers_version": "4.17.0.dev0",
104
+ "use_weighted_layer_sum": false,
105
+ "vocab_size": 32,
106
+ "xvector_output_dim": 512
107
+ }
preprocessor_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_normalize": true,
3
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
4
+ "feature_size": 1,
5
+ "padding_side": "right",
6
+ "padding_value": 0.0,
7
+ "return_attention_mask": true,
8
+ "sampling_rate": 16000
9
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31ddb630e651b5670137ac169c70f3befd37058b7e77c57183510edbf5c313f9
3
+ size 1262054897
run_train.py ADDED
@@ -0,0 +1,981 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # !/usr/bin/env python
2
+ # coding=utf-8
3
+ import functools
4
+ import json
5
+ import logging
6
+ import os
7
+ import re
8
+ import sys
9
+ import warnings
10
+ from dataclasses import dataclass, field
11
+ from typing import Any, Callable, Dict, List, Optional, Union
12
+
13
+ import datasets
14
+ import numpy as np
15
+ import torch
16
+ import torchaudio
17
+ from datasets import DatasetDict, ReadInstruction, load_dataset, load_metric, concatenate_datasets
18
+
19
+ try:
20
+ import bitsandbytes as bnb
21
+
22
+ BNB_AVAILABLE = True
23
+ except:
24
+ BNB_AVAILABLE = False
25
+ try:
26
+ import wandb
27
+
28
+ WANDB_AVAILABLE = True
29
+ except:
30
+ WANDB_AVAILABLE = False
31
+ import transformers
32
+ from transformers import (
33
+ AutoConfig,
34
+ AutoFeatureExtractor,
35
+ AutoModelForCTC,
36
+ AutoTokenizer,
37
+ HfArgumentParser,
38
+ Trainer,
39
+ TrainerCallback, TrainingArguments,
40
+ Wav2Vec2Processor,
41
+ set_seed,
42
+ )
43
+
44
+ try:
45
+ from torch_audiomentations import (
46
+ Compose,
47
+ AddGaussianNoise,
48
+ AddGaussianSNR,
49
+ ClippingDistortion,
50
+ FrequencyMask,
51
+ Gain,
52
+ LoudnessNormalization,
53
+ Normalize,
54
+ PitchShift,
55
+ PolarityInversion,
56
+ Shift,
57
+ TimeMask,
58
+ TimeStretch,
59
+ )
60
+
61
+ AUDIOMENTATIONS_AVAILABLE = True
62
+ except:
63
+ AUDIOMENTATIONS_AVAILABLE = False
64
+ try:
65
+ from transformers import AutoProcessor
66
+ except:
67
+ pass
68
+ from transformers.trainer_pt_utils import get_parameter_names
69
+ from transformers.trainer_utils import get_last_checkpoint, is_main_process
70
+ from transformers.utils import check_min_version
71
+ from transformers.utils.versions import require_version
72
+
73
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
74
+ check_min_version("4.16.0")
75
+
76
+ require_version(
77
+ "datasets>=1.13.3",
78
+ "To fix: pip install -r examples/pytorch/text-classification/requirements.txt",
79
+ )
80
+
81
+ logger = logging.getLogger(__name__)
82
+
83
+
84
+ def list_field(default=None, metadata=None):
85
+ return field(default_factory=lambda: default, metadata=metadata)
86
+
87
+
88
+ @dataclass
89
+ class ModelArguments:
90
+ """
91
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
92
+ """
93
+
94
+ model_name_or_path: str = field(
95
+ metadata={
96
+ "help": "Path to pretrained model or model identifier from huggingface.co/models"
97
+ }
98
+ )
99
+ tokenizer_name_or_path: Optional[str] = field(
100
+ default=None,
101
+ metadata={
102
+ "help": "Path to pretrained tokenizer or tokenizer identifier from huggingface.co/models"
103
+ },
104
+ )
105
+ cache_dir: Optional[str] = field(
106
+ default=None,
107
+ metadata={
108
+ "help": "Where do you want to store the pretrained models downloaded from huggingface.co"
109
+ },
110
+ )
111
+ freeze_feature_encoder: bool = field(
112
+ default=True,
113
+ metadata={"help": "Whether to freeze the feature encoder layers of the model."},
114
+ )
115
+ attention_dropout: float = field(
116
+ default=0.0,
117
+ metadata={"help": "The dropout ratio for the attention probabilities."},
118
+ )
119
+ activation_dropout: float = field(
120
+ default=0.0,
121
+ metadata={
122
+ "help": "The dropout ratio for activations inside the fully connected layer."
123
+ },
124
+ )
125
+ feat_proj_dropout: float = field(
126
+ default=0.0, metadata={"help": "The dropout ratio for the projected features."}
127
+ )
128
+ hidden_dropout: float = field(
129
+ default=0.0,
130
+ metadata={
131
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
132
+ },
133
+ )
134
+ final_dropout: float = field(
135
+ default=0.0,
136
+ metadata={"help": "The dropout probability for the final projection layer."},
137
+ )
138
+ mask_time_prob: float = field(
139
+ default=0.05,
140
+ metadata={
141
+ "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
142
+ "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
143
+ "vectors will be masked along the time axis."
144
+ },
145
+ )
146
+ mask_time_length: int = field(
147
+ default=10,
148
+ metadata={"help": "Length of vector span to mask along the time axis."},
149
+ )
150
+ mask_feature_prob: float = field(
151
+ default=0.0,
152
+ metadata={
153
+ "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
154
+ "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
155
+ },
156
+ )
157
+ mask_feature_length: int = field(
158
+ default=10,
159
+ metadata={"help": "Length of vector span to mask along the feature axis."},
160
+ )
161
+ layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
162
+ ctc_loss_reduction: Optional[str] = field(
163
+ default="mean",
164
+ metadata={
165
+ "help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."
166
+ },
167
+ )
168
+
169
+
170
+ @dataclass
171
+ class DataTrainingArguments:
172
+ """
173
+ Arguments pertaining to what data we are going to input our model for training and eval.
174
+
175
+ Using `HfArgumentParser` we can turn this class
176
+ into argparse arguments to be able to specify them on
177
+ the command line.
178
+ """
179
+
180
+ dataset_path: str = field(
181
+ default=None,
182
+ metadata={
183
+ "help": "The configuration name of the dataset to use (via the datasets library)."
184
+ }
185
+ )
186
+ dataset_name: str = field(
187
+ default=None,
188
+ metadata={
189
+ "help": "The configuration name of the dataset to use (via the datasets library)."
190
+ },
191
+ )
192
+ dataset_config_name: str = field(
193
+ default=None,
194
+ metadata={
195
+ "help": "The configuration name of the dataset to use (via the datasets library)."
196
+ },
197
+ )
198
+ train_split_name: str = field(
199
+ default="train",
200
+ metadata={
201
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
202
+ },
203
+ )
204
+ eval_split_name: str = field(
205
+ default="validation",
206
+ metadata={
207
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
208
+ },
209
+ )
210
+ audio_column_name: str = field(
211
+ default="audio",
212
+ metadata={
213
+ "help": "The name of the dataset column containing the audio data. Defaults to 'audio'"
214
+ },
215
+ )
216
+ text_column_name: str = field(
217
+ default="text",
218
+ metadata={
219
+ "help": "The name of the dataset column containing the text data. Defaults to 'text'"
220
+ },
221
+ )
222
+ wav_filesize_column_name: str = field(
223
+ default=None,
224
+ metadata={
225
+ "help": "The name of the dataset column containing the wav filesize. Defaults is None"
226
+ },
227
+ )
228
+ overwrite_cache: bool = field(
229
+ default=False,
230
+ metadata={"help": "Overwrite the cached preprocessed datasets or not."},
231
+ )
232
+ preprocessing_num_workers: Optional[int] = field(
233
+ default=None,
234
+ metadata={"help": "The number of processes to use for the preprocessing."},
235
+ )
236
+ max_train_samples: Optional[int] = field(
237
+ default=None,
238
+ metadata={
239
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
240
+ "value if set."
241
+ },
242
+ )
243
+ max_eval_samples: Optional[int] = field(
244
+ default=None,
245
+ metadata={
246
+ "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
247
+ "value if set."
248
+ },
249
+ )
250
+ chars_to_ignore: Optional[List[str]] = list_field(
251
+ default=None,
252
+ metadata={"help": "A list of characters to remove from the transcripts."},
253
+ )
254
+ eval_metrics: List[str] = list_field(
255
+ default=["wer"],
256
+ metadata={
257
+ "help": "A list of metrics the model should be evaluated on. E.g. `'wer cer'`"
258
+ },
259
+ )
260
+ max_duration_in_seconds: float = field(
261
+ default=20.0,
262
+ metadata={
263
+ "help": "Filter audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
264
+ },
265
+ )
266
+ min_duration_in_seconds: float = field(
267
+ default=0.0,
268
+ metadata={
269
+ "help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"
270
+ },
271
+ )
272
+ preprocessing_only: bool = field(
273
+ default=False,
274
+ metadata={
275
+ "help": "Whether to only do data preprocessing and skip training. "
276
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
277
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
278
+ "so that the cached datasets can consequently be loaded in distributed training"
279
+ },
280
+ )
281
+ print_samples: bool = field(
282
+ default=False,
283
+ metadata={
284
+ "help": "Print row with validation inference results to stdout after each epoch"
285
+ },
286
+ )
287
+ use_augmentations: bool = field(
288
+ default=False,
289
+ metadata={
290
+ "help": "Use data augmentation during training"
291
+ },
292
+ )
293
+ use_auth_token: str = field(
294
+ default="",
295
+ metadata={
296
+ "help": "If :obj:`True`, will use the token generated when running"
297
+ ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
298
+ },
299
+ )
300
+ unk_token: str = field(
301
+ default="[UNK]",
302
+ metadata={"help": "The unk token for the tokenizer"},
303
+ )
304
+ pad_token: str = field(
305
+ default="[PAD]",
306
+ metadata={"help": "The padding token for the tokenizer"},
307
+ )
308
+ word_delimiter_token: str = field(
309
+ default="|",
310
+ metadata={"help": "The word delimiter token for the tokenizer"},
311
+ )
312
+ phoneme_language: Optional[str] = field(
313
+ default=None,
314
+ metadata={
315
+ "help": "The target language that should be used be"
316
+ " passed to the tokenizer for tokenization. Note that"
317
+ " this is only relevant if the model classifies the"
318
+ " input audio to a sequence of phoneme sequences."
319
+ },
320
+ )
321
+
322
+
323
+ class Augmentator:
324
+
325
+ def __init__(
326
+ self,
327
+ apply_gaussian_noise_with_p=0.1,
328
+ apply_gain_with_p=0.1,
329
+ apply_pitch_shift_with_p=0.1,
330
+ apply_time_stretch_with_p=0.1,
331
+ augment_proba=0.1,
332
+ sample_rate=16_000
333
+ ):
334
+ self.augmentator_fn = None
335
+ self.sample_rate = sample_rate
336
+ self.augment_proba = augment_proba
337
+ all_p = (
338
+ apply_gaussian_noise_with_p
339
+ + apply_gain_with_p
340
+ + apply_pitch_shift_with_p
341
+ + apply_time_stretch_with_p
342
+ )
343
+ if AUDIOMENTATIONS_AVAILABLE and all_p > 0:
344
+ self.augmentator_fn = Compose([
345
+ TimeStretch(min_rate=0.8, max_rate=1.2, leave_length_unchanged=False,
346
+ p=apply_time_stretch_with_p),
347
+ PitchShift(min_semitones=-1, max_semitones=1,
348
+ p=apply_pitch_shift_with_p),
349
+ Gain(min_gain_in_db=-1, max_gain_in_db=1, p=apply_gain_with_p),
350
+ AddGaussianNoise(min_amplitude=0.0001, max_amplitude=0.001,
351
+ p=apply_gaussian_noise_with_p),
352
+ ])
353
+
354
+ def __call__(self, input_values: List[float], *args, **kwargs):
355
+ if AUDIOMENTATIONS_AVAILABLE and self.augmentator_fn is not None:
356
+ return self.augmentator_fn(samples=np.array(input_values),
357
+ sample_rate=self.sample_rate).tolist()
358
+ else:
359
+ return input_values
360
+
361
+
362
+ @dataclass
363
+ class DataCollatorCTCWithPadding:
364
+ """
365
+ Data collator that will dynamically pad the inputs received.
366
+ Args:
367
+ processor (:class:`~transformers.AutoProcessor`)
368
+ The processor used for proccessing the data.
369
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
370
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
371
+ among:
372
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
373
+ sequence if provided).
374
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
375
+ maximum acceptable input length for the model if that argument is not provided.
376
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
377
+ different lengths).
378
+ max_length (:obj:`int`, `optional`):
379
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
380
+ max_length_labels (:obj:`int`, `optional`):
381
+ Maximum length of the ``labels`` returned list and optionally padding length (see above).
382
+ pad_to_multiple_of (:obj:`int`, `optional`):
383
+ If set will pad the sequence to a multiple of the provided value.
384
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
385
+ 7.5 (Volta).
386
+ """
387
+
388
+ processor: 'AutoProcessor'
389
+ padding: Union[bool, str] = "longest"
390
+ pad_to_multiple_of: Optional[int] = None
391
+ pad_to_multiple_of_labels: Optional[int] = None
392
+ augmentator_fn: Optional[Callable] = None
393
+ use_augmentations: bool = False
394
+
395
+ def __call__(
396
+ self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
397
+ ) -> Dict[str, torch.Tensor]:
398
+ # split inputs and labels since they have to be of different lenghts and need
399
+ # different padding methods
400
+ input_features = [
401
+ {
402
+ "input_values": self.augmentator_fn(feature["input_values"])
403
+ if self.use_augmentations
404
+ else feature["input_values"]}
405
+ for feature in features
406
+ ]
407
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
408
+
409
+ batch = self.processor.pad(
410
+ input_features,
411
+ padding=self.padding,
412
+ pad_to_multiple_of=self.pad_to_multiple_of,
413
+ return_tensors="pt",
414
+ )
415
+
416
+ with self.processor.as_target_processor():
417
+ labels_batch = self.processor.pad(
418
+ label_features,
419
+ padding=self.padding,
420
+ pad_to_multiple_of=self.pad_to_multiple_of_labels,
421
+ return_tensors="pt",
422
+ )
423
+
424
+ # replace padding with -100 to ignore loss correctly
425
+ labels = labels_batch["input_ids"].masked_fill(
426
+ labels_batch.attention_mask.ne(1), -100
427
+ )
428
+
429
+ batch["labels"] = labels
430
+
431
+ return batch
432
+
433
+
434
+ def create_vocabulary_from_data(
435
+ datasets: DatasetDict,
436
+ text_column_name: str,
437
+ train_split_name: str,
438
+ word_delimiter_token: Optional[str] = None,
439
+ unk_token: Optional[str] = None,
440
+ pad_token: Optional[str] = None,
441
+ ):
442
+ # Given training and test labels create vocabulary
443
+ def extract_all_chars(batch):
444
+ all_text = " ".join(batch[text_column_name])
445
+ vocab = list(set(all_text))
446
+ return {"vocab": [vocab], "all_text": [all_text]}
447
+
448
+ print("extract chars")
449
+ vocabs = datasets.map(
450
+ extract_all_chars,
451
+ batched=True,
452
+ batch_size=-1,
453
+ keep_in_memory=True,
454
+ remove_columns=datasets[train_split_name].column_names,
455
+ )
456
+
457
+ # take union of all unique characters in each dataset
458
+ print("make vocab_set")
459
+ vocab_set = functools.reduce(
460
+ lambda vocab_1, vocab_2: set(vocab_1["vocab"][0]) | set(vocab_2["vocab"][0]),
461
+ vocabs.values(),
462
+ )
463
+
464
+ vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))}
465
+
466
+ # replace white space with delimiter token
467
+ if word_delimiter_token is not None:
468
+ vocab_dict[word_delimiter_token] = vocab_dict[" "]
469
+ del vocab_dict[" "]
470
+
471
+ # add unk and pad token
472
+ if unk_token is not None:
473
+ vocab_dict[unk_token] = len(vocab_dict)
474
+
475
+ if pad_token is not None:
476
+ vocab_dict[pad_token] = len(vocab_dict)
477
+
478
+ return vocab_dict
479
+
480
+
481
+ def speech_file_to_array_fn(batch, audio_column_name, dataset_path=""):
482
+ if dataset_path:
483
+ dataset_path = os.path.join(dataset_path, batch[audio_column_name])
484
+ else:
485
+ dataset_path = batch[audio_column_name] if isinstance(batch[audio_column_name],
486
+ str) else \
487
+ batch[audio_column_name]["path"]
488
+ speech_array, sampling_rate = torchaudio.load(dataset_path)
489
+ batch[audio_column_name] = {
490
+ "array": speech_array[0].numpy(),
491
+ "sampling_rate": sampling_rate,
492
+ }
493
+ return batch
494
+
495
+
496
+ class PrintSamplesPredictionCallback(TrainerCallback):
497
+
498
+ def __init__(self, processor, eval_dataset):
499
+ super(PrintSamplesPredictionCallback, self).__init__()
500
+ self.processor = processor
501
+ self.eval_dataset = eval_dataset
502
+ self.metric_fn = load_metric("wer")
503
+
504
+ def on_log(
505
+ self,
506
+ args: Any,
507
+ state: Any,
508
+ control: Any,
509
+ model: Any,
510
+ logs: Optional[Any] = None,
511
+ **kwargs
512
+ ):
513
+ """
514
+ :param args:
515
+ :param state:
516
+ :param control:
517
+ :param model:
518
+ :param logs:
519
+ :param kwargs: 'tokenizer', 'optimizer', 'lr_scheduler', 'train_dataloader', 'eval_dataloader'
520
+ :return:
521
+ """
522
+ if state.is_local_process_zero:
523
+ columns = ["id", "prediction", "reference", "audio", "wer"]
524
+ data = []
525
+ for idx, row in enumerate(self.eval_dataset):
526
+ input_dict = self.processor(row["input_values"],
527
+ return_tensors="pt", padding=True)
528
+ logits = model(input_dict.input_values.to(model.device)).logits
529
+ pred_ids = torch.argmax(logits, dim=-1)[0]
530
+ prediction = self.processor.decode(pred_ids)
531
+ print(f"Prediction: {prediction}")
532
+ reference = row['references'].lower()
533
+ print(f"\nReference: {reference}")
534
+
535
+ if WANDB_AVAILABLE:
536
+
537
+ audio, sample_rate = tuple(row["audio"].values())
538
+ audio = wandb.Audio(np.squeeze(audio),
539
+ sample_rate=sample_rate)
540
+ wer = self.metric_fn.compute(
541
+ predictions=[prediction],
542
+ references=[reference],
543
+ )
544
+
545
+ data.append([idx, prediction, reference, audio, wer])
546
+ if WANDB_AVAILABLE:
547
+ table = wandb.Table(data=data, columns=columns)
548
+ wandb.run.log({"audio_predictions": table})
549
+
550
+
551
+ def main():
552
+ # See all possible arguments in src/transformers/training_args.py
553
+ # or by passing the --help flag to this script.
554
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
555
+
556
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
557
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
558
+ # If we pass only one argument to the script and it's the path to a json file,
559
+ # let's parse it to get our arguments.
560
+ model_args, data_args, training_args = parser.parse_json_file(
561
+ json_file=os.path.abspath(sys.argv[1])
562
+ )
563
+ else:
564
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
565
+
566
+ # Detecting last checkpoint.
567
+ last_checkpoint = None
568
+ if (
569
+ os.path.isdir(training_args.output_dir)
570
+ and training_args.do_train
571
+ and not training_args.overwrite_output_dir
572
+ ):
573
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
574
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
575
+ raise ValueError(
576
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
577
+ "Use --overwrite_output_dir to overcome."
578
+ )
579
+ elif last_checkpoint is not None:
580
+ logger.info(
581
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
582
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
583
+ )
584
+
585
+ # Setup logging
586
+ logging.basicConfig(
587
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
588
+ datefmt="%m/%d/%Y %H:%M:%S",
589
+ handlers=[logging.StreamHandler(sys.stdout)],
590
+ )
591
+ logger.setLevel(
592
+ logging.INFO if is_main_process(training_args.local_rank) else logging.WARN
593
+ )
594
+
595
+ # Log on each process the small summary:
596
+ logger.warning(
597
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
598
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
599
+ )
600
+ # Set the verbosity to info of the Transformers logger (on main process only):
601
+ if is_main_process(training_args.local_rank):
602
+ transformers.utils.logging.set_verbosity_info()
603
+ logger.info("Training/evaluation parameters %s", training_args)
604
+
605
+ # Set seed before initializing model.
606
+ set_seed(training_args.seed)
607
+
608
+ train_split_name = data_args.train_split_name
609
+ eval_split_name = data_args.eval_split_name
610
+
611
+ # 1. First, let's load the dataset
612
+ raw_datasets = DatasetDict({
613
+ train_split_name: None,
614
+ eval_split_name: None,
615
+ })
616
+
617
+ if data_args.dataset_path:
618
+ raw_datasets = load_dataset(
619
+ "csv",
620
+ data_files={
621
+ train_split_name: os.path.join(data_args.dataset_path, "train-all.csv"),
622
+ eval_split_name: os.path.join(data_args.dataset_path, "eval-all.csv"),
623
+ },
624
+ )
625
+
626
+ if training_args.do_train:
627
+ if raw_datasets[train_split_name] is None:
628
+ raw_datasets[train_split_name] = load_dataset(
629
+ data_args.dataset_name,
630
+ data_args.dataset_config_name,
631
+ split=data_args.train_split_name,
632
+ use_auth_token=data_args.use_auth_token,
633
+ )
634
+
635
+ if data_args.audio_column_name not in raw_datasets[train_split_name].column_names:
636
+ raise ValueError(
637
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset. "
638
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
639
+ f"{', '.join(raw_datasets['train'].column_names)}."
640
+ )
641
+
642
+ if data_args.text_column_name not in raw_datasets[train_split_name].column_names:
643
+ raise ValueError(
644
+ f"--text_column_name {data_args.text_column_name} not found in dataset. "
645
+ "Make sure to set `--text_column_name` to the correct text column - one of "
646
+ f"{', '.join(raw_datasets['train'].column_names)}."
647
+ )
648
+
649
+ if data_args.max_train_samples is not None:
650
+ raw_datasets[train_split_name] = raw_datasets[train_split_name].select(
651
+ range(data_args.max_train_samples)
652
+ )
653
+
654
+ if data_args.wav_filesize_column_name is not None:
655
+ raw_datasets[train_split_name] = raw_datasets[train_split_name].sort(
656
+ data_args.wav_filesize_column_name, reverse=True)
657
+
658
+ if training_args.do_eval:
659
+ if raw_datasets[eval_split_name] is None:
660
+ raw_datasets[eval_split_name] = load_dataset(
661
+ data_args.dataset_name,
662
+ data_args.dataset_config_name,
663
+ split=data_args.eval_split_name,
664
+ use_auth_token=data_args.use_auth_token,
665
+ )
666
+
667
+ if data_args.max_eval_samples is not None:
668
+ raw_datasets[eval_split_name] = raw_datasets[eval_split_name].select(
669
+ range(data_args.max_eval_samples)
670
+ )
671
+ if data_args.wav_filesize_column_name is not None:
672
+ raw_datasets[eval_split_name] = raw_datasets[eval_split_name].sort(
673
+ data_args.wav_filesize_column_name, reverse=True)
674
+
675
+ # save special tokens for tokenizer
676
+ word_delimiter_token = data_args.word_delimiter_token
677
+ unk_token = data_args.unk_token
678
+ pad_token = data_args.pad_token
679
+
680
+ # 3. Next, let's load the config as we might need it to create
681
+ # the tokenizer
682
+ # load config
683
+ config = AutoConfig.from_pretrained(
684
+ model_args.model_name_or_path,
685
+ cache_dir=model_args.cache_dir,
686
+ use_auth_token=data_args.use_auth_token,
687
+ )
688
+
689
+ # 4. Next, if no tokenizer file is defined,
690
+ # we create the vocabulary of the model by extracting all unique characters from
691
+ # the training and evaluation datasets
692
+ # We need to make sure that only first rank saves vocabulary
693
+ # make sure all processes wait until vocab is created
694
+ tokenizer_name_or_path = model_args.tokenizer_name_or_path
695
+ tokenizer_kwargs = {}
696
+
697
+ # 5. Now we can instantiate the feature extractor, tokenizer and model
698
+ # Note for distributed training, the .from_pretrained methods guarantee that only
699
+ # one local process can concurrently download model & vocab.
700
+ with open(os.path.join(tokenizer_name_or_path, "vocab.json"), "r") as fin:
701
+ print("loading tokenizer")
702
+ print(fin.read())
703
+
704
+ # load feature_extractor and tokenizer
705
+ tokenizer = AutoTokenizer.from_pretrained(
706
+ tokenizer_name_or_path,
707
+ use_auth_token=data_args.use_auth_token,
708
+ **tokenizer_kwargs,
709
+ )
710
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
711
+ model_args.model_name_or_path,
712
+ cache_dir=model_args.cache_dir,
713
+ use_auth_token=data_args.use_auth_token,
714
+ )
715
+
716
+ # adapt config
717
+ config.update(
718
+ {
719
+ "feat_proj_dropout": model_args.feat_proj_dropout,
720
+ "attention_dropout": model_args.attention_dropout,
721
+ "hidden_dropout": model_args.hidden_dropout,
722
+ "final_dropout": model_args.final_dropout,
723
+ "mask_time_prob": model_args.mask_time_prob,
724
+ "mask_time_length": model_args.mask_time_length,
725
+ "mask_feature_prob": model_args.mask_feature_prob,
726
+ "mask_feature_length": model_args.mask_feature_length,
727
+ "gradient_checkpointing": training_args.gradient_checkpointing,
728
+ "layerdrop": model_args.layerdrop,
729
+ "ctc_loss_reduction": model_args.ctc_loss_reduction,
730
+ "pad_token_id": tokenizer.pad_token_id,
731
+ "vocab_size": len(tokenizer),
732
+ "activation_dropout": model_args.activation_dropout,
733
+ }
734
+ )
735
+
736
+ # create model
737
+ model = AutoModelForCTC.from_pretrained(
738
+ model_args.model_name_or_path,
739
+ cache_dir=model_args.cache_dir,
740
+ config=config,
741
+ use_auth_token=data_args.use_auth_token,
742
+ )
743
+
744
+ # freeze encoder
745
+ if model_args.freeze_feature_encoder:
746
+ model.freeze_feature_encoder()
747
+
748
+ # 6. Now we preprocess the datasets including loading the audio, resampling and normalization
749
+ # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
750
+ # so that we just need to set the correct target sampling rate and normalize the input
751
+ # via the `feature_extractor`
752
+
753
+ # make sure that dataset decodes audio with correct sampling rate
754
+
755
+ # derive max & min input length for sample rate & max duration
756
+ audio_column_name = data_args.audio_column_name
757
+ num_workers = data_args.preprocessing_num_workers
758
+
759
+ # `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
760
+ phoneme_language = data_args.phoneme_language
761
+
762
+ raw_datasets[train_split_name] = raw_datasets[train_split_name].map(
763
+ speech_file_to_array_fn,
764
+ num_proc=num_workers,
765
+ fn_kwargs={"dataset_path": data_args.dataset_path,
766
+ "audio_column_name": audio_column_name},
767
+ )
768
+ raw_datasets[eval_split_name] = raw_datasets[eval_split_name].map(
769
+ speech_file_to_array_fn,
770
+ num_proc=num_workers,
771
+ fn_kwargs={"dataset_path": data_args.dataset_path,
772
+ "audio_column_name": audio_column_name},
773
+ )
774
+
775
+ # Preprocessing the datasets.
776
+ # We need to read the audio files as arrays and tokenize the targets.
777
+ def prepare_dataset(batch):
778
+ # load audio
779
+ sample = batch[audio_column_name]
780
+
781
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
782
+ batch["input_values"] = inputs.input_values[0]
783
+ batch["input_length"] = len(batch["input_values"])
784
+
785
+ # encode targets
786
+ additional_kwargs = {}
787
+ if phoneme_language is not None:
788
+ additional_kwargs["phonemizer_lang"] = phoneme_language
789
+
790
+ batch["labels"] = tokenizer(batch[data_args.text_column_name],
791
+ **additional_kwargs).input_ids
792
+ return batch
793
+
794
+ print(f"Vectorizing")
795
+
796
+ with training_args.main_process_first(desc="dataset map preprocessing"):
797
+ vectorized_datasets = raw_datasets.map(
798
+ prepare_dataset,
799
+ remove_columns=next(iter(raw_datasets.values())).column_names,
800
+ num_proc=num_workers,
801
+ desc="preprocess datasets",
802
+ )
803
+
804
+ # 7. Next, we can prepare the training.
805
+ # Let's use word error rate (WER) as our evaluation metric,
806
+ # instantiate a data collator and the trainer
807
+
808
+ # Define evaluation metrics during training, *i.e.* word error rate, character error rate
809
+ eval_metrics = {metric: load_metric(metric) for metric in data_args.eval_metrics}
810
+
811
+ # for large datasets it is advised to run the preprocessing on a
812
+ # single machine first with ``args.preprocessing_only`` since there will mostly likely
813
+ # be a timeout when running the script in distributed mode.
814
+ # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
815
+ # cached dataset
816
+ if data_args.preprocessing_only:
817
+ logger.info(
818
+ f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}"
819
+ )
820
+ return
821
+
822
+ def compute_metrics(pred):
823
+ pred_logits = pred.predictions
824
+ pred_ids = np.argmax(pred_logits, axis=-1)
825
+
826
+ pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
827
+
828
+ pred_str = tokenizer.batch_decode(pred_ids)
829
+ # we do not want to group tokens when computing the metrics
830
+ label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)
831
+
832
+ metrics = {
833
+ k: v.compute(predictions=pred_str, references=label_str)
834
+ for k, v in eval_metrics.items()
835
+ }
836
+
837
+ return metrics
838
+
839
+ # Now save everything to be able to create a single processor later
840
+ if is_main_process(training_args.local_rank):
841
+ # save feature extractor, tokenizer and config
842
+ feature_extractor.save_pretrained(training_args.output_dir)
843
+ tokenizer.save_pretrained(training_args.output_dir)
844
+ config.save_pretrained(training_args.output_dir)
845
+
846
+ try:
847
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
848
+ except (OSError, KeyError):
849
+ warnings.warn(
850
+ "Loading a processor from a feature extractor config that does not"
851
+ " include a `processor_class` attribute is deprecated and will be removed in v5. Please add the following "
852
+ " attribute to your `preprocessor_config.json` file to suppress this warning: "
853
+ " `'processor_class': 'Wav2Vec2Processor'`",
854
+ FutureWarning,
855
+ )
856
+ processor = Wav2Vec2Processor.from_pretrained(training_args.output_dir)
857
+
858
+ # Instantiate custom data collator
859
+ data_collator = DataCollatorCTCWithPadding(
860
+ processor=processor,
861
+ augmentator_fn=Augmentator(),
862
+ use_augmentations=data_args.use_augmentations
863
+ )
864
+
865
+ decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm])
866
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
867
+ optimizer_grouped_parameters = [
868
+ {
869
+ "params": [p for n, p in model.named_parameters() if n in decay_parameters],
870
+ "weight_decay": training_args.weight_decay,
871
+ },
872
+ {
873
+ "params": [
874
+ p for n, p in model.named_parameters() if n not in decay_parameters
875
+ ],
876
+ "weight_decay": 0.0,
877
+ },
878
+ ]
879
+ trainer_kwargs = {}
880
+ if BNB_AVAILABLE:
881
+ optimizer = bnb.optim.Adam8bit(
882
+ params=optimizer_grouped_parameters,
883
+ betas=(training_args.adam_beta1, training_args.adam_beta2),
884
+ eps=training_args.adam_epsilon,
885
+ )
886
+ trainer_kwargs["optimizers"] = (optimizer, None)
887
+
888
+ samples_to_log = [
889
+ {
890
+ **vectorized_datasets[eval_split_name][i],
891
+ "references": raw_datasets[eval_split_name][i][data_args.text_column_name],
892
+ "audio": raw_datasets[eval_split_name][i][data_args.audio_column_name],
893
+ } for i in range(5)
894
+ ]
895
+
896
+ trainer = Trainer(
897
+ model=model,
898
+ data_collator=data_collator,
899
+ args=training_args,
900
+ compute_metrics=compute_metrics,
901
+ train_dataset=vectorized_datasets[
902
+ train_split_name] if training_args.do_train else None,
903
+ eval_dataset=vectorized_datasets[
904
+ eval_split_name] if training_args.do_eval else None,
905
+ tokenizer=feature_extractor,
906
+ **trainer_kwargs,
907
+ callbacks=[PrintSamplesPredictionCallback(
908
+ processor=processor,
909
+ eval_dataset=samples_to_log)] if data_args.print_samples and training_args.do_eval else None,
910
+ )
911
+
912
+ # 8. Finally, we can start training
913
+
914
+ # Training
915
+ if training_args.do_train:
916
+
917
+ # use last checkpoint if exist
918
+ if last_checkpoint is not None:
919
+ checkpoint = last_checkpoint
920
+ elif os.path.isdir(model_args.model_name_or_path):
921
+ checkpoint = model_args.model_name_or_path
922
+ else:
923
+ checkpoint = None
924
+
925
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
926
+ trainer.save_model()
927
+
928
+ metrics = train_result.metrics
929
+ max_train_samples = (
930
+ data_args.max_train_samples
931
+ if data_args.max_train_samples is not None
932
+ else len(vectorized_datasets[train_split_name])
933
+ )
934
+ metrics["train_samples"] = min(
935
+ max_train_samples, len(vectorized_datasets[train_split_name])
936
+ )
937
+
938
+ trainer.log_metrics(train_split_name, metrics)
939
+ trainer.save_metrics(train_split_name, metrics)
940
+ trainer.save_state()
941
+
942
+ # Evaluation
943
+ results = {}
944
+ if training_args.do_eval:
945
+ logger.info("*** Evaluate ***")
946
+ metrics = trainer.evaluate()
947
+ max_eval_samples = (
948
+ data_args.max_eval_samples
949
+ if data_args.max_eval_samples is not None
950
+ else len(vectorized_datasets[eval_split_name])
951
+ )
952
+ metrics["eval_samples"] = min(max_eval_samples,
953
+ len(vectorized_datasets[eval_split_name]))
954
+
955
+ trainer.log_metrics(eval_split_name, metrics)
956
+ trainer.save_metrics(eval_split_name, metrics)
957
+
958
+ # Write model card and (optionally) push to hub
959
+ config_name = (
960
+ data_args.dataset_config_name
961
+ if data_args.dataset_config_name is not None
962
+ else "na"
963
+ )
964
+ kwargs = {
965
+ "language": "he",
966
+ "finetuned_from": model_args.model_name_or_path,
967
+ "tasks": "speech-recognition",
968
+ "tags": ["automatic-speech-recognition", "robust-speech-event", "he"],
969
+ "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}",
970
+ }
971
+
972
+ if training_args.push_to_hub:
973
+ trainer.push_to_hub(**kwargs)
974
+ else:
975
+ trainer.create_model_card(**kwargs)
976
+
977
+ return results
978
+
979
+
980
+ if __name__ == "__main__":
981
+ main()
run_train.sh ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export CUDA_VISIBLE_DEVICES="0,1"
2
+
3
+ python -m torch.distributed.launch --nproc_per_node=2 run_train.py \
4
+ --dataset_name="imvladikon/hebrew_speech_???" \
5
+ --use_auth_token="???" \
6
+ --audio_column_name="audio" \
7
+ --text_column_name="sentence" \
8
+ --model_name_or_path="imvladikon/wav2vec2-xls-r-300m-hebrew" \
9
+ --tokenizer_name_or_path="./wav2vec2-xls-r-300m-hebrew" \
10
+ --output_dir="./wav2vec2-xls-r-300m-hebrew" \
11
+ --overwrite_output_dir \
12
+ --evaluation_strategy="steps" \
13
+ --length_column_name="input_length" \
14
+ --gradient_checkpointing \
15
+ --fp16 \
16
+ --group_by_length \
17
+ --num_train_epochs="100" \
18
+ --per_device_train_batch_size="8" \
19
+ --per_device_eval_batch_size="8" \
20
+ --gradient_accumulation_steps="4" \
21
+ --learning_rate="3e-4" \
22
+ --warmup_steps="1000" \
23
+ --save_steps="1000" \
24
+ --eval_steps="1000" \
25
+ --preprocessing_num_workers="$(nproc)" \
26
+ --logging_steps="2000" \
27
+ --layerdrop="0.0" \
28
+ --activation_dropout="0.1" \
29
+ --save_total_limit="3" \
30
+ --freeze_feature_encoder \
31
+ --feat_proj_dropout="0.0" \
32
+ --mask_time_prob="0.75" \
33
+ --mask_time_length="10" \
34
+ --mask_feature_prob="0.25" \
35
+ --mask_feature_length="64" \
36
+ --do_train --do_eval \
37
+ --print_samples \
38
+ --use_augmentations \
39
+ --push_to_hub
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "[UNK]", "pad_token": "[PAD]", "additional_special_tokens": [{"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}]}
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"unk_token": "[UNK]", "bos_token": "<s>", "eos_token": "</s>", "pad_token": "[PAD]", "do_lower_case": false, "word_delimiter_token": "|", "special_tokens_map_file": null, "tokenizer_file": null, "name_or_path": "./wav2vec2-xls-r-300m-hebrew", "tokenizer_class": "Wav2Vec2CTCTokenizer"}
train_results.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "epoch": 100.0,
3
+ "train_loss": 0.8928292760036721,
4
+ "train_runtime": 80759.6589,
5
+ "train_samples": 20306,
6
+ "train_samples_per_second": 25.144,
7
+ "train_steps_per_second": 0.393
8
+ }
trainer_state.json ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_metric": null,
3
+ "best_model_checkpoint": null,
4
+ "epoch": 99.99842519685039,
5
+ "global_step": 31700,
6
+ "is_hyper_param_search": false,
7
+ "is_local_process_zero": true,
8
+ "is_world_process_zero": true,
9
+ "log_history": [
10
+ {
11
+ "epoch": 3.15,
12
+ "eval_loss": 0.5203462243080139,
13
+ "eval_runtime": 141.5768,
14
+ "eval_samples_per_second": 35.853,
15
+ "eval_steps_per_second": 2.246,
16
+ "eval_wer": 0.4333326279704594,
17
+ "step": 1000
18
+ },
19
+ {
20
+ "epoch": 6.31,
21
+ "learning_rate": 0.0009674592833876221,
22
+ "loss": 1.4284,
23
+ "step": 2000
24
+ },
25
+ {
26
+ "epoch": 6.31,
27
+ "eval_loss": 0.48156219720840454,
28
+ "eval_runtime": 142.5636,
29
+ "eval_samples_per_second": 35.605,
30
+ "eval_steps_per_second": 2.231,
31
+ "eval_wer": 0.3950949065746873,
32
+ "step": 2000
33
+ },
34
+ {
35
+ "epoch": 9.46,
36
+ "eval_loss": 0.4314565062522888,
37
+ "eval_runtime": 139.8322,
38
+ "eval_samples_per_second": 36.301,
39
+ "eval_steps_per_second": 2.274,
40
+ "eval_wer": 0.3545506485811626,
41
+ "step": 3000
42
+ },
43
+ {
44
+ "epoch": 12.62,
45
+ "learning_rate": 0.0009023452768729642,
46
+ "loss": 1.283,
47
+ "step": 4000
48
+ },
49
+ {
50
+ "epoch": 12.62,
51
+ "eval_loss": 0.42783403396606445,
52
+ "eval_runtime": 141.0912,
53
+ "eval_samples_per_second": 35.977,
54
+ "eval_steps_per_second": 2.254,
55
+ "eval_wer": 0.34039401570137756,
56
+ "step": 4000
57
+ },
58
+ {
59
+ "epoch": 15.77,
60
+ "eval_loss": 0.40902915596961975,
61
+ "eval_runtime": 140.9561,
62
+ "eval_samples_per_second": 36.011,
63
+ "eval_steps_per_second": 2.256,
64
+ "eval_wer": 0.3053939098969465,
65
+ "step": 5000
66
+ },
67
+ {
68
+ "epoch": 18.93,
69
+ "learning_rate": 0.0008372312703583062,
70
+ "loss": 1.1777,
71
+ "step": 6000
72
+ },
73
+ {
74
+ "epoch": 18.93,
75
+ "eval_loss": 0.3892641067504883,
76
+ "eval_runtime": 139.9163,
77
+ "eval_samples_per_second": 36.279,
78
+ "eval_steps_per_second": 2.273,
79
+ "eval_wer": 0.30056922783926193,
80
+ "step": 6000
81
+ },
82
+ {
83
+ "epoch": 22.08,
84
+ "eval_loss": 0.3967570960521698,
85
+ "eval_runtime": 139.2725,
86
+ "eval_samples_per_second": 36.447,
87
+ "eval_steps_per_second": 2.283,
88
+ "eval_wer": 0.28565080305563195,
89
+ "step": 7000
90
+ },
91
+ {
92
+ "epoch": 25.24,
93
+ "learning_rate": 0.0007720846905537459,
94
+ "loss": 1.0994,
95
+ "step": 8000
96
+ },
97
+ {
98
+ "epoch": 25.24,
99
+ "eval_loss": 0.3892391324043274,
100
+ "eval_runtime": 138.7844,
101
+ "eval_samples_per_second": 36.575,
102
+ "eval_steps_per_second": 2.291,
103
+ "eval_wer": 0.27509152083289246,
104
+ "step": 8000
105
+ },
106
+ {
107
+ "epoch": 28.39,
108
+ "eval_loss": 0.4061281681060791,
109
+ "eval_runtime": 139.2312,
110
+ "eval_samples_per_second": 36.457,
111
+ "eval_steps_per_second": 2.284,
112
+ "eval_wer": 0.2689760247159151,
113
+ "step": 9000
114
+ },
115
+ {
116
+ "epoch": 31.54,
117
+ "learning_rate": 0.000706970684039088,
118
+ "loss": 1.0323,
119
+ "step": 10000
120
+ },
121
+ {
122
+ "epoch": 31.54,
123
+ "eval_loss": 0.41136494278907776,
124
+ "eval_runtime": 139.4432,
125
+ "eval_samples_per_second": 36.402,
126
+ "eval_steps_per_second": 2.28,
127
+ "eval_wer": 0.25065069725120087,
128
+ "step": 10000
129
+ },
130
+ {
131
+ "epoch": 34.7,
132
+ "eval_loss": 0.40214526653289795,
133
+ "eval_runtime": 139.4093,
134
+ "eval_samples_per_second": 36.411,
135
+ "eval_steps_per_second": 2.281,
136
+ "eval_wer": 0.2508411452271621,
137
+ "step": 11000
138
+ },
139
+ {
140
+ "epoch": 37.85,
141
+ "learning_rate": 0.00064185667752443,
142
+ "loss": 0.9623,
143
+ "step": 12000
144
+ },
145
+ {
146
+ "epoch": 37.85,
147
+ "eval_loss": 0.40321338176727295,
148
+ "eval_runtime": 139.9917,
149
+ "eval_samples_per_second": 36.259,
150
+ "eval_steps_per_second": 2.272,
151
+ "eval_wer": 0.2378060393169266,
152
+ "step": 12000
153
+ },
154
+ {
155
+ "epoch": 41.01,
156
+ "eval_loss": 0.4147748053073883,
157
+ "eval_runtime": 139.3612,
158
+ "eval_samples_per_second": 36.423,
159
+ "eval_steps_per_second": 2.282,
160
+ "eval_wer": 0.23744630425122204,
161
+ "step": 13000
162
+ },
163
+ {
164
+ "epoch": 44.16,
165
+ "learning_rate": 0.0005767100977198697,
166
+ "loss": 0.9077,
167
+ "step": 14000
168
+ },
169
+ {
170
+ "epoch": 44.16,
171
+ "eval_loss": 0.4350396394729614,
172
+ "eval_runtime": 138.8108,
173
+ "eval_samples_per_second": 36.568,
174
+ "eval_steps_per_second": 2.291,
175
+ "eval_wer": 0.23230420890026873,
176
+ "step": 14000
177
+ },
178
+ {
179
+ "epoch": 47.32,
180
+ "eval_loss": 0.4514589309692383,
181
+ "eval_runtime": 138.719,
182
+ "eval_samples_per_second": 36.592,
183
+ "eval_steps_per_second": 2.292,
184
+ "eval_wer": 0.22464396808938358,
185
+ "step": 15000
186
+ },
187
+ {
188
+ "epoch": 50.47,
189
+ "learning_rate": 0.0005115960912052118,
190
+ "loss": 0.8573,
191
+ "step": 16000
192
+ },
193
+ {
194
+ "epoch": 50.47,
195
+ "eval_loss": 0.4473990499973297,
196
+ "eval_runtime": 140.3605,
197
+ "eval_samples_per_second": 36.164,
198
+ "eval_steps_per_second": 2.266,
199
+ "eval_wer": 0.21797828893074042,
200
+ "step": 16000
201
+ },
202
+ {
203
+ "epoch": 53.63,
204
+ "eval_loss": 0.4649062752723694,
205
+ "eval_runtime": 137.3039,
206
+ "eval_samples_per_second": 36.969,
207
+ "eval_steps_per_second": 2.316,
208
+ "eval_wer": 0.21713185348202382,
209
+ "step": 17000
210
+ },
211
+ {
212
+ "epoch": 56.78,
213
+ "learning_rate": 0.00044651465798045605,
214
+ "loss": 0.8083,
215
+ "step": 18000
216
+ },
217
+ {
218
+ "epoch": 56.78,
219
+ "eval_loss": 0.44551119208335876,
220
+ "eval_runtime": 139.4699,
221
+ "eval_samples_per_second": 36.395,
222
+ "eval_steps_per_second": 2.28,
223
+ "eval_wer": 0.2102334045749836,
224
+ "step": 18000
225
+ },
226
+ {
227
+ "epoch": 59.94,
228
+ "eval_loss": 0.4586869478225708,
229
+ "eval_runtime": 139.1403,
230
+ "eval_samples_per_second": 36.481,
231
+ "eval_steps_per_second": 2.285,
232
+ "eval_wer": 0.20917536026408787,
233
+ "step": 19000
234
+ },
235
+ {
236
+ "epoch": 63.09,
237
+ "learning_rate": 0.00038140065146579803,
238
+ "loss": 0.769,
239
+ "step": 20000
240
+ },
241
+ {
242
+ "epoch": 63.09,
243
+ "eval_loss": 0.4793929159641266,
244
+ "eval_runtime": 139.562,
245
+ "eval_samples_per_second": 36.371,
246
+ "eval_steps_per_second": 2.279,
247
+ "eval_wer": 0.20117654527371606,
248
+ "step": 20000
249
+ },
250
+ {
251
+ "epoch": 66.25,
252
+ "eval_loss": 0.4844733476638794,
253
+ "eval_runtime": 138.9678,
254
+ "eval_samples_per_second": 36.526,
255
+ "eval_steps_per_second": 2.288,
256
+ "eval_wer": 0.20073216666313987,
257
+ "step": 21000
258
+ },
259
+ {
260
+ "epoch": 69.4,
261
+ "learning_rate": 0.00031628664495114006,
262
+ "loss": 0.7308,
263
+ "step": 22000
264
+ },
265
+ {
266
+ "epoch": 69.4,
267
+ "eval_loss": 0.49372631311416626,
268
+ "eval_runtime": 139.7927,
269
+ "eval_samples_per_second": 36.311,
270
+ "eval_steps_per_second": 2.275,
271
+ "eval_wer": 0.20075332754935776,
272
+ "step": 22000
273
+ },
274
+ {
275
+ "epoch": 72.55,
276
+ "eval_loss": 0.4920376241207123,
277
+ "eval_runtime": 138.7644,
278
+ "eval_samples_per_second": 36.58,
279
+ "eval_steps_per_second": 2.292,
280
+ "eval_wer": 0.1894745751952092,
281
+ "step": 23000
282
+ },
283
+ {
284
+ "epoch": 75.71,
285
+ "learning_rate": 0.0002511400651465798,
286
+ "loss": 0.6927,
287
+ "step": 24000
288
+ },
289
+ {
290
+ "epoch": 75.71,
291
+ "eval_loss": 0.5178954005241394,
292
+ "eval_runtime": 139.799,
293
+ "eval_samples_per_second": 36.309,
294
+ "eval_steps_per_second": 2.275,
295
+ "eval_wer": 0.19114628520642443,
296
+ "step": 24000
297
+ },
298
+ {
299
+ "epoch": 78.86,
300
+ "eval_loss": 0.520152747631073,
301
+ "eval_runtime": 140.3812,
302
+ "eval_samples_per_second": 36.159,
303
+ "eval_steps_per_second": 2.265,
304
+ "eval_wer": 0.18767589986668642,
305
+ "step": 25000
306
+ },
307
+ {
308
+ "epoch": 82.02,
309
+ "learning_rate": 0.00018602605863192182,
310
+ "loss": 0.6622,
311
+ "step": 26000
312
+ },
313
+ {
314
+ "epoch": 82.02,
315
+ "eval_loss": 0.5265706181526184,
316
+ "eval_runtime": 138.1289,
317
+ "eval_samples_per_second": 36.748,
318
+ "eval_steps_per_second": 2.302,
319
+ "eval_wer": 0.18401506655098715,
320
+ "step": 26000
321
+ },
322
+ {
323
+ "epoch": 85.17,
324
+ "eval_loss": 0.5350863933563232,
325
+ "eval_runtime": 140.1605,
326
+ "eval_samples_per_second": 36.216,
327
+ "eval_steps_per_second": 2.269,
328
+ "eval_wer": 0.18541168504136954,
329
+ "step": 27000
330
+ },
331
+ {
332
+ "epoch": 88.33,
333
+ "learning_rate": 0.00012091205211726384,
334
+ "loss": 0.6315,
335
+ "step": 28000
336
+ },
337
+ {
338
+ "epoch": 88.33,
339
+ "eval_loss": 0.5373002290725708,
340
+ "eval_runtime": 138.125,
341
+ "eval_samples_per_second": 36.749,
342
+ "eval_steps_per_second": 2.302,
343
+ "eval_wer": 0.18113718602535075,
344
+ "step": 28000
345
+ },
346
+ {
347
+ "epoch": 91.48,
348
+ "eval_loss": 0.5330832600593567,
349
+ "eval_runtime": 139.0156,
350
+ "eval_samples_per_second": 36.514,
351
+ "eval_steps_per_second": 2.288,
352
+ "eval_wer": 0.17923270626573842,
353
+ "step": 29000
354
+ },
355
+ {
356
+ "epoch": 94.64,
357
+ "learning_rate": 5.576547231270358e-05,
358
+ "loss": 0.6075,
359
+ "step": 30000
360
+ },
361
+ {
362
+ "epoch": 94.64,
363
+ "eval_loss": 0.538992166519165,
364
+ "eval_runtime": 138.185,
365
+ "eval_samples_per_second": 36.733,
366
+ "eval_steps_per_second": 2.301,
367
+ "eval_wer": 0.17787840954779185,
368
+ "step": 30000
369
+ },
370
+ {
371
+ "epoch": 97.79,
372
+ "eval_loss": 0.5459240078926086,
373
+ "eval_runtime": 137.8608,
374
+ "eval_samples_per_second": 36.82,
375
+ "eval_steps_per_second": 2.307,
376
+ "eval_wer": 0.17730706561990817,
377
+ "step": 31000
378
+ },
379
+ {
380
+ "epoch": 100.0,
381
+ "step": 31700,
382
+ "total_flos": 3.173184730349909e+20,
383
+ "train_loss": 0.8928292760036721,
384
+ "train_runtime": 80759.6589,
385
+ "train_samples_per_second": 25.144,
386
+ "train_steps_per_second": 0.393
387
+ }
388
+ ],
389
+ "max_steps": 31700,
390
+ "num_train_epochs": 100,
391
+ "total_flos": 3.173184730349909e+20,
392
+ "trial_name": null,
393
+ "trial_params": null
394
+ }
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e69a04248e0ba54277463cd3b7be574e9546358e8382b38381a54e5df9b996a
3
+ size 3055
validation_results.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "epoch": 100.0,
3
+ "eval_loss": 0.5438345074653625,
4
+ "eval_runtime": 140.268,
5
+ "eval_samples": 5076,
6
+ "eval_samples_per_second": 36.188,
7
+ "eval_steps_per_second": 2.267,
8
+ "eval_wer": 0.177349387392344
9
+ }
vocab.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"讗": 1, "讘": 2, "讙": 3, "讚": 4, "讛": 5, "讜": 6, "讝": 7, "讞": 8, "讟": 9, "讬": 10, "讱": 11, "讻": 12, "诇": 13, "诐": 14, "诪": 15, "谉": 16, "谞": 17, "住": 18, "注": 19, "祝": 20, "驻": 21, "抓": 22, "爪": 23, "拽": 24, "专": 25, "砖": 26, "转": 27, "|": 0, "[UNK]": 28, "[PAD]": 29}