tomaarsen HF staff commited on
Commit
0d63046
·
1 Parent(s): 07f6697

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +29 -9
train.py CHANGED
@@ -1,28 +1,47 @@
 
 
1
  from datasets import load_dataset
2
  from transformers import TrainingArguments
3
- from span_marker import SpanMarkerModel, Trainer
 
 
 
 
4
 
5
 
6
  def main() -> None:
7
  # Load the dataset, ensure "tokens" and "ner_tags" columns, and get a list of labels
8
- dataset = load_dataset("acronym_identification").rename_column("labels", "ner_tags")
 
 
9
  labels = dataset["train"].features["ner_tags"].feature.names
10
 
11
  # Initialize a SpanMarker model using a pretrained BERT-style encoder
12
- model_name = "bert-base-cased"
 
13
  model = SpanMarkerModel.from_pretrained(
14
- model_name,
15
  labels=labels,
16
  # SpanMarker hyperparameters:
17
  model_max_length=256,
18
  marker_max_length=128,
19
  entity_max_length=8,
 
 
 
 
 
 
 
 
 
20
  )
21
 
22
  # Prepare the 🤗 transformers training arguments
 
23
  args = TrainingArguments(
24
- output_dir=f"models/span_marker_bert_base_acronyms",
25
- run_name=f"bb_acronyms",
26
  # Training Hyperparameters:
27
  learning_rate=5e-5,
28
  per_device_train_batch_size=32,
@@ -49,12 +68,13 @@ def main() -> None:
49
  eval_dataset=dataset["validation"],
50
  )
51
  trainer.train()
52
- trainer.save_model(f"models/span_marker_bert_base_acronyms/checkpoint-final")
53
 
54
  # Compute & save the metrics on the test set
55
- metrics = trainer.evaluate()
56
  trainer.save_metrics("validation", metrics)
57
- trainer.create_model_card()
 
 
58
 
59
 
60
  if __name__ == "__main__":
 
1
+ from pathlib import Path
2
+ import shutil
3
  from datasets import load_dataset
4
  from transformers import TrainingArguments
5
+ from span_marker import SpanMarkerModel, Trainer, SpanMarkerModelCardData
6
+
7
+ import os
8
+
9
+ os.environ["CODECARBON_LOG_LEVEL"] = "error"
10
 
11
 
12
  def main() -> None:
13
  # Load the dataset, ensure "tokens" and "ner_tags" columns, and get a list of labels
14
+ dataset_name = "Acronym Identification"
15
+ dataset_id = "acronym_identification"
16
+ dataset = load_dataset(dataset_id).rename_column("labels", "ner_tags")
17
  labels = dataset["train"].features["ner_tags"].feature.names
18
 
19
  # Initialize a SpanMarker model using a pretrained BERT-style encoder
20
+ encoder_id = "bert-base-cased"
21
+ model_id = "tomaarsen/span-marker-bert-base-acronyms"
22
  model = SpanMarkerModel.from_pretrained(
23
+ encoder_id,
24
  labels=labels,
25
  # SpanMarker hyperparameters:
26
  model_max_length=256,
27
  marker_max_length=128,
28
  entity_max_length=8,
29
+ # Model card variables
30
+ model_card_data=SpanMarkerModelCardData(
31
+ model_id=model_id,
32
+ encoder_id=encoder_id,
33
+ dataset_name=dataset_name,
34
+ dataset_id=dataset_id,
35
+ license="apache-2.0",
36
+ language="en",
37
+ ),
38
  )
39
 
40
  # Prepare the 🤗 transformers training arguments
41
+ output_dir = Path("models") / model_id
42
  args = TrainingArguments(
43
+ output_dir=output_dir,
44
+ run_name=model_id,
45
  # Training Hyperparameters:
46
  learning_rate=5e-5,
47
  per_device_train_batch_size=32,
 
68
  eval_dataset=dataset["validation"],
69
  )
70
  trainer.train()
 
71
 
72
  # Compute & save the metrics on the test set
73
+ metrics = trainer.evaluate(metric_key_prefix="validation")
74
  trainer.save_metrics("validation", metrics)
75
+
76
+ trainer.save_model(output_dir / "checkpoint-final")
77
+ shutil.copy2(__file__, output_dir / "checkpoint-final" / "train.py")
78
 
79
 
80
  if __name__ == "__main__":