modularized
fine_tune_mlm_twitter.py +152 -0
fine_tune_mlm_twitter.py
ADDED
@@ -0,0 +1,152 @@
import math
from datasets import Dataset
from huggingface_hub import login
from transformers import (
    TrainingArguments,
    DataCollatorForLanguageModeling,
    AutoTokenizer,
    AutoModelForMaskedLM,
    Trainer,
    default_data_collator
)
import torch
import collections
import numpy as np


def tokenize_function(examples):
    result = tokenizer(examples["text"], padding=True, truncation=True)
    if tokenizer.is_fast:
        # Keep the word ids so whole-word masking can group sub-tokens later
        result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]
    print(f"tokenize function result: {result}")
    return result


def group_texts(examples):
    # Concatenate all texts
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    # Compute length of the concatenated texts
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the last chunk if it is smaller than chunk_size
    total_length = (total_length // chunk_size) * chunk_size
    # Split into chunks of chunk_size
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated_examples.items()
    }
    # Create a new labels column (the data collator masks tokens and fills in
    # the corresponding label positions during training)
    result["labels"] = result["input_ids"].copy()
    print(f"group texts result: {result}")
    return result


def whole_word_masking_data_collator(features):
    # When using this collator, TrainingArguments must be created with
    # remove_unused_columns=False so the word_ids column is not dropped
    # before the batch reaches the collator.
    for feature in features:
        word_ids = feature.pop("word_ids")

        # Create a map between words and corresponding token indices
        mapping = collections.defaultdict(list)
        current_word_index = -1
        current_word = None
        for idx, word_id in enumerate(word_ids):
            if word_id is not None:
                if word_id != current_word:
                    current_word = word_id
                    current_word_index += 1
                mapping[current_word_index].append(idx)

        # Randomly mask whole words
        mask = np.random.binomial(1, wwm_probability, (len(mapping),))
        input_ids = feature["input_ids"]
        labels = feature["labels"]
        new_labels = [-100] * len(labels)
        for word_id in np.where(mask)[0]:
            word_id = word_id.item()
            for idx in mapping[word_id]:
                new_labels[idx] = labels[idx]
                input_ids[idx] = tokenizer.mask_token_id
        feature["labels"] = new_labels

    return default_data_collator(features)


def train_model():
    batch_size = 64
    # Log the training loss once per epoch
    logging_steps = len(tw_dataset["train"]) // batch_size
    model_name = model_checkpoint.split("/")[-1]

    training_args = TrainingArguments(
        output_dir=f"{model_name}-finetuned-twitter",
        save_total_limit=3,
        overwrite_output_dir=True,
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        weight_decay=0.01,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        push_to_hub=True,
        fp16=False,  # set to True on a GPU with fp16 support
        logging_steps=logging_steps,
        # remove_unused_columns=False,  # required when using whole_word_masking_data_collator
    )
    tw_dataset["train"].set_format("torch", device="cuda")
    tw_dataset["test"].set_format("torch", device="cuda")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tw_dataset["train"],
        eval_dataset=tw_dataset["test"],
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

    # Perplexity before and after fine-tuning
    eval_results = trainer.evaluate()
    print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
    trainer.train()
    eval_results = trainer.evaluate()
    print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

    trainer.push_to_hub()


if __name__ == "__main__":
    token = "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"
    login(token=token)

    # German BERT checkpoint, fine-tuned with masked language modelling
    model_checkpoint = "deepset/gbert-base"
    model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

    # Load the tweet corpus from its Arrow file
    tw_dataset = Dataset.from_file('../data/complete_sosec_dataset/data.arrow')
    tw_dataset = tw_dataset.rename_column('topic', 'labels')
    # Optionally work on a small sample while debugging:
    # tw_dataset = tw_dataset.train_test_split(
    #     train_size=1000, test_size=10, seed=42
    # )
    print(f"tw_dataset sample: {tw_dataset}")

    # Tokenize and drop every column the language model does not need
    tokenized_datasets = tw_dataset.map(
        tokenize_function, batched=True,
        remove_columns=["text", "labels", 'id', 'sentiment', 'annotator', 'comment', 'topic_alt', 'lang',
                        'conversation_id', 'created_at', 'author_id', 'query', 'public_metrics.like_count',
                        'public_metrics.quote_count', 'public_metrics.reply_count', 'public_metrics.retweet_count',
                        'public_metrics.impression_count', '__index_level_0__']
    )
    print(f"tokenized_datasets: {tokenized_datasets}")

    # Concatenate the tweets and split them into fixed-size chunks
    chunk_size = 128
    lm_datasets = tokenized_datasets.map(group_texts, batched=True)
    print(f"lm_datasets: {lm_datasets}")

    tw_dataset = lm_datasets.train_test_split(
        train_size=0.9, test_size=0.1, seed=42
    )

    # Randomly mask 15% of the tokens in each batch
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
    print(f"data collator: {data_collator}")
    wwm_probability = 0.2  # word-level masking probability for the alternative collator

    train_model()
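
Note: the commit defines whole_word_masking_data_collator but never wires it into the Trainer; training runs with DataCollatorForLanguageModeling. Below is a minimal sketch of how the Trainer call inside train_model() could be switched to whole-word masking, reusing the names from the script (tw_dataset, model, tokenizer, model_checkpoint, wwm_probability); the "-wwm" output directory suffix is made up for illustration. Two changes matter: remove_unused_columns=False (currently commented out) keeps the word_ids column alive until collation, and the set_format("torch", device="cuda") calls have to be skipped because word_ids contains None entries that cannot be converted to tensors.

# Sketch only, not part of the commit: Trainer wiring for the whole-word
# masking collator defined in the script. Assumes tw_dataset, model, tokenizer,
# model_checkpoint and whole_word_masking_data_collator exist as above.
batch_size = 64
logging_steps = len(tw_dataset["train"]) // batch_size
model_name = model_checkpoint.split("/")[-1]

training_args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-twitter-wwm",  # hypothetical output dir
    overwrite_output_dir=True,
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    logging_steps=logging_steps,
    remove_unused_columns=False,  # keep word_ids so the collator can group sub-tokens into words
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tw_dataset["train"],  # no set_format("torch"): word_ids holds None values
    eval_dataset=tw_dataset["test"],
    data_collator=whole_word_masking_data_collator,
    tokenizer=tokenizer,
)
trainer.train()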