import logging
import os
import random
from pathlib import Path

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerModelCardData,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import NanoBEIREvaluator
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.models.StaticEmbedding import StaticEmbedding
from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers
from transformers import AutoTokenizer

from datasets import Dataset, DatasetDict, load_dataset

EXP = "030"
print("EXP:", EXP)

PROJECT_ROOT = Path(__file__).resolve().parents[1]
print(PROJECT_ROOT)
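
# English datasets (keyed as in _load_train_eval_datasets_en) used for this
# experiment; all other cached datasets are dropped by load_train_eval_datasets_en().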
EN_TARGET_DATASETS = [
    "msmarco",
    "squad",
    "allnli",
    "trivia_qa",
    "swim_ir",
    "miracl",
    "mr_tydi",
]
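
# Japanese training data: config names of the hotchpotch/sentence_transformer_japanese
# dataset, loaded (and cached to disk) by load_train_eval_datasets_jp().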
JA_TARGET_DATASETS = [
    "hpprc_emb__auto-wiki-nli-triplet",
    "hpprc_emb__auto-wiki-qa",
    "hpprc_emb__auto-wiki-qa-nemotron",
    "hpprc_emb__auto-wiki-qa-pair",
    "hpprc_emb__baobab-wiki-retrieval",
    "hpprc_emb__janli-triplet",
    "hpprc_emb__jaquad",
    "hpprc_emb__jqara",
    "hpprc_emb__jsnli-triplet",
    "hpprc_emb__jsquad",
    "hpprc_emb__miracl",
    "hpprc_emb__mkqa",
    "hpprc_emb__mkqa-triplet",
    "hpprc_emb__mr-tydi",
    "hpprc_emb__nu-mnli-triplet",
    "hpprc_emb__nu-snli-triplet",
    "hpprc_emb__quiz-no-mori",
    "hpprc_emb__quiz-works",
    "hpprc_emb__snow-triplet",
    "hpprc_llmjp-kaken",
    "hpprc_llmjp_warp_html",
    "hpprc_mqa_ja",
    "hpprc_msmarco_ja",
]
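
# Oversampling factors: each of these datasets has its rows duplicated this many
# times in the augmentation step in main().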
AUG_FACTOR_DATASETS = {
    "hpprc_emb__miracl": 20,
    "hpprc_emb__mr-tydi": 20,
    "hpprc_emb__jqara": 10,
    "hpprc_emb__baobab-wiki-retrieval": 5,
    "hpprc_emb__mkqa": 5,
    "hpprc_emb__auto-wiki-qa-nemotron": 2,
    "hpprc_msmarco_ja": 2,
}

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
random.seed(12)


def _load_train_eval_datasets_en():
    """
    Load the English train/eval DatasetDicts from disk if they have already been
    built; otherwise download the source datasets, split off evaluation sets,
    save both DatasetDicts to disk, and return them.
    """
    en_train_dataset_dir = PROJECT_ROOT / "datasets" / "en_train_dataset"
    en_eval_dataset_dir = PROJECT_ROOT / "datasets" / "en_eval_dataset"
    try:
        train_dataset = DatasetDict.load_from_disk(en_train_dataset_dir)
        eval_dataset = DatasetDict.load_from_disk(en_eval_dataset_dir)
        return train_dataset, eval_dataset
    except FileNotFoundError:
        print("Loading gooaq dataset...")
        gooaq_dataset = load_dataset("sentence-transformers/gooaq", split="train")
        gooaq_dataset_dict = gooaq_dataset.train_test_split(test_size=10_000, seed=12)
        gooaq_train_dataset: Dataset = gooaq_dataset_dict["train"]
        gooaq_eval_dataset: Dataset = gooaq_dataset_dict["test"]
        print("Loaded gooaq dataset.")

        print("Loading msmarco dataset...")
        msmarco_dataset = load_dataset(
            "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1",
            "triplet",
            split="train",
        )
        msmarco_dataset_dict = msmarco_dataset.train_test_split(
            test_size=10_000, seed=12
        )
        msmarco_train_dataset: Dataset = msmarco_dataset_dict["train"]
        msmarco_eval_dataset: Dataset = msmarco_dataset_dict["test"]
        print("Loaded msmarco dataset.")

        print("Loading squad dataset...")
        squad_dataset = load_dataset("sentence-transformers/squad", split="train")
        squad_dataset_dict = squad_dataset.train_test_split(test_size=10_000, seed=12)
        squad_train_dataset: Dataset = squad_dataset_dict["train"]
        squad_eval_dataset: Dataset = squad_dataset_dict["test"]
        print("Loaded squad dataset.")

        print("Loading s2orc dataset...")
        s2orc_dataset = load_dataset(
            "sentence-transformers/s2orc", "title-abstract-pair", split="train[:100000]"
        )
        s2orc_dataset_dict = s2orc_dataset.train_test_split(test_size=10_000, seed=12)
        s2orc_train_dataset: Dataset = s2orc_dataset_dict["train"]
        s2orc_eval_dataset: Dataset = s2orc_dataset_dict["test"]
        print("Loaded s2orc dataset.")

        print("Loading allnli dataset...")
        allnli_train_dataset = load_dataset(
            "sentence-transformers/all-nli", "triplet", split="train"
        )
        allnli_eval_dataset = load_dataset(
            "sentence-transformers/all-nli", "triplet", split="dev"
        )
        print("Loaded allnli dataset.")

        print("Loading paq dataset...")
        paq_dataset = load_dataset("sentence-transformers/paq", split="train")
        paq_dataset_dict = paq_dataset.train_test_split(test_size=10_000, seed=12)
        paq_train_dataset: Dataset = paq_dataset_dict["train"]
        paq_eval_dataset: Dataset = paq_dataset_dict["test"]
        print("Loaded paq dataset.")

        print("Loading trivia_qa dataset...")
        trivia_qa = load_dataset("sentence-transformers/trivia-qa", split="train")
        trivia_qa_dataset_dict = trivia_qa.train_test_split(test_size=5_000, seed=12)
        trivia_qa_train_dataset: Dataset = trivia_qa_dataset_dict["train"]
        trivia_qa_eval_dataset: Dataset = trivia_qa_dataset_dict["test"]
        print("Loaded trivia_qa dataset.")

        print("Loading msmarco_10m dataset...")
        msmarco_10m_dataset = load_dataset(
            "bclavie/msmarco-10m-triplets", split="train"
        )
        msmarco_10m_dataset_dict = msmarco_10m_dataset.train_test_split(
            test_size=10_000, seed=12
        )
        msmarco_10m_train_dataset: Dataset = msmarco_10m_dataset_dict["train"]
        msmarco_10m_eval_dataset: Dataset = msmarco_10m_dataset_dict["test"]
        print("Loaded msmarco_10m dataset.")

        print("Loading swim_ir dataset...")
        swim_ir_dataset = load_dataset(
            "nthakur/swim-ir-monolingual", "en", split="train"
        ).select_columns(["query", "text"])
        swim_ir_dataset_dict = swim_ir_dataset.train_test_split(
            test_size=10_000, seed=12
        )
        swim_ir_train_dataset: Dataset = swim_ir_dataset_dict["train"]
        swim_ir_eval_dataset: Dataset = swim_ir_dataset_dict["test"]
        print("Loaded swim_ir dataset.")

        print("Loading pubmedqa dataset...")
        pubmedqa_dataset = load_dataset(
            "sentence-transformers/pubmedqa", "triplet-20", split="train"
        )
        pubmedqa_dataset_dict = pubmedqa_dataset.train_test_split(
            test_size=100, seed=12
        )
        pubmedqa_train_dataset: Dataset = pubmedqa_dataset_dict["train"]
        pubmedqa_eval_dataset: Dataset = pubmedqa_dataset_dict["test"]
        print("Loaded pubmedqa dataset.")

        print("Loading miracl dataset...")
        miracl_dataset = load_dataset(
            "sentence-transformers/miracl", "en-triplet-all", split="train"
        )
        miracl_dataset_dict = miracl_dataset.train_test_split(test_size=10_000, seed=12)
        miracl_train_dataset: Dataset = miracl_dataset_dict["train"]
        miracl_eval_dataset: Dataset = miracl_dataset_dict["test"]
        print("Loaded miracl dataset.")

        print("Loading mldr dataset...")
        mldr_dataset = load_dataset(
            "sentence-transformers/mldr", "en-triplet-all", split="train"
        )
        mldr_dataset_dict = mldr_dataset.train_test_split(test_size=10_000, seed=12)
        mldr_train_dataset: Dataset = mldr_dataset_dict["train"]
        mldr_eval_dataset: Dataset = mldr_dataset_dict["test"]
        print("Loaded mldr dataset.")

        print("Loading mr_tydi dataset...")
        mr_tydi_dataset = load_dataset(
            "sentence-transformers/mr-tydi", "en-triplet-all", split="train"
        )
        mr_tydi_dataset_dict = mr_tydi_dataset.train_test_split(
            test_size=10_000, seed=12
        )
        mr_tydi_train_dataset: Dataset = mr_tydi_dataset_dict["train"]
        mr_tydi_eval_dataset: Dataset = mr_tydi_dataset_dict["test"]
        print("Loaded mr_tydi dataset.")

        train_dataset = DatasetDict(
            {
                "gooaq": gooaq_train_dataset,
                "msmarco": msmarco_train_dataset,
                "squad": squad_train_dataset,
                "s2orc": s2orc_train_dataset,
                "allnli": allnli_train_dataset,
                "paq": paq_train_dataset,
                "trivia_qa": trivia_qa_train_dataset,
                "msmarco_10m": msmarco_10m_train_dataset,
                "swim_ir": swim_ir_train_dataset,
                "pubmedqa": pubmedqa_train_dataset,
                "miracl": miracl_train_dataset,
                "mldr": mldr_train_dataset,
                "mr_tydi": mr_tydi_train_dataset,
            }
        )
        eval_dataset = DatasetDict(
            {
                "gooaq": gooaq_eval_dataset,
                "msmarco": msmarco_eval_dataset,
                "squad": squad_eval_dataset,
                "s2orc": s2orc_eval_dataset,
                "allnli": allnli_eval_dataset,
                "paq": paq_eval_dataset,
                "trivia_qa": trivia_qa_eval_dataset,
                "msmarco_10m": msmarco_10m_eval_dataset,
                "swim_ir": swim_ir_eval_dataset,
                "pubmedqa": pubmedqa_eval_dataset,
                "miracl": miracl_eval_dataset,
                "mldr": mldr_eval_dataset,
                "mr_tydi": mr_tydi_eval_dataset,
            }
        )

        train_dataset.save_to_disk(str(en_train_dataset_dir))
        eval_dataset.save_to_disk(str(en_eval_dataset_dir))
        return train_dataset, eval_dataset
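

# Public English loader: load the full cached DatasetDict and keep only the
# datasets named in target_dataset_names; everything else is removed.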
def load_train_eval_datasets_en(target_dataset_names: list[str] = []):
    print("Loading train and eval datasets...")
    if len(target_dataset_names) == 0:
        return DatasetDict(), DatasetDict()
    train_dataset, eval_dataset = _load_train_eval_datasets_en()
    ds_names = list(train_dataset.keys())
    for ds_name in ds_names:
        if ds_name not in target_dataset_names:
            del train_dataset[ds_name]
            del eval_dataset[ds_name]
        else:
            print(
                "target en ds",
                ds_name,
                len(train_dataset[ds_name]),
                len(eval_dataset[ds_name]),
            )
    return train_dataset, eval_dataset
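

# Japanese loader: each dataset config is cached on disk; on a cache miss it is
# downloaded from hotchpotch/sentence_transformer_japanese and split into train/eval
# (1% of the rows, capped at 3,000, is held out for evaluation).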
def load_train_eval_datasets_jp(target_dataset_names: list[str] = []):
    print("Loading train and eval datasets...")
    jp_train_dataset_dir = PROJECT_ROOT / "datasets" / "jp_train_dataset"
    jp_eval_dataset_dir = PROJECT_ROOT / "datasets" / "jp_eval_dataset"

    train_dataset_dict = {}
    eval_dataset_dict = {}

    for ds_name in target_dataset_names:
        print("loading jp ds", ds_name)
        try:
            train_ds = Dataset.load_from_disk(f"{jp_train_dataset_dir}/{ds_name}")
            eval_ds = Dataset.load_from_disk(f"{jp_eval_dataset_dir}/{ds_name}")
        except FileNotFoundError:
            print(f"{ds_name} not found, loading from datasets library...")
            ds = load_dataset(
                "hotchpotch/sentence_transformer_japanese", ds_name, split="train"
            )
            ds_size = len(ds)
            test_size = min(3000, ds_size // 100)
            splitted = ds.train_test_split(test_size=test_size, seed=12)
            train_ds = splitted["train"]
            eval_ds = splitted["test"]

            train_ds.save_to_disk(f"{jp_train_dataset_dir}/{ds_name}")
            eval_ds.save_to_disk(f"{jp_eval_dataset_dir}/{ds_name}")
        train_dataset_dict[ds_name] = train_ds
        eval_dataset_dict[ds_name] = eval_ds
    return DatasetDict(train_dataset_dict), DatasetDict(eval_dataset_dict)


def main():
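    # The model is a single StaticEmbedding module: each token id maps to a learned
    # 1024-dim vector and the token vectors are pooled (averaged) into the sentence
    # embedding. There are no transformer layers, so the embeddings are trained from
    # scratch on top of a Japanese-capable tokenizer.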
print("Loading model...") |
|
static_embedding = StaticEmbedding( |
|
AutoTokenizer.from_pretrained("hotchpotch/xlm-roberta-japanese-tokenizer"), |
|
embedding_dim=1024, |
|
) |
|
model = SentenceTransformer( |
|
modules=[static_embedding], |
|
model_card_data=SentenceTransformerModelCardData( |
|
language="ja", |
|
license="mit", |
|
model_name="Static Embeddings with japanese tokenizer finetuned on various datasets", |
|
), |
|
) |
|
|
|
|
|
print("Loading datasets...") |
|
train_dataset_en, eval_dataset_en = load_train_eval_datasets_en(EN_TARGET_DATASETS) |
|
train_dataset_jp, eval_dataset_jp = load_train_eval_datasets_jp(JA_TARGET_DATASETS) |
|
|
|
print("Merging datasets...") |
|
train_dataset = DatasetDict({**train_dataset_en, **train_dataset_jp}) |
|
eval_dataset = DatasetDict({**eval_dataset_en, **eval_dataset_jp}) |
|
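    # Oversample selected Japanese datasets: with batched=True each column value is a
    # list, and multiplying it by aug_factor repeats every row in the batch that many
    # times, growing the dataset by the factor given in AUG_FACTOR_DATASETS.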
    for ds_name, aug_factor in AUG_FACTOR_DATASETS.items():
        columns = train_dataset[ds_name].column_names

        def data_aug(example):
            result = {}
            for col in columns:
                result[col] = example[col] * aug_factor
            return result

        before_len = len(train_dataset[ds_name])
        train_dataset[ds_name] = train_dataset[ds_name].map(
            data_aug, batched=True, num_proc=11
        )
        print("data augmented", ds_name, before_len, len(train_dataset[ds_name]))

    for train_ds_name in train_dataset.keys():
        print(
            "train ds",
            train_ds_name,
            len(train_dataset[train_ds_name]),
            len(eval_dataset[train_ds_name]),
        )
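
    # MultipleNegativesRankingLoss treats the other passages in a batch as negatives;
    # wrapping it in MatryoshkaLoss additionally trains truncated embeddings at
    # 32/64/128/256/512/1024 dims so the model remains usable at smaller sizes.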
    loss = MultipleNegativesRankingLoss(model)
    loss = MatryoshkaLoss(model, loss, matryoshka_dims=[32, 64, 128, 256, 512, 1024])
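
    # Large batches and a high learning rate (2e-1) are typical for static embedding
    # training, since there are no pretrained transformer weights to preserve.
    # NO_DUPLICATES keeps repeated texts out of a batch (important for in-batch
    # negatives); PROPORTIONAL samples each dataset in proportion to its size.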
run_name = f"static-retrieval-mrl-jp-v1_{EXP}" |
|
args = SentenceTransformerTrainingArguments( |
|
|
|
output_dir=f"models/{run_name}", |
|
|
|
num_train_epochs=2, |
|
per_device_train_batch_size=2048 * 3, |
|
|
|
per_device_eval_batch_size=2048, |
|
learning_rate=2e-1, |
|
lr_scheduler_type="cosine", |
|
|
|
warmup_ratio=0.1, |
|
fp16=False, |
|
bf16=True, |
|
batch_sampler=BatchSamplers.NO_DUPLICATES, |
|
multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL, |
|
|
|
eval_strategy="steps", |
|
eval_steps=200, |
|
save_strategy="steps", |
|
save_steps=200, |
|
save_total_limit=20, |
|
logging_steps=20, |
|
logging_first_step=True, |
|
dataloader_prefetch_factor=4, |
|
dataloader_num_workers=15, |
|
run_name=run_name, |
|
) |
|
|
|
|
|
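
    # NanoBEIR is a small English retrieval benchmark; evaluate once before training
    # to get a baseline, and again after trainer.train() to measure the gain.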
    evaluator = NanoBEIREvaluator()
    evaluator(model)

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    evaluator(model)

    model.save_pretrained(f"{PROJECT_ROOT}/models/{run_name}/final")


if __name__ == "__main__":
    main()