# good-vibes/train.py — fine-tune albert-base-v2 for 3-class sentence classification
import torch
from transformers import AlbertTokenizer, AlbertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
import evaluate
import wandb
import numpy as np
# Initialize WandB
wandb.init(entity="dejan", project="good-vibes")
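# (Assumes `wandb login` has already been run and that the "dejan" entity exists;
# swap in your own entity/project to reproduce this run.)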
# Adjustable parameters
model_name = "albert-base-v2"
batch_size = 32
epochs = 10
learning_rate = 2e-5
gradient_clip_value = 1.0
warmup_steps = 500
# Load tokenizer and model
tokenizer = AlbertTokenizer.from_pretrained(model_name)
model = AlbertForSequenceClassification.from_pretrained(model_name, num_labels=3)
# Load dataset
dataset = load_dataset('csv', data_files={'train': 'sentences.csv'})
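# sentences.csv is expected to have a 'text' column and an integer 'label'
# column with values in {0, 1, 2}, matching num_labels=3 above.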
dataset = dataset['train'].train_test_split(test_size=0.1)
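# Note: this 90/10 split is random on every run; pass seed=42 to
# train_test_split for a reproducible split.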
# Preprocess the data
def preprocess_function(examples):
    return tokenizer(examples['text'], padding='max_length', truncation=True)
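# padding='max_length' pads every example to the model's maximum length (512
# for ALBERT). Dynamic padding via transformers' DataCollatorWithPadding would
# be faster on short sentences, at the cost of variable batch shapes.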
encoded_dataset = dataset.map(preprocess_function, batched=True)
encoded_dataset = encoded_dataset.rename_column("label", "labels")
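# The rename is explicit here for clarity; Trainer's default data collator
# would also map a 'label' column to the 'labels' argument the model expects.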
# Define metrics
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)
    f1 = f1_metric.compute(predictions=predictions, references=labels, average='weighted')
    precision = precision_metric.compute(predictions=predictions, references=labels, average='weighted')
    recall = recall_metric.compute(predictions=predictions, references=labels, average='weighted')
    return {
        "accuracy": accuracy["accuracy"],
        "f1": f1["f1"],
        "precision": precision["precision"],
        "recall": recall["recall"]
    }
# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=epochs,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",  # select the best checkpoint by accuracy
    greater_is_better=True,            # higher accuracy is better
    gradient_accumulation_steps=2,
    fp16=True,
    report_to="wandb",
    run_name="albert-finetuning",
    warmup_steps=warmup_steps,
    max_grad_norm=gradient_clip_value  # gradient clipping
)
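# With per_device_train_batch_size=32 and gradient_accumulation_steps=2, the
# effective batch size is 64 per device per optimizer step. fp16=True requires
# a CUDA-capable GPU; remove it for CPU-only runs.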
# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset['train'],
    eval_dataset=encoded_dataset['test'],
    compute_metrics=compute_metrics
)
# Train the model
trainer.train()
# Save the model and tokenizer together, so the output directory is loadable on its own
trainer.save_model("fine-tuned-albert-base-v2")
tokenizer.save_pretrained("fine-tuned-albert-base-v2")
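# Quick sanity check on the saved model (a minimal sketch; the LABEL_n names
# are placeholders, since the id-to-class mapping comes from sentences.csv,
# which this script does not define):
# from transformers import pipeline
# classifier = pipeline("text-classification", model="fine-tuned-albert-base-v2")
# print(classifier("What a lovely day!"))  # e.g. [{'label': 'LABEL_2', 'score': 0.97}]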
# Finish WandB run
wandb.finish()