import math

import torch
from transformers.utils import logging

logger = logging.get_logger(__name__)

try:
    import quant_cuda
except ImportError:
    logger.warning("CUDA extension not installed.")

class QuantizedLinear(torch.nn.Module):
    def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False, *args,
                 **kwargs):
        super().__init__()
        self.weight_bit_width = weight_bit_width
        # Only the shape of the original weight is used here; `bias`, `device`, `dtype`
        # and `empty_init` are accepted for API compatibility but not used. The packed
        # buffers below are zero-initialized placeholders that are expected to be
        # filled in when a quantized state dict is loaded.
        shape = weight.shape
        self.shape = shape
        self.group_size = 128
        # Per-group zero points, weight_bit_width-bit values packed into int32 columns.
        self.register_buffer('qzeros', torch.zeros(
            (math.ceil(shape[1] / self.group_size), shape[0] // 256 * (weight_bit_width * 8)), dtype=torch.int))
        # Per-group, per-output-channel dequantization scales.
        self.register_buffer('scales', torch.zeros(
            (math.ceil(shape[1] / self.group_size), shape[0]), dtype=torch.float))
        # Quantized weights, packed along the input dimension into int32 values.
        self.register_buffer(
            'qweight', torch.zeros((shape[1] // 256 * (weight_bit_width * 8), shape[0]), dtype=torch.int)
        )

    def forward(self, x):
        intermediate_dtype = torch.float32
        outshape = list(x.shape)
        outshape[-1] = self.shape[0]
        # Flatten all leading dimensions so the CUDA kernel sees a 2-D matmul.
        x = x.reshape(-1, x.shape[-1])
        y = torch.zeros(x.shape[0], outshape[-1], dtype=intermediate_dtype, device=x.device)
        output_dtype = x.dtype
        x = x.to(intermediate_dtype)
        # Dispatch to the fused dequantize + matmul kernel for the configured bit width.
        if self.weight_bit_width == 2:
            quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        elif self.weight_bit_width == 3:
            quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        elif self.weight_bit_width == 4:
            quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        elif self.weight_bit_width == 8:
            quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        else:
            raise NotImplementedError("Only 2, 3, 4, and 8 bits are supported.")
        y = y.to(output_dtype)
        return y.reshape(outshape)


def quantize(model, weight_bit_width, empty_init=False, device=None):
    """Replace the attention and MLP projections of every decoder layer with QuantizedLinear modules."""
    for layer in model.layers:
        layer.self_attn.q_proj = QuantizedLinear(
            weight_bit_width=weight_bit_width,
            weight=layer.self_attn.q_proj.weight,
            bias=layer.self_attn.q_proj.bias,
            dtype=layer.self_attn.q_proj.weight.dtype,
            device=layer.self_attn.q_proj.weight.device if device is None else device,
            empty_init=empty_init
        )
        layer.self_attn.k_proj = QuantizedLinear(
            weight_bit_width=weight_bit_width,
            weight=layer.self_attn.k_proj.weight,
            bias=layer.self_attn.k_proj.bias,
            dtype=layer.self_attn.k_proj.weight.dtype,
            device=layer.self_attn.k_proj.weight.device if device is None else device,
            empty_init=empty_init
        )
        layer.self_attn.v_proj = QuantizedLinear(
            weight_bit_width=weight_bit_width,
            weight=layer.self_attn.v_proj.weight,
            bias=layer.self_attn.v_proj.bias,
            dtype=layer.self_attn.v_proj.weight.dtype,
            device=layer.self_attn.v_proj.weight.device if device is None else device,
            empty_init=empty_init
        )
        layer.self_attn.o_proj = QuantizedLinear(
            weight_bit_width=weight_bit_width,
            weight=layer.self_attn.o_proj.weight,
            bias=layer.self_attn.o_proj.bias,
            dtype=layer.self_attn.o_proj.weight.dtype,
            device=layer.self_attn.o_proj.weight.device if device is None else device,
            empty_init=empty_init
        )
        layer.mlp.gate_proj = QuantizedLinear(
            weight_bit_width=weight_bit_width,
            weight=layer.mlp.gate_proj.weight,
            bias=layer.mlp.gate_proj.bias,
            dtype=layer.mlp.gate_proj.weight.dtype,
            device=layer.mlp.gate_proj.weight.device if device is None else device,
            empty_init=empty_init
        )
        layer.mlp.down_proj = QuantizedLinear(
            weight_bit_width=weight_bit_width,
            weight=layer.mlp.down_proj.weight,
            bias=layer.mlp.down_proj.bias,
            dtype=layer.mlp.down_proj.weight.dtype,
            device=layer.mlp.down_proj.weight.device if device is None else device,
            empty_init=empty_init
        )
        layer.mlp.up_proj = QuantizedLinear(
            weight_bit_width=weight_bit_width,
            weight=layer.mlp.up_proj.weight,
            bias=layer.mlp.up_proj.bias,
            dtype=layer.mlp.up_proj.weight.dtype,
            device=layer.mlp.up_proj.weight.device if device is None else device,
            empty_init=empty_init
        )
    return model
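

if __name__ == "__main__":
    # Illustrative self-check only, not part of the original module: the 4096x4096
    # weight below is a made-up placeholder, not a real checkpoint tensor. It shows
    # how QuantizedLinear sizes its packed buffers for 4-bit quantization; running
    # the actual forward pass additionally requires the quant_cuda extension and
    # real packed weights loaded from a quantized state dict.
    dummy_weight = torch.empty(4096, 4096)
    layer = QuantizedLinear(weight_bit_width=4, weight=dummy_weight)
    print(layer.qweight.shape)  # (4096 // 256 * 32, 4096) == (512, 4096): int32-packed 4-bit weights
    print(layer.qzeros.shape)   # (ceil(4096 / 128), 512) == (32, 512): packed per-group zero points
    print(layer.scales.shape)   # (32, 4096): per-group, per-output-channel scales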