robertgshaw2 committed
Commit cb5a068 · 1 Parent(s): 8d813c3

added conversion files

Files changed (3)
  1. README.md +0 -3
  2. convert.py +161 -0
  3. load.py +172 -0
README.md DELETED
@@ -1,3 +0,0 @@
- ---
- license: llama2
- ---

convert.py ADDED
@@ -0,0 +1,161 @@
+ import torch, argparse, copy
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from auto_gptq.nn_modules.qlinear.qlinear_exllama import QuantLinear
+ from marlin import Layer as MarlinLayer
+ import gc
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-id", type=str)
+ parser.add_argument("--save-path", type=str)
+ parser.add_argument("--do-generation", action="store_true")
+
+ def _validate_compatibility(model):
+     if not hasattr(model.config, "quantization_config"):
+         raise ValueError("Must be a quantized model to convert to Marlin Format")
+     quantization_config = model.config.quantization_config
+     if quantization_config.quant_method != "gptq":
+         raise ValueError(f"Only GPTQ models can be converted to Marlin format. You passed a model with quant_method={quantization_config.quant_method}")
+     if quantization_config.bits != 4:
+         raise ValueError(f"Only 4 bit quantized models can be converted to Marlin format. You passed a model with bits={quantization_config.bits}")
+     if quantization_config.group_size != 128:
+         raise ValueError(f"Only group size 128 models can be converted to Marlin format. You passed a model with group_size={quantization_config.group_size}")
+     if not quantization_config.sym:
+         raise ValueError(f"Only models with symmetric quantization can be converted to Marlin Format. You passed a model with sym={quantization_config.sym}")
+     if quantization_config.desc_act:
+         raise ValueError(f"Models with act order quantization cannot be converted to Marlin Format. You passed a model with desc_act={quantization_config.desc_act}")
+
+ @torch.no_grad()
+ def unpack_4bit_to_32bit_signed(qweight, qzeros):
+     # Unpack 4-bit values and interpret them as signed integers
+     unpacked_weights = torch.zeros((qweight.shape[0]*8, qweight.shape[1]), dtype=torch.int8, device=qweight.device, requires_grad=False)
+     unpacked_zeros = torch.zeros((qzeros.shape[0], qzeros.shape[1]*8), dtype=torch.int8, device=qzeros.device, requires_grad=False)
+
+     for row in range(unpacked_weights.shape[0]):
+         i = row % 8
+         unpacked_weights[row, :] = (qweight[row // 8, :] >> (4 * i)) & 0xF
+
+     for col in range(unpacked_zeros.shape[1]):
+         i = col % 8
+         unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF
+
+     return unpacked_weights, unpacked_zeros + 1
+
+ @torch.no_grad()
+ def dequantize_weight(layer):
+     qweight, qzeros, scales = layer.qweight, layer.qzeros, layer.scales
+     unpacked_qweight, unpacked_qzeros = unpack_4bit_to_32bit_signed(qweight, qzeros)
+     group_size = unpacked_qweight.shape[0] // scales.shape[0]
+     scales = scales.repeat_interleave(group_size, dim=0)
+     unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
+     unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales
+
+     return unpacked_qweight.T
+
+ @torch.no_grad()
+ def convert_model(model, verbose=True):
+     for name, module in model.named_modules():
+         if not isinstance(module, QuantLinear):
+             continue
+
+         if verbose:
+             print(f"--- Converting Module: {name}")
+         parent_name = ".".join(name.split(".")[:-1])
+         layer_name = name[len(parent_name) + 1:]
+
+         # Dequantize the weight.
+         dequantized_weight = dequantize_weight(module).to(torch.float16)
+         linear_module = torch.nn.Linear(
+             in_features=dequantized_weight.shape[1],
+             out_features=dequantized_weight.shape[0],
+             bias=False,
+             dtype=torch.float16,
+             device="cuda")
+         linear_module.weight.data.copy_(dequantized_weight)
+
+         # Create new linear method and copy to model.
+         new_module = MarlinLayer(
+             infeatures=linear_module.in_features,
+             outfeatures=linear_module.out_features,
+             groupsize=model.config.quantization_config.group_size)
+         new_module.pack(linear_module, scales=copy.deepcopy(module.scales.data.t()))
+
+         # Save to parent.
+         parent_module = model.get_submodule(parent_name)
+         setattr(parent_module, layer_name, new_module)
+
+         # Free cuda memory.
+         del dequantized_weight, module
+         torch.cuda.empty_cache()
+         gc.collect()
+
+     return model
+
+ @torch.no_grad()
+ def dequantize_model(model, verbose=True):
+     for name, module in model.named_modules():
+         if not isinstance(module, QuantLinear):
+             continue
+
+         if verbose:
+             print(f"--- Dequantizing Module: {name}")
+         parent_name = ".".join(name.split(".")[:-1])
+         layer_name = name[len(parent_name) + 1:]
+
+         # Dequantize the weight.
+         dequantized_weight = dequantize_weight(module)
+         dequantized_weight_cpu = dequantized_weight.to("cpu")
+
+         # Create new linear method and copy to model.
+         new_module = torch.nn.Linear(
+             in_features=dequantized_weight_cpu.shape[1],
+             out_features=dequantized_weight_cpu.shape[0],
+             bias=False,
+             dtype=torch.float16)
+         new_module.weight.data.copy_(dequantized_weight_cpu)
+         new_module.scales = torch.nn.Parameter(copy.deepcopy(module.scales.data))
+
+         # Save to parent.
+         parent_module = model.get_submodule(parent_name)
+         setattr(parent_module, layer_name, new_module)
+
+         # Free cuda memory.
+         del dequantized_weight, dequantized_weight_cpu, module
+         torch.cuda.empty_cache()
+
+     return model
+
+ if __name__ == "__main__":
+     args = parser.parse_args()
+     model_id = args.model_id
+     save_path = args.save_path
+     do_generation = args.do_generation
+
+     print("Loading gptq model...")
+     model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+     tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+     # Validate that this model is compatible with Marlin.
+     print("Validating compatibility...")
+     _validate_compatibility(model)
+
+     # Convert the model to Marlin format.
+     print("Converting model...")
+     model = convert_model(model).to("cpu")
+
+     # Save after updating quantization config.
+     print("Saving marlin model...")
+     model.config.quantization_config = {
+         "group_size": model.config.quantization_config.group_size,
+         "quant_method": "marlin"
+     }
+     model.save_pretrained(save_path)
+     tokenizer.save_pretrained(save_path)
+
+     if do_generation:
+         print("Generating sample text...")
+         model.to("cuda")
+         prompt = "My favorite song is"
+         inputs = tokenizer(prompt, return_tensors="pt")
+         inputs = {k: v.to("cuda") for k, v in inputs.items()}
+         outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False)
+         print(tokenizer.batch_decode(outputs)[0])
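
Usage sketch: convert.py is meant to be run as a standalone script against an existing GPTQ checkpoint that passes the checks in _validate_compatibility (4-bit, group_size=128, symmetric, no act-order). A minimal invocation might look like the following; the model id and output directory are placeholders for illustration, not values from this commit:

    # placeholder model id and output path
    python convert.py \
        --model-id <org>/<gptq-model-id> \
        --save-path ./marlin-model \
        --do-generation

Passing --do-generation additionally runs the short greedy-decoding check at the bottom of the script after the Marlin checkpoint has been saved.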
load.py ADDED
@@ -0,0 +1,172 @@
+ import torch
+ import numpy as np
+ from huggingface_hub import snapshot_download
+ from safetensors.torch import safe_open
+ from typing import Optional, Tuple, List, Iterator
+ import os, filelock, json, glob
+ from accelerate import init_empty_weights
+ from transformers import AutoModelForCausalLM, AutoConfig
+ import marlin
+
+ # Adapted from https://github.com/vllm-project/vllm/blob/14cc317ba48229d93ee2417822d96ccb8db56abe/vllm/model_executor/weight_utils.py#L191
+
+ def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
+     lock_dir = cache_dir if cache_dir is not None else "/tmp"
+     lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
+     lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
+     return lock
+
+ def prepare_hf_model_weights(
+     model_name_or_path: str,
+     cache_dir: Optional[str] = None,
+     load_format: str = "auto",
+     fall_back_to_pt: bool = True,
+     revision: Optional[str] = None,
+ ) -> Tuple[str, List[str], bool]:
+     # Download model weights from huggingface.
+     is_local = os.path.isdir(model_name_or_path)
+     use_safetensors = False
+     # Some quantized models use .pt files for storing the weights.
+     if load_format == "auto":
+         allow_patterns = ["*.safetensors", "*.bin"]
+     elif load_format == "safetensors":
+         use_safetensors = True
+         allow_patterns = ["*.safetensors"]
+     elif load_format == "pt":
+         allow_patterns = ["*.pt"]
+     elif load_format == "npcache":
+         allow_patterns = ["*.bin"]
+     else:
+         raise ValueError(f"Unknown load_format: {load_format}")
+
+     if fall_back_to_pt:
+         allow_patterns += ["*.pt"]
+
+     if not is_local:
+         # Use file lock to prevent multiple processes from
+         # downloading the same model weights at the same time.
+         with get_lock(model_name_or_path, cache_dir):
+             hf_folder = snapshot_download(model_name_or_path,
+                                           allow_patterns=allow_patterns,
+                                           cache_dir=cache_dir,
+                                           revision=revision)
+     else:
+         hf_folder = model_name_or_path
+     hf_weights_files: List[str] = []
+     for pattern in allow_patterns:
+         hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
+         if len(hf_weights_files) > 0:
+             if pattern == "*.safetensors":
+                 use_safetensors = True
+             break
+     if not use_safetensors:
+         # Exclude files that are not needed for inference.
+         # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
+         blacklist = [
+             "training_args.bin",
+             "optimizer.bin",
+             "optimizer.pt",
+             "scheduler.pt",
+             "scaler.pt",
+         ]
+         hf_weights_files = [
+             f for f in hf_weights_files
+             if not any(f.endswith(x) for x in blacklist)
+         ]
+
+     if len(hf_weights_files) == 0:
+         raise RuntimeError(
+             f"Cannot find any model weights with `{model_name_or_path}`")
+
+     return hf_folder, hf_weights_files, use_safetensors
+
+ def hf_model_weights_iterator(
+     model_name_or_path: str,
+     cache_dir: Optional[str] = None,
+     load_format: str = "auto",
+     revision: Optional[str] = None,
+     fall_back_to_pt: Optional[bool] = True,
+ ) -> Iterator[Tuple[str, torch.Tensor]]:
+     hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
+         model_name_or_path,
+         cache_dir=cache_dir,
+         load_format=load_format,
+         fall_back_to_pt=fall_back_to_pt,
+         revision=revision)
+
+     if load_format == "npcache":
+         # Currently np_cache only supports *.bin checkpoints
+         assert use_safetensors is False
+
+         # Convert the model weights from torch tensors to numpy arrays for
+         # faster loading.
+         np_folder = os.path.join(hf_folder, "np")
+         os.makedirs(np_folder, exist_ok=True)
+         weight_names_file = os.path.join(np_folder, "weight_names.json")
+         # Use file lock to prevent multiple processes from
+         # dumping the same model weights to numpy at the same time.
+         with get_lock(model_name_or_path, cache_dir):
+             if not os.path.exists(weight_names_file):
+                 weight_names = []
+                 for bin_file in hf_weights_files:
+                     state = torch.load(bin_file, map_location="cpu")
+                     for name, param in state.items():
+                         param_path = os.path.join(np_folder, name)
+                         with open(param_path, "wb") as f:
+                             np.save(f, param.cpu().detach().numpy())
+                         weight_names.append(name)
+                 with open(weight_names_file, "w") as f:
+                     json.dump(weight_names, f)
+
+         with open(weight_names_file, "r") as f:
+             weight_names = json.load(f)
+
+         for name in weight_names:
+             param_path = os.path.join(np_folder, name)
+             with open(param_path, "rb") as f:
+                 param = np.load(f)
+             yield name, torch.from_numpy(param)
+     elif use_safetensors:
+         for st_file in hf_weights_files:
+             with safe_open(st_file, framework="pt") as f:
+                 for name in f.keys():  # noqa: SIM118
+                     param = f.get_tensor(name)
+                     yield name, param
+     else:
+         for bin_file in hf_weights_files:
+             state = torch.load(bin_file, map_location="cpu")
+             for name, param in state.items():
+                 yield name, param
+             del state
+             torch.cuda.empty_cache()
+
+ @torch.no_grad()
+ def load_model(model_path):
+     with init_empty_weights():
+         config = AutoConfig.from_pretrained(model_path)
+
+         if not hasattr(config, "quantization_config"):
+             raise ValueError("Must be a Marlin quantized model, but your config has no quantization config.")
+         if "quant_method" not in config.quantization_config:
+             raise ValueError("Must be a Marlin quantized model, but your quantization config has no quant_method.")
+         if config.quantization_config["quant_method"] != "marlin":
+             raise ValueError(f"Must be a Marlin model, but you passed a model with quant_method = {config.quantization_config['quant_method']}")
+
+         model = AutoModelForCausalLM.from_config(config)
+         marlin.replace_linear(
+             model.model,
+             groupsize=config.quantization_config["group_size"]
+         )
+
+     module_dict = dict(model.named_modules())
+     for name, loaded_weight in hf_model_weights_iterator(model_path):
+         module_name = ".".join(name.split(".")[:-1])
+         param_name = name[len(module_name) + 1:]
+         module = module_dict[module_name]
+
+         if not hasattr(module, param_name):
+             raise ValueError("Key mismatch.")
+
+         setattr(module, param_name, torch.nn.Parameter(loaded_weight, requires_grad=False))
+
+     return model
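
Usage sketch: load.py exposes load_model(model_path), which rebuilds the architecture on the meta device, swaps in Marlin linear layers via marlin.replace_linear, and then streams the saved tensors into the module tree. A minimal example of pairing it with a tokenizer for generation might look like the following, assuming ./marlin-model is a placeholder for a directory produced by convert.py and that the checkpoint contains every parameter the model expects:

    from transformers import AutoTokenizer
    from load import load_model

    # placeholder path: a directory written by convert.py
    model = load_model("./marlin-model").to("cuda")   # Marlin kernels run on CUDA
    tokenizer = AutoTokenizer.from_pretrained("./marlin-model")

    inputs = tokenizer("My favorite song is", return_tensors="pt").to("cuda")
    outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False)
    print(tokenizer.batch_decode(outputs)[0])

This mirrors the optional --do-generation check in convert.py, just run against the reloaded Marlin checkpoint instead of the freshly converted model.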