import os
import re
import sys
import time
import torch
import shutil
from _thread import start_new_thread

# Only a single visible GPU is supported; its index is forwarded to the
# training shell script at the bottom of this file.
cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
assert cuda_visible_devices and "," not in cuda_visible_devices, "Only support training on one GPU."
RANK = 64
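

# Two local output directories live next to this script: "checkpoint/" receives
# the converted adapter checkpoints collected during training, and "generated/"
# is created up front for later stages (this script does not write to it).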
checkpoint_path = os.path.join(os.path.dirname(__file__), "checkpoint")
if not os.path.exists(checkpoint_path):
    os.makedirs(checkpoint_path, exist_ok=False)
generated_path = os.path.join(os.path.dirname(__file__), "generated")
if not os.path.exists(generated_path):
    os.makedirs(generated_path, exist_ok=False)
adapter_config_path = os.path.join(os.path.dirname(__file__), "adapter_config.json")
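

# config.json one directory up points this script at the external DoRA code
# base ("dora_root") and names the conda environment to run it in
# ("dora_env_name").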
import json
config_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), "config.json")
with open(config_file, "r") as f:
    additional_config = json.load(f)
root = additional_config["dora_root"]
sys.path.append(root)
os.chdir(root)
print(f"\033[91mWe are working under: {root}\033[0m")
if os.path.exists(f"./finetuned_result/dora_r{RANK}"):
    print(f"\033[91mWARNING: ./finetuned_result/dora_r{RANK} already exists!\033[0m")
    input("\033[91mPress ENTER to clear this dir...\033[0m")
    os.system(f"rm -rf ./finetuned_result/dora_r{RANK}/*")
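

# exit_flag is shared with the two watcher threads started below; it is set to
# True once training finishes so both loops can exit cleanly.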
exit_flag = False
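

# move_to_checkpoint: poll the DoRA output directory and, whenever a finished
# checkpoint appears, convert its adapter_model.bin into a torch-saved file
# named 0000001.pth, 0000002.pth, ... under ./checkpoint. The matching
# adapter_config.json is copied next to this script for each checkpoint.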
def move_to_checkpoint():
    global exit_flag
    index = 1
    finished_list = []
    while exit_flag is False:
        father = f"./finetuned_result/dora_r{RANK}"
        if not os.path.exists(father):
            time.sleep(1)
            continue
        item_list = os.listdir(father)
        for item in item_list:
            src = os.path.join(father, item)
            if not os.path.isdir(src):
                continue
            if item[:4] == "tmp-":
                # Skip checkpoints that the trainer is still writing.
                continue
            if src in finished_list:
                continue
            finished_list.append(src)
            try:
                shutil.copy(os.path.join(src, "adapter_config.json"), adapter_config_path)
                bin_path = os.path.join(src, "adapter_model.bin")
                diction = torch.load(bin_path, map_location="cpu", weights_only=False)
                dst = os.path.join(checkpoint_path, f"{str(index).zfill(7)}.pth")
                torch.save(diction, dst)
            except Exception as e:
                print(f"\033[91mWARNING: encountered {e} and ignored.\033[0m")
                continue
            print(f"Moved {src} to {dst}.")
            index += 1
        time.sleep(1)

start_new_thread(move_to_checkpoint, ())
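

# remove_early_checkpoint: cap the number of collected checkpoints at 50 by
# deleting the oldest .pth files under ./checkpoint.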
def remove_early_checkpoint():
    global exit_flag
    while exit_flag is False:
        item_list = [item for item in os.listdir(checkpoint_path) if item.endswith('.pth')]
        if len(item_list) <= 50:
            time.sleep(10)
            continue

        def extract_number(filename):
            match = re.search(r'(\d+)\.pth', filename)
            return int(match.group(1)) if match else -1

        sorted_items = sorted(item_list, key=extract_number)
        num_to_remove = len(sorted_items) - 50
        for i in range(num_to_remove):
            file_to_remove = os.path.join(checkpoint_path, sorted_items[i])
            os.remove(file_to_remove)
            print(f"\033[91mRemoved: {file_to_remove}\033[0m")
        time.sleep(10)

start_new_thread(remove_early_checkpoint, ())
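

# Launch the actual DoRA fine-tuning inside its dedicated conda environment
# (resolved from whichever conda is on PATH). os.system blocks until
# llama_7B_Dora.sh finishes, while the two background threads keep harvesting
# and pruning checkpoints.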
conda_bin = os.path.dirname(shutil.which("conda"))
activate_path = os.path.join(conda_bin, "activate")
env_path = os.path.join(os.path.dirname(conda_bin), "envs", additional_config["dora_env_name"])
os.system(
    f"bash -c \"source {activate_path} {env_path} && "
    f"sh llama_7B_Dora.sh {RANK} {RANK*2} ./finetuned_result/dora_r{RANK} {cuda_visible_devices}\""
)
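
# Training is done: give the watcher thread a final pass to pick up the last
# checkpoint, then signal both background threads to exit.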
time.sleep(5)
exit_flag = True
time.sleep(20)