kurogane committed on
Commit
1c786f8
1 Parent(s): cd5210b

Upload Merger.py

Files changed (1)
  1. Merger.py +104 -0
Merger.py ADDED
@@ -0,0 +1,104 @@
+ # coding: utf-8
+
+ # Chat-vector merge: dump three Llama-3 checkpoints to safetensors, then
+ # compute subj + (vector - base) element-wise over every weight tensor.
+
+ import os
+
+ import torch
+ from safetensors.torch import save_file, load_file
+ from transformers import AutoModelForCausalLM
+
+ DIR_CACHE = r"E:\llm_baack\cache"
+ DIR_OFFLOAD = r"E:\llm_baack\offload"
+ DIR_SAVE = r"E:\llm_baack\safetensors"
+
+ for _dir in [DIR_CACHE, DIR_OFFLOAD, DIR_SAVE]:
+     if not os.path.exists(_dir):
+         os.makedirs(_dir)
+
+ MODEL_SUBJ = "aaditya/Llama3-OpenBioLLM-8B"
+ MODEL_VECTOR = "aixsatoshi/Llama-3-youko-8b-instruct-chatvector"
+ MODEL_BASE = "NousResearch/Meta-Llama-3-8B"
+
+
+ def download_model(model_name):
+     """Download a model, dump its state dict to safetensors, return path and name."""
+     s_name = model_name.replace("/", "-")
+     dir_offload = os.path.join(DIR_OFFLOAD, s_name)
+     if not os.path.exists(dir_offload):
+         os.makedirs(dir_offload)
+
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         cache_dir=DIR_CACHE,
+         torch_dtype=torch.bfloat16,
+         device_map="cpu",
+         offload_folder=dir_offload,
+         offload_state_dict=True,
+         trust_remote_code=True,
+     )
+     model.eval()
+
+     model_state_dict = model.state_dict().copy()
+
+     # Sanity check: print the first tensor's name, dtype, and shape.
+     for key in model_state_dict.keys():
+         model_value = model_state_dict[key].clone().to("cpu")
+         print(key, model_value.dtype, model_value.shape, model_value)
+         break
+
+     dir_save_safe = os.path.join(DIR_SAVE, f"{s_name}.safetensors")
+     save_file(model_state_dict, dir_save_safe)
+
+     # Free the model.
+     del model
+     del model_state_dict
+
+     return dir_save_safe, s_name
+
+
+ DIR_MODEL_SUBJ, s_name_subj = download_model(MODEL_SUBJ)
+ DIR_MODEL_VECTOR, s_name_vect = download_model(MODEL_VECTOR)
+ DIR_MODEL_BASE, s_name_base = download_model(MODEL_BASE)
+
+
+ # Step 1: add the chat-vector model's weights to the subject model.
+ d_state_subj = load_file(DIR_MODEL_SUBJ, device="cpu")
+ d_state_vector = load_file(DIR_MODEL_VECTOR, device="cpu")
+ new_state_dict = d_state_subj  # updated in place
+
+ with torch.no_grad():
+     for key in d_state_subj.keys():
+         print(key)
+         new_state_dict[key] = (
+             new_state_dict[key].to("cuda") + d_state_vector[key].to("cuda")
+         ).to("cpu")
+
+ del d_state_subj, d_state_vector
+ torch.cuda.empty_cache()
+ dir_save_subjpvect = os.path.join(DIR_SAVE, f"{s_name_subj}+{s_name_vect}.safetensors")
+ save_file(new_state_dict, dir_save_subjpvect)
+
+ # Step 2: load the intermediate sum and the base model.
+ d_state_subj_subjpvect = load_file(dir_save_subjpvect, device="cpu")
+ d_state_base = load_file(DIR_MODEL_BASE, device="cpu")
+
+ # Verify that the key names match.
+ for key_subjpvect, key_base in zip(
+     d_state_subj_subjpvect.keys(), d_state_base.keys()
+ ):
+     assert key_subjpvect == key_base
+
+ new_state_dict = d_state_subj_subjpvect
+
+ # Step 3: subtract the base model, leaving subj + (vector - base).
+ with torch.no_grad():
+     for key in new_state_dict.keys():
+         print(key)
+         new_state_dict[key] = (
+             new_state_dict[key].to("cuda") - d_state_base[key].to("cuda")
+         ).to("cpu")
+
+ torch.cuda.empty_cache()
+ save_file(new_state_dict, os.path.join(DIR_SAVE, f"{s_name_subj}+{s_name_vect}-{s_name_base}.safetensors"))
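
For reference, a minimal sketch of loading the merged weights back for inference. The file name below is the one Merger.py assembles from the three repo ids; the choice of `aaditya/Llama3-OpenBioLLM-8B` for the architecture and tokenizer, and the prompt, are illustrative assumptions, not part of the script. Since the file stores subj + (vector - base), i.e. full merged weights rather than a delta, a plain load_state_dict is enough.

# load_merged.py -- hypothetical usage sketch, not part of the commit
import torch
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer

# Path assembled the same way Merger.py builds its final output name.
MERGED = (
    r"E:\llm_baack\safetensors\aaditya-Llama3-OpenBioLLM-8B"
    r"+aixsatoshi-Llama-3-youko-8b-instruct-chatvector"
    r"-NousResearch-Meta-Llama-3-8B.safetensors"
)

# Instantiate the subject architecture, then overwrite its weights.
model = AutoModelForCausalLM.from_pretrained(
    "aaditya/Llama3-OpenBioLLM-8B", torch_dtype=torch.bfloat16, device_map="cpu"
)
missing, unexpected = model.load_state_dict(load_file(MERGED, device="cpu"), strict=False)
print("missing:", missing, "unexpected:", unexpected)  # both should be empty

tokenizer = AutoTokenizer.from_pretrained("aaditya/Llama3-OpenBioLLM-8B")
inputs = tokenizer("What are the symptoms of influenza?", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(out[0], skip_special_tokens=True))

`strict=False` guards against tied or buffer tensors that `state_dict()` may not round-trip exactly; for Llama-3-8B both the missing and unexpected lists should come back empty.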