# gbert-base-finetuned-twitter / fine_tune_mlm_twitter.py
import math
from datasets import Dataset
from huggingface_hub import login
from transformers import (
    TrainingArguments,
    DataCollatorForLanguageModeling,
    AutoTokenizer,
    AutoModelForMaskedLM,
    Trainer,
    default_data_collator,
)
import torch
import collections
import numpy as np


def tokenize_function(examples):
    result = tokenizer(examples["text"], padding=True, truncation=True)
    if tokenizer.is_fast:
        result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]
    print(f"tokenize function result: {result}")
    return result
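

# Illustrative sketch (assumption, not called during training): for a fast
# tokenizer, BatchEncoding.word_ids() maps each token position back to the
# index of the word it came from, with None for special tokens and repeated
# indices for sub-word pieces of the same word. `tokenizer` is assumed to be
# the gbert-base tokenizer loaded in __main__ below.
def inspect_word_ids(sample_text="Das ist ein Beispieltweet."):
    encoding = tokenizer(sample_text)
    # prints something like [None, 0, 1, 2, 3, 3, 4, None]
    print(encoding.word_ids())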


def group_texts(examples):
    # Concatenate all texts
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    # Compute length of concatenated texts
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the last chunk if it's smaller than chunk_size
    total_length = (total_length // chunk_size) * chunk_size
    # Split by chunks of chunk_size
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated_examples.items()
    }
    # Create a new labels column (a copy of the input ids; the data collator
    # overwrites it with the actual MLM targets at batch time)
    result["labels"] = result["input_ids"].copy()
    print(f"group texts result: {result}")
    return result
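

# Illustrative sketch (assumption, not called during training): with
# chunk_size = 128 (set in __main__ below), 350 concatenated token ids are
# trimmed to 350 // 128 * 128 = 256 ids, i.e. two chunks of 128, and the
# remaining 94 ids are dropped.
def chunking_example():
    fake_batch = {"input_ids": [list(range(200)), list(range(150))]}
    chunked = group_texts(fake_batch)
    print([len(chunk) for chunk in chunked["input_ids"]])  # -> [128, 128]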


def whole_word_masking_data_collator(features):
    # Note: when using this whole-word-masking collator, set
    # remove_unused_columns=False in TrainingArguments so the word_ids column
    # is not dropped before it reaches the collator.
    for feature in features:
        word_ids = feature.pop("word_ids")
        # Create a map between words and corresponding token indices
        mapping = collections.defaultdict(list)
        current_word_index = -1
        current_word = None
        for idx, word_id in enumerate(word_ids):
            if word_id is not None:
                if word_id != current_word:
                    current_word = word_id
                    current_word_index += 1
                mapping[current_word_index].append(idx)
        # Randomly mask whole words
        mask = np.random.binomial(1, wwm_probability, (len(mapping),))
        input_ids = feature["input_ids"]
        labels = feature["labels"]
        new_labels = [-100] * len(labels)
        for word_id in np.where(mask)[0]:
            word_id = word_id.item()
            for idx in mapping[word_id]:
                new_labels[idx] = labels[idx]
                input_ids[idx] = tokenizer.mask_token_id
        feature["labels"] = new_labels
    return default_data_collator(features)
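

# Sketch (assumption, not called by default): wiring the whole-word-masking
# collator above into a Trainer. Unlike train_model() below, this requires
# remove_unused_columns=False so the word_ids column survives until the
# collator runs; it also relies on the globals (model, tokenizer, tw_dataset,
# wwm_probability) set up in __main__.
def build_wwm_trainer(output_dir="gbert-base-finetuned-twitter-wwm"):
    wwm_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,
        remove_unused_columns=False,
    )
    return Trainer(
        model=model,
        args=wwm_args,
        train_dataset=tw_dataset["train"],
        eval_dataset=tw_dataset["test"],
        data_collator=whole_word_masking_data_collator,
        tokenizer=tokenizer,
    )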


def train_model():
    batch_size = 64
    # Show the training loss with every epoch
    logging_steps = len(tw_dataset["train"]) // batch_size
    model_name = model_checkpoint.split("/")[-1]
    training_args = TrainingArguments(
        output_dir=f"{model_name}-finetuned-twitter",
        save_total_limit=3,
        overwrite_output_dir=True,
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        weight_decay=0.01,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        push_to_hub=True,
        fp16=False,  # set to True on a GPU with 16-bit support
        logging_steps=logging_steps,
        # remove_unused_columns=False,  # needed only for whole_word_masking_data_collator
    )
    tw_dataset["train"].set_format("torch", device="cuda")
    tw_dataset["test"].set_format("torch", device="cuda")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tw_dataset["train"],
        eval_dataset=tw_dataset["test"],
        data_collator=data_collator,
        tokenizer=tokenizer,
    )
    # Baseline perplexity of the pretrained checkpoint before fine-tuning
    eval_results = trainer.evaluate()
    print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
    trainer.train()
    eval_results = trainer.evaluate()
    print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
    trainer.push_to_hub()
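

# Sketch (assumption, standalone helper): the perplexity printed above is just
# the exponential of the mean eval cross-entropy loss, so e.g. an eval_loss of
# 2.3 corresponds to a perplexity of roughly 10.
def loss_to_perplexity(eval_loss: float) -> float:
    return math.exp(eval_loss)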


if __name__ == "__main__":
    # NOTE: avoid committing a real Hugging Face token; reading it from an
    # environment variable or using `huggingface-cli login` is safer.
    token = "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"
    login(token=token)
    model_checkpoint = "deepset/gbert-base"
    model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    tw_dataset = Dataset.from_file('../data/complete_sosec_dataset/data.arrow')
    tw_dataset = tw_dataset.rename_column('topic', 'labels')
    # sample dataset
    # tw_dataset = tw_dataset.train_test_split(
    #     train_size=1000, test_size=10, seed=42
    # )
    print(f"tw_dataset sample: {tw_dataset}")
    tokenized_datasets = tw_dataset.map(
        tokenize_function, batched=True,
        remove_columns=["text", "labels", 'id', 'sentiment', 'annotator', 'comment', 'topic_alt', 'lang',
                        'conversation_id', 'created_at', 'author_id', 'query', 'public_metrics.like_count',
                        'public_metrics.quote_count', 'public_metrics.reply_count', 'public_metrics.retweet_count',
                        'public_metrics.impression_count', '__index_level_0__']
    )
    print(f"tokenized_datasets: {tokenized_datasets}")
    chunk_size = 128
    lm_datasets = tokenized_datasets.map(group_texts, batched=True)
    print(f"lm_datasets: {lm_datasets}")
    tw_dataset = lm_datasets.train_test_split(
        train_size=0.9, test_size=0.1, seed=42
    )
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
    print(f"data collator: {data_collator}")
    wwm_probability = 0.2
    train_model()
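    # Sketch (assumption, not executed here): once trainer.push_to_hub() has
    # uploaded the checkpoint, it can be tried out with the fill-mask pipeline.
    # The repo id below is hypothetical and depends on the logged-in account:
    #
    #   from transformers import pipeline
    #   fill_mask = pipeline("fill-mask", model="JanSt/gbert-base-finetuned-twitter")
    #   print(fill_mask("Das Wetter ist heute [MASK]."))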