flan-t5-custom-handler / start_training.py
MjolnirThor's picture
Create start_training.py
ff2cfed verified
raw
history blame
1.54 kB
# Progress message emitted first, before the library imports below are executed.
print("Starting training process...")
from datasets import load_dataset
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
Trainer,
DataCollatorForSeq2Seq
)
from training_config import training_args
# Load the first 100,000 rows of the training split.
# The slice spec is a plain string: it has no placeholders, so no f-prefix.
dataset = load_dataset("health360/Healix-Shot", split="train[:100000]")

# Initialize the tokenizer and the seq2seq model from the same checkpoint name
# so both are loaded from one pretrained source.
model_name = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
def tokenize_function(examples):
    """Tokenize the ``text`` column and attach training labels.

    Args:
        examples: A mapping with a ``'text'`` entry. With ``batched=True``
            (how this script calls ``map``) that entry is a list of strings;
            a single string is also accepted.

    Returns:
        The tokenizer output (``input_ids`` and ``attention_mask``, padded or
        truncated to 512 tokens) with an added ``labels`` entry.
    """
    encoded = tokenizer(
        examples['text'],
        padding="max_length",
        truncation=True,
        max_length=512,
        return_attention_mask=True,
    )

    # A seq2seq model needs ``labels`` to compute a loss; without them the
    # Trainer has nothing to optimize and the first training step fails.
    # NOTE(review): only the 'text' column is read in this script, so the
    # input text is reused as the target. If the dataset has a dedicated
    # target column, tokenize that column here instead.
    pad_id = tokenizer.pad_token_id

    def _mask_padding(ids):
        # -100 is the label value the loss ignores, so padded positions do
        # not contribute to training.
        return [-100 if token == pad_id else token for token in ids]

    if isinstance(examples['text'], str):
        # Non-batched call: ``input_ids`` is a flat list of token ids.
        encoded["labels"] = _mask_padding(encoded["input_ids"])
    else:
        encoded["labels"] = [_mask_padding(ids) for ids in encoded["input_ids"]]
    return encoded
# Hold out 10% of the rows for evaluation, then tokenize both partitions the
# same way, dropping the original columns so only tokenizer outputs remain.
train_test_split = dataset.train_test_split(test_size=0.1)
_map_kwargs = {
    "batched": True,
    "remove_columns": dataset.column_names,
}
tokenized_train = train_test_split['train'].map(tokenize_function, **_map_kwargs)
tokenized_eval = train_test_split['test'].map(tokenize_function, **_map_kwargs)
# Build the batch collator first, then hand everything to the Trainer.
_seq2seq_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    data_collator=_seq2seq_collator,
)
# Run training, then upload the model followed by the tokenizer to the same
# Hub repository.
print("Starting the training...")
trainer.train()

print("Training complete, saving model...")
_hub_repo_id = "MjolnirThor/flan-t5-custom-handler"
for _artifact in (model, tokenizer):
    _artifact.push_to_hub(_hub_repo_id)
print("Model saved successfully!")