Pre-trained embedding value extraction
Thank you very much for your excellent work. I would like to ask how to extract the pre-trained model embedding values from the Hugging Face transformers-based Geneformer pre-training code.
Here is the code fragment.
model = BertForMaskedLM(config)
model = model.train()

# define the training arguments
training_args = {
    "learning_rate": max_lr,
    "do_train": True,
    "do_eval": False,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": False,
    "lr_scheduler_type": lr_schedule_fn,
    "warmup_steps": warmup_steps,
    "weight_decay": weight_decay,
    "per_device_train_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
    "save_strategy": "steps",
    "save_steps": np.floor(num_examples / geneformer_batch_size / 8),  # 8 saves per epoch
    "logging_steps": 1000,
    "output_dir": training_output_dir,
    "logging_dir": logging_dir,
}
training_args = TrainingArguments(**training_args)

print("Starting training.")

# define the trainer
trainer = GeneformerPretrainer(
    model=model,
    args=training_args,
    train_dataset=load_from_disk("predigs_revise.dataset"),
    example_lengths_file="predigs_revise.pkl",
    token_dictionary=token_dictionary,
)

# train
trainer.train()

# save model
trainer.save_model(model_output_dir)
Thank you for your question. Please let us know if you were asking something else, but if you are asking how to extract gene or cell embeddings, you can use the embedding extractor we provide: https://geneformer.readthedocs.io/en/latest/geneformer.emb_extractor.html
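For intuition, here is a minimal NumPy sketch of what cell-level embedding extraction conceptually does: run a cell's gene sequence through the model and pool the final hidden layer over its non-padded gene tokens. The function name, the mean-pooling choice, and the toy values are illustrative assumptions for this sketch, not Geneformer's exact implementation; use the EmbExtractor linked above for actual extraction.

```python
import numpy as np

def cell_embedding(hidden_states: np.ndarray, n_genes: int) -> np.ndarray:
    """Mean-pool the final hidden layer over the cell's real gene tokens.

    hidden_states: (seq_len, hidden_dim) final-layer output for one cell.
    n_genes: number of non-padding gene tokens at the start of the sequence.
    """
    # average only over the real gene tokens, excluding padding rows
    return hidden_states[:n_genes].mean(axis=0)

# toy example: 4-token sequence (3 real genes + 1 padding token), hidden size 2
h = np.array([[1.0, 2.0],
              [3.0, 4.0],
              [5.0, 6.0],
              [0.0, 0.0]])  # padding row, excluded from the pool
emb = cell_embedding(h, n_genes=3)
print(emb)  # [3. 4.]
```

The resulting vector is one fixed-length embedding per cell, regardless of how many genes the cell's rank-value encoding contains.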