GPTQ quantization related code

#43
Files changed (3)
  1. gptq_quantization.py +324 -0
  2. modeling_chatglm.py +18 -5
  3. quantization.py +17 -3
gptq_quantization.py ADDED
@@ -0,0 +1,324 @@
import contextlib
import logging
import math
from typing import List, Optional

import torch
import transformers
from torch import nn

LOGGER = logging.getLogger(__name__)

QUANT_LAYERS = [nn.Linear, nn.Conv2d, transformers.Conv1D]


def is_transformer_conv1d(layer):
    return isinstance(layer, transformers.Conv1D)


# These two functions only work on per-channel symmetric quantization for weight
def get_weight_scale(weight, weight_bit_width):
    weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
    return weight_scale


def fake_quantize_weight(weight, weight_scale):
    weight_scale = weight_scale[:, None]
    fake_quantized_weight = torch.round(weight / weight_scale) * weight_scale
    return fake_quantized_weight


class GPTQLayerWrapper:
    def __init__(self, layer_name, layer, weight_bit_width):
        super().__init__()
        self.layer_name = layer_name
        self.layer = layer
        self.device = layer.weight.device
        columns = layer.weight.shape[1]
        self.columns = columns
        self.H = torch.zeros((columns, columns), device=self.device)
        self.nsamples = 0
        self.is_record = True
        self.weight_bit_width = weight_bit_width
        self.weight_scale = None

    def record_h(self, x):
        # Accumulate the running Hessian approximation H ~ 2/nsamples * sum(x x^T)
        # from the inputs seen during calibration.
        if self.is_record:
            x = x.detach().clone()
            if len(x.shape) == 2:
                x = x.unsqueeze(0)
            batch = x.shape[0]
            if isinstance(self.layer, nn.Linear) or is_transformer_conv1d(self.layer):
                if len(x.shape) == 3:
                    x = x.reshape((-1, x.shape[-1]))
                x = x.t()

            if isinstance(self.layer, nn.Conv2d):
                unfold = nn.Unfold(
                    self.layer.kernel_size,
                    dilation=self.layer.dilation,
                    padding=self.layer.padding,
                    stride=self.layer.stride
                )
                x = unfold(x)
                x = x.permute([1, 0, 2])
                x = x.flatten(1)

            self.H *= self.nsamples / (self.nsamples + batch)
            self.nsamples += batch
            x = math.sqrt(2 / self.nsamples) * x.float()
            self.H += x.matmul(x.t())

    def quant_weight(self, blocksize=128, percdamp=.01, groupsize=-1):
        if groupsize != -1:
            raise RuntimeError("Group quantization of gptq quantizer is not supported for now")
        weight = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            weight = weight.flatten(1)
        if is_transformer_conv1d(self.layer):
            weight = weight.t()
        weight = weight.float()

        weight_scale = get_weight_scale(weight, self.weight_bit_width)
        # todo: use buffer to store scale
        self.weight_scale = weight_scale
        H = self.H
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        weight[:, dead] = 0

        losses = torch.zeros_like(weight)
        Q = torch.zeros_like(weight)

        # Dampen the diagonal, then compute the Cholesky factor of the inverse Hessian.
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.device)
        H[diag, diag] += damp
        try:
            H = torch.linalg.cholesky(H)
            H = torch.cholesky_inverse(H)
            H = torch.linalg.cholesky(H, upper=True)
        except Exception:
            logging.warning(f"Warning: cannot do compression on layer {self.layer_name} because of inverse error")
            return

        if H.isnan().any():
            logging.warning(f"Warning: cannot do compression on layer {self.layer_name} because of inverse error")
            return

        hinv = H

        # Quantize column blocks one at a time, propagating the quantization error
        # to the not-yet-quantized columns (the GPTQ update rule).
        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            w1 = weight[:, i1:i2].clone()
            q1 = torch.zeros_like(w1)
            total_err = torch.zeros_like(w1)
            losses1 = torch.zeros_like(w1)
            hinv1 = hinv[i1:i2, i1:i2]

            for i in range(count):
                w = w1[:, i]
                d = hinv1[i, i]

                q = fake_quantize_weight(w.unsqueeze(1), weight_scale).flatten()

                q1[:, i] = q
                losses1[:, i] = (w - q) ** 2 / d ** 2
                err = (w - q) / d
                w1[:, i:] -= err.unsqueeze(1).matmul(hinv1[i, i:].unsqueeze(0))
                total_err[:, i] = err

            Q[:, i1:i2] = q1
            losses[:, i1:i2] = losses1 / 2

            weight[:, i2:] -= total_err.matmul(hinv[i1:i2, i2:])

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        if is_transformer_conv1d(self.layer):
            Q = Q.t()
        shape = self.layer.weight.shape
        dtype = self.layer.weight.data.dtype
        del self.layer.weight
        setattr(self.layer, "weight", nn.Parameter(Q.reshape(shape).to(dtype), requires_grad=False))
        del self.H


class GPTQBlockWrapper:
    def __init__(self, block_name: str, block: nn.Module, weight_bit_width=8):
        self.layer_wrappers = {}
        self.hook_handles = []
        # block order in the whole network
        self.order = 0
        self.block_name = block_name

        def get_hook(layer_name):
            def record_hook(_, x):
                self.layer_wrappers[layer_name].record_h(x[0])
            return record_hook

        for layer_name, layer in block.named_modules():
            if isinstance(layer, tuple(QUANT_LAYERS)):
                full_layer_name = f"{block_name}.{layer_name}" if layer_name else f"{block_name}"
                self.layer_wrappers[full_layer_name] = GPTQLayerWrapper(full_layer_name, layer, weight_bit_width)
                handle = layer.register_forward_pre_hook(get_hook(full_layer_name))
                self.hook_handles.append(handle)

    def quant_block(self):
        for _, wrapper in self.layer_wrappers.items():
            wrapper.quant_weight()

        for h in self.hook_handles:
            h.remove()

    def set_order(self, idx):
        self.order = idx

    def get_order(self):
        return self.order

    def enable(self):
        for n, l in self.layer_wrappers.items():
            l.is_record = True

    def disable(self):
        for n, l in self.layer_wrappers.items():
            l.is_record = False


class GPTQuantizer:
    def __init__(self, block_type: Optional[List[type]] = None):
        self.gptq_block_wrappers = {}
        self.block_type = block_type

    def wrap_model(self, model: nn.Module, weight_bit_width=8):

        def wrap_block(m, prefix=""):
            for name, child in m.named_children():
                child_prefix = f"{prefix}.{name}" if prefix else name
                if isinstance(child, tuple(self.block_type)):
                    self.gptq_block_wrappers[name] = GPTQBlockWrapper(child_prefix, child, weight_bit_width)
                    LOGGER.debug(f"Calibrate block {child_prefix} as a whole block in GPTQ")
                else:
                    wrap_block(child, child_prefix)

        wrap_block(model)
        return model

    @property
    def calibration_iters(self):
        return len(self.gptq_block_wrappers)

    @contextlib.contextmanager
    def record_order(self):
        counter = 0
        record_handles = []
        orders = {}
        try:
            def get_record_order_hook(block_name):
                def record_hook(*args, **kwargs):
                    nonlocal counter
                    if block_name not in orders:
                        orders[block_name] = counter
                        counter += 1
                return record_hook

            for block_name, block_wrapper in self.gptq_block_wrappers.items():
                # disable the record
                for _, layer_wrapper in block_wrapper.layer_wrappers.items():
                    layer_wrapper.is_record = False

                one_layer_wrapper_in_block = list(block_wrapper.layer_wrappers.values())[0]
                handles = one_layer_wrapper_in_block.layer.register_forward_pre_hook(get_record_order_hook(block_name))
                record_handles.append(handles)
            yield
        except Exception as e:
            logging.warning(e)
        finally:
            for block_name, order in orders.items():
                self.gptq_block_wrappers[block_name].set_order(order)

            for h in record_handles:
                h.remove()

            for _, block_wrapper in self.gptq_block_wrappers.items():
                # re-enable the record
                for _, layer_wrapper in block_wrapper.layer_wrappers.items():
                    layer_wrapper.is_record = True

    @contextlib.contextmanager
    def start_calib_iter(self, i):
        assert i < len(self.gptq_block_wrappers)
        target_block_wrapper = None
        try:
            for _, block_wrapper in self.gptq_block_wrappers.items():
                if block_wrapper.get_order() == i:
                    block_wrapper.enable()
                    target_block_wrapper = block_wrapper
                else:
                    block_wrapper.disable()
            yield
        finally:
            target_block_wrapper.quant_block()

    def release_reference(self):
        # delete reference so that `torch.cuda.empty_cache()` can
        # release all the gpu memory cache used during calibration
        for _, block_wrapper in self.gptq_block_wrappers.items():
            for _, layer_wrapper in block_wrapper.layer_wrappers.items():
                del layer_wrapper.layer

        torch.cuda.empty_cache()


def locate_parent(root: nn.Module, full_path: str):
    parent = root
    path = full_path.split('.')
    for p in path[:-1]:
        parent = getattr(parent, p)
    return parent, path[-1]


@torch.no_grad()
def gptq_quantize(model, tokenizer, weight_bit_width, calib_data):
    from .modeling_chatglm import GLMBlock
    from .quantization import QuantizedLinear

    quantizer = GPTQuantizer([GLMBlock])
    calib_model = quantizer.wrap_model(model, weight_bit_width)
    with quantizer.record_order():
        calib_model.chat(tokenizer, calib_data[0], history=[])

    logging.info("Start doing calibration using GPTQ")
    for i in range(quantizer.calibration_iters):
        logging.info(f"Process: {i + 1}/{quantizer.calibration_iters}")
        # todo: should add early return to speed up the calibration
        # todo: add cpu offload to reduce the gpu memory requirements.
        with quantizer.start_calib_iter(i):
            for prompt in calib_data:
                model.chat(tokenizer, prompt, history=[])

    # replace the fp16 linear with quantized linear
    for _, block_wrapper in quantizer.gptq_block_wrappers.items():
        for layer_name, layer_wrapper in block_wrapper.layer_wrappers.items():
            layer = layer_wrapper.layer
            parent, name_in_parent = locate_parent(model, layer_name)
            quantized_layer = QuantizedLinear(
                weight_bit_width=weight_bit_width,
                weight_tensor=layer.weight,
                bias_tensor=layer.bias,
                weight_scale=layer_wrapper.weight_scale,
                in_features=layer.in_features,
                out_features=layer.out_features,
                bias=True,
                dtype=torch.half,
                device=layer_wrapper.device,
                empty_init=False
            )
            parent.add_module(name_in_parent, quantized_layer)

    # release the memory cache used during calibration
    quantizer.release_reference()
    return
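Note (not part of the diff): the weight path in this file is plain per-channel symmetric fake quantization; GPTQ only decides, column by column, how to compensate the rounding error. A minimal sketch of what get_weight_scale / fake_quantize_weight do on a toy tensor, assuming the module can be imported standalone (shapes and values are made up for illustration):

    import torch
    from gptq_quantization import get_weight_scale, fake_quantize_weight  # assumed importable as a plain module

    weight = torch.randn(4, 8)  # 4 output channels, 8 input features

    scale = get_weight_scale(weight, 8)        # one int8 scale per output channel, shape (4,)
    fq = fake_quantize_weight(weight, scale)   # round(w / s) * s, same shape as weight

    # round-to-nearest keeps the per-element error within half a quantization step
    print((weight - fq).abs().max())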
modeling_chatglm.py CHANGED
@@ -1408,12 +1408,14 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 break
             yield input_ids
 
-    def quantize(self, bits: int, empty_init=False, **kwargs):
+    def quantize(
+            self, bits: int, empty_init=False, quant_algo_type: str="min_max",
+            calib_data: Optional[List[str]]=None, tokenizer=None, **kwargs):
         if bits == 0:
             return
 
-        from .quantization import quantize
-
+        from .quantization import quantize, QuantAlgoType
+        from .gptq_quantization import gptq_quantize
         if self.quantized:
             logger.info("Already quantized.")
             return self
@@ -1421,6 +1423,17 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         self.quantized = True
 
         self.config.quantization_bit = bits
-
-        self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs)
+        quant_algo_type = QuantAlgoType(quant_algo_type)
+        if quant_algo_type == QuantAlgoType.min_max:
+            self.transformer = quantize(
+                self.transformer, bits, empty_init=empty_init, algo_type=quant_algo_type, calib_data=calib_data, tokenizer=tokenizer, **kwargs)
+        elif quant_algo_type == QuantAlgoType.gptq:
+            if calib_data is None or tokenizer is None:
+                raise RuntimeError("If using gptq to quantize the model, "
+                                   "calibration data (e.g. some string prompts) and tokenizer should be provided")
+            gptq_quantize(
+                self, tokenizer, bits, calib_data
+            )
+        else:
+            raise RuntimeError("Unsupported quantization algorithm type")
         return self
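With this change, GPTQ calibration is reachable directly from the model's quantize() entry point. A hedged usage sketch; the checkpoint name and the calibration prompts below are placeholders, and loading follows the usual ChatGLM trust_remote_code workflow:

    from transformers import AutoModel, AutoTokenizer

    model_path = "THUDM/chatglm-6b"  # placeholder: any checkpoint that ships these modified files
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()

    # a handful of representative prompts is enough for calibration
    calib_data = ["你好", "请介绍一下量子计算", "Write a short poem about autumn."]

    # quant_algo_type="min_max" keeps the original RTN path;
    # "gptq" requires both calib_data and a tokenizer, as enforced in quantize()
    model = model.quantize(bits=8, quant_algo_type="gptq", calib_data=calib_data, tokenizer=tokenizer)

    response, history = model.chat(tokenizer, "你好", history=[])
    print(response)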
quantization.py CHANGED
@@ -8,7 +8,7 @@ import ctypes
 from transformers.utils import logging
 
 from typing import List
-from functools import partial
+from enum import Enum
 
 logger = logging.get_logger(__name__)
 
@@ -41,6 +41,17 @@ except Exception as exception:
     logger.warning("Failed to load cpm_kernels:" + str(exception))
 
 
+class QuantAlgoType(Enum):
+    min_max = 'min_max'
+    gptq = 'gptq'
+
+    @classmethod
+    def _missing_(cls, value):
+        supported_types = [e.value for e in cls]
+        raise ValueError(f"Unsupported quantization algorithm type. Support list: "
+                         f"{supported_types}. Got: '{value}'")
+
+
 class W8A16Linear(torch.autograd.Function):
     @staticmethod
     def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
@@ -118,7 +129,7 @@ def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, sourc
 
 
 class QuantizedLinear(Linear):
-    def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, empty_init=False, *args, **kwargs):
+    def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, weight_scale=None, empty_init=False, *args, **kwargs):
         super(QuantizedLinear, self).__init__(*args, **kwargs)
         self.weight_bit_width = weight_bit_width
 
@@ -131,7 +142,10 @@ class QuantizedLinear(Linear):
             )
             self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"])
         else:
-            self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
+            if weight_scale is None:
+                self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
+            else:
+                self.weight_scale = weight_scale
             self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8)
             if weight_bit_width == 4:
                 self.weight = compress_int4_weight(self.weight)
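The new weight_scale argument lets gptq_quantize hand the per-channel scales it computed during calibration straight to QuantizedLinear instead of recomputing them from the already fake-quantized weights, and the _missing_ hook makes an unknown algorithm name fail fast with the supported list. A small sketch of the expected behavior, assuming the module is imported the same way as above:

    from quantization import QuantAlgoType  # relative import inside the actual model repo

    QuantAlgoType("gptq")      # -> QuantAlgoType.gptq
    QuantAlgoType("min_max")   # -> QuantAlgoType.min_max

    try:
        QuantAlgoType("awq")   # not supported by this PR
    except ValueError as e:
        print(e)  # Unsupported quantization algorithm type. Support list: ['min_max', 'gptq']. Got: 'awq'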