Canstralian commited on
Commit
e2e74c5
·
verified ·
1 Parent(s): 6a1b091

Update fine_tuner.py

Browse files
Files changed (1) hide show
  1. fine_tuner.py +28 -2
fine_tuner.py CHANGED
@@ -1,9 +1,30 @@
1
  import torch
2
  from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
3
  from datasets import load_dataset
 
4
 
5
- def fine_tune_model(dataset, model_name, epochs, batch_size, learning_rate):
6
- # Load the pre-trained model for sequence classification
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
8
 
9
  # Define the training arguments
@@ -14,6 +35,11 @@ def fine_tune_model(dataset, model_name, epochs, batch_size, learning_rate):
14
  learning_rate=learning_rate, # Learning rate for the optimizer
15
  logging_dir='./logs', # Directory for storing logs
16
  logging_steps=10, # Log every 10 steps
 
 
 
 
 
17
  )
18
 
19
  # Initialize the Trainer with the model, arguments, and dataset
 
1
  import torch
2
  from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
3
  from datasets import load_dataset
4
+ from transformers import set_seed
5
 
6
+ # Set seed for reproducibility
7
+ set_seed(42)
8
+
9
+ def fine_tune_model(dataset_url, model_name, epochs, batch_size, learning_rate):
10
+ """
11
+ Fine-tunes a pre-trained transformer model on a custom dataset.
12
+
13
+ Parameters:
14
+ - dataset_url (str): URL or path to the dataset.
15
+ - model_name (str): Name of the pre-trained model.
16
+ - epochs (int): Number of training epochs.
17
+ - batch_size (int): Batch size for training.
18
+ - learning_rate (float): Learning rate for the optimizer.
19
+
20
+ Returns:
21
+ - dict: Status message containing training completion status.
22
+ """
23
+
24
+ # Load the dataset
25
+ dataset = load_dataset(dataset_url)
26
+
27
+ # Load the pre-trained model for sequence classification (2 labels for binary classification)
28
  model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
29
 
30
  # Define the training arguments
 
35
  learning_rate=learning_rate, # Learning rate for the optimizer
36
  logging_dir='./logs', # Directory for storing logs
37
  logging_steps=10, # Log every 10 steps
38
+ evaluation_strategy="epoch", # Evaluate every epoch
39
+ save_strategy="epoch", # Save checkpoint every epoch
40
+ load_best_model_at_end=True, # Load the best model at the end of training
41
+ metric_for_best_model="accuracy", # Metric to monitor for selecting the best model
42
+ greater_is_better=True, # Set to True if higher metric values are better
43
  )
44
 
45
  # Initialize the Trainer with the model, arguments, and dataset