MjolnirThor commited on
Commit
ff2cfed
·
verified ·
1 Parent(s): 76939ee

Create start_training.py

Browse files
Files changed (1) hide show
  1. start_training.py +56 -0
start_training.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ print("Starting training process...")
2
+ from datasets import load_dataset
3
+ from transformers import (
4
+ AutoModelForSeq2SeqLM,
5
+ AutoTokenizer,
6
+ Trainer,
7
+ DataCollatorForSeq2Seq
8
+ )
9
+ from training_config import training_args
10
+
11
+ # Load dataset
12
+ dataset = load_dataset("health360/Healix-Shot", split=f"train[:100000]")
13
+
14
+ # Initialize model and tokenizer
15
+ model_name = "google/flan-t5-large"
16
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
17
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
18
+
19
+ def tokenize_function(examples):
20
+ return tokenizer(
21
+ examples['text'],
22
+ padding="max_length",
23
+ truncation=True,
24
+ max_length=512,
25
+ return_attention_mask=True
26
+ )
27
+
28
+ # Process dataset
29
+ train_test_split = dataset.train_test_split(test_size=0.1)
30
+ tokenized_train = train_test_split['train'].map(
31
+ tokenize_function,
32
+ batched=True,
33
+ remove_columns=dataset.column_names
34
+ )
35
+ tokenized_eval = train_test_split['test'].map(
36
+ tokenize_function,
37
+ batched=True,
38
+ remove_columns=dataset.column_names
39
+ )
40
+
41
+ # Initialize trainer
42
+ trainer = Trainer(
43
+ model=model,
44
+ args=training_args,
45
+ train_dataset=tokenized_train,
46
+ eval_dataset=tokenized_eval,
47
+ data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
48
+ )
49
+
50
+ # Train and save
51
+ print("Starting the training...")
52
+ trainer.train()
53
+ print("Training complete, saving model...")
54
+ model.push_to_hub("MjolnirThor/flan-t5-custom-handler")
55
+ tokenizer.push_to_hub("MjolnirThor/flan-t5-custom-handler")
56
+ print("Model saved successfully!")