tomaarsen HF staff commited on
Commit
416e115
·
1 Parent(s): 4b8af05

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +85 -0
train.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import shutil
3
+ from typing import Any, Dict
4
+ from datasets import load_dataset
5
+ from transformers import TrainingArguments
6
+ from span_marker import SpanMarkerModel, Trainer
7
+ from span_marker.model_card import SpanMarkerModelCardData
8
+
9
+ import os
10
+
11
+ os.environ["CODECARBON_LOG_LEVEL"] = "error"
12
+
13
+
14
+ def main() -> None:
15
+ # Load the dataset, ensure "tokens" and "ner_tags" columns, and get a list of labels
16
+ dataset_id = "EMBO/SourceData"
17
+ dataset_name = "SourceData"
18
+ dataset = load_dataset(dataset_id, version="1.0.1").rename_columns({"labels": "ner_tags", "words": "tokens"})
19
+ labels = dataset["train"].features["ner_tags"].feature.names
20
+
21
+ # Initialize a SpanMarker model using a pretrained BERT-style encoder
22
+ encoder_id = "bert-base-uncased"
23
+ model_id = f"tomaarsen/span-marker-{encoder_id}-sourcedata"
24
+ model = SpanMarkerModel.from_pretrained(
25
+ encoder_id,
26
+ labels=labels,
27
+ # SpanMarker hyperparameters:
28
+ model_max_length=256,
29
+ marker_max_length=128,
30
+ entity_max_length=8,
31
+ # Model card variables
32
+ model_card_data=SpanMarkerModelCardData(
33
+ model_id=model_id,
34
+ encoder_id=encoder_id,
35
+ dataset_name=dataset_name,
36
+ dataset_id=dataset_id,
37
+ license="cc-by-4.0",
38
+ language="en",
39
+ ),
40
+ )
41
+
42
+ # Prepare the 🤗 transformers training arguments
43
+ output_dir = Path("models") / model_id
44
+ args = TrainingArguments(
45
+ output_dir=output_dir,
46
+ run_name=model_id,
47
+ # Training Hyperparameters:
48
+ learning_rate=5e-5,
49
+ per_device_train_batch_size=32,
50
+ per_device_eval_batch_size=32,
51
+ num_train_epochs=3,
52
+ weight_decay=0.01,
53
+ warmup_ratio=0.1,
54
+ bf16=True, # Replace `bf16` with `fp16` if your hardware can't use bf16.
55
+ # Other Training parameters
56
+ logging_first_step=True,
57
+ logging_steps=50,
58
+ evaluation_strategy="steps",
59
+ save_strategy="steps",
60
+ eval_steps=3000,
61
+ save_total_limit=2,
62
+ dataloader_num_workers=2,
63
+ )
64
+
65
+ # Initialize the trainer using our model, training args & dataset, and train
66
+ trainer = Trainer(
67
+ model=model,
68
+ args=args,
69
+ train_dataset=dataset["train"],
70
+ eval_dataset=dataset["validation"],
71
+ )
72
+ trainer.train()
73
+
74
+ # Compute & save the metrics on the test set
75
+ metrics = trainer.evaluate(dataset["test"], metric_key_prefix="test")
76
+ trainer.save_metrics("test", metrics)
77
+
78
+ trainer.save_model(output_dir / "checkpoint-final")
79
+ shutil.copy2(__file__, output_dir / "checkpoint-final" / "train.py")
80
+
81
+
82
+ if __name__ == "__main__":
83
+ main()
84
+
85
+ # TODO: Embedding resizing & SpanMarkerModelCardData docs