hotchpotch committed
Commit 2b3a7e0 · verified · 1 parent: 656fb4f

Upload trainer.py


add trainer script

Files changed (1)
  1. trainer.py +432 -0
trainer.py ADDED
@@ -0,0 +1,432 @@
+ # static-embedding-japanese trainer.py
+ # base: https://huggingface.co/blog/static-embeddings
+ # MIT License
+
+ import logging
+ import os
+ import random
+ from pathlib import Path
+
+ from sentence_transformers import (
+     SentenceTransformer,
+     SentenceTransformerModelCardData,
+     SentenceTransformerTrainer,
+     SentenceTransformerTrainingArguments,
+ )
+ from sentence_transformers.evaluation import NanoBEIREvaluator
+ from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
+ from sentence_transformers.models.StaticEmbedding import StaticEmbedding
+ from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers
+ from transformers import AutoTokenizer
+
+ from datasets import Dataset, DatasetDict, load_dataset
+
+ EXP = "030"
+ print("EXP:", EXP)
+
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
+ print(PROJECT_ROOT)
+
+ EN_TARGET_DATASETS = [
+     # "gooaq",  # non-commercial
+     "msmarco",
+     "squad",
+     # "s2orc",  # large
+     "allnli",
+     # "paq",  # large
+     "trivia_qa",
+     # "msmarco_10m",
+     "swim_ir",
+     # "pubmedqa",
+     "miracl",
+     # "mldr",  # non-commercial
+     "mr_tydi",
+ ]
+
+ JA_TARGET_DATASETS = [
+     "hpprc_emb__auto-wiki-nli-triplet",
+     "hpprc_emb__auto-wiki-qa",
+     "hpprc_emb__auto-wiki-qa-nemotron",
+     "hpprc_emb__auto-wiki-qa-pair",
+     "hpprc_emb__baobab-wiki-retrieval",
+     # "hpprc_emb__jagovfaqs",  # excluded: the answers appear in the JMTEB task's test split
+     "hpprc_emb__janli-triplet",
+     "hpprc_emb__jaquad",
+     "hpprc_emb__jqara",  # same domain as a JMTEB task
+     "hpprc_emb__jsnli-triplet",
+     "hpprc_emb__jsquad",
+     "hpprc_emb__miracl",  # same domain as a JMTEB task
+     "hpprc_emb__mkqa",
+     "hpprc_emb__mkqa-triplet",
+     # "hpprc_emb__mmarco",  # excluded: noisy, contains mojibake and similar artifacts
+     "hpprc_emb__mr-tydi",  # same domain as a JMTEB task
+     "hpprc_emb__nu-mnli-triplet",
+     "hpprc_emb__nu-snli-triplet",
+     # "hpprc_emb__paws-x-triplet",  # excluded: possibly included in the JMTEB task's test split?
+     "hpprc_emb__quiz-no-mori",
+     "hpprc_emb__quiz-works",
+     "hpprc_emb__snow-triplet",
+     "hpprc_llmjp-kaken",
+     "hpprc_llmjp_warp_html",
+     "hpprc_mqa_ja",
+     "hpprc_msmarco_ja",
+ ]
+
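+ # Oversampling factors for the Japanese retrieval sets below: data_aug() in main()
+ # duplicates each row this many times, so these smaller datasets contribute more
+ # examples per epoch.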
+ AUG_FACTOR_DATASETS = {
+     "hpprc_emb__miracl": 20,
+     "hpprc_emb__mr-tydi": 20,
+     "hpprc_emb__jqara": 10,
+     "hpprc_emb__baobab-wiki-retrieval": 5,
+     "hpprc_emb__mkqa": 5,
+     "hpprc_emb__auto-wiki-qa-nemotron": 2,
+     "hpprc_msmarco_ja": 2,
+ }
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+ logging.basicConfig(
+     format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
+ )
+ random.seed(12)
+
+
+ def _load_train_eval_datasets_en():
+     """
+     Either load the train and eval datasets from disk or load them from the datasets library & save them to disk.
+
+     On the first run the freshly built datasets are saved to disk so that later runs can load them without rebuilding.
+     """
+     en_train_dataset_dir = PROJECT_ROOT / "datasets" / "en_train_dataset"
+     en_eval_dataset_dir = PROJECT_ROOT / "datasets" / "en_eval_dataset"
+     try:
+         train_dataset = DatasetDict.load_from_disk(en_train_dataset_dir)
+         eval_dataset = DatasetDict.load_from_disk(en_eval_dataset_dir)
+         return train_dataset, eval_dataset
+     except FileNotFoundError:
+         print("Loading gooaq dataset...")
+         gooaq_dataset = load_dataset("sentence-transformers/gooaq", split="train")
+         gooaq_dataset_dict = gooaq_dataset.train_test_split(test_size=10_000, seed=12)
+         gooaq_train_dataset: Dataset = gooaq_dataset_dict["train"]
+         gooaq_eval_dataset: Dataset = gooaq_dataset_dict["test"]
+         print("Loaded gooaq dataset.")
+
+         print("Loading msmarco dataset...")
+         msmarco_dataset = load_dataset(
+             "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1",
+             "triplet",
+             split="train",
+         )
+         msmarco_dataset_dict = msmarco_dataset.train_test_split(
+             test_size=10_000, seed=12
+         )
+         msmarco_train_dataset: Dataset = msmarco_dataset_dict["train"]
+         msmarco_eval_dataset: Dataset = msmarco_dataset_dict["test"]
+         print("Loaded msmarco dataset.")
+
+         print("Loading squad dataset...")
+         squad_dataset = load_dataset("sentence-transformers/squad", split="train")
+         squad_dataset_dict = squad_dataset.train_test_split(test_size=10_000, seed=12)
+         squad_train_dataset: Dataset = squad_dataset_dict["train"]
+         squad_eval_dataset: Dataset = squad_dataset_dict["test"]
+         print("Loaded squad dataset.")
+
+         print("Loading s2orc dataset...")
+         s2orc_dataset = load_dataset(
+             "sentence-transformers/s2orc", "title-abstract-pair", split="train[:100000]"
+         )
+         s2orc_dataset_dict = s2orc_dataset.train_test_split(test_size=10_000, seed=12)
+         s2orc_train_dataset: Dataset = s2orc_dataset_dict["train"]
+         s2orc_eval_dataset: Dataset = s2orc_dataset_dict["test"]
+         print("Loaded s2orc dataset.")
+
+         print("Loading allnli dataset...")
+         allnli_train_dataset = load_dataset(
+             "sentence-transformers/all-nli", "triplet", split="train"
+         )
+         allnli_eval_dataset = load_dataset(
+             "sentence-transformers/all-nli", "triplet", split="dev"
+         )
+         print("Loaded allnli dataset.")
+
+         print("Loading paq dataset...")
+         paq_dataset = load_dataset("sentence-transformers/paq", split="train")
+         paq_dataset_dict = paq_dataset.train_test_split(test_size=10_000, seed=12)
+         paq_train_dataset: Dataset = paq_dataset_dict["train"]
+         paq_eval_dataset: Dataset = paq_dataset_dict["test"]
+         print("Loaded paq dataset.")
+
+         print("Loading trivia_qa dataset...")
+         trivia_qa = load_dataset("sentence-transformers/trivia-qa", split="train")
+         trivia_qa_dataset_dict = trivia_qa.train_test_split(test_size=5_000, seed=12)
+         trivia_qa_train_dataset: Dataset = trivia_qa_dataset_dict["train"]
+         trivia_qa_eval_dataset: Dataset = trivia_qa_dataset_dict["test"]
+         print("Loaded trivia_qa dataset.")
+
+         print("Loading msmarco_10m dataset...")
+         msmarco_10m_dataset = load_dataset(
+             "bclavie/msmarco-10m-triplets", split="train"
+         )
+         msmarco_10m_dataset_dict = msmarco_10m_dataset.train_test_split(
+             test_size=10_000, seed=12
+         )
+         msmarco_10m_train_dataset: Dataset = msmarco_10m_dataset_dict["train"]
+         msmarco_10m_eval_dataset: Dataset = msmarco_10m_dataset_dict["test"]
+         print("Loaded msmarco_10m dataset.")
+
+         print("Loading swim_ir dataset...")
+         swim_ir_dataset = load_dataset(
+             "nthakur/swim-ir-monolingual", "en", split="train"
+         ).select_columns(["query", "text"])
+         swim_ir_dataset_dict = swim_ir_dataset.train_test_split(
+             test_size=10_000, seed=12
+         )
+         swim_ir_train_dataset: Dataset = swim_ir_dataset_dict["train"]
+         swim_ir_eval_dataset: Dataset = swim_ir_dataset_dict["test"]
+         print("Loaded swim_ir dataset.")
+
+         # NOTE: 20 negatives
+         print("Loading pubmedqa dataset...")
+         pubmedqa_dataset = load_dataset(
+             "sentence-transformers/pubmedqa", "triplet-20", split="train"
+         )
+         pubmedqa_dataset_dict = pubmedqa_dataset.train_test_split(
+             test_size=100, seed=12
+         )
+         pubmedqa_train_dataset: Dataset = pubmedqa_dataset_dict["train"]
+         pubmedqa_eval_dataset: Dataset = pubmedqa_dataset_dict["test"]
+         print("Loaded pubmedqa dataset.")
+
+         # NOTE: A lot of overlap with anchor/positives
+         print("Loading miracl dataset...")
+         miracl_dataset = load_dataset(
+             "sentence-transformers/miracl", "en-triplet-all", split="train"
+         )
+         miracl_dataset_dict = miracl_dataset.train_test_split(test_size=10_000, seed=12)
+         miracl_train_dataset: Dataset = miracl_dataset_dict["train"]
+         miracl_eval_dataset: Dataset = miracl_dataset_dict["test"]
+         print("Loaded miracl dataset.")
+
+         # NOTE: A lot of overlap with anchor/positives
+         print("Loading mldr dataset...")
+         mldr_dataset = load_dataset(
+             "sentence-transformers/mldr", "en-triplet-all", split="train"
+         )
+         mldr_dataset_dict = mldr_dataset.train_test_split(test_size=10_000, seed=12)
+         mldr_train_dataset: Dataset = mldr_dataset_dict["train"]
+         mldr_eval_dataset: Dataset = mldr_dataset_dict["test"]
+         print("Loaded mldr dataset.")
+
+         # NOTE: A lot of overlap with anchor/positives
+         print("Loading mr_tydi dataset...")
+         mr_tydi_dataset = load_dataset(
+             "sentence-transformers/mr-tydi", "en-triplet-all", split="train"
+         )
+         mr_tydi_dataset_dict = mr_tydi_dataset.train_test_split(
+             test_size=10_000, seed=12
+         )
+         mr_tydi_train_dataset: Dataset = mr_tydi_dataset_dict["train"]
+         mr_tydi_eval_dataset: Dataset = mr_tydi_dataset_dict["test"]
+         print("Loaded mr_tydi dataset.")
+
+         train_dataset = DatasetDict(
+             {
+                 "gooaq": gooaq_train_dataset,
+                 "msmarco": msmarco_train_dataset,
+                 "squad": squad_train_dataset,
+                 "s2orc": s2orc_train_dataset,
+                 "allnli": allnli_train_dataset,
+                 "paq": paq_train_dataset,
+                 "trivia_qa": trivia_qa_train_dataset,
+                 "msmarco_10m": msmarco_10m_train_dataset,
+                 "swim_ir": swim_ir_train_dataset,
+                 "pubmedqa": pubmedqa_train_dataset,
+                 "miracl": miracl_train_dataset,
+                 "mldr": mldr_train_dataset,
+                 "mr_tydi": mr_tydi_train_dataset,
+             }
+         )
+         eval_dataset = DatasetDict(
+             {
+                 "gooaq": gooaq_eval_dataset,
+                 "msmarco": msmarco_eval_dataset,
+                 "squad": squad_eval_dataset,
+                 "s2orc": s2orc_eval_dataset,
+                 "allnli": allnli_eval_dataset,
+                 "paq": paq_eval_dataset,
+                 "trivia_qa": trivia_qa_eval_dataset,
+                 "msmarco_10m": msmarco_10m_eval_dataset,
+                 "swim_ir": swim_ir_eval_dataset,
+                 "pubmedqa": pubmedqa_eval_dataset,
+                 "miracl": miracl_eval_dataset,
+                 "mldr": mldr_eval_dataset,
+                 "mr_tydi": mr_tydi_eval_dataset,
+             }
+         )
+
+         train_dataset.save_to_disk(str(en_train_dataset_dir))
+         eval_dataset.save_to_disk(str(en_eval_dataset_dir))
+         return train_dataset, eval_dataset
+
+
+
+ def load_train_eval_datasets_en(target_dataset_names: list[str] = []):
+     print("Loading train and eval datasets...")
+     if len(target_dataset_names) == 0:
+         return DatasetDict(), DatasetDict()
+     train_dataset, eval_dataset = _load_train_eval_datasets_en()
+     ds_names = list(train_dataset.keys())
+     for ds_name in ds_names:
+         if ds_name not in target_dataset_names:
+             del train_dataset[ds_name]
+             del eval_dataset[ds_name]
+         else:
+             print(
+                 "target en ds",
+                 ds_name,
+                 len(train_dataset[ds_name]),
+                 len(eval_dataset[ds_name]),
+             )
+     return train_dataset, eval_dataset
+
+
+ def load_train_eval_datasets_jp(target_dataset_names: list[str] = []):
+     print("Loading train and eval datasets...")
+     jp_train_dataset_dir = PROJECT_ROOT / "datasets" / "jp_train_dataset"
+     jp_eval_dataset_dir = PROJECT_ROOT / "datasets" / "jp_eval_dataset"
+
+     train_dataset_dict = {}
+     eval_dataset_dict = {}
+
+     for ds_name in target_dataset_names:
+         print("loading jp ds", ds_name)
+         try:
+             train_ds = Dataset.load_from_disk(f"{jp_train_dataset_dir}/{ds_name}")
+             eval_ds = Dataset.load_from_disk(f"{jp_eval_dataset_dir}/{ds_name}")
+
+         except FileNotFoundError:
+             print(f"{ds_name} not found, loading from datasets library...")
+             ds = load_dataset(
+                 "hotchpotch/sentence_transformer_japanese", ds_name, split="train"
+             )
+             ds_size = len(ds)
+             test_size = min(3000, ds_size // 100)
+             splitted = ds.train_test_split(test_size=test_size, seed=12)
+             train_ds = splitted["train"]
+             eval_ds = splitted["test"]
+             # save
+             train_ds.save_to_disk(f"{jp_train_dataset_dir}/{ds_name}")
+             eval_ds.save_to_disk(f"{jp_eval_dataset_dir}/{ds_name}")
+         train_dataset_dict[ds_name] = train_ds
+         eval_dataset_dict[ds_name] = eval_ds
+     return DatasetDict(train_dataset_dict), DatasetDict(eval_dataset_dict)
+
+
+ def main():
+     # 1. Load a model to finetune with 2. (Optional) model card data
+     print("Loading model...")
+     static_embedding = StaticEmbedding(
+         AutoTokenizer.from_pretrained("hotchpotch/xlm-roberta-japanese-tokenizer"),
+         embedding_dim=1024,
+     )
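+     # StaticEmbedding is a plain token-embedding lookup with no transformer layers:
+     # a sentence vector is just the pooled (averaged) token embeddings, which keeps
+     # encoding fast even on CPU.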
+     model = SentenceTransformer(
+         modules=[static_embedding],
+         model_card_data=SentenceTransformerModelCardData(
+             language="ja",
+             license="mit",
+             model_name="Static Embeddings with japanese tokenizer finetuned on various datasets",
+         ),
+     )
+
+     # 3. Set up training & evaluation datasets - each dataset is trained with MNRL (with MRL)
+     print("Loading datasets...")
+     train_dataset_en, eval_dataset_en = load_train_eval_datasets_en(EN_TARGET_DATASETS)
+     train_dataset_jp, eval_dataset_jp = load_train_eval_datasets_jp(JA_TARGET_DATASETS)
+     # merge
+     print("Merging datasets...")
+     train_dataset = DatasetDict({**train_dataset_en, **train_dataset_jp})
+     eval_dataset = DatasetDict({**eval_dataset_en, **eval_dataset_jp})
+     for ds_name, aug_factor in AUG_FACTOR_DATASETS.items():
+         columns = train_dataset[ds_name].column_names
+
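+         # data_aug runs under a batched map, where each column value is a Python list;
+         # multiplying the list by aug_factor repeats every row aug_factor times.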
+         def data_aug(example):
+             result = {}
+             for col in columns:
+                 result[col] = example[col] * aug_factor
+             return result
+
+         before_len = len(train_dataset[ds_name])
+         train_dataset[ds_name] = train_dataset[ds_name].map(
+             data_aug, batched=True, num_proc=11
+         )
+         print("data augmented", ds_name, before_len, len(train_dataset[ds_name]))
+     for train_ds_name in train_dataset.keys():
+         print(
+             "train ds",
+             train_ds_name,
+             len(train_dataset[train_ds_name]),
+             len(eval_dataset[train_ds_name]),
+         )
+
+     # 4. Define a loss function
+     loss = MultipleNegativesRankingLoss(model)
+     loss = MatryoshkaLoss(model, loss, matryoshka_dims=[32, 64, 128, 256, 512, 1024])
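+     # MatryoshkaLoss applies the ranking loss at every dimension listed above, so the
+     # embeddings can later be truncated (e.g. to 256 dims) with relatively little quality loss.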
+
+     # 5. (Optional) Specify training arguments
+     run_name = f"static-retrieval-mrl-jp-v1_{EXP}"
+     args = SentenceTransformerTrainingArguments(
+         # Required parameter:
+         output_dir=f"models/{run_name}",
+         # Optional training parameters:
+         num_train_epochs=2,
+         per_device_train_batch_size=2048 * 3,
+         # gradient_accumulation_steps=4,
+         per_device_eval_batch_size=2048,
+         learning_rate=2e-1,
+         lr_scheduler_type="cosine",
+         # optim="adafactor",
+         warmup_ratio=0.1,
+         fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
+         bf16=True,  # Set to True if you have a GPU that supports BF16
+         batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
+         multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
+         # Optional tracking/debugging parameters:
+         eval_strategy="steps",
+         eval_steps=200,
+         save_strategy="steps",
+         save_steps=200,
+         save_total_limit=20,
+         logging_steps=20,
+         logging_first_step=True,
+         dataloader_prefetch_factor=4,
+         dataloader_num_workers=15,
+         run_name=run_name,  # Will be used in W&B if `wandb` is installed
+     )
+
+     # 6. (Optional) Create an evaluator & evaluate the base model
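+     # NanoBEIREvaluator runs a set of small English BEIR retrieval tasks; for this
+     # Japanese-focused model it mainly serves as a quick sanity check before and after training.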
+     evaluator = NanoBEIREvaluator()
+     evaluator(model)
+
+     # 7. Create a trainer & train
+     trainer = SentenceTransformerTrainer(
+         model=model,
+         args=args,
+         train_dataset=train_dataset,
+         eval_dataset=eval_dataset,
+         loss=loss,
+         evaluator=evaluator,
+     )
+     trainer.train()
+
+     # (Optional) Evaluate the trained model on the evaluator after training
+     evaluator(model)
+
+     # 8. Save the trained model
+     model.save_pretrained(f"{PROJECT_ROOT}/models/{run_name}/final")
+
+     # 9. (Optional) Push it to the Hugging Face Hub
+     # model.push_to_hub(run_name, private=True)
+
+
+ if __name__ == "__main__":
+     main()
+