The results of the cell classification task were not repeatable

Discussion #431, opened by babykai

While replicating the previous version of your cell classification script, I found that running the script below several times without any changes produced slightly different accuracy, F1 scores, etc. How can I set things up so that repeated runs give the same result?
Here is the relevant part of the code:

```python
import datetime
import os
import pickle

from transformers import BertForSequenceClassification, Trainer, TrainingArguments
from geneformer import DataCollatorForCellClassification

# organ_list, trainset_dict, evalset_dict, traintargetdict_dict and compute_metrics
# are defined earlier in the notebook (not shown here)

max_input_size = 2 ** 11  # 2048

# set training hyperparameters
# max learning rate
max_lr = 6e-5
# how many pretrained layers to freeze
freeze_layers = 0
# number gpus
num_gpus = 1
# number cpu cores
num_proc = 16
# batch size for training and eval
geneformer_batch_size = 16
# learning schedule
lr_schedule_fn = "linear"
# warmup steps
warmup_steps = 500
# number of epochs
epochs = 12
# optimizer
optimizer = "adamw"

for organ in organ_list:
    print(organ)
    organ_trainset = trainset_dict[organ]
    organ_evalset = evalset_dict[organ]
    organ_label_dict = traintargetdict_dict[organ]

    # set logging steps
    logging_steps = round(len(organ_trainset) / geneformer_batch_size / 10)

    # reload pretrained model
    model = BertForSequenceClassification.from_pretrained(
        "all_cellxgene_geo_predigs_revise.dataset",
        num_labels=len(organ_label_dict.keys()),
        output_attentions=False,
        output_hidden_states=False,
    )

    # define output directory path
    current_date = datetime.datetime.now()
    datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
    output_dir = f"./{datestamp}_geneformer_CellClassifier_{organ}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"

    # ensure not overwriting previously saved model
    saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
    if os.path.isfile(saved_model_test):
        raise Exception("Model already saved to this directory.")

    # make output directory
    os.makedirs(output_dir, exist_ok=True)

    # set training arguments
    training_args = {
        "learning_rate": max_lr,
        "do_train": True,
        "do_eval": True,
        "evaluation_strategy": "epoch",  # named eval_strategy in newer transformers releases
        "save_strategy": "epoch",
        "logging_steps": logging_steps,
        "group_by_length": True,
        "length_column_name": "length",
        "disable_tqdm": False,
        "lr_scheduler_type": lr_schedule_fn,
        "warmup_steps": warmup_steps,
        "weight_decay": 0.001,
        "per_device_train_batch_size": geneformer_batch_size,
        "per_device_eval_batch_size": geneformer_batch_size,
        "num_train_epochs": epochs,
        "load_best_model_at_end": True,
        "output_dir": output_dir,
        "seed": 42,
        "data_seed": 42,
    }

    training_args_init = TrainingArguments(**training_args)

    # create the trainer
    trainer = Trainer(
        model=model,
        args=training_args_init,
        data_collator=DataCollatorForCellClassification(),
        train_dataset=organ_trainset,
        eval_dataset=organ_evalset,
        compute_metrics=compute_metrics,
    )

    # train the cell type classifier
    trainer.train()
    predictions = trainer.predict(organ_evalset)
    with open(os.path.join(output_dir, "predictions.pickle"), "wb") as fp:
        pickle.dump(predictions, fp)
    trainer.save_metrics("eval", predictions.metrics)
    trainer.save_model(output_dir)
```
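To put a number on how much two runs of this script differ, one option is to compare the predictions.pickle files it writes. The helpers below (`load_predictions` and `compare_runs` are hypothetical names, not part of Geneformer or transformers) are a minimal sketch. They assume each file holds the `PredictionOutput` returned by `trainer.predict`, whose `predictions` field is the logits array, and that both runs scored the eval cells in the same order.

```python
import pickle

import numpy as np


def load_predictions(path):
    """Load the PredictionOutput that the script pickles to predictions.pickle."""
    # Unpickling needs transformers importable, since PredictionOutput is its type
    with open(path, "rb") as fp:
        output = pickle.load(fp)
    return output.predictions, output.label_ids


def compare_runs(logits_a, logits_b, labels):
    """Summarize how two runs' predictions on the same eval set differ.

    Assumes both runs scored the eval cells in the same order.
    """
    pred_a = np.argmax(logits_a, axis=-1)
    pred_b = np.argmax(logits_b, axis=-1)
    return {
        # share of cells that received the same predicted label in both runs
        "agreement": float(np.mean(pred_a == pred_b)),
        "accuracy_a": float(np.mean(pred_a == labels)),
        "accuracy_b": float(np.mean(pred_b == labels)),
    }


# Toy logits for three cells and two classes, standing in for two real runs
logits_a = np.array([[2.0, 1.0], [0.1, 0.9], [0.3, 0.2]])
logits_b = np.array([[1.5, 0.5], [0.8, 0.2], [0.1, 0.4]])
labels = np.array([0, 1, 0])
result = compare_runs(logits_a, logits_b, labels)
print(result)
```

With real runs, load each run's predictions.pickle with `load_predictions` and pass the two logit arrays plus one run's `label_ids`. High agreement with only small accuracy gaps would match the "slightly different" results described above.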

Thank you for your question! This code uses the Hugging Face Trainer to fine-tune the model. You are already setting seed and data_seed, so results should be broadly comparable between runs. If the differences in F1 are small, nothing at all changes in the code between runs, and, in particular, the model initialized each time is the pretrained model rather than one that has already been fine-tuned, then the results are probably sufficiently comparable. If the differences are large and you think the Trainer is at fault, consider opening an issue in the Hugging Face transformers repository. In our experience, though, results are sufficiently consistent between runs.
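One possible contributor to the remaining variation, offered as a hedged suggestion rather than a diagnosis: the transformers documentation for `TrainingArguments.seed` notes that, to get reproducible runs, a model with randomly initialized parameters should be created through `model_init`. In the script above, `from_pretrained` runs before the Trainer exists, and the new classification head on top of the pretrained weights is randomly initialized at that moment, so it may differ between runs (for example when a notebook cell is re-run in the same session). The sketch below reseeds with `transformers.set_seed` immediately before building the model. `make_model_init` and `tiny_config` are hypothetical names; the tiny, randomly configured BERT only exists so the example runs offline.

```python
import torch
from transformers import BertConfig, BertForSequenceClassification, set_seed

SEED = 42  # same value as "seed" in the training arguments above


def make_model_init(build_model, seed=SEED):
    """Wrap a model builder so every call reseeds first.

    set_seed seeds Python's random, NumPy and PyTorch (CPU and CUDA), so the
    randomly initialized classification head comes out identical on every call.
    """
    def model_init():
        set_seed(seed)
        return build_model()
    return model_init


# Hypothetical tiny config so the sketch runs offline. In the script above,
# build_model would instead be something like:
#   lambda: BertForSequenceClassification.from_pretrained(
#       "all_cellxgene_geo_predigs_revise.dataset", num_labels=len(organ_label_dict))
tiny_config = BertConfig(
    vocab_size=100,
    hidden_size=16,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=32,
    num_labels=3,
)

model_init = make_model_init(lambda: BertForSequenceClassification(tiny_config))

first = model_init()
second = model_init()
# Compare every parameter, including the newly initialized classifier head
same = all(torch.equal(p, q) for p, q in zip(first.parameters(), second.parameters()))
print(same)
```

In the script above, that would mean either calling `set_seed(42)` right before `from_pretrained`, or passing a zero-argument `model_init=...` to `Trainer` in place of `model=model` (the Trainer then reseeds with `args.seed` itself before calling it, so the wrapper becomes optional). If small differences persist on GPU, `full_determinism=True` in `TrainingArguments` asks PyTorch for deterministic algorithms, usually at some cost in speed; results can still differ across hardware or library versions.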

ctheodoris changed discussion status to closed
