Commit
•
a5ba550
1
Parent(s):
fa5a1b0
修改 quantization.py 中待量化权重的移动逻辑 (#47)
Browse files- 修改 quantization.py 中待量化权重的移动逻辑 (4e17ef37342ef49abe1bfcc54cd43ded06b51347)
Co-authored-by: Miku <[email protected]>
- quantization.py +6 -5
quantization.py
CHANGED
@@ -125,8 +125,9 @@ class QuantizedLinear(torch.nn.Module):
|
|
125 |
def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False, *args,
|
126 |
**kwargs):
|
127 |
super().__init__()
|
|
|
|
|
128 |
self.weight_bit_width = weight_bit_width
|
129 |
-
|
130 |
shape = weight.shape
|
131 |
|
132 |
if weight is None or empty_init:
|
@@ -154,7 +155,7 @@ def quantize(model, weight_bit_width, empty_init=False, device=None):
|
|
154 |
for layer in model.layers:
|
155 |
layer.self_attention.query_key_value = QuantizedLinear(
|
156 |
weight_bit_width=weight_bit_width,
|
157 |
-
weight=layer.self_attention.query_key_value.weight
|
158 |
bias=layer.self_attention.query_key_value.bias,
|
159 |
dtype=layer.self_attention.query_key_value.weight.dtype,
|
160 |
device=layer.self_attention.query_key_value.weight.device if device is None else device,
|
@@ -162,7 +163,7 @@ def quantize(model, weight_bit_width, empty_init=False, device=None):
|
|
162 |
)
|
163 |
layer.self_attention.dense = QuantizedLinear(
|
164 |
weight_bit_width=weight_bit_width,
|
165 |
-
weight=layer.self_attention.dense.weight
|
166 |
bias=layer.self_attention.dense.bias,
|
167 |
dtype=layer.self_attention.dense.weight.dtype,
|
168 |
device=layer.self_attention.dense.weight.device if device is None else device,
|
@@ -170,7 +171,7 @@ def quantize(model, weight_bit_width, empty_init=False, device=None):
|
|
170 |
)
|
171 |
layer.mlp.dense_h_to_4h = QuantizedLinear(
|
172 |
weight_bit_width=weight_bit_width,
|
173 |
-
weight=layer.mlp.dense_h_to_4h.weight
|
174 |
bias=layer.mlp.dense_h_to_4h.bias,
|
175 |
dtype=layer.mlp.dense_h_to_4h.weight.dtype,
|
176 |
device=layer.mlp.dense_h_to_4h.weight.device if device is None else device,
|
@@ -178,7 +179,7 @@ def quantize(model, weight_bit_width, empty_init=False, device=None):
|
|
178 |
)
|
179 |
layer.mlp.dense_4h_to_h = QuantizedLinear(
|
180 |
weight_bit_width=weight_bit_width,
|
181 |
-
weight=layer.mlp.dense_4h_to_h.weight
|
182 |
bias=layer.mlp.dense_4h_to_h.bias,
|
183 |
dtype=layer.mlp.dense_4h_to_h.weight.dtype,
|
184 |
device=layer.mlp.dense_4h_to_h.weight.device if device is None else device,
|
|
|
125 |
def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False, *args,
|
126 |
**kwargs):
|
127 |
super().__init__()
|
128 |
+
assert str(weight.device).startswith('cuda'), 'The weights that need to be quantified should be on the CUDA device'
|
129 |
+
|
130 |
self.weight_bit_width = weight_bit_width
|
|
|
131 |
shape = weight.shape
|
132 |
|
133 |
if weight is None or empty_init:
|
|
|
155 |
for layer in model.layers:
|
156 |
layer.self_attention.query_key_value = QuantizedLinear(
|
157 |
weight_bit_width=weight_bit_width,
|
158 |
+
weight=layer.self_attention.query_key_value.weight,
|
159 |
bias=layer.self_attention.query_key_value.bias,
|
160 |
dtype=layer.self_attention.query_key_value.weight.dtype,
|
161 |
device=layer.self_attention.query_key_value.weight.device if device is None else device,
|
|
|
163 |
)
|
164 |
layer.self_attention.dense = QuantizedLinear(
|
165 |
weight_bit_width=weight_bit_width,
|
166 |
+
weight=layer.self_attention.dense.weight,
|
167 |
bias=layer.self_attention.dense.bias,
|
168 |
dtype=layer.self_attention.dense.weight.dtype,
|
169 |
device=layer.self_attention.dense.weight.device if device is None else device,
|
|
|
171 |
)
|
172 |
layer.mlp.dense_h_to_4h = QuantizedLinear(
|
173 |
weight_bit_width=weight_bit_width,
|
174 |
+
weight=layer.mlp.dense_h_to_4h.weight,
|
175 |
bias=layer.mlp.dense_h_to_4h.bias,
|
176 |
dtype=layer.mlp.dense_h_to_4h.weight.dtype,
|
177 |
device=layer.mlp.dense_h_to_4h.weight.device if device is None else device,
|
|
|
179 |
)
|
180 |
layer.mlp.dense_4h_to_h = QuantizedLinear(
|
181 |
weight_bit_width=weight_bit_width,
|
182 |
+
weight=layer.mlp.dense_4h_to_h.weight,
|
183 |
bias=layer.mlp.dense_4h_to_h.bias,
|
184 |
dtype=layer.mlp.dense_4h_to_h.weight.dtype,
|
185 |
device=layer.mlp.dense_4h_to_h.weight.device if device is None else device,
|