JanSt committed
Commit 0999ba2 · 1 Parent(s): 0b8c9c9

modularized

Files changed (1)
  1. fine_tune_mlm_twitter.py +152 -0
fine_tune_mlm_twitter.py ADDED
@@ -0,0 +1,152 @@
+ import math
+ import os
+ from datasets import Dataset
+ from huggingface_hub import login
+ from transformers import (
+     TrainingArguments,
+     DataCollatorForLanguageModeling,
+     AutoTokenizer,
+     AutoModelForMaskedLM,
+     Trainer,
+     default_data_collator
+ )
+ import torch
+ import collections
+ import numpy as np
+
+ def tokenize_function(examples):
+     # NOTE: padding before group_texts puts [PAD] tokens inside the concatenated chunks;
+     # the standard MLM recipe tokenizes without padding at this step.
+     result = tokenizer(examples["text"], padding=True, truncation=True)
+     if tokenizer.is_fast:
+         result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]
+     print(f"tokenize function result: {result}")
+     return result
+
+
+ def group_texts(examples):
+     # Concatenate all texts
+     concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
+     # Compute length of concatenated texts
+     total_length = len(concatenated_examples[list(examples.keys())[0]])
+     # We drop the last chunk if it's smaller than chunk_size
+     total_length = (total_length // chunk_size) * chunk_size
+     # Split by chunks of chunk_size
+     result = {
+         k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
+         for k, t in concatenated_examples.items()
+     }
+     # Create a new labels column (copied from input_ids; whole_word_masking_data_collator expects "labels")
+     result["labels"] = result["input_ids"].copy()
+     print(f"group texts result: {result}")
+     return result
+
+
+ def whole_word_masking_data_collator(features):
+     # When using this whole-word masking collator, remove_unused_columns=False
+     # must be set in TrainingArguments so the word_ids column is not dropped
+     # during training.
+     for feature in features:
+         word_ids = feature.pop("word_ids")
+
+         # Create a map between words and corresponding token indices
+         mapping = collections.defaultdict(list)
+         current_word_index = -1
+         current_word = None
+         for idx, word_id in enumerate(word_ids):
+             if word_id is not None:
+                 if word_id != current_word:
+                     current_word = word_id
+                     current_word_index += 1
+                 mapping[current_word_index].append(idx)
+
+         # Randomly mask whole words with probability wwm_probability
+         mask = np.random.binomial(1, wwm_probability, (len(mapping),))
+         input_ids = feature["input_ids"]
+         labels = feature["labels"]
+         new_labels = [-100] * len(labels)
+         for word_id in np.where(mask)[0]:
+             word_id = word_id.item()
+             for idx in mapping[word_id]:
+                 new_labels[idx] = labels[idx]
+                 input_ids[idx] = tokenizer.mask_token_id
+         feature["labels"] = new_labels
+
+     return default_data_collator(features)
+
+ def train_model():
+
+     batch_size = 64
+     # Show the training loss with every epoch
+     logging_steps = len(tw_dataset["train"]) // batch_size
+     model_name = model_checkpoint.split("/")[-1]
+
+     training_args = TrainingArguments(
+         output_dir=f"{model_name}-finetuned-twitter",
+         save_total_limit=3,
+         overwrite_output_dir=True,
+         evaluation_strategy="epoch",
+         learning_rate=2e-5,
+         weight_decay=0.01,
+         per_device_train_batch_size=batch_size,
+         per_device_eval_batch_size=batch_size,
+         push_to_hub=True,
+         fp16=False,  # set to True when training on a GPU with 16-bit support
+         logging_steps=logging_steps,
+         # remove_unused_columns=False,  # needed if whole_word_masking_data_collator is used
+     )
+     tw_dataset["train"].set_format("torch", device="cuda")
+     tw_dataset["test"].set_format("torch", device="cuda")
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=tw_dataset["train"],
+         eval_dataset=tw_dataset["test"],
+         data_collator=data_collator,
+         tokenizer=tokenizer,
+     )
+
+     # Perplexity before and after fine-tuning
+     eval_results = trainer.evaluate()
+     print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
+     trainer.train()
+     eval_results = trainer.evaluate()
+     print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
+
+     trainer.push_to_hub()
+
+
+ if __name__ == "__main__":
+     token = os.environ["HF_TOKEN"]  # read the Hub token from the environment; never commit credentials
+     login(token=token)
+     model_checkpoint = "deepset/gbert-base"
+     model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
+     tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
+     tw_dataset = Dataset.from_file('../data/complete_sosec_dataset/data.arrow')
+     tw_dataset = tw_dataset.rename_column('topic', 'labels')
+     # sample dataset
+     # tw_dataset = tw_dataset.train_test_split(
+     #     train_size=1000, test_size=10, seed=42
+     # )
+     print(f"tw_dataset sample: {tw_dataset}")
+     tokenized_datasets = tw_dataset.map(
+         tokenize_function, batched=True,
+         remove_columns=["text", "labels", 'id', 'sentiment', 'annotator', 'comment', 'topic_alt', 'lang',
+                         'conversation_id', 'created_at', 'author_id', 'query', 'public_metrics.like_count',
+                         'public_metrics.quote_count', 'public_metrics.reply_count', 'public_metrics.retweet_count',
+                         'public_metrics.impression_count', '__index_level_0__']
+     )
+     print(f"tokenized_datasets: {tokenized_datasets}")
+     chunk_size = 128
+
+     lm_datasets = tokenized_datasets.map(group_texts, batched=True)
+     print(f"lm_datasets: {lm_datasets}")
+
+     tw_dataset = lm_datasets.train_test_split(
+         train_size=0.9, test_size=0.1, seed=42
+     )
+
+     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
+     print(f"data collator: {data_collator}")
+     wwm_probability = 0.2
+
+     train_model()
+
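
Note: whole_word_masking_data_collator is defined in the file but never passed to the Trainer; training runs with the token-level DataCollatorForLanguageModeling. As a quick illustration (not part of the commit), the whole-word collator could be exercised on a couple of grouped chunks, e.g. inside the __main__ block after wwm_probability is set and before train_model() applies the CUDA set_format:

    # Illustrative sketch only: apply the whole-word masking collator to two
    # grouped chunks and decode them to see [MASK] spans covering complete words.
    samples = [tw_dataset["train"][i] for i in range(2)]  # plain-list features with word_ids and labels
    batch = whole_word_masking_data_collator(samples)
    for chunk in batch["input_ids"]:
        print(f">>> {tokenizer.decode(chunk)}")

Using it for actual training would additionally require remove_unused_columns=False in TrainingArguments (as the in-function comment notes) and keeping word_ids out of the torch/CUDA set_format call, since that column contains None entries that cannot be converted to tensors.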