File size: 6,921 Bytes
0d17543
 
 
 
 
 
 
 
9c3f080
 
 
 
0d17543
 
 
 
 
 
 
 
 
 
9c3f080
0d17543
 
9c3f080
0d17543
 
 
 
9c3f080
0d17543
 
 
 
 
 
 
9c3f080
0d17543
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c3f080
0d17543
 
 
 
 
 
 
9c3f080
0d17543
 
 
9c3f080
0d17543
 
 
9c3f080
0d17543
 
 
 
 
9c3f080
 
 
 
 
 
 
 
0d17543
 
9c3f080
0d17543
9c3f080
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d17543
 
 
 
 
 
 
 
 
 
 
 
 
9c3f080
 
 
 
 
 
 
 
 
 
 
 
0d17543
 
9c3f080
 
0d17543
9c3f080
 
 
0d17543
 
 
9c3f080
0d17543
9c3f080
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    TrainingArguments, 
    Trainer, 
    DataCollatorForLanguageModeling, 
    TrainerCallback, 
    TrainerState, 
    TrainerControl
)
from peft import LoraConfig, get_peft_model
import spaces
import time

# Đường dẫn lưu checkpoint
CHECKPOINT_DIR = "./checkpoints"
if not os.path.exists(CHECKPOINT_DIR):
    os.makedirs(CHECKPOINT_DIR)

# Tải Dataset (CPU)
dataset = load_dataset('vntc/wiki-mini-corpus')

# Chia Dataset thành train và validation (CPU)
split_dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)
train_dataset = split_dataset['train']
validation_dataset = split_dataset['test']

# Tiền Xử Lý Văn Bản (CPU)
def preprocess_function(examples):
    passages = [passage.lower().strip() for passage in examples['passage']]
    return {'passage': passages}

processed_train = train_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
processed_validation = validation_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])

# Tokenization (CPU)
model_name = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Đảm bảo tokenizer có pad_token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    return tokenizer(
        examples['passage'],
        padding='max_length',
        truncation=True,
        max_length=512,
    )

tokenized_train = processed_train.map(tokenize_function, batched=True)
tokenized_validation = processed_validation.map(tokenize_function, batched=True)

# Thêm trường 'labels' (CPU)
def add_labels(examples):
    examples['labels'] = examples['input_ids'].copy()
    return examples

tokenized_train = tokenized_train.map(add_labels, batched=True)
tokenized_validation = tokenized_validation.map(add_labels, batched=True)

# Loại bỏ các cột không cần thiết (CPU)
tokenized_train = tokenized_train.remove_columns(['passage'])
tokenized_validation = tokenized_validation.remove_columns(['passage'])

# Định dạng dữ liệu cho PyTorch (CPU)
tokenized_train.set_format('torch')
tokenized_validation.set_format('torch')

# Tạo DatasetDict (CPU)
final_dataset = {
    'train': tokenized_train,
    'validation': tokenized_validation
}

# Định Nghĩa TrainerCallback để Lưu Checkpoint Nhanh Hơn
class SaveCheckpointCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % args.save_steps == 0 and state.global_step != 0:
            checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
            print(f"Lưu checkpoint tại: {checkpoint_path}")
            trainer.save_model(checkpoint_path)
        return TrainerControl.CONTINUE

# Định Nghĩa Hàm Huấn Luyện với Decorator @spaces.GPU
@spaces.GPU(duration=15, queue=False)
def run_training():
    """
    Hàm huấn luyện mô hình sử dụng GPU với thời gian hạn chế.
    """
    # Tải và Cấu Hình Mô Hình với LoRA (GPU)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16,
        load_in_8bit=False
    )
    
    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
        lora_dropout=0.1,
        bias="none",
    )
    
    model = get_peft_model(model, lora_config)
    print(model)
    
    # Cấu Hình TrainingArguments (GPU)
    training_args = TrainingArguments(
        output_dir=CHECKPOINT_DIR,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=8,
        num_train_epochs=3,
        max_steps=50,  # Đặt max_steps tại đây
        learning_rate=3e-4,
        weight_decay=0.01,
        logging_steps=10,  # Giảm số bước logging để theo dõi thường xuyên hơn
        eval_strategy="steps",  # Đánh giá sau mỗi vài bước
        eval_steps=50,  # Đánh giá sau mỗi 50 bước
        save_strategy="steps",  # Lưu checkpoint sau mỗi vài bước
        save_steps=50,  # Lưu checkpoint sau mỗi 50 bước
        save_total_limit=5,  # Giới hạn số lượng checkpoint lưu trữ
        fp16=True,
        report_to="none",
        load_best_model_at_end=True,
    )
    
    # Data Collator (GPU)
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, 
        mlm=False,  # Vì bạn đang thực hiện Causal LM
        pad_to_multiple_of=8
    )
    
    # Tạo Trainer (GPU)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=final_dataset['train'],
        eval_dataset=final_dataset['validation'],
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[SaveCheckpointCallback()],  # Thêm callback
    )
    
    # Kiểm tra nếu có checkpoint
    checkpoints = [os.path.join(CHECKPOINT_DIR, d) for d in os.listdir(CHECKPOINT_DIR) if d.startswith('checkpoint')]
    if checkpoints:
        latest_checkpoint = max(checkpoints, key=os.path.getctime)
        print(f"Đang tiếp tục huấn luyện từ checkpoint: {latest_checkpoint}")
        trainer.train(resume_from_checkpoint=latest_checkpoint)
    else:
        trainer.train()
    
    # Lưu checkpoint sau khi huấn luyện
    trainer.save_model(CHECKPOINT_DIR)
    return "Huấn luyện hoàn tất hoặc đã tiếp tục từ checkpoint."

# Hàm Tự Động Hóa Việc Gọi Lặp Lại Hàm Huấn Luyện
def continuous_training(total_steps=300, steps_per_call=50):
    """
    Hàm tự động gọi lại `run_training` để hoàn thành quá trình huấn luyện.
    
    Args:
        total_steps (int): Tổng số bước huấn luyện mong muốn.
        steps_per_call (int): Số bước huấn luyện mỗi lần gọi hàm.
    """
    steps_done = 0
    while steps_done < total_steps:
        print(f"Bắt đầu huấn luyện cho {steps_per_call} bước.")
        result = run_training()
        print(result)
        steps_done += steps_per_call
        print(f"Đã huấn luyện {steps_done} / {total_steps} bước.")
        
        # Kiểm tra nếu đã đạt số bước mong muốn
        if steps_done >= total_steps:
            print("Đã hoàn thành toàn bộ quá trình huấn luyện.")
            break
        
        # Chờ một khoảng thời gian trước khi gọi lại (tùy thuộc vào yêu cầu của hệ thống)
        time.sleep(2)  # Thời gian chờ có thể điều chỉnh

# Gọi hàm huấn luyện liên tục
if __name__ == "__main__":
    continuous_training(total_steps=300, steps_per_call=50)