|
|
|
import contextlib |
|
import logging |
|
import math |
|
from typing import List, Optional |
|
|
|
import torch |
|
import transformers |
|
from torch import nn |
|
|
|
# Module-level logger for this file.
LOGGER = logging.getLogger(__name__)

# Module types that are wrapped for GPTQ calibration; submodules are matched
# with isinstance(layer, tuple(QUANT_LAYERS)).
QUANT_LAYERS = [nn.Linear, nn.Conv2d, transformers.Conv1D]
|
|
|
def is_transformer_conv1d(layer):
    """Return True when ``layer`` is an instance of ``transformers.Conv1D``."""
    conv1d_type = transformers.Conv1D
    return isinstance(layer, conv1d_type)
|
|
|
|
|
|
|
def get_weight_scale(weight, weight_bit_width):
    """Return the symmetric quantization scale of ``weight`` along its last dim.

    The scale is ``max(|w|) / (2 ** (weight_bit_width - 1) - 1)`` reduced over the
    last dimension (one value per row of a 2-D weight), returned as float16.

    Args:
        weight: weight tensor; the reduction is over ``dim=-1``.
        weight_bit_width: number of bits of the signed integer grid.

    Returns:
        A float16 tensor of per-row scales. Rows whose scale would be zero
        (an all-zero row, or one whose scale underflows in float16) get a
        scale of 1 instead, so later ``weight / scale`` never divides by zero.
    """
    max_int = (2 ** (weight_bit_width - 1)) - 1
    weight_scale = (weight.abs().max(dim=-1).values / max_int).half()
    # A zero scale turns quantization into 0/0 = NaN (or x/0 = inf), and GPTQ's
    # error propagation then spreads the NaN across the whole row. For such
    # rows every entry is (numerically) zero, so any positive scale quantizes
    # them to 0; 1 is the simplest choice.
    return weight_scale.masked_fill(weight_scale == 0, 1)
|
|
|
def fake_quantize_weight(weight, weight_scale, weight_bit_width=None):
    """Round ``weight`` to its per-row quantization grid and map it back to floats.

    Args:
        weight: 2-D tensor whose row ``r`` is quantized with ``weight_scale[r]``.
        weight_scale: 1-D tensor of per-row scales.
        weight_bit_width: optional bit width. When given, the rounded integers
            are clamped to ``[-(2 ** (b - 1) - 1), 2 ** (b - 1) - 1]`` so the
            result is representable on the signed integer grid. When omitted
            (the default), no clamping is applied.

    Returns:
        ``round(weight / scale) * scale`` with the scale broadcast over columns.
    """
    scale = weight_scale[:, None]
    quantized = torch.round(weight / scale)
    if weight_bit_width is not None:
        max_int = (2 ** (weight_bit_width - 1)) - 1
        quantized = quantized.clamp(-max_int, max_int)
    return quantized * scale
|
|
|
|
|
class GPTQLayerWrapper:
    """Accumulate GPTQ statistics for one layer and quantize its weight in place.

    Layer inputs passed to ``record_h`` are folded into a running Hessian
    estimate ``H = 2 / nsamples * sum(x @ x.T)`` (float32, ``columns x columns``).
    ``quant_weight`` then runs GPTQ on the weight viewed as an ``(out, columns)``
    matrix and replaces ``layer.weight`` with the fake-quantized result.
    """

    def __init__(self, layer_name, layer, weight_bit_width):
        super().__init__()
        self.layer_name = layer_name
        self.layer = layer
        self.device = layer.weight.device
        weight_shape = layer.weight.shape
        if is_transformer_conv1d(layer):
            # quant_weight transposes Conv1D weights, so their input dim is dim 0.
            columns = weight_shape[0]
        else:
            # nn.Linear is (out, in); nn.Conv2d is (out, in, kh, kw), which
            # quant_weight flattens to (out, in * kh * kw) to match the unfolded
            # inputs recorded by record_h.
            columns = weight_shape[1:].numel()
        self.columns = columns
        self.H = torch.zeros((columns, columns), device=self.device)
        self.nsamples = 0
        # Inputs are only accumulated while this flag is True.
        self.is_record = True
        self.weight_bit_width = weight_bit_width
        # Per-output-row scale; set by quant_weight.
        self.weight_scale = None

    def record_h(self, x):
        """Fold one layer input ``x`` into the running Hessian estimate.

        Does nothing while ``is_record`` is False. A 2-D input is treated as a
        batch of one; ``nsamples`` grows by the size of the leading dimension.
        """
        if not self.is_record:
            return

        # No clone needed: every step below produces a new tensor or a view
        # that is never modified in place.
        x = x.detach()
        if len(x.shape) == 2:
            x = x.unsqueeze(0)
        batch = x.shape[0]

        if isinstance(self.layer, nn.Linear) or is_transformer_conv1d(self.layer):
            # Linear-style layers act on the last dim, so every leading index is
            # one input vector: flatten them into (tokens, features).
            if len(x.shape) >= 3:
                x = x.reshape((-1, x.shape[-1]))
            x = x.t()

        if isinstance(self.layer, nn.Conv2d):
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride,
            )
            x = unfold(x)
            x = x.permute([1, 0, 2])
            x = x.flatten(1)

        # Rescale the old sum so H stays equal to 2 / nsamples * sum(x x^T).
        self.H *= self.nsamples / (self.nsamples + batch)
        self.nsamples += batch
        x = math.sqrt(2 / self.nsamples) * x.float()
        self.H += x.matmul(x.t())

    def quant_weight(self, blocksize=128, percdamp=.01, groupsize=-1):
        """Quantize the layer weight with GPTQ and write it back in place.

        Columns are processed in chunks of ``blocksize``. Each column is rounded
        to the per-row scale (clamped to the signed ``weight_bit_width`` range),
        and the rounding error is propagated to the not-yet-quantized columns
        through the upper Cholesky factor of the damped inverse Hessian.

        Args:
            blocksize: number of columns handled per lazy-update chunk.
            percdamp: fraction of the mean Hessian diagonal added to the
                diagonal before inversion.
            groupsize: only -1 (one scale per output row) is supported.

        Raises:
            RuntimeError: if ``groupsize`` is not -1.

        If the Hessian cannot be inverted, a warning is logged and the layer
        weight is left unchanged; ``weight_scale`` is still set. The
        accumulated Hessian is released on every path past the scale setup.
        """
        if groupsize != -1:
            raise RuntimeError("Group quantization of gptq quantizer is not supported for now")

        weight = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            weight = weight.flatten(1)
        if is_transformer_conv1d(self.layer):
            weight = weight.t()
        weight = weight.float()

        weight_scale = get_weight_scale(weight, self.weight_bit_width)
        self.weight_scale = weight_scale
        # Largest representable magnitude per row; q = round(w / s) * s is
        # clamped to +-max_int * s so it stays on the signed integer grid even
        # after error propagation has grown the remaining weights.
        max_int = 2 ** (self.weight_bit_width - 1) - 1
        q_bound = max_int * weight_scale.float()

        # Take ownership of the Hessian and drop the attribute now, so it is
        # freed on the early-return paths as well as on success.
        H = self.H
        del self.H

        # Inputs that were always zero give zero-diagonal entries; make H
        # invertible there and zero the matching weight columns.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        weight[:, dead] = 0

        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.device)
        H[diag, diag] += damp
        try:
            H = torch.linalg.cholesky(H)
            H = torch.cholesky_inverse(H)
            H = torch.linalg.cholesky(H, upper=True)
        except RuntimeError:
            LOGGER.warning("Warning: cannot do compression on layer %s because of inverse error", self.layer_name)
            return

        if H.isnan().any():
            LOGGER.warning("Warning: cannot do compression on layer %s because of inverse error", self.layer_name)
            return

        hinv = H
        Q = torch.zeros_like(weight)
        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            w1 = weight[:, i1:i2].clone()
            q1 = torch.zeros_like(w1)
            block_err = torch.zeros_like(w1)
            hinv1 = hinv[i1:i2, i1:i2]

            for i in range(i2 - i1):
                w = w1[:, i]
                d = hinv1[i, i]
                q = fake_quantize_weight(w.unsqueeze(1), weight_scale).flatten()
                q = torch.minimum(torch.maximum(q, -q_bound), q_bound)
                q1[:, i] = q
                err = (w - q) / d
                # Compensate the remaining columns of this chunk for the error.
                w1[:, i:] -= err.unsqueeze(1).matmul(hinv1[i, i:].unsqueeze(0))
                block_err[:, i] = err

            Q[:, i1:i2] = q1
            # Lazy batch update of all columns after this chunk.
            weight[:, i2:] -= block_err.matmul(hinv[i1:i2, i2:])

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        if is_transformer_conv1d(self.layer):
            Q = Q.t()
        shape = self.layer.weight.shape
        dtype = self.layer.weight.data.dtype
        del self.layer.weight
        setattr(self.layer, "weight", nn.Parameter(Q.reshape(shape).to(dtype), requires_grad=False))
|
|
|
|
|
class GPTQBlockWrapper:
    """Wrap every quantizable layer of one block for GPTQ calibration.

    Each submodule of ``block`` whose type is in ``QUANT_LAYERS`` gets a
    ``GPTQLayerWrapper`` keyed by ``"<block_name>.<submodule name>"``. It also
    gets a forward pre-hook that passes the layer's first positional input to
    that wrapper's ``record_h``.
    """

    def __init__(self, block_name: str, block: nn.Module, weight_bit_width=8):
        self.layer_wrappers = {}
        self.hook_handles = []
        # Position of this block in the forward pass (see set_order).
        self.order = 0
        self.block_name = block_name

        def get_hook(layer_name):
            # A factory binds layer_name per hook (avoids late-binding closures).
            def record_hook(_, x):
                self.layer_wrappers[layer_name].record_h(x[0])
            return record_hook

        quant_layer_types = tuple(QUANT_LAYERS)
        for layer_name, layer in block.named_modules():
            if not isinstance(layer, quant_layer_types):
                continue
            # named_modules() yields the block itself under the empty name.
            full_layer_name = f"{block_name}.{layer_name}" if layer_name else f"{block_name}"
            self.layer_wrappers[full_layer_name] = GPTQLayerWrapper(full_layer_name, layer, weight_bit_width)
            self.hook_handles.append(layer.register_forward_pre_hook(get_hook(full_layer_name)))

    def quant_block(self):
        """GPTQ-quantize every wrapped layer, then remove the recording hooks.

        The hooks are removed even if quantizing a layer raises, so a failed
        block does not keep feeding inputs to its wrappers on later forwards.
        """
        try:
            for wrapper in self.layer_wrappers.values():
                wrapper.quant_weight()
        finally:
            for handle in self.hook_handles:
                handle.remove()
            self.hook_handles.clear()

    def set_order(self, idx):
        """Set this block's position in the forward pass."""
        self.order = idx

    def get_order(self):
        """Return this block's position in the forward pass."""
        return self.order

    def enable(self):
        """Let every layer wrapper in this block record its inputs."""
        for wrapper in self.layer_wrappers.values():
            wrapper.is_record = True

    def disable(self):
        """Stop every layer wrapper in this block from recording inputs."""
        for wrapper in self.layer_wrappers.values():
            wrapper.is_record = False
|
|
|
|
|
class GPTQuantizer:
    """Drive block-by-block GPTQ calibration of a model.

    Typical flow:
    1. ``wrap_model`` wraps every submodule whose type is in ``block_type``.
    2. ``record_order`` learns the order in which those blocks run during a
       forward pass.
    3. For each ``i`` in ``range(calibration_iters)``, ``start_calib_iter(i)``
       lets only block ``i`` record inputs and quantizes it on exit.
    """

    def __init__(self, block_type: Optional[List[type]] = None):
        # Full dotted module path -> GPTQBlockWrapper.
        self.gptq_block_wrappers = {}
        self.block_type = block_type

    def wrap_model(self, model: nn.Module, weight_bit_width=8):
        """Wrap every ``block_type`` submodule of ``model`` and return ``model``.

        Blocks are matched top-down; a matched block's children are not
        searched further. Blocks with no quantizable layers are skipped.

        Raises:
            ValueError: if this quantizer was created without ``block_type``.
        """
        if self.block_type is None:
            raise ValueError("GPTQuantizer needs block_type to know which modules to calibrate as blocks")
        block_types = tuple(self.block_type)

        def wrap_block(m, prefix=""):
            for name, child in m.named_children():
                child_prefix = f"{prefix}.{name}" if prefix else name
                if not isinstance(child, block_types):
                    wrap_block(child, child_prefix)
                    continue
                block_wrapper = GPTQBlockWrapper(child_prefix, child, weight_bit_width)
                if not block_wrapper.layer_wrappers:
                    # No hooks were registered, so there is nothing to calibrate.
                    LOGGER.debug(f"Skip block {child_prefix}: no quantizable layers")
                    continue
                # Key by the full path: local child names (e.g. "0") repeat
                # across containers and would overwrite each other.
                self.gptq_block_wrappers[child_prefix] = block_wrapper
                LOGGER.debug(f"Calibrate block {child_prefix} as a whole block in GPTQ")

        wrap_block(model)
        return model

    @property
    def calibration_iters(self):
        """Number of calibration iterations, one per wrapped block."""
        return len(self.gptq_block_wrappers)

    @contextlib.contextmanager
    def record_order(self):
        """Record the order in which wrapped blocks execute inside the ``with`` body.

        Input recording is disabled while the body runs. On exit, each block's
        order is set to its first-execution index. Blocks that never ran are
        ordered after those that did, and a warning is logged for each. Input
        recording is re-enabled on exit. Exceptions from the body propagate.
        """
        counter = 0
        orders = {}
        record_handles = []

        def get_record_order_hook(block_name):
            def record_hook(*args, **kwargs):
                nonlocal counter
                if block_name not in orders:
                    orders[block_name] = counter
                    counter += 1
            return record_hook

        try:
            for block_name, block_wrapper in self.gptq_block_wrappers.items():
                for layer_wrapper in block_wrapper.layer_wrappers.values():
                    layer_wrapper.is_record = False
                if not block_wrapper.layer_wrappers:
                    continue
                # Any layer of the block fires while the block runs, which is
                # enough to order whole blocks.
                first_layer_wrapper = next(iter(block_wrapper.layer_wrappers.values()))
                record_handles.append(
                    first_layer_wrapper.layer.register_forward_pre_hook(get_record_order_hook(block_name))
                )
            yield
        finally:
            for handle in record_handles:
                handle.remove()
            for block_name, block_wrapper in self.gptq_block_wrappers.items():
                if block_name not in orders:
                    # Give unseen blocks unique trailing orders so they do not
                    # collide with a block that really ran at position 0.
                    LOGGER.warning("Block %s did not run while recording the block order; "
                                   "it will be calibrated last", block_name)
                    orders[block_name] = counter
                    counter += 1
                block_wrapper.set_order(orders[block_name])
                for layer_wrapper in block_wrapper.layer_wrappers.values():
                    layer_wrapper.is_record = True

    @contextlib.contextmanager
    def start_calib_iter(self, i):
        """Enable input recording only for the block with order ``i``; quantize it on exit.

        Raises:
            IndexError: if ``i`` is not in ``range(calibration_iters)``.
        """
        if not 0 <= i < len(self.gptq_block_wrappers):
            raise IndexError(f"calibration iteration {i} out of range [0, {len(self.gptq_block_wrappers)})")
        target_block_wrapper = None
        try:
            for block_wrapper in self.gptq_block_wrappers.values():
                if block_wrapper.get_order() == i:
                    block_wrapper.enable()
                    target_block_wrapper = block_wrapper
                else:
                    block_wrapper.disable()
            yield
        finally:
            if target_block_wrapper is None:
                LOGGER.warning("No block has calibration order %d; nothing was quantized", i)
            else:
                target_block_wrapper.quant_block()

    def release_reference(self):
        """Drop every layer wrapper's reference to its layer and empty the CUDA cache."""
        for block_wrapper in self.gptq_block_wrappers.values():
            for layer_wrapper in block_wrapper.layer_wrappers.values():
                del layer_wrapper.layer
        torch.cuda.empty_cache()
|
|
|
|
|
def locate_parent(root: nn.Module, full_path: str):
    """Resolve the object that owns the last component of a dotted path.

    Returns ``(owner, leaf)`` where ``owner`` is reached from ``root`` by
    following every component of ``full_path`` except the last, and ``leaf``
    is that last component. A path without dots yields ``(root, full_path)``.
    """
    owner_path, sep, leaf = full_path.rpartition('.')
    owner = root
    if sep:
        for attr in owner_path.split('.'):
            owner = getattr(owner, attr)
    return owner, leaf
|
|
|
|
|
@torch.no_grad()
def gptq_quantize(model, tokenizer, weight_bit_width, calib_data):
    """Quantize the GLMBlock layers of ``model`` in place with GPTQ.

    Blocks are calibrated one at a time in forward-pass order. Each block
    records its inputs while every prompt in ``calib_data`` is run through
    ``model.chat`` and is quantized before the next block is calibrated, so
    later blocks see the outputs of already-quantized earlier blocks. Finally,
    every calibrated ``nn.Linear`` is replaced by a ``QuantizedLinear``.

    Args:
        model: model exposing ``chat(tokenizer, prompt, history=[])``.
        tokenizer: tokenizer passed through to ``model.chat``.
        weight_bit_width: number of weight bits.
        calib_data: iterable of calibration prompts; it is materialized into
            a list because it is traversed once per block.

    Raises:
        ValueError: if ``calib_data`` is empty.
    """
    calib_data = list(calib_data)
    if not calib_data:
        raise ValueError("gptq_quantize needs at least one calibration prompt")

    from .modeling_chatglm import GLMBlock
    from .quantization import QuantizedLinear

    quantizer = GPTQuantizer([GLMBlock])
    # wrap_model attaches hooks and returns the same model object.
    quantizer.wrap_model(model, weight_bit_width)
    with quantizer.record_order():
        model.chat(tokenizer, calib_data[0], history=[])

    LOGGER.info("Start doing calibration using GPTQ ")
    total_iters = quantizer.calibration_iters
    for i in range(total_iters):
        LOGGER.info("Process: %d/%d", i + 1, total_iters)
        # Only block i records inputs; it is GPTQ-quantized when the with-block exits.
        with quantizer.start_calib_iter(i):
            for prompt in calib_data:
                model.chat(tokenizer, prompt, history=[])

    for block_wrapper in quantizer.gptq_block_wrappers.values():
        for layer_name, layer_wrapper in block_wrapper.layer_wrappers.items():
            layer = layer_wrapper.layer
            if not isinstance(layer, nn.Linear):
                # QuantizedLinear needs in_features/out_features, which only
                # nn.Linear has; keep the GPTQ-updated weights in place instead.
                LOGGER.warning("Skip replacing %s with QuantizedLinear: %s is not nn.Linear",
                               layer_name, type(layer).__name__)
                continue
            parent, name_in_parent = locate_parent(model, layer_name)
            quantized_layer = QuantizedLinear(
                weight_bit_width=weight_bit_width,
                weight_tensor=layer.weight,
                bias_tensor=layer.bias,
                weight_scale=layer_wrapper.weight_scale,
                in_features=layer.in_features,
                out_features=layer.out_features,
                bias=True,
                dtype=torch.half,
                device=layer_wrapper.device,
                empty_init=False,
            )
            parent.add_module(name_in_parent, quantized_layer)

    quantizer.release_reference()
    return
|
|