salti commited on
Commit
ab26e54
·
1 Parent(s): be04ff1

add training and tokenization scripts

Browse files
run-t5v1_1-small.sh DELETED
@@ -1 +0,0 @@
1
- ../run-t5v1_1-small.sh
 
 
run-t5v1_1-small.sh ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export model_dir=arabic-t5-small
2
+ export train_batch_size=48
3
+ export eval_batch_size=96
4
+
5
+ python ./run_t5_mlm_flax.py \
6
+ --model_type t5 \
7
+ --config_name ${model_dir} \
8
+ --tokenizer_name ${model_dir} \
9
+ --use_fast_tokenizer True \
10
+ --dtype float32 \
11
+ --max_seq_length 512 \
12
+ --preprocessing_num_workers 96 \
13
+ --output_dir ${model_dir} \
14
+ --overwrite_output_dir True \
15
+ --do_train \
16
+ --per_device_train_batch_size ${train_batch_size} \
17
+ --per_device_eval_batch_size ${eval_batch_size} \
18
+ --learning_rate 1e-2 \
19
+ --num_train_epochs 1 \
20
+ --logging_steps 100 \
21
+ --eval_steps 1000 \
22
+ --save_steps 1000 \
23
+ --seed 12 \
24
+ --adafactor True \
25
+ --push_to_hub \
26
+ --cache_dir ./training_cache \
run_t5_mlm_flax.py ADDED
@@ -0,0 +1,684 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Pretraining the library models for T5-like span-masked language modeling on a text file or a dataset.
18
+
19
+ Here is the full list of checkpoints on the hub that can be pretrained by this script:
20
+ https://huggingface.co/models?filter=t5
21
+ """
22
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
23
+ import logging
24
+ import os
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+ from pathlib import Path
29
+ from typing import Dict, List, Optional
30
+
31
+ import numpy as np
32
+ from datasets import load_dataset, concatenate_datasets, load_from_disk
33
+ from tqdm import tqdm
34
+
35
+ import flax
36
+ import jax
37
+ import jax.numpy as jnp
38
+ import optax
39
+ from flax import jax_utils, traverse_util
40
+ from flax.training import train_state
41
+ from flax.training.checkpoints import save_checkpoint
42
+ from flax.training.common_utils import get_metrics, onehot, shard
43
+ from transformers import (
44
+ CONFIG_MAPPING,
45
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
46
+ BatchEncoding,
47
+ FlaxT5ForConditionalGeneration,
48
+ HfArgumentParser,
49
+ PreTrainedTokenizerBase,
50
+ T5Config,
51
+ T5TokenizerFast,
52
+ TrainingArguments,
53
+ is_tensorboard_available,
54
+ set_seed,
55
+ )
56
+ from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
57
+
58
+
59
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
60
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
61
+
62
+
63
+ @dataclass
64
+ class ModelArguments:
65
+ """
66
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
67
+ """
68
+
69
+ model_name_or_path: Optional[str] = field(
70
+ default=None,
71
+ metadata={
72
+ "help": "The model checkpoint for weights initialization."
73
+ "Don't set if you want to train a model from scratch."
74
+ },
75
+ )
76
+ model_type: Optional[str] = field(
77
+ default=None,
78
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
79
+ )
80
+ config_name: Optional[str] = field(
81
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
82
+ )
83
+ tokenizer_name: Optional[str] = field(
84
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
85
+ )
86
+ cache_dir: Optional[str] = field(
87
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
88
+ )
89
+ use_fast_tokenizer: bool = field(
90
+ default=True,
91
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
92
+ )
93
+ dtype: Optional[str] = field(
94
+ default="float32",
95
+ metadata={
96
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
97
+ },
98
+ )
99
+
100
+
101
+ @dataclass
102
+ class DataTrainingArguments:
103
+ """
104
+ Arguments pertaining to what data we are going to input our model for training and eval.
105
+ """
106
+
107
+ overwrite_cache: bool = field(
108
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
109
+ )
110
+ max_seq_length: Optional[int] = field(
111
+ default=None,
112
+ metadata={
113
+ "help": "The maximum total input sequence length after tokenization and masking. Sequences longer than this will be truncated. Default to the max input length of the model."
114
+ },
115
+ )
116
+ preprocessing_num_workers: Optional[int] = field(
117
+ default=None,
118
+ metadata={"help": "The number of processes to use for the preprocessing."},
119
+ )
120
+ mlm_probability: float = field(
121
+ default=0.15, metadata={"help": "Ratio of tokens to mask for span masked language modeling loss"}
122
+ )
123
+ mean_noise_span_length: float = field(
124
+ default=3.0,
125
+ metadata={"help": "Mean span length of masked tokens"},
126
+ )
127
+
128
+
129
+ def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
130
+ """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ .
131
+
132
+ Training parameters to avoid padding with random_spans_noise_mask.
133
+ When training a model with random_spans_noise_mask, we would like to set the other
134
+ training hyperparmeters in a way that avoids padding.
135
+ This function helps us compute these hyperparameters.
136
+ We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
137
+ and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
138
+ This function tells us the required number of tokens in the raw example (for split_tokens())
139
+ as well as the length of the encoded targets. Note that this function assumes
140
+ the inputs and targets will have EOS appended and includes that in the reported length.
141
+
142
+ Args:
143
+ inputs_length: an integer - desired length of the tokenized inputs sequence
144
+ noise_density: a float
145
+ mean_noise_span_length: a float
146
+ Returns:
147
+ tokens_length: length of original text in tokens
148
+ targets_length: an integer - length in tokens of encoded targets sequence
149
+ """
150
+
151
+ def _tokens_length_to_inputs_length_targets_length(tokens_length):
152
+ num_noise_tokens = int(round(tokens_length * noise_density))
153
+ num_nonnoise_tokens = tokens_length - num_noise_tokens
154
+ num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
155
+ # inputs contain all nonnoise tokens, sentinels for all noise spans
156
+ # and one EOS token.
157
+ _input_length = num_nonnoise_tokens + num_noise_spans + 1
158
+ _output_length = num_noise_tokens + num_noise_spans + 1
159
+ return _input_length, _output_length
160
+
161
+ tokens_length = inputs_length
162
+
163
+ while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
164
+ tokens_length += 1
165
+
166
+ inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)
167
+
168
+ # minor hack to get the targets length to be equal to inputs length
169
+ # which is more likely to have been set to a nice round number.
170
+ if noise_density == 0.5 and targets_length > inputs_length:
171
+ tokens_length -= 1
172
+ targets_length -= 1
173
+ return tokens_length, targets_length
174
+
175
+
176
+ @flax.struct.dataclass
177
+ class FlaxDataCollatorForT5MLM:
178
+ """
179
+ Data collator used for T5 span-masked language modeling.
180
+ It is made sure that after masking the inputs are of length `data_args.max_seq_length` and targets are also of fixed length.
181
+ For more information on how T5 span-masked language modeling works, one can take a look
182
+ at the `official paper <https://arxiv.org/pdf/1910.10683.pdf>`__
183
+ or the `official code for preprocessing <https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/preprocessors.py>`__ .
184
+
185
+ Args:
186
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
187
+ The tokenizer used for encoding the data.
188
+ noise_density (:obj:`float`):
189
+ The probability with which to (randomly) mask tokens in the input.
190
+ mean_noise_span_length (:obj:`float`):
191
+ The average span length of the masked tokens.
192
+ input_length (:obj:`int`):
193
+ The expected input length after masking.
194
+ target_length (:obj:`int`):
195
+ The expected target length after masking.
196
+ pad_token_id: (:obj:`int`):
197
+ The pad token id of the model
198
+ decoder_start_token_id: (:obj:`int):
199
+ The decoder start token id of the model
200
+ """
201
+
202
+ tokenizer: PreTrainedTokenizerBase
203
+ noise_density: float
204
+ mean_noise_span_length: float
205
+ input_length: int
206
+ target_length: int
207
+ pad_token_id: int
208
+ decoder_start_token_id: int
209
+
210
+ def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
211
+
212
+ # convert list to dict and tensorize input
213
+ batch = BatchEncoding(
214
+ {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
215
+ )
216
+
217
+ input_ids = batch["input_ids"]
218
+ batch_size, expandend_input_length = input_ids.shape
219
+
220
+ mask_indices = np.asarray([self.random_spans_noise_mask(expandend_input_length) for i in range(batch_size)])
221
+ labels_mask = ~mask_indices
222
+
223
+ input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8))
224
+ labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8))
225
+
226
+ batch["input_ids"] = self.filter_input_ids(input_ids, input_ids_sentinel)
227
+ batch["labels"] = self.filter_input_ids(input_ids, labels_sentinel)
228
+
229
+ if batch["input_ids"].shape[-1] != self.input_length:
230
+ raise ValueError(
231
+ f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but should be {self.target_length}."
232
+ )
233
+
234
+ if batch["labels"].shape[-1] != self.target_length:
235
+ raise ValueError(
236
+ f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be {self.target_length}."
237
+ )
238
+
239
+ # to check that tokens are correctly proprocessed, one can run `self.tokenizer.batch_decode(input_ids)` and `self.tokenizer.batch_decode(labels)` here...
240
+ batch["decoder_input_ids"] = shift_tokens_right(
241
+ batch["labels"], self.pad_token_id, self.decoder_start_token_id
242
+ )
243
+
244
+ return batch
245
+
246
+ def create_sentinel_ids(self, mask_indices):
247
+ """
248
+ Sentinel ids creation given the indices that should be masked.
249
+ The start indices of each mask are replaced by the sentinel ids in increasing
250
+ order. Consecutive mask indices to be deleted are replaced with `-1`.
251
+ """
252
+ start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
253
+ start_indices[:, 0] = mask_indices[:, 0]
254
+
255
+ sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices)
256
+ sentinel_ids = np.where(sentinel_ids != 0, (sentinel_ids + self.tokenizer.vocab_size - 1), 0)
257
+ sentinel_ids -= mask_indices - start_indices
258
+
259
+ return sentinel_ids
260
+
261
+ def filter_input_ids(self, input_ids, sentinel_ids):
262
+ """
263
+ Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting.
264
+ This will reduce the sequence length from `expanded_inputs_length` to `input_length`.
265
+ """
266
+ batch_size = input_ids.shape[0]
267
+
268
+ input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
269
+ input_ids = input_ids_full[input_ids_full > 0].reshape((batch_size, -1))
270
+ input_ids = np.concatenate(
271
+ [input_ids, np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32)], axis=-1
272
+ )
273
+ return input_ids
274
+
275
+ def random_spans_noise_mask(self, length):
276
+
277
+ """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
278
+
279
+ Noise mask consisting of random spans of noise tokens.
280
+ The number of noise tokens and the number of noise spans and non-noise spans
281
+ are determined deterministically as follows:
282
+ num_noise_tokens = round(length * noise_density)
283
+ num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
284
+ Spans alternate between non-noise and noise, beginning with non-noise.
285
+ Subject to the above restrictions, all masks are equally likely.
286
+
287
+ Args:
288
+ length: an int32 scalar (length of the incoming token sequence)
289
+ noise_density: a float - approximate density of output mask
290
+ mean_noise_span_length: a number
291
+
292
+ Returns:
293
+ a boolean tensor with shape [length]
294
+ """
295
+
296
+ orig_length = length
297
+
298
+ num_noise_tokens = int(np.round(length * self.noise_density))
299
+ # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
300
+ num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
301
+ num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))
302
+
303
+ # avoid degeneracy by ensuring positive number of noise spans
304
+ num_noise_spans = max(num_noise_spans, 1)
305
+ num_nonnoise_tokens = length - num_noise_tokens
306
+
307
+ # pick the lengths of the noise spans and the non-noise spans
308
+ def _random_segmentation(num_items, num_segments):
309
+ """Partition a sequence of items randomly into non-empty segments.
310
+ Args:
311
+ num_items: an integer scalar > 0
312
+ num_segments: an integer scalar in [1, num_items]
313
+ Returns:
314
+ a Tensor with shape [num_segments] containing positive integers that add
315
+ up to num_items
316
+ """
317
+ mask_indices = np.arange(num_items - 1) < (num_segments - 1)
318
+ np.random.shuffle(mask_indices)
319
+ first_in_segment = np.pad(mask_indices, [[1, 0]])
320
+ segment_id = np.cumsum(first_in_segment)
321
+ segment_length = np.asarray(jax.ops.segment_sum(np.ones_like(segment_id), segment_id))
322
+ return segment_length
323
+
324
+ noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
325
+ nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)
326
+
327
+ interleaved_span_lengths = np.reshape(
328
+ np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2]
329
+ )
330
+ span_starts = np.cumsum(interleaved_span_lengths)[:-1]
331
+ span_start_indicator = np.zeros((length,), dtype=np.int8)
332
+ span_start_indicator[span_starts] = True
333
+ span_num = np.cumsum(span_start_indicator)
334
+ is_noise = np.equal(span_num % 2, 1)
335
+
336
+ return is_noise[:orig_length]
337
+
338
+
339
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
340
+ num_samples = len(samples_idx)
341
+ samples_to_remove = num_samples % batch_size
342
+
343
+ if samples_to_remove != 0:
344
+ samples_idx = samples_idx[:-samples_to_remove]
345
+ sections_split = num_samples // batch_size
346
+ batch_idx = np.split(samples_idx, sections_split)
347
+ return batch_idx
348
+
349
+
350
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
351
+ summary_writer.scalar("train_time", train_time, step)
352
+
353
+ train_metrics = get_metrics(train_metrics)
354
+ for key, vals in train_metrics.items():
355
+ tag = f"train_{key}"
356
+ for i, val in enumerate(vals):
357
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
358
+
359
+
360
+ def write_eval_metric(summary_writer, eval_metrics, step):
361
+ for metric_name, value in eval_metrics.items():
362
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
363
+
364
+
365
+ if __name__ == "__main__":
366
+ # See all possible arguments in src/transformers/training_args.py
367
+ # or by passing the --help flag to this script.
368
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
369
+
370
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
371
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
372
+ # If we pass only one argument to the script and it's the path to a json file,
373
+ # let's parse it to get our arguments.
374
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
375
+ else:
376
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
377
+
378
+ if (
379
+ os.path.exists(training_args.output_dir)
380
+ and os.listdir(training_args.output_dir)
381
+ and training_args.do_train
382
+ and not training_args.overwrite_output_dir
383
+ ):
384
+ raise ValueError(
385
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
386
+ "Use --overwrite_output_dir to overcome."
387
+ )
388
+
389
+ # Setup logging
390
+ logging.basicConfig(
391
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
392
+ level="NOTSET",
393
+ datefmt="[%X]",
394
+ )
395
+
396
+ # Log on each process the small summary:
397
+ logger = logging.getLogger(__name__)
398
+
399
+ # Set the verbosity to info of the Transformers logger (on main process only):
400
+ logger.info(f"Training/evaluation parameters {training_args}")
401
+
402
+ # Set seed before initializing model.
403
+ set_seed(training_args.seed)
404
+
405
+ # Load pretrained model and tokenizer
406
+
407
+ if model_args.tokenizer_name:
408
+ tokenizer = T5TokenizerFast.from_pretrained(
409
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
410
+ )
411
+ elif model_args.model_name_or_path:
412
+ tokenizer = T5TokenizerFast.from_pretrained(
413
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
414
+ )
415
+ else:
416
+ raise ValueError(
417
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
418
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
419
+ )
420
+
421
+ if model_args.config_name:
422
+ config = T5Config.from_pretrained(
423
+ model_args.config_name, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
424
+ )
425
+ elif model_args.model_name_or_path:
426
+ config = T5Config.from_pretrained(
427
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
428
+ )
429
+ else:
430
+ config = CONFIG_MAPPING[model_args.model_type]()
431
+ logger.warning("You are instantiating a new config instance from scratch.")
432
+
433
+
434
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
435
+
436
+ # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
437
+ # To ensure that the input length is `max_seq_length`, we need to increase the maximum length
438
+ # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
439
+ expanded_inputs_length, targets_length = compute_input_and_target_lengths(
440
+ inputs_length=max_seq_length,
441
+ noise_density=data_args.mlm_probability,
442
+ mean_noise_span_length=data_args.mean_noise_span_length,
443
+ )
444
+
445
+ # load the tokenized and grouped dataset
446
+ tokenized_datasets = load_from_disk("./training_cache")
447
+
448
+ # Enable tensorboard only on the master node
449
+ has_tensorboard = is_tensorboard_available()
450
+ if has_tensorboard and jax.process_index() == 0:
451
+ try:
452
+ from flax.metrics.tensorboard import SummaryWriter
453
+
454
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
455
+ except ImportError as ie:
456
+ has_tensorboard = False
457
+ logger.warning(
458
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
459
+ )
460
+ else:
461
+ logger.warning(
462
+ "Unable to display metrics through TensorBoard because the package is not installed: "
463
+ "Please run pip install tensorboard to enable."
464
+ )
465
+
466
+ # Initialize our training
467
+ rng = jax.random.PRNGKey(training_args.seed)
468
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
469
+
470
+ logger.info(
471
+ f"JAX devices:\n{jax.devices()}\nNum devices: {jax.device_count()}\nBackend: {jax.lib.xla_bridge.get_backend().platform}"
472
+ )
473
+
474
+ logger.info(
475
+ "\n==================================================Initializing the model==================================================\n"
476
+ )
477
+
478
+ if model_args.model_name_or_path:
479
+ model = FlaxT5ForConditionalGeneration.from_pretrained(
480
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
481
+ )
482
+ else:
483
+ model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
484
+
485
+ logger.info(
486
+ "\n==================================================Done!==================================================\n"
487
+ )
488
+
489
+ # Data collator
490
+ # This one will take care of randomly masking the tokens.
491
+ data_collator = FlaxDataCollatorForT5MLM(
492
+ tokenizer=tokenizer,
493
+ noise_density=data_args.mlm_probability,
494
+ mean_noise_span_length=data_args.mean_noise_span_length,
495
+ input_length=max_seq_length,
496
+ target_length=targets_length,
497
+ pad_token_id=model.config.pad_token_id,
498
+ decoder_start_token_id=model.config.decoder_start_token_id,
499
+ )
500
+
501
+ # Store some constant
502
+ num_epochs = int(training_args.num_train_epochs)
503
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
504
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
505
+
506
+ num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
507
+
508
+ # Create learning rate schedule
509
+ warmup_steps = num_train_steps * 5 // 100
510
+ warmup_fn = optax.linear_schedule(
511
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=warmup_steps
512
+ )
513
+ decay_fn = optax.linear_schedule(
514
+ init_value=training_args.learning_rate,
515
+ end_value=0,
516
+ transition_steps=num_train_steps - warmup_steps,
517
+ )
518
+ linear_decay_lr_schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])
519
+
520
+ # We use Optax's "masking" functionality to not apply weight decay
521
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
522
+ # mask boolean with the same structure as the parameters.
523
+ # The mask is True for parameters that should be decayed.
524
+ def decay_mask_fn(params):
525
+ flat_params = traverse_util.flatten_dict(params)
526
+ flat_mask = {
527
+ path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
528
+ for path in flat_params
529
+ }
530
+ return traverse_util.unflatten_dict(flat_mask)
531
+
532
+ # create adam optimizer
533
+ if training_args.adafactor:
534
+ # We use the default parameters here to initialize adafactor,
535
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
536
+ optimizer = optax.adafactor(
537
+ learning_rate=linear_decay_lr_schedule_fn,
538
+ )
539
+ else:
540
+ optimizer = optax.adamw(
541
+ learning_rate=linear_decay_lr_schedule_fn,
542
+ b1=training_args.adam_beta1,
543
+ b2=training_args.adam_beta2,
544
+ weight_decay=training_args.weight_decay,
545
+ mask=decay_mask_fn,
546
+ )
547
+
548
+ # Setup train state
549
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
550
+
551
+ # Define gradient update step fn
552
+ def train_step(state, batch, dropout_rng):
553
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
554
+
555
+ def loss_fn(params):
556
+ labels = batch.pop("labels")
557
+
558
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
559
+
560
+ # compute loss
561
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
562
+
563
+ return loss
564
+
565
+ grad_fn = jax.value_and_grad(loss_fn)
566
+ loss, grad = grad_fn(state.params)
567
+ grad = jax.lax.pmean(grad, "batch")
568
+ new_state = state.apply_gradients(grads=grad)
569
+
570
+ metrics = jax.lax.pmean(
571
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
572
+ )
573
+
574
+ return new_state, metrics, new_dropout_rng
575
+
576
+ # Create parallel version of the train step
577
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
578
+
579
+ # Define eval fn
580
+ def eval_step(params, batch):
581
+ labels = batch.pop("labels")
582
+
583
+ logits = model(**batch, params=params, train=False)[0]
584
+
585
+ # compute loss
586
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
587
+
588
+ # compute accuracy
589
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)
590
+
591
+ # summarize metrics
592
+ metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()}
593
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
594
+
595
+ return metrics
596
+
597
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
598
+
599
+ # Replicate the train state on each device
600
+ state = jax_utils.replicate(state)
601
+
602
+ train_time = 0
603
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
604
+ for epoch in epochs:
605
+ # ======================== Training ================================
606
+ train_start = time.time()
607
+ train_metrics = []
608
+
609
+ # Create sampling rng
610
+ rng, input_rng = jax.random.split(rng)
611
+
612
+ # Generate an epoch by shuffling sampling indices from the train dataset
613
+ num_train_samples = len(tokenized_datasets["train"])
614
+ train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
615
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
616
+
617
+ # Gather the indexes for creating the batch and do a training step
618
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
619
+ samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
620
+ model_inputs = data_collator(samples)
621
+
622
+ # Model forward
623
+ model_inputs = shard(model_inputs.data)
624
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
625
+ train_metrics.append(train_metric)
626
+
627
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
628
+
629
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
630
+ # Save metrics
631
+ train_metric = jax_utils.unreplicate(train_metric)
632
+ train_time += time.time() - train_start
633
+ if has_tensorboard and jax.process_index() == 0:
634
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
635
+
636
+ epochs.write(
637
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
638
+ )
639
+
640
+ train_metrics = []
641
+
642
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
643
+ # ======================== Evaluating ==============================
644
+ num_eval_samples = len(tokenized_datasets["validation"])
645
+ eval_samples_idx = jnp.arange(num_eval_samples)
646
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
647
+
648
+ eval_metrics = []
649
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
650
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
651
+ model_inputs = data_collator(samples)
652
+
653
+ # Model forward
654
+ model_inputs = shard(model_inputs.data)
655
+ metrics = p_eval_step(state.params, model_inputs)
656
+ eval_metrics.append(metrics)
657
+
658
+ # get eval metrics
659
+ eval_metrics = get_metrics(eval_metrics)
660
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
661
+
662
+ # Update progress bar
663
+ epochs.write(f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})")
664
+
665
+ # Save metrics
666
+ if has_tensorboard and jax.process_index() == 0:
667
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
668
+
669
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
670
+ # save checkpoint after each epoch and push checkpoint to the hub
671
+ if jax.process_index() == 0:
672
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
673
+ save_checkpoint(
674
+ ckpt_dir=training_args.output_dir,
675
+ target=jax_utils.unreplicate(state),
676
+ step=cur_step,
677
+ overwrite=True,
678
+ )
679
+ model.save_pretrained(
680
+ training_args.output_dir,
681
+ params=params,
682
+ push_to_hub=training_args.push_to_hub,
683
+ commit_message=f"Saving weights and logs of step {cur_step}",
684
+ )
t5_tokenizer_model.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import json
3
+ from typing import Iterator, List, Union
4
+
5
+ from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, trainers
6
+ from tokenizers.implementations.base_tokenizer import BaseTokenizer
7
+ from tokenizers.models import Unigram
8
+ from tokenizers.processors import TemplateProcessing
9
+
10
+
11
+ class SentencePieceUnigramTokenizer(BaseTokenizer):
12
+ """
13
+ This class is a copy of `DeDLOC's tokenizer implementation <https://github.com/yandex-research/DeDLOC/blob/main/sahajbert/tokenizer/tokenizer_model.py>`__ .
14
+
15
+ Custom SentencePiece Unigram Tokenizer with NMT, NKFC, spaces and lower-casing characters normalization
16
+ Represents the Unigram algorithm, with the pretokenization used by SentencePiece
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ replacement: str = "▁",
22
+ add_prefix_space: bool = True,
23
+ unk_token: Union[str, AddedToken] = "<unk>",
24
+ eos_token: Union[str, AddedToken] = "</s>",
25
+ pad_token: Union[str, AddedToken] = "<pad>",
26
+ ):
27
+ self.special_tokens = {
28
+ "pad": {"id": 0, "token": pad_token},
29
+ "eos": {"id": 1, "token": eos_token},
30
+ "unk": {"id": 2, "token": unk_token},
31
+ }
32
+
33
+ self.special_tokens_list = [None] * len(self.special_tokens)
34
+ for token_dict in self.special_tokens.values():
35
+ self.special_tokens_list[token_dict["id"]] = token_dict["token"]
36
+
37
+ tokenizer = Tokenizer(Unigram())
38
+
39
+ # the following regexes are taken directly from https://github.com/aub-mind/arabert/blob/f92f06a29804f74878e2d1e39ea57fba8dcb0eac/preprocess.py
40
+ url = " [رابط] "
41
+ email = " [بريد] "
42
+ usr = " [مستخدم] "
43
+
44
+ url_regexes = [
45
+ r"(http(s)?:\/\/.)?(www\.)?[-a-zA-Z0-9@:%._\+~#=]{2,256}\.[a-z]{2,6}\b([-a-zA-Z0-9@:%_\+.~#?&//=]*)",
46
+ r"@(https?|ftp)://(-\.)?([^\s/?\.#-]+\.?)+(/[^\s]*)?$@iS",
47
+ r"http[s]?://[a-zA-Z0-9_\-./~\?=%&]+",
48
+ r"www[a-zA-Z0-9_\-?=%&/.~]+",
49
+ r"[a-zA-Z]+\.com",
50
+ r"(?=http)[^\s]+",
51
+ r"(?=www)[^\s]+",
52
+ r"://",
53
+ ]
54
+
55
+ email_regexes = [r"[\w-]+@([\w-]+\.)+[\w-]+", r"\S+@\S+"]
56
+
57
+ user_mention_regex = r"@[\w\d]+"
58
+
59
+ tokenizer.normalizer = normalizers.Sequence(
60
+ [
61
+ normalizers.Nmt(),
62
+ normalizers.NFKC(),
63
+ # remove links, emails, user mentions ans hashtags
64
+ *[normalizers.Replace(Regex(r), url) for r in url_regexes],
65
+ *[normalizers.Replace(Regex(r), email) for r in email_regexes],
66
+ normalizers.Replace(Regex(user_mention_regex), usr),
67
+ # remove html
68
+ normalizers.Replace(Regex("<br />"), " "),
69
+ normalizers.Replace(Regex("</?[^>]+>"), " "),
70
+ # remove extra white space
71
+ normalizers.Replace(Regex(" {2,}"), " "),
72
+ normalizers.Lowercase(),
73
+ ]
74
+ )
75
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
76
+ [
77
+ pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
78
+ pre_tokenizers.Digits(individual_digits=True),
79
+ pre_tokenizers.Punctuation(),
80
+ ]
81
+ )
82
+ tokenizer.decoder = decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
83
+
84
+ tokenizer.post_processor = TemplateProcessing(
85
+ single=f"$A {self.special_tokens['eos']['token']}",
86
+ special_tokens=[(self.special_tokens["eos"]["token"], self.special_tokens["eos"]["id"])],
87
+ )
88
+
89
+ parameters = {
90
+ "model": "SentencePieceUnigram",
91
+ "replacement": replacement,
92
+ "add_prefix_space": add_prefix_space,
93
+ }
94
+
95
+ super().__init__(tokenizer, parameters)
96
+
97
+ def train(
98
+ self,
99
+ files: Union[str, List[str]],
100
+ vocab_size: int = 8000,
101
+ show_progress: bool = True,
102
+ ):
103
+ """Train the model using the given files"""
104
+
105
+ trainer = trainers.UnigramTrainer(
106
+ vocab_size=vocab_size,
107
+ special_tokens=self.special_tokens_list,
108
+ show_progress=show_progress,
109
+ )
110
+
111
+ if isinstance(files, str):
112
+ files = [files]
113
+ self._tokenizer.train(files, trainer=trainer)
114
+
115
+ self.add_unk_id()
116
+
117
+ def train_from_iterator(
118
+ self,
119
+ iterator: Union[Iterator[str], Iterator[Iterator[str]]],
120
+ vocab_size: int = 8000,
121
+ show_progress: bool = True,
122
+ ):
123
+ """Train the model using the given iterator"""
124
+
125
+ trainer = trainers.UnigramTrainer(
126
+ vocab_size=vocab_size,
127
+ special_tokens=self.special_tokens_list,
128
+ show_progress=show_progress,
129
+ )
130
+
131
+ self._tokenizer.train_from_iterator(iterator, trainer=trainer)
132
+
133
+ self.add_unk_id()
134
+
135
+ def add_unk_id(self):
136
+ tokenizer_json = json.loads(self._tokenizer.to_str())
137
+
138
+ tokenizer_json["model"]["unk_id"] = self.special_tokens["unk"]["id"]
139
+
140
+ self._tokenizer = Tokenizer.from_str(json.dumps(tokenizer_json))