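"""GPTQ-style quantization helpers for a LLaMA-style decoder (descriptive summary added by the editor, inferred from the code below).

`QuantizedLinear` stores packed integer weights (`qweight`) together with
per-group `scales` and packed zero points (`qzeros`) as buffers and dispatches
its matrix multiply to the `quant_cuda` CUDA extension.  `quantize()` swaps
these modules into every attention and MLP projection of `model.layers`.
"""
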
import math

import torch
from transformers.utils import logging

logger = logging.get_logger(__name__)

try:
    import quant_cuda
except ImportError:
    logger.warning("quant_cuda CUDA extension not installed.")

class QuantizedLinear(torch.nn.Module):
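    """Linear layer whose weight is stored in packed quantized form.

    The floating-point weight of shape (out_features, in_features) is replaced
    by three buffers: `qweight` (packed int32 weights), plus per-group `scales`
    and packed zero points `qzeros`, with one quantization group per
    `group_size` input channels.  The buffers are allocated zero-filled and are
    expected to be overwritten when a quantized checkpoint is loaded.
    """
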
    def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False,
                 *args, **kwargs):
        super().__init__()
        self.weight_bit_width = weight_bit_width

        # `bias`, `device`, `dtype` and `empty_init` are accepted for interface
        # compatibility with `quantize()` but are not used here: the fp weight is
        # replaced by the packed buffers below, which are allocated zero-filled.
        # `weight` comes from an nn.Linear, so shape == (out_features, in_features).
        shape = weight.shape
        self.shape = shape
        self.group_size = 128  # one quantization group per 128 input channels
        # Packed zero points and per-group scales, one row per input-channel group.
        self.register_buffer(
            'qzeros',
            torch.zeros((math.ceil(shape[1] / self.group_size), shape[0] // 256 * (weight_bit_width * 8)),
                        dtype=torch.int),
        )
        self.register_buffer(
            'scales', torch.zeros((math.ceil(shape[1] / self.group_size), shape[0]), dtype=torch.float)
        )
        # Packed int32 weight matrix consumed by the quant_cuda kernels.
        self.register_buffer(
            'qweight', torch.zeros((shape[1] // 256 * (weight_bit_width * 8), shape[0]), dtype=torch.int)
        )

    def forward(self, x):
        # The matmul kernels are called with fp32 inputs; cast up here and cast
        # the result back to the caller's dtype at the end.
        intermediate_dtype = torch.float32
        output_dtype = x.dtype
        # Collapse all leading dimensions of x into a single batch dimension;
        # the last output dimension becomes out_features (self.shape[0]).
        outshape = list(x.shape)
        outshape[-1] = self.shape[0]
        x = x.reshape(-1, x.shape[-1]).to(intermediate_dtype)
        y = torch.zeros(x.shape[0], outshape[-1], dtype=intermediate_dtype, device=x.device)
        # Dispatch to the packed-matmul kernel for the configured bit width.
        if self.weight_bit_width == 2:
            quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        elif self.weight_bit_width == 3:
            quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        elif self.weight_bit_width == 4:
            quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        elif self.weight_bit_width == 8:
            quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        else:
            raise NotImplementedError("Only 2,3,4,8 bits are supported.")
        y = y.to(output_dtype)
        return y.reshape(outshape)

def quantize(model, weight_bit_width, empty_init=False, device=None):
    """Replace every attention and MLP projection in `model.layers` with a QuantizedLinear."""
    for layer in model.layers:
        # The seven projections of a LLaMA-style decoder layer.
        for parent, name in (
            (layer.self_attn, "q_proj"),
            (layer.self_attn, "k_proj"),
            (layer.self_attn, "v_proj"),
            (layer.self_attn, "o_proj"),
            (layer.mlp, "gate_proj"),
            (layer.mlp, "down_proj"),
            (layer.mlp, "up_proj"),
        ):
            linear = getattr(parent, name)
            setattr(parent, name, QuantizedLinear(
                weight_bit_width=weight_bit_width,
                weight=linear.weight,
                bias=linear.bias,
                dtype=linear.weight.dtype,
                device=linear.weight.device if device is None else device,
                empty_init=empty_init,
            ))

    return model
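
# Usage sketch (not part of the original file; the paths and loading flow below
# are assumptions, not this repo's documented API).  `quantize` only swaps
# `QuantizedLinear` modules in with zero-filled buffers, so the packed
# `qweight`, `scales` and `qzeros` tensors still have to be filled from a
# quantized state dict afterwards:
#
#     model = ...  # a LLaMA-style decoder whose layers expose
#                  # self_attn.{q,k,v,o}_proj and mlp.{gate,down,up}_proj
#     model = quantize(model, weight_bit_width=4)
#     model.load_state_dict(torch.load("quantized_state_dict.pt"), strict=False)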