sanchit-gandhi committed on
Commit
31d7cf2
·
1 Parent(s): 52dde8a
config.json ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "SpeechEncoderDecoderModel"
4
+ ],
5
+ "decoder": {
6
+ "_name_or_path": "facebook/bart-large-cnn",
7
+ "_num_labels": 3,
8
+ "activation_dropout": 0.0,
9
+ "activation_function": "gelu",
10
+ "add_cross_attention": true,
11
+ "add_final_layer_norm": false,
12
+ "architectures": [
13
+ "BartForConditionalGeneration"
14
+ ],
15
+ "attention_dropout": 0.0,
16
+ "bad_words_ids": null,
17
+ "bos_token_id": 0,
18
+ "chunk_size_feed_forward": 0,
19
+ "classif_dropout": 0.0,
20
+ "classifier_dropout": 0.0,
21
+ "cross_attention_hidden_size": null,
22
+ "d_model": 1024,
23
+ "decoder_attention_heads": 16,
24
+ "decoder_ffn_dim": 4096,
25
+ "decoder_layerdrop": 0.0,
26
+ "decoder_layers": 12,
27
+ "decoder_start_token_id": 2,
28
+ "diversity_penalty": 0.0,
29
+ "do_sample": false,
30
+ "dropout": 0.1,
31
+ "early_stopping": true,
32
+ "encoder_attention_heads": 16,
33
+ "encoder_ffn_dim": 4096,
34
+ "encoder_layerdrop": 0.0,
35
+ "encoder_layers": 12,
36
+ "encoder_no_repeat_ngram_size": 0,
37
+ "eos_token_id": 2,
38
+ "finetuning_task": null,
39
+ "force_bos_token_to_be_generated": true,
40
+ "forced_bos_token_id": 0,
41
+ "forced_eos_token_id": 2,
42
+ "gradient_checkpointing": false,
43
+ "id2label": {
44
+ "0": "LABEL_0",
45
+ "1": "LABEL_1",
46
+ "2": "LABEL_2"
47
+ },
48
+ "init_std": 0.02,
49
+ "is_decoder": true,
50
+ "is_encoder_decoder": false,
51
+ "label2id": {
52
+ "LABEL_0": 0,
53
+ "LABEL_1": 1,
54
+ "LABEL_2": 2
55
+ },
56
+ "length_penalty": 2.0,
57
+ "max_length": 142,
58
+ "max_position_embeddings": 1024,
59
+ "min_length": 56,
60
+ "model_type": "bart",
61
+ "no_repeat_ngram_size": 3,
62
+ "normalize_before": false,
63
+ "num_beam_groups": 1,
64
+ "num_beams": 4,
65
+ "num_hidden_layers": 12,
66
+ "num_return_sequences": 1,
67
+ "output_attentions": false,
68
+ "output_hidden_states": false,
69
+ "output_past": true,
70
+ "output_scores": false,
71
+ "pad_token_id": 1,
72
+ "prefix": " ",
73
+ "problem_type": null,
74
+ "pruned_heads": {},
75
+ "remove_invalid_values": false,
76
+ "repetition_penalty": 1.0,
77
+ "return_dict": true,
78
+ "return_dict_in_generate": false,
79
+ "scale_embedding": false,
80
+ "sep_token_id": null,
81
+ "task_specific_params": {
82
+ "summarization": {
83
+ "early_stopping": true,
84
+ "length_penalty": 2.0,
85
+ "max_length": 142,
86
+ "min_length": 56,
87
+ "no_repeat_ngram_size": 3,
88
+ "num_beams": 4
89
+ }
90
+ },
91
+ "temperature": 1.0,
92
+ "tie_encoder_decoder": false,
93
+ "tie_word_embeddings": true,
94
+ "tokenizer_class": null,
95
+ "top_k": 50,
96
+ "top_p": 1.0,
97
+ "torch_dtype": null,
98
+ "torchscript": false,
99
+ "transformers_version": "4.18.0.dev0",
100
+ "typical_p": 1.0,
101
+ "use_bfloat16": false,
102
+ "use_cache": true,
103
+ "vocab_size": 50264
104
+ },
105
+ "decoder_start_token_id": 0,
106
+ "encoder": {
107
+ "_name_or_path": "facebook/wav2vec2-large-lv60",
108
+ "activation_dropout": 0.1,
109
+ "adapter_kernel_size": 3,
110
+ "adapter_stride": 2,
111
+ "add_adapter": true,
112
+ "add_cross_attention": false,
113
+ "apply_spec_augment": true,
114
+ "architectures": [
115
+ "Wav2Vec2ForPreTraining"
116
+ ],
117
+ "attention_dropout": 0.1,
118
+ "bad_words_ids": null,
119
+ "bos_token_id": 1,
120
+ "chunk_size_feed_forward": 0,
121
+ "classifier_proj_size": 256,
122
+ "codevector_dim": 768,
123
+ "contrastive_logits_temperature": 0.1,
124
+ "conv_bias": true,
125
+ "conv_dim": [
126
+ 512,
127
+ 512,
128
+ 512,
129
+ 512,
130
+ 512,
131
+ 512,
132
+ 512
133
+ ],
134
+ "conv_kernel": [
135
+ 10,
136
+ 3,
137
+ 3,
138
+ 3,
139
+ 3,
140
+ 2,
141
+ 2
142
+ ],
143
+ "conv_stride": [
144
+ 5,
145
+ 2,
146
+ 2,
147
+ 2,
148
+ 2,
149
+ 2,
150
+ 2
151
+ ],
152
+ "cross_attention_hidden_size": null,
153
+ "ctc_loss_reduction": "sum",
154
+ "ctc_zero_infinity": false,
155
+ "decoder_start_token_id": null,
156
+ "diversity_loss_weight": 0.1,
157
+ "diversity_penalty": 0.0,
158
+ "do_sample": false,
159
+ "do_stable_layer_norm": true,
160
+ "early_stopping": false,
161
+ "encoder_no_repeat_ngram_size": 0,
162
+ "eos_token_id": 2,
163
+ "feat_extract_activation": "gelu",
164
+ "feat_extract_dropout": 0.0,
165
+ "feat_extract_norm": "layer",
166
+ "feat_proj_dropout": 0.0,
167
+ "feat_quantizer_dropout": 0.0,
168
+ "final_dropout": 0.0,
169
+ "finetuning_task": null,
170
+ "forced_bos_token_id": null,
171
+ "forced_eos_token_id": null,
172
+ "gradient_checkpointing": false,
173
+ "hidden_act": "gelu",
174
+ "hidden_dropout": 0.1,
175
+ "hidden_dropout_prob": 0.1,
176
+ "hidden_size": 1024,
177
+ "id2label": {
178
+ "0": "LABEL_0",
179
+ "1": "LABEL_1"
180
+ },
181
+ "initializer_range": 0.02,
182
+ "intermediate_size": 4096,
183
+ "is_decoder": false,
184
+ "is_encoder_decoder": false,
185
+ "label2id": {
186
+ "LABEL_0": 0,
187
+ "LABEL_1": 1
188
+ },
189
+ "layer_norm_eps": 1e-05,
190
+ "layerdrop": 0.0,
191
+ "length_penalty": 1.0,
192
+ "mask_feature_length": 10,
193
+ "mask_feature_min_masks": 0,
194
+ "mask_feature_prob": 0.0,
195
+ "mask_time_length": 10,
196
+ "mask_time_min_masks": 2,
197
+ "mask_time_prob": 0.1,
198
+ "max_length": 20,
199
+ "min_length": 0,
200
+ "model_type": "wav2vec2",
201
+ "no_repeat_ngram_size": 0,
202
+ "num_adapter_layers": 3,
203
+ "num_attention_heads": 16,
204
+ "num_beam_groups": 1,
205
+ "num_beams": 1,
206
+ "num_codevector_groups": 2,
207
+ "num_codevectors_per_group": 320,
208
+ "num_conv_pos_embedding_groups": 16,
209
+ "num_conv_pos_embeddings": 128,
210
+ "num_feat_extract_layers": 7,
211
+ "num_hidden_layers": 24,
212
+ "num_negatives": 100,
213
+ "num_return_sequences": 1,
214
+ "output_attentions": false,
215
+ "output_hidden_size": 1024,
216
+ "output_hidden_states": false,
217
+ "output_scores": false,
218
+ "pad_token_id": 0,
219
+ "prefix": null,
220
+ "problem_type": null,
221
+ "proj_codevector_dim": 768,
222
+ "pruned_heads": {},
223
+ "remove_invalid_values": false,
224
+ "repetition_penalty": 1.0,
225
+ "return_dict": true,
226
+ "return_dict_in_generate": false,
227
+ "sep_token_id": null,
228
+ "task_specific_params": null,
229
+ "tdnn_dilation": [
230
+ 1,
231
+ 2,
232
+ 3,
233
+ 1,
234
+ 1
235
+ ],
236
+ "tdnn_dim": [
237
+ 512,
238
+ 512,
239
+ 512,
240
+ 512,
241
+ 1500
242
+ ],
243
+ "tdnn_kernel": [
244
+ 5,
245
+ 3,
246
+ 3,
247
+ 1,
248
+ 1
249
+ ],
250
+ "temperature": 1.0,
251
+ "tie_encoder_decoder": false,
252
+ "tie_word_embeddings": true,
253
+ "tokenizer_class": null,
254
+ "top_k": 50,
255
+ "top_p": 1.0,
256
+ "torch_dtype": null,
257
+ "torchscript": false,
258
+ "transformers_version": "4.18.0.dev0",
259
+ "typical_p": 1.0,
260
+ "use_bfloat16": false,
261
+ "use_weighted_layer_sum": false,
262
+ "vocab_size": 32,
263
+ "xvector_output_dim": 512
264
+ },
265
+ "eos_token_id": 2,
266
+ "is_encoder_decoder": true,
267
+ "max_length": 40,
268
+ "model_type": "speech-encoder-decoder",
269
+ "pad_token_id": 1,
270
+ "processor_class": "Wav2Vec2Processor",
271
+ "transformers_version": null,
272
+ "use_cache": false
273
+ }
create_model.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Builds a Flax speech encoder-decoder checkpoint from a pretrained wav2vec2
# encoder and a pretrained BART decoder, then saves the model together with the
# encoder's feature extractor and the decoder's tokenizer in the current directory.
import jax.numpy as jnp
from transformers import AutoFeatureExtractor, AutoTokenizer, FlaxSpeechEncoderDecoderModel


encoder_id = "facebook/wav2vec2-large-lv60"
decoder_id = "facebook/bart-large-cnn"

# `encoder_add_adapter=True` is forwarded to the encoder config as `add_adapter`.
model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_id,
    decoder_id,
    encoder_add_adapter=True,
)

# Encoder-side overrides.
encoder_config = model.config.encoder
encoder_config.feat_proj_dropout = 0.0
encoder_config.final_dropout = 0.0
encoder_config.mask_time_prob = 0.1
encoder_config.layerdrop = 0.0

# Special token ids of the combined model are copied from the decoder config.
decoder_config = model.config.decoder
model.config.decoder_start_token_id = decoder_config.bos_token_id
model.config.pad_token_id = decoder_config.pad_token_id
model.config.eos_token_id = decoder_config.eos_token_id

# Generation defaults and bookkeeping stored on the top-level config.
model.config.max_length = 40
model.config.num_beams = 1
model.config.use_cache = False
model.config.processor_class = "Wav2Vec2Processor"

# check if generation works (dummy input of shape (1, 2000); the result is not inspected)
out = model.generate(jnp.ones((1, 2000)))

model.save_pretrained("./")

# Store the pre-/post-processors next to the model weights.
feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_id)
feature_extractor.save_pretrained("./")
tokenizer = AutoTokenizer.from_pretrained(decoder_id)
tokenizer.save_pretrained("./")
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
preprocessor_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_normalize": true,
3
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
4
+ "feature_size": 1,
5
+ "padding_side": "right",
6
+ "padding_value": 0.0,
7
+ "return_attention_mask": true,
8
+ "sampling_rate": 16000
9
+ }
run_flax_speech_recognition_seq2seq.py ADDED
@@ -0,0 +1,897 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the Flax library models for sequence to sequence speech recognition.
18
+ """
19
+ # You can also adapt this script to your own sequence-to-sequence task. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import os
23
+ import sys
24
+ import time
25
+ from dataclasses import field
26
+ from functools import partial
27
+ from pathlib import Path
28
+ from typing import Any, Callable, Dict, List, Optional, Union
29
+
30
+ import datasets
31
+ import numpy as np
32
+ from datasets import DatasetDict, load_dataset, load_metric
33
+ from tqdm import tqdm
34
+
35
+ import flax
36
+ import jax
37
+ import jax.numpy as jnp
38
+ import optax
39
+ import transformers
40
+ from flax import jax_utils, traverse_util
41
+ from flax.jax_utils import unreplicate
42
+ from flax.training import train_state
43
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
44
+ from huggingface_hub import Repository
45
+ from transformers import (
46
+ AutoConfig,
47
+ AutoFeatureExtractor,
48
+ AutoProcessor,
49
+ AutoTokenizer,
50
+ FlaxAutoModelForSpeechSeq2Seq,
51
+ HfArgumentParser,
52
+ Seq2SeqTrainingArguments,
53
+ is_tensorboard_available,
54
+ )
55
+ from transformers.file_utils import get_full_repo_name
56
+ from transformers.trainer_utils import get_last_checkpoint, is_main_process
57
+ from transformers.utils import check_min_version
58
+ from transformers.utils.versions import require_version
59
+
60
+
61
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
62
+ check_min_version("4.17.0.dev0")
63
+
64
+ require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
65
+
66
+ logger = logging.getLogger(__name__)
67
+
68
+
69
@flax.struct.dataclass
class ModelArguments:
    """
    Options that select which pretrained model, config, tokenizer and feature extractor are loaded.
    """

    # --- what to load -------------------------------------------------------
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
    )
    feature_extractor_name: Optional[str] = field(
        default=None,
        metadata={"help": "feature extractor name or path if not the same as model_name"},
    )

    # --- how to load it -----------------------------------------------------
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `transformers-cli login` "
                "(necessary to use this script with private models)."
            )
        },
    )

    # --- training behaviour -------------------------------------------------
    freeze_feature_encoder: bool = field(
        default=True,
        metadata={"help": "Whether to freeze the feature encoder layers of the model."},
    )
109
+
110
+
111
@flax.struct.dataclass
class DataTrainingArguments:
    """
    Options describing the dataset to load and how audio inputs and text targets are preprocessed.

    Every field is exposed as a command-line flag; its `metadata["help"]` text is what `--help` prints.
    """

    # --- dataset selection --------------------------------------------------
    dataset_name: str = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    text_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
    )

    # --- caching / preprocessing workers ------------------------------------
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )

    # --- optional subsampling -----------------------------------------------
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )

    # --- column names -------------------------------------------------------
    audio_column_name: str = field(
        default="audio",
        metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
    )
    text_column_name: str = field(
        default="text",
        metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
    )

    # --- length limits ------------------------------------------------------
    max_duration_in_seconds: float = field(
        default=20.0,
        metadata={
            "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
        },
    )
    min_duration_in_seconds: float = field(
        default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    min_target_length: Optional[int] = field(
        default=0,
        metadata={
            "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
            "than this will be filtered."
        },
    )

    # --- padding granularity ------------------------------------------------
    pad_input_to_multiple_of: Optional[int] = field(
        default=None,
        metadata={
            "help": "If set will pad the input sequence to a multiple of the provided value. This is important to avoid triggering recompilations on TPU"
        },
    )
    pad_target_to_multiple_of: Optional[int] = field(
        default=None,
        metadata={
            "help": "If set will pad the target sequence to a multiple of the provided value. This is important to avoid triggering recompilations on TPU"
        },
    )

    # --- run mode -----------------------------------------------------------
    preprocessing_only: bool = field(
        default=False,
        metadata={
            "help": "Whether to only do data preprocessing and skip training. "
            "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
            "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
            "so that the cached datasets can consequently be loaded in distributed training"
        },
    )

    # --- split names --------------------------------------------------------
    train_split_name: str = field(
        default="train",
        metadata={
            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
        },
    )
    # The help text below describes the evaluation split and its actual default ("test").
    eval_split_name: str = field(
        default="test",
        metadata={
            "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"
        },
    )

    # --- text normalisation -------------------------------------------------
    do_lower_case: bool = field(
        default=True,
        metadata={"help": "Whether the target text should be lower cased."},
    )
216
+
217
+
218
class TrainState(train_state.TrainState):
    """Training state extended with the PRNG key used for dropout."""

    dropout_rng: jnp.ndarray

    def replicate(self):
        """
        Return the state passed through `jax_utils.replicate`, with `dropout_rng`
        replaced by the result of `shard_prng_key(self.dropout_rng)`.
        """
        replicated_state = jax_utils.replicate(self)
        sharded_dropout_rng = shard_prng_key(self.dropout_rng)
        return replicated_state.replace(dropout_rng=sharded_dropout_rng)
223
+
224
+
225
def shift_tokens_right(label_ids: np.array, decoder_start_token_id: int) -> np.ndarray:
    """
    Build decoder inputs from labels by shifting every row one position to the right.

    Args:
        label_ids: 2-D array of label ids (indexed as `[row, position]`).
        decoder_start_token_id: id written into the first position of every row.

    Returns:
        A new array with the same shape and dtype as `label_ids`; the last label
        of each row is dropped and the input array is left unmodified.
    """
    decoder_inputs = np.empty_like(label_ids)
    # Every position is written below, so the uninitialised buffer never leaks out.
    decoder_inputs[:, 0] = decoder_start_token_id
    decoder_inputs[:, 1:] = label_ids[:, :-1]
    return decoder_inputs
234
+
235
+
236
@flax.struct.dataclass
class FlaxDataCollatorSpeechSeq2SeqWithPadding:
    """
    Data collator that pads audio inputs and text labels separately and returns NumPy arrays.

    Args:
        processor ([`Wav2Vec2Processor`])
            Object exposing `feature_extractor.pad` (used for the audio inputs) and
            `tokenizer.pad` (used for the labels).
        decoder_start_token_id (`int`)
            Token id placed at the start of `decoder_input_ids`. If every label sequence in
            the batch already starts with this id, that first column is removed from the labels.
        input_padding (`bool` or `str`, defaults to `"max_length"`):
            Passed as `padding` to `feature_extractor.pad`.
        target_padding (`bool` or `str`, defaults to `"max_length"`):
            Passed as `padding` to `tokenizer.pad`.
        max_input_length (`float`, *optional*):
            Passed as `max_length` to `feature_extractor.pad`.
        max_target_length (`int`, *optional*):
            Passed as `max_length` to `tokenizer.pad`.
        pad_input_to_multiple_of (`int`, *optional*):
            Passed as `pad_to_multiple_of` to `feature_extractor.pad`.
        pad_target_to_multiple_of (`int`, *optional*):
            Passed as `pad_to_multiple_of` to `tokenizer.pad`.
    """

    processor: Any
    decoder_start_token_id: int
    input_padding: Union[bool, str] = "max_length"
    target_padding: Union[bool, str] = "max_length"
    max_input_length: Optional[float] = None
    max_target_length: Optional[int] = None
    pad_input_to_multiple_of: Optional[int] = None
    pad_target_to_multiple_of: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
        """
        Collate `features` (each with `"input_values"` and `"labels"`) into one padded batch.

        The returned batch holds the feature extractor's padded output with `"input_values"`
        renamed to `"inputs"`, plus `"labels"` (padding replaced by -100) and `"decoder_input_ids"`.
        """
        # Audio and text have different lengths and padding rules, so they are padded independently.
        audio_inputs = [{"input_values": example["input_values"]} for example in features]
        label_inputs = [{"input_ids": example["labels"]} for example in features]

        batch = self.processor.feature_extractor.pad(
            audio_inputs,
            max_length=self.max_input_length,
            padding=self.input_padding,
            pad_to_multiple_of=self.pad_input_to_multiple_of,
            return_tensors="np",
        )
        padded_labels = self.processor.tokenizer.pad(
            label_inputs,
            max_length=self.max_target_length,
            padding=self.target_padding,
            pad_to_multiple_of=self.pad_target_to_multiple_of,
            return_tensors="np",
        )

        labels = padded_labels["input_ids"]
        # If tokenization already prepended the start token to every sequence, drop it here:
        # `shift_tokens_right` adds it back when building the decoder inputs.
        already_has_start_token = (labels[:, 0] == self.decoder_start_token_id).all().item()
        if already_has_start_token:
            labels = labels[:, 1:]
            label_attention_mask = padded_labels.attention_mask[:, 1:]
        else:
            label_attention_mask = padded_labels.attention_mask

        decoder_input_ids = shift_tokens_right(labels, self.decoder_start_token_id)

        # Positions whose attention mask is not 1 are padding; mark them with -100 in the labels.
        is_padding = np.not_equal(label_attention_mask, 1)
        labels = np.ma.array(labels, mask=is_padding).filled(fill_value=-100)

        batch["inputs"] = batch.pop("input_values")
        batch["labels"] = labels
        batch["decoder_input_ids"] = decoder_input_ids
        # decoder_attention_mask known to give issues with nan's
        # remove decoder_attention_mask as an arg for the time being - handled by the causal mask in XXXForCausalLM
        # batch["decoder_attention_mask"] = labels_batch.attention_mask

        return batch
324
+
325
+
326
def write_train_metric(summary_writer, train_metrics, train_time, step):
    """
    Write the training time and the accumulated training metrics to `summary_writer`.

    `train_time` is logged under "train_time" at `step`. The metrics are first combined
    with `get_metrics`; each resulting series is logged under "train_<name>", one value
    per step, with the last value landing on `step`.
    """
    summary_writer.scalar("train_time", train_time, step)

    stacked_metrics = get_metrics(train_metrics)
    for metric_name, values in stacked_metrics.items():
        series_tag = f"train_{metric_name}"
        # Step index of the first value so that the final value is written at `step`.
        first_step = step - len(values) + 1
        for offset, value in enumerate(values):
            summary_writer.scalar(series_tag, value, first_step + offset)
334
+
335
+
336
def write_eval_metric(summary_writer, eval_metrics, step):
    """
    Write every entry of `eval_metrics` to `summary_writer` as a scalar tagged
    "eval_<name>" at the given `step`.
    """
    for name, metric_value in eval_metrics.items():
        tag = f"eval_{name}"
        summary_writer.scalar(tag, metric_value, step)
339
+
340
+
341
def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
    """
    Build a learning-rate schedule: linear warmup followed by linear decay.

    The rate rises from 0.0 to `learning_rate` over `num_warmup_steps`, then falls to 0
    over the remaining steps, where the total is
    `(train_ds_size // train_batch_size) * num_train_epochs`.
    """
    total_train_steps = (train_ds_size // train_batch_size) * num_train_epochs
    decay_steps = total_train_steps - num_warmup_steps

    warmup_schedule = optax.linear_schedule(
        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
    )
    decay_schedule = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=decay_steps)

    # Switch from the warmup schedule to the decay schedule at `num_warmup_steps`.
    return optax.join_schedules(schedules=[warmup_schedule, decay_schedule], boundaries=[num_warmup_steps])
353
+
354
+
355
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
    """
    Split `samples_idx` into consecutive batches of exactly `batch_size` indices.

    Trailing indices that do not fill a complete batch are dropped.

    Args:
        samples_idx: 1-D array of sample indices.
        batch_size: number of indices per batch.

    Returns:
        A list of arrays as produced by `np.split`, one per full batch. The list is
        empty when `samples_idx` holds fewer than `batch_size` entries.
    """
    num_full_batches = len(samples_idx) // batch_size

    # With no full batch, `np.split(..., 0)` would raise ZeroDivisionError; report "no batches" instead.
    if num_full_batches == 0:
        return []

    # Keep only the prefix that divides evenly into batches.
    usable_idx = samples_idx[: num_full_batches * batch_size]
    return np.split(usable_idx, num_full_batches)
364
+
365
+
366
+ def main():
367
+ # 1. Parse input arguments
368
+ # See all possible arguments in src/transformers/training_args.py
369
+ # or by passing the --help flag to this script.
370
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
371
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
372
+
373
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
374
+ # If we pass only one argument to the script and it's the path to a json file,
375
+ # let's parse it to get our arguments.
376
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
377
+ else:
378
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
379
+
380
+ # 2. Setup logging
381
+ logging.basicConfig(
382
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
383
+ datefmt="%m/%d/%Y %H:%M:%S",
384
+ handlers=[logging.StreamHandler(sys.stdout)],
385
+ )
386
+ # We only want one process per machine to log things on the screen.
387
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
388
+ if jax.process_index() == 0:
389
+ datasets.utils.logging.set_verbosity_warning()
390
+ transformers.utils.logging.set_verbosity_info()
391
+ else:
392
+ datasets.utils.logging.set_verbosity_error()
393
+ transformers.utils.logging.set_verbosity_error()
394
+
395
+ # Log on each process the small summary:
396
+ logger.warning(
397
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
398
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
399
+ )
400
+
401
+ # Set the verbosity to info of the Transformers logger (on main process only):
402
+ if is_main_process(training_args.local_rank):
403
+ transformers.utils.logging.set_verbosity_info()
404
+ logger.info("Training/evaluation parameters %s", training_args)
405
+
406
+ logger.info(f"JAX devices: {jax.device_count()}")
407
+
408
+ # 3. Detecting last checkpoint and eventually continue from last checkpoint
409
+ last_checkpoint = None
410
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
411
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
412
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
413
+ raise ValueError(
414
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
415
+ "Use --overwrite_output_dir to overcome."
416
+ )
417
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
418
+ logger.info(
419
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
420
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
421
+ )
422
+
423
+ # 4. Load dataset
424
+ raw_datasets = DatasetDict()
425
+
426
+ if training_args.do_train:
427
+ raw_datasets["train"] = load_dataset(
428
+ data_args.dataset_name, data_args.dataset_config_name, split=data_args.train_split_name
429
+ )
430
+
431
+ if training_args.do_eval:
432
+ raw_datasets["eval"] = load_dataset(
433
+ data_args.dataset_name, data_args.dataset_config_name, split=data_args.eval_split_name
434
+ )
435
+
436
+ if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
437
+ raise ValueError(
438
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
439
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
440
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
441
+ )
442
+
443
+ if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
444
+ raise ValueError(
445
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
446
+ "Make sure to set `--text_column_name` to the correct text column - one of "
447
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
448
+ )
449
+
450
+ # 5. Load pretrained model, tokenizer, and feature extractor
451
+ #
452
+ # Distributed training:
453
+ # The .from_pretrained methods guarantee that only one local process can concurrently download the model and vocabulary.
454
+ config = AutoConfig.from_pretrained(
455
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
456
+ cache_dir=model_args.cache_dir,
457
+ revision=model_args.model_revision,
458
+ use_auth_token=True if model_args.use_auth_token else None,
459
+ )
460
+
461
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
462
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
463
+ cache_dir=model_args.cache_dir,
464
+ revision=model_args.model_revision,
465
+ use_auth_token=True if model_args.use_auth_token else None,
466
+ )
467
+ tokenizer = AutoTokenizer.from_pretrained(
468
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
469
+ cache_dir=model_args.cache_dir,
470
+ use_fast=model_args.use_fast_tokenizer,
471
+ revision=model_args.model_revision,
472
+ use_auth_token=True if model_args.use_auth_token else None,
473
+ )
474
+ model = FlaxAutoModelForSpeechSeq2Seq.from_pretrained(
475
+ model_args.model_name_or_path,
476
+ config=config,
477
+ cache_dir=model_args.cache_dir,
478
+ revision=model_args.model_revision,
479
+ use_auth_token=True if model_args.use_auth_token else None,
480
+ )
481
+
482
+ if model.config.decoder_start_token_id is None:
483
+ raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
484
+
485
+ # 6. Resample speech dataset if necessary
486
+ dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
487
+ if dataset_sampling_rate != feature_extractor.sampling_rate:
488
+ raw_datasets = raw_datasets.cast_column(
489
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
490
+ )
491
+
492
+ # 7. Preprocessing the datasets.
493
+ # We need to read the audio files as arrays and tokenize the targets.
494
+ max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
495
+ min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
496
+ max_target_length = data_args.max_target_length
497
+ min_target_length = data_args.min_target_length
498
+ pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
499
+ pad_target_to_multiple_of = data_args.pad_target_to_multiple_of
500
+ audio_column_name = data_args.audio_column_name
501
+ num_workers = data_args.preprocessing_num_workers
502
+ text_column_name = data_args.text_column_name
503
+ model_input_name = feature_extractor.model_input_names[0]
504
+ do_lower_case = data_args.do_lower_case
505
+
506
+ if data_args.max_train_samples is not None:
507
+ raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
508
+
509
+ if data_args.max_eval_samples is not None:
510
+ raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
511
+
512
def prepare_dataset(batch):
    """
    Turn one raw example into model-ready features.

    Reads the audio from `audio_column_name`, runs it through `feature_extractor`
    and tokenizes the transcription found in `text_column_name` (lower-cased first
    when `do_lower_case` is set).

    Adds the following keys to `batch` and returns it:
        * `model_input_name`: the extracted features of the (single) example
        * "input_length":     number of audio samples in the raw waveform
        * "labels":           token ids of the transcription
        * "labels_length":    number of label tokens
    """
    # process audio
    sample = batch[audio_column_name]
    inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])

    # Look the features up under the extractor's own input name instead of the
    # hard-coded "input_values", so the key that is written is also the key that is read.
    batch[model_input_name] = inputs[model_input_name][0]

    # The length filter thresholds are expressed in audio samples
    # (duration in seconds * sampling rate), so measure the raw waveform.
    batch["input_length"] = len(sample["array"])

    # process targets
    input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
    batch["labels"] = tokenizer(input_str).input_ids
    batch["labels_length"] = len(batch["labels"])
    return batch
525
+
526
+ with training_args.main_process_first(desc="dataset map pre-processing"):
527
+ vectorized_datasets = raw_datasets.map(
528
+ prepare_dataset,
529
+ remove_columns=next(iter(raw_datasets.values())).column_names,
530
+ num_proc=data_args.preprocessing_num_workers,
531
+ desc="preprocess train dataset",
532
+ )
533
+
534
+ # filter data with inputs shorter than min_input_length or longer than
535
+ # max_input_length
536
def is_audio_in_length_range(length):
    """Return True when `length` lies strictly between `min_input_length` and `max_input_length`."""
    return min_input_length < length < max_input_length
538
+
539
+ vectorized_datasets = vectorized_datasets.filter(
540
+ is_audio_in_length_range,
541
+ num_proc=num_workers,
542
+ input_columns=["input_length"],
543
+ )
544
+
545
+ # filter data with targets shorter than min_target_length or longer than
546
+ # max_target_length
547
def is_labels_in_length_range(length):
    """Return True when `length` lies strictly between `min_target_length` and `max_target_length`."""
    return min_target_length < length < max_target_length
549
+
550
+ vectorized_datasets = vectorized_datasets.filter(
551
+ is_labels_in_length_range,
552
+ num_proc=num_workers,
553
+ input_columns=["labels_length"],
554
+ )
555
+
556
+ # for large datasets it is advised to run the preprocessing on a
557
+ # single machine first with `args.preprocessing_only` since there will most likely
558
+ # be a timeout when running the script in distributed mode.
559
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
560
+ # cached dataset
561
+ if data_args.preprocessing_only:
562
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
563
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
564
+ return
565
+
566
+ # 8. Load Metric
567
+ metric = load_metric("wer")
568
+
569
def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
    """
    Decode predictions and references and score them with the loaded `metric`.

    Args:
        pred_ids: generated token ids, one sequence per example.
        label_ids: reference token ids; entries equal to -100 are replaced by the
            tokenizer's pad token id before decoding.

    Returns:
        A dict with the single key "wer".
    """
    references = np.asarray(label_ids)
    # -100 is not a real token id, so swap it for the pad id to make the labels decodable
    references = np.where(references == -100, tokenizer.pad_token_id, references)

    decoded_preds = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    # we do not want to group tokens when computing the metrics
    decoded_refs = tokenizer.batch_decode(references, skip_special_tokens=True)

    return {"wer": metric.compute(predictions=decoded_preds, references=decoded_refs)}
579
+
580
+ # 9. Create a single speech processor
581
+ if is_main_process(training_args.local_rank):
582
+ # save feature extractor, tokenizer and config
583
+ feature_extractor.save_pretrained(training_args.output_dir)
584
+ tokenizer.save_pretrained(training_args.output_dir)
585
+ config.save_pretrained(training_args.output_dir)
586
+
587
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
588
+
589
+ data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
590
+ processor=processor,
591
+ decoder_start_token_id=model.config.decoder_start_token_id,
592
+ input_padding="max_length",
593
+ target_padding="max_length",
594
+ max_input_length=max_input_length,
595
+ max_target_length=max_target_length,
596
+ pad_input_to_multiple_of=pad_input_to_multiple_of,
597
+ pad_target_to_multiple_of=pad_target_to_multiple_of,
598
+ )
599
+
600
+ # Enable tensorboard only on the master node
601
+ has_tensorboard = is_tensorboard_available()
602
+ if has_tensorboard and jax.process_index() == 0:
603
+ try:
604
+ from flax.metrics.tensorboard import SummaryWriter
605
+
606
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
607
+ except ImportError as ie:
608
+ has_tensorboard = False
609
+ logger.warning(
610
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
611
+ )
612
+ else:
613
+ logger.warning(
614
+ "Unable to display metrics through TensorBoard because the package is not installed: "
615
+ "Please run `pip install tensorboard` to enable."
616
+ )
617
+
618
+ # 10. Handle the repository creation
619
+ if training_args.push_to_hub:
620
+ if training_args.hub_model_id is None:
621
+ repo_name = get_full_repo_name(
622
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
623
+ )
624
+ else:
625
+ repo_name = training_args.hub_model_id
626
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
627
+
628
+ # 11. Initialize our training
629
+ rng = jax.random.PRNGKey(training_args.seed)
630
+ rng, dropout_rng = jax.random.split(rng)
631
+
632
+ # Store some constant
633
+ num_epochs = int(training_args.num_train_epochs)
634
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
635
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
636
+ steps_per_epoch = len(vectorized_datasets["train"]) // train_batch_size
637
+ total_train_steps = steps_per_epoch * num_epochs
638
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
639
+
640
+ # Create learning rate schedule
641
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
642
+ len(vectorized_datasets["train"]),
643
+ train_batch_size,
644
+ training_args.num_train_epochs,
645
+ training_args.warmup_steps,
646
+ training_args.learning_rate,
647
+ )
648
+
649
+ # We use Optax's "masking" functionality to not apply weight decay
650
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
651
+ # mask boolean with the same structure as the parameters.
652
+ # The mask is True for parameters that should be decayed.
653
+ # Note that this mask is specifically adapted for FlaxBart.
654
+ # For FlaxT5, one should correct the layer norm parameter naming
655
+ # accordingly - see `run_t5_mlm_flax.py` e.g.
656
+ # TODO: check param dictionary of encoder and decoder match the layer_norm_params list
657
def decay_mask_fn(params):
    """
    Build a boolean mask with the same nested structure as `params`.

    A leaf is True (weight decay applied) unless its path ends in "bias" or its
    last two path components are one of the listed layer-norm modules followed
    by "scale".
    """
    no_decay_suffixes = {
        (module, "scale") for module in ("self_attn_layer_norm", "layernorm_embedding", "final_layer_norm")
    }

    mask = {}
    for path in traverse_util.flatten_dict(params):
        is_bias = path[-1] == "bias"
        is_layer_norm_scale = path[-2:] in no_decay_suffixes
        mask[path] = not (is_bias or is_layer_norm_scale)

    return traverse_util.unflatten_dict(mask)
664
+
665
+ # create adam optimizer
666
+ adamw = optax.adamw(
667
+ learning_rate=linear_decay_lr_schedule_fn,
668
+ b1=training_args.adam_beta1,
669
+ b2=training_args.adam_beta2,
670
+ eps=training_args.adam_epsilon,
671
+ weight_decay=training_args.weight_decay,
672
+ mask=decay_mask_fn,
673
+ )
674
+
675
+ # augment adam optimizer to facilitate gradient accumulation (ignore for now)
676
+ # optim = optax.chain(adamw, optax.apply_every(gradient_accumulation_steps))
677
+
678
+ # Setup train state
679
+ state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
680
+
681
+ # label smoothed cross entropy
682
def loss_fn(logits, labels, label_smoothing_factor=0.0):
    """
    Label-smoothed softmax cross entropy, averaged over the non-ignored targets.

    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104

    Args:
        logits: unnormalised scores; the last axis is the vocabulary.
        labels: integer target ids. Negative entries (the -100 padding value)
            are excluded from the loss.
        label_smoothing_factor: probability mass moved from the target class
            to the other classes; 0.0 disables smoothing.

    Returns:
        The summed per-token loss divided by the number of non-ignored tokens.
    """
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing_factor
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    # entropy of the smoothed target distribution; subtracting it makes a perfect prediction score 0
    normalizing_constant = -(
        confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    )
    soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)

    loss = optax.softmax_cross_entropy(logits, soft_labels)
    loss = loss - normalizing_constant

    # ignore padded tokens from loss, i.e. positions where labels are set to -100.
    # `>= 0` (not `> 0`) so that the valid token id 0 still contributes to the loss.
    padding = labels >= 0
    loss = loss * padding
    loss = loss.sum() / padding.sum()
    return loss
703
+
704
+ # Define gradient update step fn
705
def train_step(state, batch, label_smoothing_factor=0.0):
    """
    Run one optimisation step and collect diagnostic norms.

    Args:
        state: train state providing `apply_fn`, `params`, `dropout_rng` and `apply_gradients`.
        batch: dict of model inputs that contains "labels" and "inputs".
        label_smoothing_factor: forwarded to `loss_fn`.

    Returns:
        `(new_state, metrics)` where `metrics` holds the loss and the diagnostic
        norms, averaged over the "batch" axis.
    """
    dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

    def compute_loss(params):
        # labels are not a model input, so take them out before the forward pass
        labels = batch.pop("labels")
        outputs = state.apply_fn(
            **batch,
            params=params,
            dropout_rng=dropout_rng,
            freeze_feature_encoder=model_args.freeze_feature_encoder,
            return_dict=True,
            output_attentions=True,
            output_hidden_states=True,
            train=True,
        )
        # stack the per-layer hidden states along a new leading axis
        # (assumes every layer's hidden state has the same shape)
        encoder_hidden_states = jnp.asarray(outputs.encoder_hidden_states)
        encoder_outputs = outputs.encoder_last_hidden_state
        decoder_hidden_states = jnp.asarray(outputs.decoder_hidden_states)
        logits = outputs.logits

        # check for nan in inputs by taking l2-norm over inputs
        # a single nan in the inputs will return a nan when normed
        logs = {"inputs": jnp.linalg.norm(batch["inputs"])}

        # check for nan in encoder_hidden_states, encoder_outputs.
        # One norm per stacked layer: flatten everything except the leading axis and
        # reduce over the flattened axis. Reshaping to (-1, n_layers) would mix
        # values of different layers into each column.
        logs["encoder_hidden_states"] = jnp.linalg.norm(
            encoder_hidden_states.reshape(encoder_hidden_states.shape[0], -1), axis=1
        )
        logs["encoder_outputs"] = jnp.linalg.norm(encoder_outputs)

        # check for nan in decoder_hidden_states, decoder_outputs (logits)
        logs["decoder_hidden_states"] = jnp.linalg.norm(
            decoder_hidden_states.reshape(decoder_hidden_states.shape[0], -1), axis=1
        )
        logs["logits"] = jnp.linalg.norm(logits)

        loss = loss_fn(logits, labels, label_smoothing_factor)
        # normalize loss over gradient accumulation steps (ignore for now)
        # loss = loss / gradient_accumulation_steps
        return loss, logs

    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, logs), grad = grad_fn(state.params)
    # TODO: compute loss correctly over pmapped axis
    grad = jax.lax.pmean(grad, "batch")

    # compute gradient norm for monitoring
    # (re-introduce when no nan's on forward pass, currently meaningless)
    # grad_norm = jnp.linalg.norm(jax.tree_util.tree_leaves(jax.tree_map(jnp.linalg.norm, grad)))

    new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

    # don't log learning-rate and grad-norm until forward pass returns real-valued numbers
    metrics = {"loss": loss}
    metrics.update(logs)
    metrics = jax.lax.pmean(metrics, axis_name="batch")

    return new_state, metrics
763
+
764
+ # Define eval fn
765
def eval_step(params, batch, label_smoothing_factor=0.0):
    """
    Compute the evaluation loss for one batch.

    Removes "labels" from `batch`, runs `model` in non-training mode with the
    given `params` and returns `{"loss": ...}` averaged over the "batch" axis.
    """
    targets = batch.pop("labels")
    outputs = model(**batch, params=params, train=False)
    eval_loss = loss_fn(outputs[0], targets, label_smoothing_factor)

    # summarize metrics
    return jax.lax.pmean({"loss": eval_loss}, axis_name="batch")
774
+
775
+ # Define generation function
776
+ gen_kwargs = {"max_length": training_args.generation_max_length, "num_beams": training_args.generation_num_beams}
777
+
778
def generate_step(params, batch):
    """Generate token ids for `batch["inputs"]` using `params` and the shared `gen_kwargs`."""
    # the model generates with whatever parameters are currently attached to it
    model.params = params
    return model.generate(batch["inputs"], **gen_kwargs).sequences
782
+
783
+ # Create parallel version of the train and eval step
784
+ p_train_step = jax.pmap(
785
+ partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
786
+ )
787
+ p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
788
+ p_generate_step = jax.pmap(generate_step, "batch")
789
+
790
+ # Replicate the train state on each device
791
+ state = state.replicate()
792
+
793
+ logger.info("***** Running training *****")
794
+ logger.info(f" Num examples = {len(vectorized_datasets['train'])}")
795
+ logger.info(f" Num Epochs = {num_epochs}")
796
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
797
+ logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
798
+ logger.info(f" Total optimization steps = {total_train_steps}")
799
+
800
+ train_time = 0
801
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
802
+ for epoch in epochs:
803
+ # ======================== Training ================================
804
+ train_start = time.time()
805
+
806
+ # Create sampling rng
807
+ rng, input_rng = jax.random.split(rng)
808
+ train_metrics = []
809
+
810
+ # Generate an epoch by shuffling sampling indices from the train dataset
811
+ num_train_samples = len(vectorized_datasets["train"])
812
+ train_samples_idx = np.random.permutation(np.arange(num_train_samples))
813
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
814
+
815
+ # Gather the indexes for creating the batch and do a training step
816
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
817
+ samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
818
+ batch = data_collator(samples)
819
+ batch = shard(batch.data)
820
+ state, train_metric = p_train_step(state, batch)
821
+ train_metrics.append(train_metric)
822
+
823
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
824
+
825
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
826
+ # Save metrics
827
+ train_metric = jax_utils.unreplicate(train_metric)
828
+ train_time += time.time() - train_start
829
+ # if has_tensorboard and jax.process_index() == 0:
830
+ # write_train_metric(summary_writer, train_metrics, train_time, cur_step)
831
+
832
+ # Log everything
833
+ metric_desc = " ".join([f"{key}: {value} |" for key, value in train_metric.items()])
834
+ epochs.write(f"Step... ({cur_step}) | {metric_desc}")
835
+
836
+ train_metrics = []
837
+
838
+ # epochs.write(
839
+ # f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
840
+ # )
841
+
842
+ continue
843
+ # ======================== Evaluating ==============================
844
+ eval_metrics = []
845
+ eval_preds = []
846
+ eval_labels = []
847
+
848
+ num_eval_samples = len(vectorized_datasets["eval"])
849
+ eval_samples_idx = jnp.arange(num_eval_samples)
850
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
851
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
852
+ samples = [vectorized_datasets["eval"][int(idx)] for idx in batch_idx]
853
+ batch = data_collator(samples)
854
+ batch = shard(batch.data)
855
+ labels = batch["labels"]
856
+
857
+ metrics = p_eval_step(state.params, batch)
858
+ eval_metrics.append(metrics)
859
+
860
+ # generation
861
+ if training_args.predict_with_generate:
862
+ generated_ids = p_generate_step(state.params, batch)
863
+ eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
864
+ eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
865
+
866
+ # normalize eval metrics
867
+ eval_metrics = get_metrics(eval_metrics)
868
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
869
+
870
+ # compute WER metric
871
+ wer_desc = ""
872
+ if training_args.predict_with_generate:
873
+ wer_metric = compute_metrics(eval_preds, eval_labels)
874
+ eval_metrics.update(wer_metric)
875
+ wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
876
+
877
+ # Print metrics and update progress bar
878
+ desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {wer_desc})"
879
+ epochs.write(desc)
880
+ epochs.desc = desc
881
+
882
+ # Save metrics
883
+ if has_tensorboard and jax.process_index() == 0:
884
+ cur_step = epoch * (len(vectorized_datasets["train"]) // train_batch_size)
885
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
886
+
887
+ # save checkpoint after each epoch and push checkpoint to the hub
888
+ if jax.process_index() == 0:
889
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
890
+ model.save_pretrained(training_args.output_dir, params=params)
891
+ tokenizer.save_pretrained(training_args.output_dir)
892
+ if training_args.push_to_hub:
893
+ repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
894
+
895
+
896
+ if __name__ == "__main__":
897
+ main()
run_librispeech.sh ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch run_flax_speech_recognition_seq2seq.py on the "clean" config of
# librispeech_asr. The model is loaded from, and outputs are written to, the
# current directory.

train_args=(
  # data
  --dataset_name="librispeech_asr"
  --dataset_config_name="clean"
  --train_split_name="train.100[:5%]"
  --eval_split_name="validation[:5%]"
  --text_column_name="text"
  --length_column_name="input_length"
  --preprocessing_num_workers="16"
  --max_duration_in_seconds="10"
  --max_target_length="32"
  --do_lower_case

  # model and output locations
  --model_name_or_path="./"
  --output_dir="./"
  --overwrite_output_dir
  --save_total_limit="1"
  --freeze_feature_encoder

  # optimisation
  --num_train_epochs="1"
  --per_device_train_batch_size="2"
  --per_device_eval_batch_size="2"
  --learning_rate="3e-4"
  --warmup_steps="500"
  --logging_steps="1"

  # generation during evaluation
  --generation_max_length="40"
  --generation_num_beams="1"
  --predict_with_generate

  # stages to run
  --do_eval
  --do_train
)

# The environment variable is set for this python invocation only.
JAX_DEFAULT_MATMUL_PRECISION=float32 python run_flax_speech_recognition_seq2seq.py "${train_args[@]}"
29
+
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"errors": "replace", "bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "<unk>", "pad_token": "<pad>", "mask_token": "<mask>", "add_prefix_space": false, "trim_offsets": true, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "facebook/bart-large-cnn", "tokenizer_class": "BartTokenizer"}
vocab.json ADDED
The diff for this file is too large to render. See raw diff