Spaces:
Sleeping
Sleeping
hoduyquocbao
commited on
Commit
•
3d31e68
1
Parent(s):
9c3f080
update feature check
Browse files- checkpoint.py +3 -4
checkpoint.py
CHANGED
@@ -7,9 +7,7 @@ from transformers import (
|
|
7 |
TrainingArguments,
|
8 |
Trainer,
|
9 |
DataCollatorForLanguageModeling,
|
10 |
-
TrainerCallback
|
11 |
-
TrainerState,
|
12 |
-
TrainerControl
|
13 |
)
|
14 |
from peft import LoraConfig, get_peft_model
|
15 |
import spaces
|
@@ -83,8 +81,9 @@ class SaveCheckpointCallback(TrainerCallback):
|
|
83 |
if state.global_step % args.save_steps == 0 and state.global_step != 0:
|
84 |
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
|
85 |
print(f"Lưu checkpoint tại: {checkpoint_path}")
|
|
|
86 |
trainer.save_model(checkpoint_path)
|
87 |
-
return
|
88 |
|
89 |
# Định Nghĩa Hàm Huấn Luyện với Decorator @spaces.GPU
|
90 |
@spaces.GPU(duration=15, queue=False)
|
|
|
7 |
TrainingArguments,
|
8 |
Trainer,
|
9 |
DataCollatorForLanguageModeling,
|
10 |
+
TrainerCallback
|
|
|
|
|
11 |
)
|
12 |
from peft import LoraConfig, get_peft_model
|
13 |
import spaces
|
|
|
81 |
if state.global_step % args.save_steps == 0 and state.global_step != 0:
|
82 |
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
|
83 |
print(f"Lưu checkpoint tại: {checkpoint_path}")
|
84 |
+
trainer = kwargs['trainer'] # Truy cập trainer từ kwargs
|
85 |
trainer.save_model(checkpoint_path)
|
86 |
+
return control # Trả về đối tượng control hiện tại
|
87 |
|
88 |
# Định Nghĩa Hàm Huấn Luyện với Decorator @spaces.GPU
|
89 |
@spaces.GPU(duration=15, queue=False)
|