# BlueLM-7B-Chat-4bits / quantization.py
import math

import torch
from transformers.utils import logging
logger = logging.get_logger(__name__)
try:
    import quant_cuda
except ImportError:
    quant_cuda = None
    logger.warning("CUDA extension 'quant_cuda' is not installed; quantized matmuls will not run.")
class QuantizedLinear(torch.nn.Module):
    """Linear layer whose weight is stored as packed low-bit integers (GPTQ-style).

    `weight` is only used for its shape: the packed buffers are allocated empty here
    and are expected to be filled from the quantized checkpoint.  `bias`, `device`,
    `dtype` and `empty_init` are accepted for API compatibility but are not used.
    """

    def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False, *args,
                 **kwargs):
        super().__init__()
        self.weight_bit_width = weight_bit_width
        shape = weight.shape  # (out_features, in_features)
        self.shape = shape
        self.group_size = 128  # number of input channels sharing one scale / zero-point
        # Per-group zero-points, packed at `weight_bit_width` bits per value into int32 words.
        self.register_buffer('qzeros', torch.zeros((math.ceil(shape[1] / self.group_size), shape[0] // 256 * (weight_bit_width * 8)), dtype=torch.int))
        # Per-group, per-output-channel dequantization scales.
        self.register_buffer('scales', torch.zeros((math.ceil(shape[1] / self.group_size), shape[0]), dtype=torch.float))
        # Quantized weights, packed along the input dimension into int32 words.
        self.register_buffer(
            'qweight', torch.zeros((shape[1] // 256 * (weight_bit_width * 8), shape[0]), dtype=torch.int)
        )
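    # Worked size example (illustrative numbers, not from this file): a hypothetical
    # 4096 x 4096 projection quantized to 4 bits with group_size = 128 yields
    #   qweight: (4096 // 256 * 32, 4096) = (512, 4096) int32 -> 8 weights per word, 8 MiB
    #   scales : (ceil(4096 / 128), 4096) = (32, 4096) float32 -> 0.5 MiB
    #   qzeros : (32, 4096 // 256 * 32)   = (32, 512)  int32   -> 64 KiB
    # versus 32 MiB for the original fp16 weight matrix.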
    def forward(self, x):
        # Accumulate in fp32, then cast back to the input dtype at the end.
        intermediate_dtype = torch.float32
        outshape = list(x.shape)
        outshape[-1] = self.shape[0]
        x = x.reshape(-1, x.shape[-1])  # flatten leading dims to (tokens, in_features)
        y = torch.zeros(x.shape[0], outshape[-1], dtype=intermediate_dtype, device=x.device)
        output_dtype = x.dtype
        x = x.to(intermediate_dtype)
        # Dispatch to the CUDA kernel for the configured bit width; the kernel
        # dequantizes the packed weight on the fly and accumulates the product into y.
        if self.weight_bit_width == 2:
            quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        elif self.weight_bit_width == 3:
            quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        elif self.weight_bit_width == 4:
            quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        elif self.weight_bit_width == 8:
            quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        else:
            raise NotImplementedError("Only 2, 3, 4 and 8-bit weights are supported.")
        y = y.to(output_dtype)
        return y.reshape(outshape)
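
# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration, not part of the original
# module): build a standalone 4-bit QuantizedLinear from a placeholder weight
# and run one forward pass.  Requires the quant_cuda extension and a CUDA
# device; in real use the zero-initialized buffers would first be overwritten
# by the quantized checkpoint, so the output here is all zeros.
# ---------------------------------------------------------------------------
def _demo_quantized_linear():
    placeholder = torch.empty(4096, 4096, dtype=torch.float16)  # (out_features, in_features)
    layer = QuantizedLinear(weight_bit_width=4, weight=placeholder).cuda()
    x = torch.randn(2, 8, 4096, dtype=torch.float16, device="cuda")
    y = layer(x)  # -> shape (2, 8, 4096), same dtype as x
    return y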
def quantize(model, weight_bit_width, empty_init=False, device=None):
    """Replace every attention and MLP projection in `model.layers` with a QuantizedLinear."""
    target_projections = (
        ("self_attn", "q_proj"), ("self_attn", "k_proj"),
        ("self_attn", "v_proj"), ("self_attn", "o_proj"),
        ("mlp", "gate_proj"), ("mlp", "down_proj"), ("mlp", "up_proj"),
    )
    for layer in model.layers:
        for parent_name, proj_name in target_projections:
            parent = getattr(layer, parent_name)
            proj = getattr(parent, proj_name)
            setattr(parent, proj_name, QuantizedLinear(
                weight_bit_width=weight_bit_width,
                weight=proj.weight,
                bias=proj.bias,
                dtype=proj.weight.dtype,
                device=proj.weight.device if device is None else device,
                empty_init=empty_init,
            ))
    return model
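
# ---------------------------------------------------------------------------
# Usage sketch (added for illustration, not part of the original file): the
# tiny stand-in decoder below only mimics the attribute layout that
# `quantize` expects (`model.layers[i].self_attn.*_proj` and
# `model.layers[i].mlp.*_proj`); in practice the real BlueLM decoder would be
# passed instead, and the quantized checkpoint loaded afterwards.
# ---------------------------------------------------------------------------
def _demo_quantize(weight_bit_width: int = 4):
    import torch.nn as nn

    hidden, intermediate = 256, 512  # toy sizes, divisible by 256 as the packing formulas assume

    class _Attn(nn.Module):
        def __init__(self):
            super().__init__()
            self.q_proj = nn.Linear(hidden, hidden, bias=False)
            self.k_proj = nn.Linear(hidden, hidden, bias=False)
            self.v_proj = nn.Linear(hidden, hidden, bias=False)
            self.o_proj = nn.Linear(hidden, hidden, bias=False)

    class _Mlp(nn.Module):
        def __init__(self):
            super().__init__()
            self.gate_proj = nn.Linear(hidden, intermediate, bias=False)
            self.up_proj = nn.Linear(hidden, intermediate, bias=False)
            self.down_proj = nn.Linear(intermediate, hidden, bias=False)

    class _Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.self_attn = _Attn()
            self.mlp = _Mlp()

    class _Decoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([_Block() for _ in range(2)])

    model = quantize(_Decoder(), weight_bit_width)
    # Every projection has been swapped for a QuantizedLinear with empty packed buffers.
    assert isinstance(model.layers[0].self_attn.q_proj, QuantizedLinear)
    return model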