import bz2
import torch
import base64
import ctypes
import os
import sys
import traceback
import math
from torch.nn.parameter import Parameter
from transformers.utils import logging
import pkg_resources
from typing import List

logger = logging.get_logger(__name__)

try:
    import quant_cuda
except Exception:
    print('CUDA extension not installed.')


class QuantizedLinear(torch.nn.Module):
    def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False, *args,
                 **kwargs):
        super().__init__()
        self.weight_bit_width = weight_bit_width

        # Only the shape of `weight` is used here; `bias`, `device`, `dtype` and `empty_init`
        # are accepted for API compatibility. The packed buffers below are allocated as zero
        # placeholders and are expected to be filled from a pre-quantized (GPTQ-style)
        # checkpoint rather than computed from `weight`.
        shape = weight.shape
        self.shape = shape
        self.group_size = 128
        self.register_buffer('qzeros', torch.zeros((math.ceil(shape[1] / self.group_size), shape[0] // 256 * (weight_bit_width * 8)), dtype=torch.int))
        self.register_buffer('scales', torch.zeros((math.ceil(shape[1] / self.group_size), shape[0]), dtype=torch.float))
        self.register_buffer(
            'qweight', torch.zeros((shape[1] // 256 * (weight_bit_width * 8), shape[0]), dtype=torch.int)
        )

    def forward(self, x):
        intermediate_dtype = torch.float32
        outshape = list(x.shape)
        outshape[-1] = self.shape[0]
        # Flatten all leading dimensions and upcast to float32 before calling the
        # quantized matmul kernel; the result is cast back to the input dtype below.
        x = x.reshape(-1, x.shape[-1])
        y = torch.zeros(x.shape[0], outshape[-1], dtype=intermediate_dtype, device=x.device)
        output_dtype = x.dtype
        x = x.to(intermediate_dtype)
        if self.weight_bit_width == 2:
            quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        elif self.weight_bit_width == 3:
            quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        elif self.weight_bit_width == 4:
            quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        elif self.weight_bit_width == 8:
            quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        else:
            raise NotImplementedError("Only 2,3,4,8 bits are supported.")
        y = y.to(output_dtype)
        return y.reshape(outshape)


def quantize(model, weight_bit_width, empty_init=False, device=None):
    # Swap every attention and MLP projection for a QuantizedLinear built from the
    # original layer's weight, keeping its dtype and device unless `device` is given.
    for layer in model.layers:
        for module, name in (
            (layer.self_attn, "q_proj"),
            (layer.self_attn, "k_proj"),
            (layer.self_attn, "v_proj"),
            (layer.self_attn, "o_proj"),
            (layer.mlp, "gate_proj"),
            (layer.mlp, "down_proj"),
            (layer.mlp, "up_proj"),
        ):
            linear = getattr(module, name)
            setattr(module, name, QuantizedLinear(
                weight_bit_width=weight_bit_width,
                weight=linear.weight,
                bias=linear.bias,
                dtype=linear.weight.dtype,
                device=linear.weight.device if device is None else device,
                empty_init=empty_init,
            ))

    return model
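
# A minimal usage sketch (an assumption, not something this module defines): `quantize`
# expects a LLaMA-style decoder whose layers expose self_attn.{q,k,v,o}_proj and
# mlp.{gate,down,up}_proj; the zero-initialized qweight/qzeros/scales buffers are then
# overwritten from a GPTQ-style quantized state dict, e.g.:
#
#     model = AutoModelForCausalLM.from_pretrained("path/to/llama")      # hypothetical path
#     model.model = quantize(model.model, weight_bit_width=4)
#     state_dict = torch.load("path/to/quantized_weights.pt")            # hypothetical file
#     model.load_state_dict(state_dict, strict=False)
#     model.cuda()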