# static-embedding-japanese trainer.py
# base: https://huggingface.co./blog/static-embeddings
# MIT License
import logging
import os
import random
from pathlib import Path

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerModelCardData,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import NanoBEIREvaluator
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.models.StaticEmbedding import StaticEmbedding
from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers
from transformers import AutoTokenizer

from datasets import Dataset, DatasetDict, load_dataset

EXP = "030"
print("EXP:", EXP)

PROJECT_ROOT = Path(__file__).resolve().parents[1]
print(PROJECT_ROOT)

EN_TARGET_DATASETS = [
    # "gooaq",  # non-commercial
    "msmarco",
    "squad",
    # "s2orc",  # large
    "allnli",
    # "paq",  # large
    "trivia_qa",
    # "msmarco_10m",
    "swim_ir",
    # "pubmedqa",
    "miracl",
    # "mldr",  # non-commercial
    "mr_tydi",
]

JA_TARGET_DATASETS = [
    "hpprc_emb__auto-wiki-nli-triplet",
    "hpprc_emb__auto-wiki-qa",
    "hpprc_emb__auto-wiki-qa-nemotron",
    "hpprc_emb__auto-wiki-qa-pair",
    "hpprc_emb__baobab-wiki-retrieval",
    # "hpprc_emb__jagovfaqs",  # the JMTEB task's test split contains the answers
    "hpprc_emb__janli-triplet",
    "hpprc_emb__jaquad",
    "hpprc_emb__jqara",  # same domain as a JMTEB task
    "hpprc_emb__jsnli-triplet",
    "hpprc_emb__jsquad",
    "hpprc_emb__miracl",  # same domain as a JMTEB task
    "hpprc_emb__mkqa",
    "hpprc_emb__mkqa-triplet",
    # "hpprc_emb__mmarco",  # noisy, contains mojibake and similar artifacts
    "hpprc_emb__mr-tydi",  # same domain as a JMTEB task
    "hpprc_emb__nu-mnli-triplet",
    "hpprc_emb__nu-snli-triplet",
    # "hpprc_emb__paws-x-triplet",  # possibly included in the JMTEB task's test split?
    "hpprc_emb__quiz-no-mori",
    "hpprc_emb__quiz-works",
    "hpprc_emb__snow-triplet",
    "hpprc_llmjp-kaken",
    "hpprc_llmjp_warp_html",
    "hpprc_mqa_ja",
    "hpprc_msmarco_ja",
]

# Oversampling factors applied to the Japanese training sets below.
AUG_FACTOR_DATASETS = {
    "hpprc_emb__miracl": 20,
    "hpprc_emb__mr-tydi": 20,
    "hpprc_emb__jqara": 10,
    "hpprc_emb__baobab-wiki-retrieval": 5,
    "hpprc_emb__mkqa": 5,
    "hpprc_emb__auto-wiki-qa-nemotron": 2,
    "hpprc_msmarco_ja": 2,
}

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
random.seed(12)


def _load_train_eval_datasets_en():
    """
    Load the English train and eval datasets from disk if they have already been
    prepared; otherwise download them via the datasets library, split them, save
    them to disk for later runs, and return them.
""" en_train_dataset_dir = PROJECT_ROOT / "datasets" / "en_train_dataset" en_eval_dataset_dir = PROJECT_ROOT / "datasets" / "en_eval_dataset" try: train_dataset = DatasetDict.load_from_disk(en_train_dataset_dir) eval_dataset = DatasetDict.load_from_disk(en_eval_dataset_dir) return train_dataset, eval_dataset except FileNotFoundError: print("Loading gooaq dataset...") gooaq_dataset = load_dataset("sentence-transformers/gooaq", split="train") gooaq_dataset_dict = gooaq_dataset.train_test_split(test_size=10_000, seed=12) gooaq_train_dataset: Dataset = gooaq_dataset_dict["train"] gooaq_eval_dataset: Dataset = gooaq_dataset_dict["test"] print("Loaded gooaq dataset.") print("Loading msmarco dataset...") msmarco_dataset = load_dataset( "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1", "triplet", split="train", ) msmarco_dataset_dict = msmarco_dataset.train_test_split( test_size=10_000, seed=12 ) msmarco_train_dataset: Dataset = msmarco_dataset_dict["train"] msmarco_eval_dataset: Dataset = msmarco_dataset_dict["test"] print("Loaded msmarco dataset.") print("Loading squad dataset...") squad_dataset = load_dataset("sentence-transformers/squad", split="train") squad_dataset_dict = squad_dataset.train_test_split(test_size=10_000, seed=12) squad_train_dataset: Dataset = squad_dataset_dict["train"] squad_eval_dataset: Dataset = squad_dataset_dict["test"] print("Loaded squad dataset.") print("Loading s2orc dataset...") s2orc_dataset = load_dataset( "sentence-transformers/s2orc", "title-abstract-pair", split="train[:100000]" ) s2orc_dataset_dict = s2orc_dataset.train_test_split(test_size=10_000, seed=12) s2orc_train_dataset: Dataset = s2orc_dataset_dict["train"] s2orc_eval_dataset: Dataset = s2orc_dataset_dict["test"] print("Loaded s2orc dataset.") print("Loading allnli dataset...") allnli_train_dataset = load_dataset( "sentence-transformers/all-nli", "triplet", split="train" ) allnli_eval_dataset = load_dataset( "sentence-transformers/all-nli", "triplet", split="dev" ) print("Loaded allnli dataset.") print("Loading paq dataset...") paq_dataset = load_dataset("sentence-transformers/paq", split="train") paq_dataset_dict = paq_dataset.train_test_split(test_size=10_000, seed=12) paq_train_dataset: Dataset = paq_dataset_dict["train"] paq_eval_dataset: Dataset = paq_dataset_dict["test"] print("Loaded paq dataset.") print("Loading trivia_qa dataset...") trivia_qa = load_dataset("sentence-transformers/trivia-qa", split="train") trivia_qa_dataset_dict = trivia_qa.train_test_split(test_size=5_000, seed=12) trivia_qa_train_dataset: Dataset = trivia_qa_dataset_dict["train"] trivia_qa_eval_dataset: Dataset = trivia_qa_dataset_dict["test"] print("Loaded trivia_qa dataset.") print("Loading msmarco_10m dataset...") msmarco_10m_dataset = load_dataset( "bclavie/msmarco-10m-triplets", split="train" ) msmarco_10m_dataset_dict = msmarco_10m_dataset.train_test_split( test_size=10_000, seed=12 ) msmarco_10m_train_dataset: Dataset = msmarco_10m_dataset_dict["train"] msmarco_10m_eval_dataset: Dataset = msmarco_10m_dataset_dict["test"] print("Loaded msmarco_10m dataset.") print("Loading swim_ir dataset...") swim_ir_dataset = load_dataset( "nthakur/swim-ir-monolingual", "en", split="train" ).select_columns(["query", "text"]) swim_ir_dataset_dict = swim_ir_dataset.train_test_split( test_size=10_000, seed=12 ) swim_ir_train_dataset: Dataset = swim_ir_dataset_dict["train"] swim_ir_eval_dataset: Dataset = swim_ir_dataset_dict["test"] print("Loaded swim_ir dataset.") # NOTE: 20 negatives print("Loading pubmedqa 
dataset...") pubmedqa_dataset = load_dataset( "sentence-transformers/pubmedqa", "triplet-20", split="train" ) pubmedqa_dataset_dict = pubmedqa_dataset.train_test_split( test_size=100, seed=12 ) pubmedqa_train_dataset: Dataset = pubmedqa_dataset_dict["train"] pubmedqa_eval_dataset: Dataset = pubmedqa_dataset_dict["test"] print("Loaded pubmedqa dataset.") # NOTE: A lot of overlap with anchor/positives print("Loading miracl dataset...") miracl_dataset = load_dataset( "sentence-transformers/miracl", "en-triplet-all", split="train" ) miracl_dataset_dict = miracl_dataset.train_test_split(test_size=10_000, seed=12) miracl_train_dataset: Dataset = miracl_dataset_dict["train"] miracl_eval_dataset: Dataset = miracl_dataset_dict["test"] print("Loaded miracl dataset.") # NOTE: A lot of overlap with anchor/positives print("Loading mldr dataset...") mldr_dataset = load_dataset( "sentence-transformers/mldr", "en-triplet-all", split="train" ) mldr_dataset_dict = mldr_dataset.train_test_split(test_size=10_000, seed=12) mldr_train_dataset: Dataset = mldr_dataset_dict["train"] mldr_eval_dataset: Dataset = mldr_dataset_dict["test"] print("Loaded mldr dataset.") # NOTE: A lot of overlap with anchor/positives print("Loading mr_tydi dataset...") mr_tydi_dataset = load_dataset( "sentence-transformers/mr-tydi", "en-triplet-all", split="train" ) mr_tydi_dataset_dict = mr_tydi_dataset.train_test_split( test_size=10_000, seed=12 ) mr_tydi_train_dataset: Dataset = mr_tydi_dataset_dict["train"] mr_tydi_eval_dataset: Dataset = mr_tydi_dataset_dict["test"] print("Loaded mr_tydi dataset.") train_dataset = DatasetDict( { "gooaq": gooaq_train_dataset, "msmarco": msmarco_train_dataset, "squad": squad_train_dataset, "s2orc": s2orc_train_dataset, "allnli": allnli_train_dataset, "paq": paq_train_dataset, "trivia_qa": trivia_qa_train_dataset, "msmarco_10m": msmarco_10m_train_dataset, "swim_ir": swim_ir_train_dataset, "pubmedqa": pubmedqa_train_dataset, "miracl": miracl_train_dataset, "mldr": mldr_train_dataset, "mr_tydi": mr_tydi_train_dataset, } ) eval_dataset = DatasetDict( { "gooaq": gooaq_eval_dataset, "msmarco": msmarco_eval_dataset, "squad": squad_eval_dataset, "s2orc": s2orc_eval_dataset, "allnli": allnli_eval_dataset, "paq": paq_eval_dataset, "trivia_qa": trivia_qa_eval_dataset, "msmarco_10m": msmarco_10m_eval_dataset, "swim_ir": swim_ir_eval_dataset, "pubmedqa": pubmedqa_eval_dataset, "miracl": miracl_eval_dataset, "mldr": mldr_eval_dataset, "mr_tydi": mr_tydi_eval_dataset, } ) train_dataset.save_to_disk(str(en_train_dataset_dir)) eval_dataset.save_to_disk(str(en_eval_dataset_dir)) return train_dataset, eval_dataset def load_train_eval_datasets_en(target_dataset_names: list[str] = []): print("Loading train and eval datasets...") if len(target_dataset_names) == 0: return DatasetDict(), DatasetDict() train_dataset, eval_dataset = _load_train_eval_datasets_en() ds_names = list(train_dataset.keys()) for ds_name in ds_names: if ds_name not in target_dataset_names: del train_dataset[ds_name] del eval_dataset[ds_name] else: print( "target en ds", ds_name, len(train_dataset[ds_name]), len(eval_dataset[ds_name]), ) return train_dataset, eval_dataset def load_train_eval_datasets_jp(target_dataset_names: list[str] = []): print("Loading train and eval datasets...") jp_train_dataset_dir = PROJECT_ROOT / "datasets" / "jp_train_dataset" jp_eval_dataset_dir = PROJECT_ROOT / "datasets" / "jp_eval_dataset" train_dataset_dict = {} eval_dataset_dict = {} for ds_name in target_dataset_names: print("loading jp ds", ds_name) try: train_ds = 
            train_ds = Dataset.load_from_disk(f"{jp_train_dataset_dir}/{ds_name}")
            eval_ds = Dataset.load_from_disk(f"{jp_eval_dataset_dir}/{ds_name}")
        except FileNotFoundError:
            print(f"{ds_name} not found, loading from datasets library...")
            ds = load_dataset(
                "hotchpotch/sentence_transformer_japanese", ds_name, split="train"
            )
            ds_size = len(ds)
            test_size = min(3000, ds_size // 100)
            splitted = ds.train_test_split(test_size=test_size, seed=12)
            train_ds = splitted["train"]
            eval_ds = splitted["test"]
            # save the split so later runs can load it from disk
            train_ds.save_to_disk(f"{jp_train_dataset_dir}/{ds_name}")
            eval_ds.save_to_disk(f"{jp_eval_dataset_dir}/{ds_name}")
        train_dataset_dict[ds_name] = train_ds
        eval_dataset_dict[ds_name] = eval_ds
    return DatasetDict(train_dataset_dict), DatasetDict(eval_dataset_dict)


def main():
    # 1. Load a model to finetune with 2. (Optional) model card data
    print("Loading model...")
    static_embedding = StaticEmbedding(
        AutoTokenizer.from_pretrained("hotchpotch/xlm-roberta-japanese-tokenizer"),
        embedding_dim=1024,
    )
    model = SentenceTransformer(
        modules=[static_embedding],
        model_card_data=SentenceTransformerModelCardData(
            language="ja",
            license="mit",
            model_name="Static Embeddings with Japanese tokenizer finetuned on various datasets",
        ),
    )

    # 3. Set up training & evaluation datasets - each dataset is trained with MNRL (with MRL)
    print("Loading datasets...")
    train_dataset_en, eval_dataset_en = load_train_eval_datasets_en(EN_TARGET_DATASETS)
    train_dataset_jp, eval_dataset_jp = load_train_eval_datasets_jp(JA_TARGET_DATASETS)

    # merge
    print("Merging datasets...")
    train_dataset = DatasetDict({**train_dataset_en, **train_dataset_jp})
    eval_dataset = DatasetDict({**eval_dataset_en, **eval_dataset_jp})

    # Oversample the listed datasets: with batched=True each column value is a list,
    # so `example[col] * aug_factor` repeats every row aug_factor times.
    for ds_name, aug_factor in AUG_FACTOR_DATASETS.items():
        columns = train_dataset[ds_name].column_names

        def data_aug(example):
            result = {}
            for col in columns:
                result[col] = example[col] * aug_factor
            return result

        before_len = len(train_dataset[ds_name])
        train_dataset[ds_name] = train_dataset[ds_name].map(
            data_aug, batched=True, num_proc=11
        )
        print("data augmented", ds_name, before_len, len(train_dataset[ds_name]))

    for train_ds_name in train_dataset.keys():
        print(
            "train ds",
            train_ds_name,
            len(train_dataset[train_ds_name]),
            len(eval_dataset[train_ds_name]),
        )

    # 4. Define a loss function
    loss = MultipleNegativesRankingLoss(model)
    loss = MatryoshkaLoss(model, loss, matryoshka_dims=[32, 64, 128, 256, 512, 1024])

    # 5. (Optional) Specify training arguments
    run_name = f"static-retrieval-mrl-jp-v1_{EXP}"
    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=2,
        per_device_train_batch_size=2048 * 3,
        # gradient_accumulation_steps=4,
        per_device_eval_batch_size=2048,
        learning_rate=2e-1,  # static embeddings train well with a high learning rate (see the base blog post)
        lr_scheduler_type="cosine",
        # optim="adafactor",
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=200,
        save_strategy="steps",
        save_steps=200,
        save_total_limit=20,
        logging_steps=20,
        logging_first_step=True,
        dataloader_prefetch_factor=4,
        dataloader_num_workers=15,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
    )
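
    # NOTE (added annotation): NanoBEIREvaluator scores the model on the English
    # NanoBEIR retrieval datasets, so the before/after numbers it reports track
    # English quality only; Japanese performance is monitored via the per-dataset
    # evaluation losses computed on eval_dataset during training.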
    # 6. (Optional) Create an evaluator & evaluate the base model
    evaluator = NanoBEIREvaluator()
    evaluator(model)

    # 7. Create a trainer & train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # (Optional) Evaluate the trained model on the evaluator after training
    evaluator(model)

    # 8. Save the trained model
    model.save_pretrained(f"{PROJECT_ROOT}/models/{run_name}/final")

    # 9. (Optional) Push it to the Hugging Face Hub
    # model.push_to_hub(run_name, private=True)


if __name__ == "__main__":
    main()
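
# Usage sketch (not part of the original script): once training has finished, the
# directory written by save_pretrained() above loads like any other Sentence
# Transformers model; the path below just mirrors run_name for EXP = "030".
#
#   from sentence_transformers import SentenceTransformer
#
#   model = SentenceTransformer("models/static-retrieval-mrl-jp-v1_030/final")
#   embeddings = model.encode(["自然言語処理とは何ですか?"])
#   print(embeddings.shape)  # (1, 1024); thanks to MatryoshkaLoss, smaller dims
#                            # can also be used, e.g. SentenceTransformer(path, truncate_dim=256)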