Pre-trained embedding value extraction

#440
by babykai - opened

Thank you very much for your excellent work. I would like to ask how to extract the pre-trained model's embedding values from the Geneformer pre-training code, which is based on the Hugging Face transformers package.
Here is the code fragment.

```python
import numpy as np
from datasets import load_from_disk
from transformers import BertForMaskedLM, TrainingArguments

from geneformer.pretrainer import GeneformerPretrainer

model = BertForMaskedLM(config)
model = model.train()

# define the training arguments
training_args = {
    "learning_rate": max_lr,
    "do_train": True,
    "do_eval": False,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": False,
    "lr_scheduler_type": lr_schedule_fn,
    "warmup_steps": warmup_steps,
    "weight_decay": weight_decay,
    "per_device_train_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
    "save_strategy": "steps",
    "save_steps": np.floor(num_examples / geneformer_batch_size / 8),  # 8 saves per epoch
    "logging_steps": 1000,
    "output_dir": training_output_dir,
    "logging_dir": logging_dir,
}
training_args = TrainingArguments(**training_args)

print("Starting training.")

# define the trainer
trainer = GeneformerPretrainer(
    model=model,
    args=training_args,
    train_dataset=load_from_disk("predigs_revise.dataset"),
    example_lengths_file="predigs_revise.pkl",
    token_dictionary=token_dictionary,
)

# train
trainer.train()

# save model
trainer.save_model(model_output_dir)
```

Thank you for your question. Please let us know if you were asking something else, but if you are asking how to extract gene or cell embeddings, you can use the embedding extractor we provide: https://geneformer.readthedocs.io/en/latest/geneformer.emb_extractor.html
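For intuition about what the extractor produces: a cell embedding is obtained by pooling the token-level hidden states from a forward pass of the pretrained model (e.g. mean-pooling over the cell's gene tokens), and gene embeddings are the per-token hidden states themselves. Below is a minimal NumPy sketch of that pooling step only; the hidden-state array and its dimensions are made up for illustration, and in practice the values would come from a forward pass of the saved `BertForMaskedLM` (or from the `EmbExtractor` linked above, which handles batching and layer selection for you).

```python
import numpy as np

# Hypothetical final-layer hidden states for one cell: one row per gene
# token, one column per hidden dimension. In real use these come from the
# pretrained model's forward pass, not from a random generator.
rng = np.random.default_rng(0)
num_gene_tokens, hidden_dim = 5, 4
hidden_states = rng.standard_normal((num_gene_tokens, hidden_dim))

# Cell embedding: mean-pool the hidden states over the gene-token axis,
# yielding one vector of length hidden_dim per cell.
cell_embedding = hidden_states.mean(axis=0)

# Gene embeddings: the individual token rows (in practice, a gene's
# embedding may be averaged across all cells in which it appears).
gene_embeddings = hidden_states

print(cell_embedding.shape)   # (4,)
print(gene_embeddings.shape)  # (5, 4)
```

This is only a conceptual sketch of the pooling; for actual extraction, use the `EmbExtractor` documented at the link above, which operates on the saved model directory and a tokenized dataset.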

ctheodoris changed discussion status to closed
