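"""Drive a single-GPU DoRA fine-tuning run and archive its adapter checkpoints.

While the training script runs, background threads copy each saved adapter
into ./checkpoint/ as a numbered .pth snapshot and prune old snapshots so at
most 50 are kept on disk.
"""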
import json
import os
import re
import shutil
import sys
import threading
import time

import torch

# Exactly one GPU may be visible; its id is forwarded to the training script.
cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
assert len(cuda_visible_devices.split(",")) == 1, "Only training on one GPU is supported."
RANK = 64  # DoRA rank; also used to name the output directory




checkpoint_path = os.path.join(os.path.dirname(__file__), "checkpoint")
os.makedirs(checkpoint_path, exist_ok=True)
generated_path = os.path.join(os.path.dirname(__file__), "generated")
os.makedirs(generated_path, exist_ok=True)
adapter_config_path = os.path.join(os.path.dirname(__file__), "adapter_config.json")


# Change the working directory to the DoRA training root from the repo-level config.
config_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), "config.json")
with open(config_file, "r") as f:
    additional_config = json.load(f)
root = additional_config["dora_root"]
sys.path.append(root)
os.chdir(root)
print(f"\033[91mWe are working under: {root}\033[0m")
if os.path.exists(f"./finetuned_result/dora_r{RANK}"):
    print(f"\033[91mWARNING: ./finetuned_result/dora_r{RANK} already exists!\033[0m")
    input("\033[91mPress ENTER to clear this dir...\033[0m")
    os.system(f"rm -rf ./finetuned_result/dora_r{RANK}/*")




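# Shared stop flag: the main thread sets this to True once training has
# finished, which ends both watcher loops below.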
exit_flag = False


def save_order(name):
    """Sort key: the trailing step number of a save dir (e.g. "checkpoint-200")."""
    match = re.search(r"(\d+)$", name)
    return int(match.group(1)) if match else -1


def move_to_checkpoint():
    """Archive every completed adapter save as a numbered .pth snapshot."""
    index = 1
    finished_list = []
    while not exit_flag:
        result_dir = f"./finetuned_result/dora_r{RANK}"
        if not os.path.exists(result_dir):
            time.sleep(1)
            continue
        # Visit saves in step order so snapshot numbering follows training order.
        for item in sorted(os.listdir(result_dir), key=save_order):
            src = os.path.join(result_dir, item)
            if not os.path.isdir(src):
                continue  # a plain file (e.g. the final model saved at the end)
            if item.startswith("tmp-"):
                continue  # an in-progress temporary save directory
            if src in finished_list:
                continue  # already archived on an earlier pass
            finished_list.append(src)
            try:  # the save may be deleted before we finish loading it
                shutil.copy(os.path.join(src, "adapter_config.json"), adapter_config_path)
                bin_path = os.path.join(src, "adapter_model.bin")
                diction = torch.load(bin_path, map_location="cpu", weights_only=False)
                dst = os.path.join(checkpoint_path, f"{str(index).zfill(7)}.pth")
                torch.save(diction, dst)
            except Exception as e:
                print(f"\033[91mWARNING: encountered {e} and ignored.\033[0m")
                continue
            print(f"Moved {bin_path} to {dst}.")
            index += 1
        time.sleep(1)


threading.Thread(target=move_to_checkpoint, daemon=True).start()


def remove_early_checkpoint():
    """Keep only the 50 newest snapshots in checkpoint_path; delete the rest."""
    def extract_number(filename):
        match = re.search(r"(\d+)\.pth", filename)
        return int(match.group(1)) if match else -1

    while not exit_flag:
        item_list = [item for item in os.listdir(checkpoint_path) if item.endswith(".pth")]
        if len(item_list) <= 50:
            time.sleep(10)
            continue
        sorted_items = sorted(item_list, key=extract_number)
        num_to_remove = len(sorted_items) - 50
        for i in range(num_to_remove):
            file_to_remove = os.path.join(checkpoint_path, sorted_items[i])
            os.remove(file_to_remove)
            print(f"\033[91mRemoved: {file_to_remove}\033[0m")
        time.sleep(10)


threading.Thread(target=remove_early_checkpoint, daemon=True).start()




# Start training inside the conda env named in config.json, on the visible GPU.
conda_path = shutil.which("conda")
assert conda_path is not None, "conda was not found on PATH."
conda_bin = os.path.dirname(conda_path)
activate_path = os.path.join(conda_bin, "activate")
env_path = os.path.join(os.path.dirname(conda_bin), "envs", additional_config["dora_env_name"])
os.system(
    f"bash -c \"source {activate_path} {env_path} && " +
    f"sh llama_7B_Dora.sh {RANK} {RANK*2} ./finetuned_result/dora_r{RANK} {cuda_visible_devices}\""
)
time.sleep(5)
# noinspection PyRedeclaration
exit_flag = True  # tell both watcher threads to stop
time.sleep(20)  # grace period so the last save is archived before exit
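# Usage sketch (the script name here is illustrative; run with exactly one
# visible GPU id):
#   CUDA_VISIBLE_DEVICES=0 python train_dora.py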