gptq量化相关代码
#43
by
BigMaoGoGoGo
- opened
- gptq_quantization.py +324 -0
- modeling_chatglm.py +18 -5
- quantization.py +17 -3
gptq_quantization.py
ADDED
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import contextlib
|
3 |
+
import logging
|
4 |
+
import math
|
5 |
+
from typing import List, Optional
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import transformers
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
LOGGER = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
QUANT_LAYERS = [nn.Linear, nn.Conv2d, transformers.Conv1D]
|
14 |
+
|
15 |
+
def is_transformer_conv1d(layer):
|
16 |
+
return isinstance(layer, transformers.Conv1D)
|
17 |
+
|
18 |
+
|
19 |
+
# These two functions only work on per-channel symmetric quantization for weight
|
20 |
+
def get_weight_scale(weight, weight_bit_width):
|
21 |
+
weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
|
22 |
+
return weight_scale
|
23 |
+
|
24 |
+
def fake_quantize_weight(weight, weight_scale):
|
25 |
+
weight_scale = weight_scale[:, None]
|
26 |
+
fake_quantized_weight = torch.round(weight / weight_scale) * weight_scale
|
27 |
+
return fake_quantized_weight
|
28 |
+
|
29 |
+
|
30 |
+
class GPTQLayerWrapper:
|
31 |
+
def __init__(self, layer_name, layer, weight_bit_width):
|
32 |
+
super().__init__()
|
33 |
+
self.layer_name = layer_name
|
34 |
+
self.layer = layer
|
35 |
+
self.device = layer.weight.device
|
36 |
+
columns = layer.weight.shape[1]
|
37 |
+
self.columns = columns
|
38 |
+
self.H = torch.zeros((columns, columns), device=self.device)
|
39 |
+
self.nsamples = 0
|
40 |
+
self.is_record = True
|
41 |
+
self.weight_bit_width = weight_bit_width
|
42 |
+
self.weight_scale = None
|
43 |
+
|
44 |
+
def record_h(self, x):
|
45 |
+
if self.is_record:
|
46 |
+
x = x.detach().clone()
|
47 |
+
if len(x.shape) == 2:
|
48 |
+
x = x.unsqueeze(0)
|
49 |
+
batch = x.shape[0]
|
50 |
+
if isinstance(self.layer, nn.Linear) or is_transformer_conv1d(self.layer):
|
51 |
+
if len(x.shape) == 3:
|
52 |
+
x = x.reshape((-1, x.shape[-1]))
|
53 |
+
x = x.t()
|
54 |
+
|
55 |
+
if isinstance(self.layer, nn.Conv2d):
|
56 |
+
unfold = nn.Unfold(
|
57 |
+
self.layer.kernel_size,
|
58 |
+
dilation=self.layer.dilation,
|
59 |
+
padding=self.layer.padding,
|
60 |
+
stride=self.layer.stride
|
61 |
+
)
|
62 |
+
x = unfold(x)
|
63 |
+
x = x.permute([1, 0, 2])
|
64 |
+
x = x.flatten(1)
|
65 |
+
|
66 |
+
self.H *= self.nsamples / (self.nsamples + batch)
|
67 |
+
self.nsamples += batch
|
68 |
+
x = math.sqrt(2 / self.nsamples) * x.float()
|
69 |
+
self.H += x.matmul(x.t())
|
70 |
+
|
71 |
+
def quant_weight(self, blocksize=128, percdamp=.01, groupsize=-1):
|
72 |
+
if groupsize != -1:
|
73 |
+
raise RuntimeError("Group quantization of gptq quantizer is not supported for now")
|
74 |
+
weight = self.layer.weight.data.clone()
|
75 |
+
if isinstance(self.layer, nn.Conv2d):
|
76 |
+
weight = weight.flatten(1)
|
77 |
+
if is_transformer_conv1d(self.layer):
|
78 |
+
weight = weight.t()
|
79 |
+
weight = weight.float()
|
80 |
+
|
81 |
+
weight_scale = get_weight_scale(weight, self.weight_bit_width)
|
82 |
+
# todo: use buffer to store scale
|
83 |
+
self.weight_scale = weight_scale
|
84 |
+
H = self.H
|
85 |
+
dead = torch.diag(H) == 0
|
86 |
+
H[dead, dead] = 1
|
87 |
+
weight[:, dead] = 0
|
88 |
+
|
89 |
+
losses = torch.zeros_like(weight)
|
90 |
+
Q = torch.zeros_like(weight)
|
91 |
+
|
92 |
+
damp = percdamp * torch.mean(torch.diag(H))
|
93 |
+
diag = torch.arange(self.columns, device=self.device)
|
94 |
+
H[diag, diag] += damp
|
95 |
+
try:
|
96 |
+
H = torch.linalg.cholesky(H)
|
97 |
+
H = torch.cholesky_inverse(H)
|
98 |
+
H = torch.linalg.cholesky(H, upper=True)
|
99 |
+
except Exception:
|
100 |
+
logging.warning(f"Warning: cannot do compression on layer {self.layer_name} because of inverse error")
|
101 |
+
return
|
102 |
+
|
103 |
+
if H.isnan().any():
|
104 |
+
logging.warning(f"Warning: cannot do compression on layer {self.layer_name} because of inverse error")
|
105 |
+
return
|
106 |
+
|
107 |
+
hinv = H
|
108 |
+
|
109 |
+
for i1 in range(0, self.columns, blocksize):
|
110 |
+
i2 = min(i1 + blocksize, self.columns)
|
111 |
+
count = i2 - i1
|
112 |
+
|
113 |
+
w1 = weight[:, i1:i2].clone()
|
114 |
+
q1 = torch.zeros_like(w1)
|
115 |
+
total_err = torch.zeros_like(w1)
|
116 |
+
losses1 = torch.zeros_like(w1)
|
117 |
+
hinv1 = hinv[i1:i2, i1:i2]
|
118 |
+
|
119 |
+
for i in range(count):
|
120 |
+
w = w1[:, i]
|
121 |
+
d = hinv1[i, i]
|
122 |
+
|
123 |
+
q = fake_quantize_weight(w.unsqueeze(1), weight_scale).flatten()
|
124 |
+
|
125 |
+
q1[:, i] = q
|
126 |
+
losses1[:, i] = (w - q) ** 2 / d ** 2
|
127 |
+
err = (w - q) / d
|
128 |
+
w1[:, i:] -= err.unsqueeze(1).matmul(hinv1[i, i:].unsqueeze(0))
|
129 |
+
total_err[:, i] = err
|
130 |
+
|
131 |
+
Q[:, i1:i2] = q1
|
132 |
+
losses[:, i1:i2] = losses1 / 2
|
133 |
+
|
134 |
+
weight[:, i2:] -= total_err.matmul(hinv[i1:i2, i2:])
|
135 |
+
|
136 |
+
if torch.cuda.is_available():
|
137 |
+
torch.cuda.synchronize()
|
138 |
+
|
139 |
+
if is_transformer_conv1d(self.layer):
|
140 |
+
Q = Q.t()
|
141 |
+
shape = self.layer.weight.shape
|
142 |
+
dtype = self.layer.weight.data.dtype
|
143 |
+
del self.layer.weight
|
144 |
+
setattr(self.layer, "weight", nn.Parameter(Q.reshape(shape).to(dtype), requires_grad=False))
|
145 |
+
del self.H
|
146 |
+
|
147 |
+
|
148 |
+
class GPTQBlockWrapper:
|
149 |
+
def __init__(self, block_name: str, block: nn.Module, weight_bit_width=8):
|
150 |
+
self.layer_wrappers = {}
|
151 |
+
self.hook_handles = []
|
152 |
+
# block order in the whole network
|
153 |
+
self.order = 0
|
154 |
+
self.block_name = block_name
|
155 |
+
|
156 |
+
def get_hook(layer_name):
|
157 |
+
def record_hook(_, x):
|
158 |
+
self.layer_wrappers[layer_name].record_h(x[0])
|
159 |
+
return record_hook
|
160 |
+
|
161 |
+
for layer_name, layer in block.named_modules():
|
162 |
+
if isinstance(layer, tuple(QUANT_LAYERS)):
|
163 |
+
full_layer_name = f"{block_name}.{layer_name}" if layer_name else f"{block_name}"
|
164 |
+
self.layer_wrappers[full_layer_name] = GPTQLayerWrapper(full_layer_name, layer, weight_bit_width)
|
165 |
+
handle = layer.register_forward_pre_hook(get_hook(full_layer_name))
|
166 |
+
self.hook_handles.append(handle)
|
167 |
+
|
168 |
+
def quant_block(self):
|
169 |
+
for _, wrapper in self.layer_wrappers.items():
|
170 |
+
wrapper.quant_weight()
|
171 |
+
|
172 |
+
for h in self.hook_handles:
|
173 |
+
h.remove()
|
174 |
+
|
175 |
+
def set_order(self, idx):
|
176 |
+
self.order = idx
|
177 |
+
|
178 |
+
def get_order(self):
|
179 |
+
return self.order
|
180 |
+
|
181 |
+
def enable(self):
|
182 |
+
for n, l in self.layer_wrappers.items():
|
183 |
+
l.is_record = True
|
184 |
+
|
185 |
+
def disable(self):
|
186 |
+
for n, l in self.layer_wrappers.items():
|
187 |
+
l.is_record = False
|
188 |
+
|
189 |
+
|
190 |
+
class GPTQuantizer:
|
191 |
+
def __init__(self, block_type: Optional[List[type]] = None):
|
192 |
+
self.gptq_block_wrappers = {}
|
193 |
+
self.block_type = block_type
|
194 |
+
|
195 |
+
def wrap_model(self, model: nn.Module, weight_bit_width=8):
|
196 |
+
|
197 |
+
def wrap_block(m, prefix=""):
|
198 |
+
for name, child in m.named_children():
|
199 |
+
child_prefix = f"{prefix}.{name}" if prefix else name
|
200 |
+
if isinstance(child, tuple(self.block_type)):
|
201 |
+
self.gptq_block_wrappers[name] = GPTQBlockWrapper(child_prefix, child, weight_bit_width)
|
202 |
+
LOGGER.debug(f"Calibrate block {child_prefix} as a whole block in GPTQ")
|
203 |
+
else:
|
204 |
+
wrap_block(child, child_prefix)
|
205 |
+
|
206 |
+
wrap_block(model)
|
207 |
+
return model
|
208 |
+
|
209 |
+
@property
|
210 |
+
def calibration_iters(self):
|
211 |
+
return len(self.gptq_block_wrappers)
|
212 |
+
|
213 |
+
@contextlib.contextmanager
|
214 |
+
def record_order(self):
|
215 |
+
counter = 0
|
216 |
+
record_handles = []
|
217 |
+
orders = {}
|
218 |
+
try:
|
219 |
+
def get_record_order_hook(block_name):
|
220 |
+
def record_hook(*args, **kwargs):
|
221 |
+
nonlocal counter
|
222 |
+
if block_name not in orders:
|
223 |
+
orders[block_name] = counter
|
224 |
+
counter += 1
|
225 |
+
return record_hook
|
226 |
+
|
227 |
+
for block_name, block_wrapper in self.gptq_block_wrappers.items():
|
228 |
+
# disable the record
|
229 |
+
for _, layer_wrapper in block_wrapper.layer_wrappers.items():
|
230 |
+
layer_wrapper.is_record = False
|
231 |
+
|
232 |
+
one_layer_wrapper_in_block = list(block_wrapper.layer_wrappers.values())[0]
|
233 |
+
handles = one_layer_wrapper_in_block.layer.register_forward_pre_hook(get_record_order_hook(block_name))
|
234 |
+
record_handles.append(handles)
|
235 |
+
yield
|
236 |
+
except Exception as e:
|
237 |
+
logging.warning(e)
|
238 |
+
finally:
|
239 |
+
for block_name, order in orders.items():
|
240 |
+
self.gptq_block_wrappers[block_name].set_order(order)
|
241 |
+
|
242 |
+
for h in record_handles:
|
243 |
+
h.remove()
|
244 |
+
|
245 |
+
for _, block_wrapper in self.gptq_block_wrappers.items():
|
246 |
+
# disable the record
|
247 |
+
for _, layer_wrapper in block_wrapper.layer_wrappers.items():
|
248 |
+
layer_wrapper.is_record = True
|
249 |
+
|
250 |
+
|
251 |
+
@contextlib.contextmanager
|
252 |
+
def start_calib_iter(self, i):
|
253 |
+
assert i < len(self.gptq_block_wrappers)
|
254 |
+
target_block_wrapper = None
|
255 |
+
try:
|
256 |
+
for _, block_wrapper in self.gptq_block_wrappers.items():
|
257 |
+
if block_wrapper.get_order() == i:
|
258 |
+
block_wrapper.enable()
|
259 |
+
target_block_wrapper = block_wrapper
|
260 |
+
else:
|
261 |
+
block_wrapper.disable()
|
262 |
+
yield
|
263 |
+
finally:
|
264 |
+
target_block_wrapper.quant_block()
|
265 |
+
|
266 |
+
def release_reference(self):
|
267 |
+
# delete reference so that `torch.cuda.empty_cache()` can
|
268 |
+
# release all the gpu memory cache used during calibration
|
269 |
+
for _, block_wrapper in self.gptq_block_wrappers.items():
|
270 |
+
for _, layer_wrapper in block_wrapper.layer_wrappers.items():
|
271 |
+
del layer_wrapper.layer
|
272 |
+
|
273 |
+
torch.cuda.empty_cache()
|
274 |
+
|
275 |
+
|
276 |
+
def locate_parent(root: nn.Module, full_path: str):
|
277 |
+
parent = root
|
278 |
+
path = full_path.split('.')
|
279 |
+
for p in path[:-1]:
|
280 |
+
parent = getattr(parent, p)
|
281 |
+
return parent, path[-1]
|
282 |
+
|
283 |
+
|
284 |
+
@torch.no_grad()
|
285 |
+
def gptq_quantize(model, tokenizer, weight_bit_width, calib_data):
|
286 |
+
from .modeling_chatglm import GLMBlock
|
287 |
+
from .quantization import QuantizedLinear
|
288 |
+
|
289 |
+
quantizer = GPTQuantizer([GLMBlock])
|
290 |
+
calib_model = quantizer.wrap_model(model, weight_bit_width)
|
291 |
+
with quantizer.record_order():
|
292 |
+
calib_model.chat(tokenizer, calib_data[0], history=[])
|
293 |
+
|
294 |
+
logging.info("Start doing calibration using GPTQ ")
|
295 |
+
for i in range(quantizer.calibration_iters):
|
296 |
+
logging.info(f"Process: {i + 1}/{quantizer.calibration_iters}")
|
297 |
+
# todo: should add early return to speed up the calibration
|
298 |
+
# todo: add cpu offload to reduce the gpu memory requirements.
|
299 |
+
with quantizer.start_calib_iter(i):
|
300 |
+
for prompt in calib_data:
|
301 |
+
model.chat(tokenizer, prompt, history=[])
|
302 |
+
|
303 |
+
# replace the fp16 linear with quantized linear
|
304 |
+
for _, block_wrapper in quantizer.gptq_block_wrappers.items():
|
305 |
+
for layer_name, layer_wrapper in block_wrapper.layer_wrappers.items():
|
306 |
+
layer = layer_wrapper.layer
|
307 |
+
parent, name_in_parent = locate_parent(model, layer_name)
|
308 |
+
quantized_layer = QuantizedLinear(
|
309 |
+
weight_bit_width=weight_bit_width,
|
310 |
+
weight_tensor=layer.weight,
|
311 |
+
bias_tensor=layer.bias,
|
312 |
+
weight_scale=layer_wrapper.weight_scale,
|
313 |
+
in_features=layer.in_features,
|
314 |
+
out_features=layer.out_features,
|
315 |
+
bias=True,
|
316 |
+
dtype=torch.half,
|
317 |
+
device=layer_wrapper.device,
|
318 |
+
empty_init=False
|
319 |
+
)
|
320 |
+
parent.add_module(name_in_parent, quantized_layer)
|
321 |
+
|
322 |
+
# release the memory caache during calibration
|
323 |
+
quantizer.release_reference()
|
324 |
+
return
|
modeling_chatglm.py
CHANGED
@@ -1408,12 +1408,14 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1408 |
break
|
1409 |
yield input_ids
|
1410 |
|
1411 |
-
def quantize(
|
|
|
|
|
1412 |
if bits == 0:
|
1413 |
return
|
1414 |
|
1415 |
-
from .quantization import quantize
|
1416 |
-
|
1417 |
if self.quantized:
|
1418 |
logger.info("Already quantized.")
|
1419 |
return self
|
@@ -1421,6 +1423,17 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1421 |
self.quantized = True
|
1422 |
|
1423 |
self.config.quantization_bit = bits
|
1424 |
-
|
1425 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1426 |
return self
|
|
|
1408 |
break
|
1409 |
yield input_ids
|
1410 |
|
1411 |
+
def quantize(
|
1412 |
+
self, bits: int, empty_init=False, quant_algo_type: str="min_max",
|
1413 |
+
calib_data: Optional[List[str]]=None, tokenizer=None, **kwargs):
|
1414 |
if bits == 0:
|
1415 |
return
|
1416 |
|
1417 |
+
from .quantization import quantize, QuantAlgoType
|
1418 |
+
from .gptq_quantization import gptq_quantize
|
1419 |
if self.quantized:
|
1420 |
logger.info("Already quantized.")
|
1421 |
return self
|
|
|
1423 |
self.quantized = True
|
1424 |
|
1425 |
self.config.quantization_bit = bits
|
1426 |
+
quant_algo_type = QuantAlgoType(quant_algo_type)
|
1427 |
+
if quant_algo_type == QuantAlgoType.min_max:
|
1428 |
+
self.transformer = quantize(
|
1429 |
+
self.transformer, bits, empty_init=empty_init, algo_type=quant_algo_type, calib_data=calib_data, tokenizer=tokenizer, **kwargs)
|
1430 |
+
elif quant_algo_type == QuantAlgoType.gptq:
|
1431 |
+
if calib_data is None or tokenizer is None:
|
1432 |
+
raise RuntimeError("If using gptq to quantize the model, "
|
1433 |
+
"calibration data (e.g. some string prompts) and tokenizer should be provided")
|
1434 |
+
gptq_quantize(
|
1435 |
+
self, tokenizer, bits, calib_data
|
1436 |
+
)
|
1437 |
+
else:
|
1438 |
+
raise RuntimeError("Unsupported quantization algorithm type")
|
1439 |
return self
|
quantization.py
CHANGED
@@ -8,7 +8,7 @@ import ctypes
|
|
8 |
from transformers.utils import logging
|
9 |
|
10 |
from typing import List
|
11 |
-
from
|
12 |
|
13 |
logger = logging.get_logger(__name__)
|
14 |
|
@@ -41,6 +41,17 @@ except Exception as exception:
|
|
41 |
logger.warning("Failed to load cpm_kernels:" + str(exception))
|
42 |
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
class W8A16Linear(torch.autograd.Function):
|
45 |
@staticmethod
|
46 |
def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
|
@@ -118,7 +129,7 @@ def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, sourc
|
|
118 |
|
119 |
|
120 |
class QuantizedLinear(Linear):
|
121 |
-
def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, empty_init=False, *args, **kwargs):
|
122 |
super(QuantizedLinear, self).__init__(*args, **kwargs)
|
123 |
self.weight_bit_width = weight_bit_width
|
124 |
|
@@ -131,7 +142,10 @@ class QuantizedLinear(Linear):
|
|
131 |
)
|
132 |
self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"])
|
133 |
else:
|
134 |
-
|
|
|
|
|
|
|
135 |
self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8)
|
136 |
if weight_bit_width == 4:
|
137 |
self.weight = compress_int4_weight(self.weight)
|
|
|
8 |
from transformers.utils import logging
|
9 |
|
10 |
from typing import List
|
11 |
+
from enum import Enum
|
12 |
|
13 |
logger = logging.get_logger(__name__)
|
14 |
|
|
|
41 |
logger.warning("Failed to load cpm_kernels:" + str(exception))
|
42 |
|
43 |
|
44 |
+
class QuantAlgoType(Enum):
|
45 |
+
min_max = 'min_max'
|
46 |
+
gptq = 'gptq'
|
47 |
+
|
48 |
+
@classmethod
|
49 |
+
def _missing_(cls, value):
|
50 |
+
supported_types = [e.value for e in cls]
|
51 |
+
raise ValueError(f"Unsupported quantization algorithm type. Support list: "
|
52 |
+
f"{supported_types}. Got: '{value}'")
|
53 |
+
|
54 |
+
|
55 |
class W8A16Linear(torch.autograd.Function):
|
56 |
@staticmethod
|
57 |
def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
|
|
|
129 |
|
130 |
|
131 |
class QuantizedLinear(Linear):
|
132 |
+
def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, weight_scale=None, empty_init=False, *args, **kwargs):
|
133 |
super(QuantizedLinear, self).__init__(*args, **kwargs)
|
134 |
self.weight_bit_width = weight_bit_width
|
135 |
|
|
|
142 |
)
|
143 |
self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"])
|
144 |
else:
|
145 |
+
if weight_scale is None:
|
146 |
+
self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
|
147 |
+
else:
|
148 |
+
self.weight_scale = weight_scale
|
149 |
self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8)
|
150 |
if weight_bit_width == 4:
|
151 |
self.weight = compress_int4_weight(self.weight)
|