alvinwatner committed on
Commit 1d7fe78 · 1 Parent(s): 42a738c

Saving weights and logs of epoch 0

config.json CHANGED
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "facebook/bart-base",
+  "_name_or_path": "/home/alvinwatner/bart-qg-alpha-interro",
   "activation_dropout": 0.1,
   "activation_function": "gelu",
   "add_bias_logits": false,
events.out.tfevents.1639719034.t1v-n-22127d47-w-0.139681.0.v2 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:43113075790298ba2536dcf4db82634b5932fd18d0f39203900967eaeafa8628
+size 219397
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:cae9a29397548b360c4df3cecb1060626f6abe2c6592712e8fee68c9fcce592e
+oid sha256:4698d1df8c200105890e3cd40d5b397d67b00f416dcc4778c31551f77c0474c1
 size 557891914
run_finetuning.sh ADDED
@@ -0,0 +1,26 @@
+export MODEL_DIR="$(pwd)"
+export DATA_PATH=/home/$USER/dataset
+
+python3 run_summarization_flax.py \
+    --output_dir ${MODEL_DIR} \
+    --model_name_or_path ${MODEL_DIR}/flax_model.msgpack \
+    --config_name ${MODEL_DIR} \
+    --tokenizer_name ${MODEL_DIR} \
+    --train_file ${DATA_PATH}/train_sen_jsonlines.json \
+    --validation_file ${DATA_PATH}/val_sen_jsonlines.json \
+    --test_file ${DATA_PATH}/test_sen_jsonlines.json \
+    --do_train --do_eval --do_predict --predict_with_generate \
+    --num_train_epochs 3 \
+    --learning_rate 5e-5 --warmup_steps 0 \
+    --per_device_train_batch_size 8 \
+    --per_device_eval_batch_size 8 \
+    --overwrite_output_dir \
+    --max_source_length 256 \
+    --max_target_length 64 \
+    --text_column src \
+    --summary_column tgt \
+    --hub_model_id alvinwatner/bart-qg-alpha-interro \
+    --hub_token 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 \
+    --push_to_hub
+
+
run_summarization_flax.py ADDED
@@ -0,0 +1,895 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2021 The HuggingFace Team All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Fine-tuning the library models for summarization.
+"""
+# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
+
+import json
+import logging
+import os
+import sys
+import time
+from dataclasses import asdict, dataclass, field
+from enum import Enum
+from functools import partial
+from pathlib import Path
+from typing import Callable, Optional
+
+import datasets
+import nltk  # Here to have a nice missing dependency error message early on
+import numpy as np
+from datasets import Dataset, load_dataset, load_metric
+from tqdm import tqdm
+
+import jax
+import jax.numpy as jnp
+import optax
+import transformers
+from filelock import FileLock
+from flax import jax_utils, traverse_util
+from flax.jax_utils import unreplicate
+from flax.training import train_state
+from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
+from huggingface_hub import Repository
+from transformers import (
+    CONFIG_MAPPING,
+    FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+    AutoConfig,
+    AutoTokenizer,
+    FlaxAutoModelForSeq2SeqLM,
+    HfArgumentParser,
+    is_tensorboard_available,
+)
+from transformers.file_utils import get_full_repo_name, is_offline_mode
+
+
+logger = logging.getLogger(__name__)
+
+try:
+    nltk.data.find("tokenizers/punkt")
+except (LookupError, OSError):
+    if is_offline_mode():
+        raise LookupError(
+            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
+        )
+    with FileLock(".lock") as lock:
+        nltk.download("punkt", quiet=True)
+
+
+MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
+MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
+
+
+@dataclass
+class TrainingArguments:
+    output_dir: str = field(
+        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
+    )
+    overwrite_output_dir: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Overwrite the content of the output directory. "
+                "Use this to continue training if output_dir points to a checkpoint directory."
+            )
+        },
+    )
+    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
+    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
+    do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
+    per_device_train_batch_size: int = field(
+        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
+    )
+    per_device_eval_batch_size: int = field(
+        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
+    )
+    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
+    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
+    adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
+    adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
+    adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
+    label_smoothing_factor: float = field(
+        default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
+    )
+    adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
+    num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
+    warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
+    logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
+    save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
+    eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
+    seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
+    push_to_hub: bool = field(
+        default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
+    )
+    hub_model_id: str = field(
+        default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
+    )
+    hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+
+    def __post_init__(self):
+        if self.output_dir is not None:
+            self.output_dir = os.path.expanduser(self.output_dir)
+
+    def to_dict(self):
+        """
+        Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
+        the token values by removing their value.
+        """
+        d = asdict(self)
+        for k, v in d.items():
+            if isinstance(v, Enum):
+                d[k] = v.value
+            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
+                d[k] = [x.value for x in v]
+            if k.endswith("_token"):
+                d[k] = f"<{k.upper()}>"
+        return d
+
+
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+    """
+
+    model_name_or_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The model checkpoint for weights initialization."
+            "Don't set if you want to train a model from scratch."
+        },
+    )
+    model_type: Optional[str] = field(
+        default=None,
+        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
+    )
+    config_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+    )
+    tokenizer_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+    )
+    cache_dir: Optional[str] = field(
+        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
+    )
+    use_fast_tokenizer: bool = field(
+        default=True,
+        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
+    )
+    dtype: Optional[str] = field(
+        default="float32",
+        metadata={
+            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+        },
+    )
+
+
+@dataclass
+class DataTrainingArguments:
+    """
+    Arguments pertaining to what data we are going to input our model for training and eval.
+    """
+
+    dataset_name: Optional[str] = field(
+        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+    )
+    dataset_config_name: Optional[str] = field(
+        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+    )
+    text_column: Optional[str] = field(
+        default=None,
+        metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
+    )
+    summary_column: Optional[str] = field(
+        default=None,
+        metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
+    )
+    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
+    validation_file: Optional[str] = field(
+        default=None,
+        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
+    )
+    test_file: Optional[str] = field(
+        default=None,
+        metadata={"help": "An optional input predict data file to do prediction on (a text file)."},
+    )
+    max_source_length: Optional[int] = field(
+        default=1024,
+        metadata={
+            "help": "The maximum total input sequence length after tokenization. Sequences longer "
+            "than this will be truncated, sequences shorter will be padded."
+        },
+    )
+    max_target_length: Optional[int] = field(
+        default=128,
+        metadata={
+            "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
+            "than this will be truncated, sequences shorter will be padded."
+        },
+    )
+    val_max_target_length: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+            "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
+            "This argument is also used to override the `max_length` param of `model.generate`, which is used "
+            "during evaluation."
+        },
+    )
+    max_train_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+            "value if set."
+        },
+    )
+    max_eval_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+            "value if set."
+        },
+    )
+    max_predict_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+            "value if set."
+        },
+    )
+    preprocessing_num_workers: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of processes to use for the preprocessing."},
+    )
+    source_prefix: Optional[str] = field(
+        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
+    )
+    predict_with_generate: bool = field(
+        default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
+    )
+    num_beams: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
+            "which is used during evaluation."
+        },
+    )
+    overwrite_cache: bool = field(
+        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+    )
+
+    def __post_init__(self):
+        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
+            raise ValueError("Need either a dataset name or a training/validation file.")
+        else:
+            if self.train_file is not None:
+                extension = self.train_file.split(".")[-1]
+                assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
+            if self.validation_file is not None:
+                extension = self.validation_file.split(".")[-1]
+                assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+        if self.val_max_target_length is None:
+            self.val_max_target_length = self.max_target_length
+
+
+summarization_name_mapping = {
+    "amazon_reviews_multi": ("review_body", "review_title"),
+    "big_patent": ("description", "abstract"),
+    "cnn_dailymail": ("article", "highlights"),
+    "orange_sum": ("text", "summary"),
+    "pn_summary": ("article", "summary"),
+    "psc": ("extract_text", "summary_text"),
+    "samsum": ("dialogue", "summary"),
+    "thaisum": ("body", "summary"),
+    "xglue": ("news_body", "news_title"),
+    "xsum": ("document", "summary"),
+    "wiki_summary": ("article", "highlights"),
+}
+
+
+class TrainState(train_state.TrainState):
+    dropout_rng: jnp.ndarray
+
+    def replicate(self):
+        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
+
+
+def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
+    """
+    Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
+    Shuffle batches if `shuffle` is `True`.
+    """
+    steps_per_epoch = len(dataset) // batch_size
+
+    if shuffle:
+        batch_idx = jax.random.permutation(rng, len(dataset))
+    else:
+        batch_idx = jnp.arange(len(dataset))
+
+    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
+    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
+
+    for idx in batch_idx:
+        batch = dataset[idx]
+        batch = {k: jnp.array(v) for k, v in batch.items()}
+
+        batch = shard(batch)
+
+        yield batch
+
+
+def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
+    summary_writer.scalar("train_time", train_time, step)
+
+    train_metrics = get_metrics(train_metrics)
+    for key, vals in train_metrics.items():
+        tag = f"train_{key}"
+        for i, val in enumerate(vals):
+            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
+
+    for metric_name, value in eval_metrics.items():
+        summary_writer.scalar(f"eval_{metric_name}", value, step)
+
+
+def create_learning_rate_fn(
+    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
+) -> Callable[[int], jnp.array]:
+    """Returns a linear warmup, linear_decay learning rate function."""
+    steps_per_epoch = train_ds_size // train_batch_size
+    num_train_steps = steps_per_epoch * num_train_epochs
+    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
+    decay_fn = optax.linear_schedule(
+        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
+    )
+    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
+    return schedule_fn
+
+
+def main():
+    # See all possible arguments in src/transformers/training_args.py
+    # or by passing the --help flag to this script.
+    # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+    if (
+        os.path.exists(training_args.output_dir)
+        and os.listdir(training_args.output_dir)
+        and training_args.do_train
+        and not training_args.overwrite_output_dir
+    ):
+        raise ValueError(
+            f"Output directory ({training_args.output_dir}) already exists and is not empty."
+            "Use --overwrite_output_dir to overcome."
+        )
+
+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    # Setup logging, we only want one process per machine to log things on the screen.
+    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+    if jax.process_index() == 0:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+
+    # Set the verbosity to info of the Transformers logger (on main process only):
+    logger.info(f"Training/evaluation parameters {training_args}")
+
+    # Handle the repository creation
+    if training_args.push_to_hub:
+        if training_args.hub_model_id is None:
+            repo_name = get_full_repo_name(
+                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
+            )
+        else:
+            repo_name = training_args.hub_model_id
+        repo = Repository(training_args.output_dir, clone_from=repo_name)
+
+    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
+    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
+    # (the dataset will be downloaded automatically from the datasets Hub).
+    #
+    # For CSV/JSON files this script will use the first column for the full texts and the second column for the
+    # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
+    #
+    if data_args.dataset_name is not None:
+        # Downloading and loading a dataset from the hub.
+        dataset = load_dataset(
+            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
+        )
+    else:
+        data_files = {}
+        if data_args.train_file is not None:
+            data_files["train"] = data_args.train_file
+            extension = data_args.train_file.split(".")[-1]
+        if data_args.validation_file is not None:
+            data_files["validation"] = data_args.validation_file
+            extension = data_args.validation_file.split(".")[-1]
+        if data_args.test_file is not None:
+            data_files["test"] = data_args.test_file
+            extension = data_args.test_file.split(".")[-1]
+        dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
+    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
+    # https://huggingface.co/docs/datasets/loading_datasets.html.
+
+    # Load pretrained model and tokenizer
+
+    if model_args.config_name:
+        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
+    elif model_args.model_name_or_path:
+        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
+    else:
+        config = CONFIG_MAPPING[model_args.model_type]()
+        logger.warning("You are instantiating a new config instance from scratch.")
+
+    if model_args.tokenizer_name:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+        )
+    elif model_args.model_name_or_path:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+        )
+    else:
+        raise ValueError(
+            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
+            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+        )
+
+    if model_args.model_name_or_path:
+        model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
+            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
+    else:
+        model = FlaxAutoModelForSeq2SeqLM.from_config(
+            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
+
+    if model.config.decoder_start_token_id is None:
+        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
+
+    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
+
+    # Preprocessing the datasets.
+    # We need to tokenize inputs and targets.
+    if training_args.do_train:
+        column_names = dataset["train"].column_names
+    elif training_args.do_eval:
+        column_names = dataset["validation"].column_names
+    elif training_args.do_predict:
+        column_names = dataset["test"].column_names
+    else:
+        logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
+        return
+
+    # Get the column names for input/target.
+    dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
+    if data_args.text_column is None:
+        text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+    else:
+        text_column = data_args.text_column
+        if text_column not in column_names:
+            raise ValueError(
+                f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
+            )
+    if data_args.summary_column is None:
+        summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+    else:
+        summary_column = data_args.summary_column
+        if summary_column not in column_names:
+            raise ValueError(
+                f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
+            )
+
+    # Temporarily set max_target_length for training.
+    max_target_length = data_args.max_target_length
+
+    # In Flax, for seq2seq models we need to pass `decoder_input_ids`
+    # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
+    # for that dynamically import the `shift_tokens_right` function from the model file
+    model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])
+    shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
+
+    # Setting padding="max_length" as we need fixed length inputs for jitted functions
+    def preprocess_function(examples):
+        inputs = examples[text_column]
+        targets = examples[summary_column]
+        inputs = [prefix + inp for inp in inputs]
+        model_inputs = tokenizer(
+            inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np"
+        )
+
+        # Setup the tokenizer for targets
+        with tokenizer.as_target_tokenizer():
+            labels = tokenizer(
+                targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
+            )
+
+        model_inputs["labels"] = labels["input_ids"]
+        decoder_input_ids = shift_tokens_right_fn(
+            labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
+        )
+        model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
+
+        # We need decoder_attention_mask so we can ignore pad tokens from loss
+        model_inputs["decoder_attention_mask"] = labels["attention_mask"]
+
+        return model_inputs
+
+    if training_args.do_train:
+        if "train" not in dataset:
+            raise ValueError("--do_train requires a train dataset")
+        train_dataset = dataset["train"]
+        if data_args.max_train_samples is not None:
+            train_dataset = train_dataset.select(range(data_args.max_train_samples))
+        train_dataset = train_dataset.map(
+            preprocess_function,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            remove_columns=column_names,
+            load_from_cache_file=not data_args.overwrite_cache,
+            desc="Running tokenizer on train dataset",
+        )
+
+    if training_args.do_eval:
+        max_target_length = data_args.val_max_target_length
+        if "validation" not in dataset:
+            raise ValueError("--do_eval requires a validation dataset")
+        eval_dataset = dataset["validation"]
+        if data_args.max_eval_samples is not None:
+            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+        eval_dataset = eval_dataset.map(
+            preprocess_function,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            remove_columns=column_names,
+            load_from_cache_file=not data_args.overwrite_cache,
+            desc="Running tokenizer on validation dataset",
+        )
+
+    if training_args.do_predict:
+        max_target_length = data_args.val_max_target_length
+        if "test" not in dataset:
+            raise ValueError("--do_predict requires a test dataset")
+        predict_dataset = dataset["test"]
+        if data_args.max_predict_samples is not None:
+            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
+        predict_dataset = predict_dataset.map(
+            preprocess_function,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            remove_columns=column_names,
+            load_from_cache_file=not data_args.overwrite_cache,
+            desc="Running tokenizer on prediction dataset",
+        )
+
+    # Metric
+    metric = load_metric("rouge")
+
+    def postprocess_text(preds, labels):
+        preds = [pred.strip() for pred in preds]
+        labels = [label.strip() for label in labels]
+
+        # rougeLSum expects newline after each sentence
+        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
+        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
+
+        return preds, labels
+
+    def compute_metrics(preds, labels):
+        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
+        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+        # Some simple post-processing
+        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
+
+        result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
+        # Extract a few results from ROUGE
+        result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
+
+        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
+        result["gen_len"] = np.mean(prediction_lens)
+        result = {k: round(v, 4) for k, v in result.items()}
+        return result
+
+    # Enable tensorboard only on the master node
+    has_tensorboard = is_tensorboard_available()
+    if has_tensorboard and jax.process_index() == 0:
+        try:
+            from flax.metrics.tensorboard import SummaryWriter
+
+            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+        except ImportError as ie:
+            has_tensorboard = False
+            logger.warning(
+                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
+            )
+    else:
+        logger.warning(
+            "Unable to display metrics through TensorBoard because the package is not installed: "
+            "Please run pip install tensorboard to enable."
+        )
+
+    # Initialize our training
+    rng = jax.random.PRNGKey(training_args.seed)
+    rng, dropout_rng = jax.random.split(rng)
+
+    # Store some constant
+    num_epochs = int(training_args.num_train_epochs)
+    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
+    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+    steps_per_epoch = len(train_dataset) // train_batch_size
+    total_train_steps = steps_per_epoch * num_epochs
+
+    # Create learning rate schedule
+    linear_decay_lr_schedule_fn = create_learning_rate_fn(
+        len(train_dataset),
+        train_batch_size,
+        training_args.num_train_epochs,
+        training_args.warmup_steps,
+        training_args.learning_rate,
+    )
+
+    # We use Optax's "masking" functionality to not apply weight decay
+    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
+    # mask boolean with the same structure as the parameters.
+    # The mask is True for parameters that should be decayed.
+    # Note that this mask is specifically adapted for FlaxBart.
+    # For FlaxT5, one should correct the layer norm parameter naming
+    # accordingly - see `run_t5_mlm_flax.py` e.g.
+    def decay_mask_fn(params):
+        flat_params = traverse_util.flatten_dict(params)
+        layer_norm_params = [
+            (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
+        ]
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
+        return traverse_util.unflatten_dict(flat_mask)
+
+    # create adam optimizer
+    adamw = optax.adamw(
+        learning_rate=linear_decay_lr_schedule_fn,
+        b1=training_args.adam_beta1,
+        b2=training_args.adam_beta2,
+        eps=training_args.adam_epsilon,
+        weight_decay=training_args.weight_decay,
+        mask=decay_mask_fn,
+    )
+
+    # Setup train state
+    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
+
+    # label smoothed cross entropy
+    def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
+        """
+        The label smoothing implementation is adapted from Flax's official example:
+        https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
+        """
+        vocab_size = logits.shape[-1]
+        confidence = 1.0 - label_smoothing_factor
+        low_confidence = (1.0 - confidence) / (vocab_size - 1)
+        normalizing_constant = -(
+            confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
+        )
+        soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
+
+        loss = optax.softmax_cross_entropy(logits, soft_labels)
+        loss = loss - normalizing_constant
+
+        # ignore padded tokens from loss
+        loss = loss * padding_mask
+        loss = loss.sum() / padding_mask.sum()
+        return loss
+
+    # Define gradient update step fn
+    def train_step(state, batch, label_smoothing_factor=0.0):
+        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
+
+        def compute_loss(params):
+            labels = batch.pop("labels")
+            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
+            loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
+            return loss
+
+        grad_fn = jax.value_and_grad(compute_loss)
+        loss, grad = grad_fn(state.params)
+        grad = jax.lax.pmean(grad, "batch")
+
+        new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
+
+        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
+        metrics = jax.lax.pmean(metrics, axis_name="batch")
+
+        return new_state, metrics
+
+    # Define eval fn
+    def eval_step(params, batch, label_smoothing_factor=0.0):
+        labels = batch.pop("labels")
+        logits = model(**batch, params=params, train=False)[0]
+        loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
+
+        # summarize metrics
+        metrics = {"loss": loss}
+        metrics = jax.lax.pmean(metrics, axis_name="batch")
+        return metrics
+
+    # Define generation function
+    max_length = (
+        data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
+    )
+    num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
+    gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
+
+    def generate_step(params, batch):
+        model.params = params
+        output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
+        return output_ids.sequences
+
+    # Create parallel version of the train and eval step
+    p_train_step = jax.pmap(
+        partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
+    )
+    p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
+    p_generate_step = jax.pmap(generate_step, "batch")
+
+    # Replicate the train state on each device
+    state = state.replicate()
+
+    logger.info("***** Running training *****")
+    logger.info(f" Num examples = {len(train_dataset)}")
+    logger.info(f" Num Epochs = {num_epochs}")
+    logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
+    logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
+    logger.info(f" Total optimization steps = {total_train_steps}")
+
+    train_time = 0
+    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
+    for epoch in epochs:
+        # ======================== Training ================================
+        train_start = time.time()
+
+        # Create sampling rng
+        rng, input_rng = jax.random.split(rng)
+        train_metrics = []
+
+        # Generate an epoch by shuffling sampling indices from the train dataset
+        train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
+        steps_per_epoch = len(train_dataset) // train_batch_size
+        # train
+        for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
+            batch = next(train_loader)
+            state, train_metric = p_train_step(state, batch)
+            train_metrics.append(train_metric)
+
+        train_time += time.time() - train_start
+
+        train_metric = unreplicate(train_metric)
+
+        epochs.write(
+            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+        )
+
+        # ======================== Evaluating ==============================
+        eval_metrics = []
+        eval_preds = []
+        eval_labels = []
+
+        eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
+        eval_steps = len(eval_dataset) // eval_batch_size
+        for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
+            # Model forward
+            batch = next(eval_loader)
+            labels = batch["labels"]
+
+            metrics = p_eval_step(state.params, batch)
+            eval_metrics.append(metrics)
+
+            # generation
+            if data_args.predict_with_generate:
+                generated_ids = p_generate_step(state.params, batch)
+                eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
+                eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
+
+        # normalize eval metrics
+        eval_metrics = get_metrics(eval_metrics)
+        eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+
+        # compute ROUGE metrics
+        rouge_desc = ""
+        if data_args.predict_with_generate:
+            rouge_metrics = compute_metrics(eval_preds, eval_labels)
+            eval_metrics.update(rouge_metrics)
+            rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
+
+        # Print metrics and update progress bar
+        desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
+        epochs.write(desc)
+        epochs.desc = desc
+
+        # Save metrics
+        if has_tensorboard and jax.process_index() == 0:
+            cur_step = epoch * (len(train_dataset) // train_batch_size)
+            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
+
+        # save checkpoint after each epoch and push checkpoint to the hub
+        if jax.process_index() == 0:
+            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+            model.save_pretrained(training_args.output_dir, params=params)
+            tokenizer.save_pretrained(training_args.output_dir)
+            if training_args.push_to_hub:
+                repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
+
+    # ======================== Prediction loop ==============================
+    if training_args.do_predict:
+        logger.info("*** Predict ***")
+
+        pred_metrics = []
+        pred_generations = []
+        pred_labels = []
+
+        pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
+        pred_steps = len(predict_dataset) // eval_batch_size
+        for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
+            # Model forward
+            batch = next(pred_loader)
+            labels = batch["labels"]
+
+            metrics = p_eval_step(state.params, batch)
+            pred_metrics.append(metrics)
+
+            # generation
+            if data_args.predict_with_generate:
+                generated_ids = p_generate_step(state.params, batch)
+                pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
+                pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
+
+        # normalize prediction metrics
+        pred_metrics = get_metrics(pred_metrics)
+        pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
+
+        # compute ROUGE metrics
+        rouge_desc = ""
+        if data_args.predict_with_generate:
+            rouge_metrics = compute_metrics(pred_generations, pred_labels)
+            pred_metrics.update(rouge_metrics)
+            rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
+
+        # Print metrics
+        desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
+        logger.info(desc)
+
+        # save final metrics in json
+        if jax.process_index() == 0:
+            rouge_metrics = {f"test_{metric_name}": value for metric_name, value in rouge_metrics.items()}
+            path = os.path.join(training_args.output_dir, "test_results.json")
+            with open(path, "w") as f:
+                json.dump(rouge_metrics, f, indent=4, sort_keys=True)
+
+
+if __name__ == "__main__":
+    main()
test_results.json ADDED
@@ -0,0 +1,7 @@
+{
+    "test_gen_len": 12.9904,
+    "test_rouge1": 66.2214,
+    "test_rouge2": 42.4956,
+    "test_rougeL": 60.8554,
+    "test_rougeLsum": 60.8631
+}
testing.sh ADDED
@@ -0,0 +1,2 @@
+export MODEL_DIR="$(pwd)"
+${MODEL_DIR}
tokenizer_config.json CHANGED
@@ -1 +1 @@
-{"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>", "add_prefix_space": false, "errors": "replace", "sep_token": "</s>", "cls_token": "<s>", "pad_token": "<pad>", "mask_token": "<mask>", "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "facebook/bart-base", "tokenizer_class": "BartTokenizer"}
+{"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>", "add_prefix_space": false, "errors": "replace", "sep_token": "</s>", "cls_token": "<s>", "pad_token": "<pad>", "mask_token": "<mask>", "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "/home/alvinwatner/bart-qg-alpha-interro", "tokenizer_class": "BartTokenizer"}