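# Source: chatglm-6b / gptq_quantization.py (Hugging Face file viewer)
# Last commit ac792e3 "fix gpu cache" by BigMaoGoGoGo; file size 11.6 kB.
# (Viewer links on the original page: raw | history | blame)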
import contextlib
import logging
import math
from typing import List, Optional
import torch
import transformers
from torch import nn
# Logger for this module.
LOGGER = logging.getLogger(__name__)

# Module types whose weights are wrapped, calibrated and quantized by GPTQ.
QUANT_LAYERS = [
    nn.Linear,
    nn.Conv2d,
    transformers.Conv1D,
]
def is_transformer_conv1d(layer):
    """Return True when *layer* is a Hugging Face ``transformers.Conv1D`` module.

    ``Conv1D`` stores its weight transposed relative to ``nn.Linear``
    (``(in_features, out_features)``), so callers use this check to decide
    whether the weight has to be transposed before per-row processing.
    """
    conv1d_cls = transformers.Conv1D
    return isinstance(layer, conv1d_cls)
# These two functions only work on per-channel symmetric quantization for weight
def get_weight_scale(weight, weight_bit_width):
    """Return the symmetric per-row quantization scale of *weight* as fp16.

    The scale of each row (reduction over the last dim) is
    ``max(|row|) / (2 ** (weight_bit_width - 1) - 1)``, so the row's largest
    magnitude maps onto the largest positive integer level.

    Rows whose scale is zero get a scale of 1 instead. This covers all-zero
    rows and rows whose scale underflows in fp16. Their values round to 0
    with any scale, but a zero scale would make ``weight / scale`` produce
    NaN downstream.

    Args:
        weight: Weight tensor; presumably ``(out_features, in_features)``.
        weight_bit_width: Target bit width, e.g. 8 or 4.

    Returns:
        fp16 tensor of shape ``weight.shape[:-1]``.
    """
    max_level = (2 ** (weight_bit_width - 1)) - 1
    weight_scale = (weight.abs().max(dim=-1).values / max_level).half()
    # Check after the fp16 cast so underflowed scales are caught as well.
    return weight_scale.masked_fill(weight_scale == 0, 1.0)
def fake_quantize_weight(weight, weight_scale):
    """Round *weight* onto the integer grid of *weight_scale* and map it back.

    ``weight_scale`` holds one scale per row of ``weight`` and is broadcast
    along dim 1. The result is ``round(weight / scale) * scale``. No clipping
    to the integer range of any bit width is applied here.
    """
    per_row_scale = weight_scale.unsqueeze(1)
    levels = (weight / per_row_scale).round()
    return levels * per_row_scale
class GPTQLayerWrapper:
    """GPTQ state for one ``nn.Linear`` / ``nn.Conv2d`` / ``transformers.Conv1D`` layer.

    While ``is_record`` is true, :meth:`record_h` folds the layer inputs into
    ``H``, keeping ``H == 2 / nsamples * sum(x @ x.T)`` over everything
    recorded so far (``x`` laid out as ``(columns, n)``). :meth:`quant_weight`
    then quantizes the weight column by column with symmetric per-row scales.
    Each column's quantization error is fed back into the columns not yet
    quantized. Finally ``layer.weight`` is replaced with the fake-quantized
    (dequantized) result.

    Attributes:
        layer_name: Dotted layer name, used in log messages.
        layer: The wrapped module; ``quant_weight`` replaces its ``weight``.
        device: Device of the layer weight; ``H`` is allocated there.
        columns: Input width of the weight viewed as an ``(out, in)`` matrix.
        H: ``(columns, columns)`` Hessian estimate; removed by ``quant_weight``.
        nsamples: Number of first-dimension entries accumulated into ``H``.
        is_record: Whether ``record_h`` currently accumulates statistics.
        weight_bit_width: Target bit width of the symmetric quantization.
        weight_scale: fp16 per-row scale set by ``quant_weight`` (``None`` before).
    """

    def __init__(self, layer_name, layer, weight_bit_width):
        super().__init__()
        self.layer_name = layer_name
        self.layer = layer
        self.device = layer.weight.device
        # Size the Hessian by the *input* width of the (out, in) weight view.
        # For Conv1D (transposed weight) and Conv2d (in * kh * kw after unfold)
        # this is not simply ``weight.shape[1]``.
        self.columns = self._as_matrix(layer.weight).shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.device)
        self.nsamples = 0
        self.is_record = True
        self.weight_bit_width = weight_bit_width
        self.weight_scale = None

    def _as_matrix(self, weight):
        """Return *weight* viewed as a 2-D ``(out_features, in_features)`` matrix."""
        if isinstance(self.layer, nn.Conv2d):
            weight = weight.flatten(1)
        if is_transformer_conv1d(self.layer):
            # transformers.Conv1D keeps its weight as (in, out).
            weight = weight.t()
        return weight

    def record_h(self, x):
        """Accumulate the layer input *x* into the running Hessian ``self.H``.

        Does nothing while ``is_record`` is false.
        """
        if not self.is_record:
            return
        # Only out-of-place ops follow, so a detached view is enough (no clone).
        x = x.detach()
        if x.dim() == 2:
            x = x.unsqueeze(0)
        batch = x.shape[0]
        if isinstance(self.layer, nn.Linear) or is_transformer_conv1d(self.layer):
            # Fold all leading dims into one sample axis -> (in_features, n).
            x = x.reshape((-1, x.shape[-1])).t()
        if isinstance(self.layer, nn.Conv2d):
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride,
            )
            x = unfold(x)
            x = x.permute([1, 0, 2])
            x = x.flatten(1)
        # Rescale the old average so H stays 2/N * sum(x x^T) over all samples.
        self.H *= self.nsamples / (self.nsamples + batch)
        self.nsamples += batch
        x = math.sqrt(2 / self.nsamples) * x.float()
        self.H += x.matmul(x.t())

    def quant_weight(self, blocksize=128, percdamp=.01, groupsize=-1):
        """Quantize ``self.layer.weight`` in place using the recorded Hessian.

        Args:
            blocksize: Number of columns processed together before their
                accumulated error is propagated to the remaining columns.
            percdamp: Fraction of the mean Hessian diagonal added to the
                diagonal before inversion, for numerical stability.
            groupsize: Only ``-1`` (one scale per output row) is supported.

        Raises:
            RuntimeError: If ``groupsize`` is not ``-1``.

        If the damped Hessian cannot be Cholesky-factorised, or the result
        contains NaN, a warning is logged and the weight is left unchanged.
        ``weight_scale`` is still set in that case. ``H`` is released on every
        path after the ``groupsize`` check.
        """
        if groupsize != -1:
            raise RuntimeError("Group quantization of gptq quantizer is not supported for now")
        # Take ownership of the Hessian so it is freed on every exit path below,
        # including the early returns after a failed factorisation.
        H = self.H
        del self.H

        weight = self._as_matrix(self.layer.weight.data.clone()).float()
        weight_scale = get_weight_scale(weight, self.weight_bit_width)
        # todo: use buffer to store scale
        self.weight_scale = weight_scale
        # Largest magnitude each row can represent with its scale. Error
        # feedback can push not-yet-quantized weights past the row's original
        # max, so clip before rounding to stay inside the integer range.
        qmax = 2 ** (self.weight_bit_width - 1) - 1
        bound = weight_scale.float() * qmax

        # Input features that were always zero during calibration: give them a
        # unit Hessian diagonal and zero weights so the inversion stays defined.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        weight[:, dead] = 0

        Q = torch.zeros_like(weight)
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.device)
        H[diag, diag] += damp
        try:
            # Upper Cholesky factor of H^-1, used by the GPTQ update rule.
            H = torch.linalg.cholesky(H)
            H = torch.cholesky_inverse(H)
            H = torch.linalg.cholesky(H, upper=True)
        except RuntimeError:  # includes torch.linalg.LinAlgError
            LOGGER.warning(f"Warning: cannot do compression on layer {self.layer_name} because of inverse error")
            return
        if H.isnan().any():
            LOGGER.warning(f"Warning: cannot do compression on layer {self.layer_name} because of inverse error")
            return
        hinv = H

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            w1 = weight[:, i1:i2].clone()
            q1 = torch.zeros_like(w1)
            err1 = torch.zeros_like(w1)
            hinv1 = hinv[i1:i2, i1:i2]
            for i in range(i2 - i1):
                w = w1[:, i]
                d = hinv1[i, i]
                w_clipped = torch.maximum(torch.minimum(w, bound), -bound)
                q = fake_quantize_weight(w_clipped.unsqueeze(1), weight_scale).flatten()
                q1[:, i] = q
                # The error (rounding plus any clipping) is compensated in the
                # remaining columns of this block...
                err = (w - q) / d
                w1[:, i:] -= err.unsqueeze(1).matmul(hinv1[i, i:].unsqueeze(0))
                err1[:, i] = err
            Q[:, i1:i2] = q1
            # ...and, once per block, in every column after the block.
            weight[:, i2:] -= err1.matmul(hinv[i1:i2, i2:])

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if is_transformer_conv1d(self.layer):
            Q = Q.t()
        shape = self.layer.weight.shape
        dtype = self.layer.weight.data.dtype
        del self.layer.weight
        self.layer.weight = nn.Parameter(Q.reshape(shape).to(dtype), requires_grad=False)
class GPTQBlockWrapper:
    """Group the GPTQ layer wrappers belonging to one block of the model.

    Every sub-module of *block* (including *block* itself) whose type is in
    ``QUANT_LAYERS`` gets a ``GPTQLayerWrapper``. It also gets a forward
    pre-hook that passes the layer's first positional input to
    ``GPTQLayerWrapper.record_h``.

    Attributes:
        layer_wrappers: Full dotted layer name -> ``GPTQLayerWrapper``.
        hook_handles: Handles of the registered pre-hooks; ``quant_block``
            removes them.
        order: Position of this block in the forward pass (see ``set_order``).
        block_name: Dotted name of the block inside the model.
    """

    def __init__(self, block_name: str, block: nn.Module, weight_bit_width=8):
        self.layer_wrappers = {}
        self.hook_handles = []
        # block order in the whole network
        self.order = 0
        self.block_name = block_name

        def get_hook(layer_name):
            def record_hook(_, x):
                # ``x`` is the tuple of positional arguments of the layer call.
                self.layer_wrappers[layer_name].record_h(x[0])
            return record_hook

        quant_types = tuple(QUANT_LAYERS)
        for layer_name, layer in block.named_modules():
            if isinstance(layer, quant_types):
                # named_modules() yields the block itself under the empty name.
                full_layer_name = f"{block_name}.{layer_name}" if layer_name else f"{block_name}"
                self.layer_wrappers[full_layer_name] = GPTQLayerWrapper(full_layer_name, layer, weight_bit_width)
                handle = layer.register_forward_pre_hook(get_hook(full_layer_name))
                self.hook_handles.append(handle)

    def quant_block(self):
        """Quantize every wrapped layer, then detach the recording hooks.

        The hooks are removed even if quantizing a layer raises, so no stale
        hooks are left on the model's layers.
        """
        try:
            for wrapper in self.layer_wrappers.values():
                wrapper.quant_weight()
        finally:
            for handle in self.hook_handles:
                handle.remove()
            self.hook_handles = []

    def set_order(self, idx):
        """Store this block's position in the forward pass."""
        self.order = idx

    def get_order(self):
        """Return this block's position in the forward pass."""
        return self.order

    def enable(self):
        """Turn on Hessian recording for every wrapped layer."""
        for wrapper in self.layer_wrappers.values():
            wrapper.is_record = True

    def disable(self):
        """Turn off Hessian recording for every wrapped layer."""
        for wrapper in self.layer_wrappers.values():
            wrapper.is_record = False
class GPTQuantizer:
    """Drive block-by-block GPTQ calibration of a model.

    Typical use: ``wrap_model`` once, run one forward pass inside
    ``record_order`` to learn the block execution order, then for each
    ``i in range(calibration_iters)`` run calibration data inside
    ``start_calib_iter(i)``. That quantizes block ``i`` on exit.

    Attributes:
        gptq_block_wrappers: Full dotted block name -> ``GPTQBlockWrapper``.
        block_type: Module classes treated as calibration blocks.
    """

    def __init__(self, block_type: Optional[List[type]] = None):
        self.gptq_block_wrappers = {}
        self.block_type = block_type

    def wrap_model(self, model: nn.Module, weight_bit_width=8):
        """Wrap every outermost sub-module of *model* whose type is in ``block_type``.

        Recursion stops at a matching block, so nested blocks are not wrapped
        separately. Returns *model* itself (hooks are attached in place).
        """
        block_types = tuple(self.block_type)

        def wrap_block(m, prefix=""):
            for name, child in m.named_children():
                child_prefix = f"{prefix}.{name}" if prefix else name
                if isinstance(child, block_types):
                    # Key by the full dotted path: a bare child name such as
                    # "0" is only unique within its parent container.
                    self.gptq_block_wrappers[child_prefix] = GPTQBlockWrapper(child_prefix, child, weight_bit_width)
                    LOGGER.debug(f"Calibrate block {child_prefix} as a whole block in GPTQ")
                else:
                    wrap_block(child, child_prefix)

        wrap_block(model)
        return model

    @property
    def calibration_iters(self):
        """Number of calibration iterations, i.e. number of wrapped blocks."""
        return len(self.gptq_block_wrappers)

    @contextlib.contextmanager
    def record_order(self):
        """Record the order in which blocks first run during the ``with`` body.

        Hessian recording is paused for all layers while the body runs. On
        exit, each observed block gets its order via ``set_order``, the
        order hooks are removed and recording is re-enabled. Exceptions
        raised by the body propagate to the caller.
        """
        counter = 0
        record_handles = []
        orders = {}

        def get_record_order_hook(block_name):
            def record_hook(*args, **kwargs):
                nonlocal counter
                if block_name not in orders:
                    orders[block_name] = counter
                    counter += 1
            return record_hook

        try:
            for block_name, block_wrapper in self.gptq_block_wrappers.items():
                # Pause Hessian recording: this pass only observes execution order.
                for layer_wrapper in block_wrapper.layer_wrappers.values():
                    layer_wrapper.is_record = False
                one_layer_wrapper_in_block = list(block_wrapper.layer_wrappers.values())[0]
                handle = one_layer_wrapper_in_block.layer.register_forward_pre_hook(
                    get_record_order_hook(block_name)
                )
                record_handles.append(handle)
            yield
        finally:
            for block_name, order in orders.items():
                self.gptq_block_wrappers[block_name].set_order(order)
            for handle in record_handles:
                handle.remove()
            for block_wrapper in self.gptq_block_wrappers.values():
                # Re-enable the Hessian recording paused above.
                for layer_wrapper in block_wrapper.layer_wrappers.values():
                    layer_wrapper.is_record = True

    @contextlib.contextmanager
    def start_calib_iter(self, i):
        """Calibrate and quantize the block whose recorded order is *i*.

        Before the ``with`` body runs, only the block(s) with order *i* keep
        recording; all others are disabled. When the body exits, normally
        or with an exception, that block is quantized via ``quant_block``.

        Asserts that *i* is smaller than the number of blocks.

        Raises:
            RuntimeError: If no block has order *i* (e.g. ``record_order`` was
                not run, or did not observe every block). Raised before the
                body runs, so no calibration pass is wasted.
        """
        assert i < len(self.gptq_block_wrappers)
        target_block_wrapper = None
        for block_wrapper in self.gptq_block_wrappers.values():
            if block_wrapper.get_order() == i:
                block_wrapper.enable()
                target_block_wrapper = block_wrapper
            else:
                block_wrapper.disable()
        if target_block_wrapper is None:
            raise RuntimeError(
                f"No GPTQ block has calibration order {i}; did `record_order` observe every block?"
            )
        try:
            yield
        finally:
            target_block_wrapper.quant_block()

    def release_reference(self):
        """Drop the wrappers' layer references and empty the CUDA cache."""
        # delete reference so that `torch.cuda.empty_cache()` can
        # release all the gpu memory cache used during calibration
        for block_wrapper in self.gptq_block_wrappers.values():
            for layer_wrapper in block_wrapper.layer_wrappers.values():
                del layer_wrapper.layer
        torch.cuda.empty_cache()
def locate_parent(root: nn.Module, full_path: str):
    """Resolve the parent of the attribute addressed by dotted *full_path*.

    Returns ``(parent, attr_name)`` such that ``getattr(parent, attr_name)``
    is the addressed object. Every component but the last is resolved with
    ``getattr`` starting from *root*; an undotted path yields ``(root, path)``.
    """
    *ancestors, leaf = full_path.split('.')
    node = root
    for attr_name in ancestors:
        node = getattr(node, attr_name)
    return node, leaf
@torch.no_grad()
def gptq_quantize(model, tokenizer, weight_bit_width, calib_data):
    """Quantize the ``GLMBlock`` layers of *model* in place with GPTQ.

    The first prompt is used for one ``model.chat`` pass that records the
    block execution order. Then, for every block in that order, all prompts
    are run through ``model.chat`` while only that block records its Hessian,
    after which the block is quantized. Finally each wrapped layer is
    replaced by a ``QuantizedLinear`` built from the GPTQ weights and scales,
    and the calibration references are released.

    Args:
        model: ChatGLM model providing ``chat(tokenizer, prompt, history=...)``.
        tokenizer: Tokenizer passed through to ``model.chat``.
        weight_bit_width: Target weight bit width.
        calib_data: Non-empty sequence of calibration prompts for ``model.chat``.

    Raises:
        ValueError: If *calib_data* is empty.
    """
    if not calib_data:
        raise ValueError("gptq_quantize needs at least one calibration prompt in `calib_data`")

    from .modeling_chatglm import GLMBlock
    from .quantization import QuantizedLinear

    quantizer = GPTQuantizer([GLMBlock])
    calib_model = quantizer.wrap_model(model, weight_bit_width)
    # One pass to learn the order in which the blocks execute.
    with quantizer.record_order():
        calib_model.chat(tokenizer, calib_data[0], history=[])

    LOGGER.info("Start doing calibration using GPTQ ")
    n_iters = quantizer.calibration_iters
    for i in range(n_iters):
        LOGGER.info(f"Process: {i + 1}/{n_iters}")
        # todo: should add early return to speed up the calibration
        # todo: add cpu offload to reduce the gpu memory requirements.
        with quantizer.start_calib_iter(i):
            for prompt in calib_data:
                model.chat(tokenizer, prompt, history=[])

    # replace the fp16 linear with quantized linear
    for block_wrapper in quantizer.gptq_block_wrappers.values():
        for layer_name, layer_wrapper in block_wrapper.layer_wrappers.items():
            layer = layer_wrapper.layer
            parent, name_in_parent = locate_parent(model, layer_name)
            quantized_layer = QuantizedLinear(
                weight_bit_width=weight_bit_width,
                weight_tensor=layer.weight,
                bias_tensor=layer.bias,
                weight_scale=layer_wrapper.weight_scale,
                in_features=layer.in_features,
                out_features=layer.out_features,
                bias=True,
                dtype=torch.half,
                device=layer_wrapper.device,
                empty_init=False,
            )
            parent.add_module(name_in_parent, quantized_layer)

    # release the memory cache used during calibration
    quantizer.release_reference()