Staticaliza committed: Upload 3 files

- modules/gpt_fast/generate.py +436 -0
- modules/gpt_fast/model.py +356 -0
- modules/gpt_fast/quantize.py +622 -0
modules/gpt_fast/generate.py
ADDED
@@ -0,0 +1,436 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import sys
import time
from pathlib import Path
from typing import Optional, Tuple

import torch
import torch._dynamo.config
import torch._inductor.config

def device_sync(device):
    if "cuda" in device:
        torch.cuda.synchronize(device)
    elif ("cpu" in device) or ("mps" in device):
        pass
    else:
        print(f"device={device} is not yet suppported")


torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future

default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from model import Transformer
from tokenizer import get_tokenizer

def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    probs = logits_to_probs(logits[0, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs

def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
    # input_pos: [B, S]
    logits = model(x, input_pos)
    return sample(logits, **sampling_kwargs)[0]

def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
    # input_pos: [B, 1]
    assert input_pos.shape[-1] == 1
    logits = model(x, input_pos)
    return sample(logits, **sampling_kwargs)

def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
    new_tokens, new_probs = [], []
    for i in range(num_new_tokens):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
            next_token, next_prob = decode_one_token(
                model, cur_token, input_pos, **sampling_kwargs
            )
            input_pos += 1
            new_tokens.append(next_token.clone())
            callback(new_tokens[-1])
            new_probs.append(next_prob.clone())
            cur_token = next_token.view(1, -1)

    return new_tokens, new_probs


def model_forward(model, x, input_pos):
    return model(x, input_pos)

def speculative_decode(
    model: Transformer,
    draft_model: Transformer,
    cur_token: torch.Tensor,
    input_pos: int,
    speculate_k: int,
    **sampling_kwargs
) -> torch.Tensor:
    # draft model inference sequentially
    device = cur_token.device
    orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device)
    draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs)

    draft_tokens = torch.cat(draft_tokens)
    # parallel inference on target model using draft tokens
    target_logits = model_forward(
        model,
        torch.cat([cur_token.view(1), draft_tokens]).view(1, -1),
        torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device)
    )
    target_probs = logits_to_probs(target_logits[0], **sampling_kwargs)
    draft_probs = torch.stack(draft_probs)
    # q: target prob, p: draft prob
    # q >= p: always accept draft token
    # q < p: q/p prob to accept draft token
    p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
    q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
    accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k]/ p)
    rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero()

    if rejected_locations.shape[0] == 0: # All draft tokens have been accepted
        accept_length = speculate_k + 1
        last_token = multinomial_sample_one_no_sync(target_probs[-1])
        # fill last token into draft model
        model_forward(
            draft_model,
            draft_tokens[-1].view(1, -1),
            orig_input_pos + speculate_k,
        )
        return torch.cat([draft_tokens, last_token])
    else:
        accept_length = rejected_locations[0].item()
        p = draft_probs[accept_length]
        q = target_probs[accept_length]
        new = q - p
        new = torch.where(new > 0, new, 0.0)
        new = new / new.sum()
        next_token = multinomial_sample_one_no_sync(new)
        return torch.cat([draft_tokens[:accept_length], next_token])

@torch.no_grad()
def generate(
    model: Transformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    *,
    interactive: bool,
    draft_model: Transformer,
    speculate_k: Optional[int] = 8,
    callback = lambda x: x,
    **sampling_kwargs
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    """

    is_speculative = draft_model is not None
    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(0)
    T_new = T + max_new_tokens
    if interactive:
        max_seq_length = 350
    else:
        max_seq_length = min(T_new, model.config.block_size)

    device, dtype = prompt.device, prompt.dtype
    max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
    with torch.device(device):
        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
        if is_speculative and draft_model is not model:
            draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(T_new, dtype=dtype, device=device)
    empty[:T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone()
    if is_speculative:
        prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
    seq[T] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    accept_counts = [0] * (speculate_k + 1)

    if is_speculative:
        input_pos = input_pos.item()  # for speculative decoding easier to keep on host
        while input_pos < T_new - 1:
            cur_token = next_token.view(())

            next_tokens = speculative_decode(
                model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs
            )

            accept_counts[len(next_tokens) - 1] += 1
            num_added = min(T_new - input_pos - 1, len(next_tokens))
            seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[: num_added]
            for i in next_tokens[: num_added,]:
                callback(i)
            input_pos = input_pos + num_added
            next_token = next_tokens[-1]
    else:
        generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
        seq[T + 1:] = torch.cat(generated_tokens)

    generate_stats = {
        'accept_counts': accept_counts
    }
    return seq, generate_stats

def encode_tokens(tokenizer, string, bos=True, device=default_device):
    tokens = tokenizer.encode(string)
    if bos:
        tokens = [tokenizer.bos_id()] + tokens
    return torch.tensor(tokens, dtype=torch.int, device=device)

def _load_model(checkpoint_path, device, precision, use_tp):
    use_cuda = 'cuda' in device
    with torch.device('meta'):
        model = Transformer.from_name(checkpoint_path.parent.name)

    if "int8" in str(checkpoint_path):
        print("Using int8 weight-only quantization!")
        from quantize import WeightOnlyInt8QuantHandler
        simple_quantizer = WeightOnlyInt8QuantHandler(model)
        model = simple_quantizer.convert_for_runtime()

    if "int4" in str(checkpoint_path):
        print("Using int4 weight-only quantization!")
        path_comps = checkpoint_path.name.split(".")
        groupsize = int(path_comps[-2][1:])
        from quantize import WeightOnlyInt4QuantHandler
        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
        model = simple_quantizer.convert_for_runtime()

    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
    if "model" in checkpoint and "stories" in str(checkpoint_path):
        checkpoint = checkpoint["model"]
    model.load_state_dict(checkpoint, assign=True)

    if use_tp:
        from tp import apply_tp
        print("Applying tensor parallel to model ...")
        apply_tp(model)

    model = model.to(device=device, dtype=precision)
    return model.eval()

def _get_model_size(model):
    model_size = 0
    for name, child in model.named_children():
        if not isinstance(child, torch.nn.Embedding):
            model_size += sum(
                [
                    p.numel() * p.dtype.itemsize
                    for p in itertools.chain(child.parameters(), child.buffers())
                ]
            )
    return model_size

B_INST, E_INST = "[INST]", "[/INST]"

def main(
    prompt: str = "Hello, my name is",
    interactive: bool = False,
    num_samples: int = 5,
    max_new_tokens: int = 100,
    top_k: int = 200,
    temperature: float = 0.8,
    checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
    compile: bool = True,
    compile_prefill: bool = False,
    profile: Optional[Path] = None,
    draft_checkpoint_path: Optional[Path] = None,
    speculate_k: int = 5,
    device=default_device,
) -> None:
    """Generates text samples based on a pre-trained Transformer model and tokenizer.
    """
    assert checkpoint_path.is_file(), checkpoint_path

    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
    assert tokenizer_path.is_file(), str(tokenizer_path)

    global print
    from tp import maybe_init_dist
    rank = maybe_init_dist()
    use_tp = rank is not None
    if use_tp:
        if rank != 0:
            # only print on rank 0
            print = lambda *args, **kwargs: None

    print(f"Using device={device}")
    precision = torch.bfloat16
    is_speculative = draft_checkpoint_path is not None
    is_chat = "chat" in str(checkpoint_path)

    print("Loading model ...")
    t0 = time.time()
    model = _load_model(checkpoint_path, device, precision, use_tp)

    if is_speculative:
        draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp)
    else:
        draft_model = None

    device_sync(device=device) # MKG
    print(f"Time to load model: {time.time() - t0:.02f} seconds")

    tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

    encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
    prompt_length = encoded.size(0)

    torch.manual_seed(1234)
    model_size = _get_model_size(model)
    if compile:
        if is_speculative and use_tp: # and ("cuda" in device):
            torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case

        if is_speculative:
            global model_forward, logits_to_prob
            model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)

        global decode_one_token, prefill
        decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)

        # Uncomment to squeeze more perf out of prefill
        if compile_prefill:
            prefill = torch.compile(prefill, fullgraph=True, dynamic=True)


    aggregate_metrics = {
        'tokens_per_sec': [],
        'accept_counts': [],
    }
    start = -1 if compile else 0

    for i in range(start, num_samples):
        device_sync(device=device) # MKG
        if i >= 0 and interactive:
            prompt = input("What is your prompt? ")
            if is_chat:
                prompt = f"{B_INST} {prompt.strip()} {E_INST}"
            encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)

        if interactive and i >= 0:
            buffer = []
            period_id = tokenizer.encode('.')[0]
            done_generating = False
            def callback(x):
                nonlocal done_generating
                if done_generating:
                    return
                buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
                if x.item() == tokenizer.eos_id():
                    done_generating = True
                if len(buffer) == 4 or done_generating:
                    print(''.join(buffer), end='', flush=True)
                    buffer.clear()
                # print(, end='', flush=True)
        else:
            callback = lambda x : x
        t0 = time.perf_counter()
        import contextlib
        if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
            prof = contextlib.nullcontext()
        else:
            torch.profiler._utils._init_for_cuda_graphs()
            prof = torch.profiler.profile()
        with prof:
            y, metrics = generate(
                model,
                encoded,
                max_new_tokens,
                draft_model=draft_model,
                speculate_k=speculate_k,
                interactive=interactive,
                callback=callback,
                temperature=temperature,
                top_k=top_k,
            )
            aggregate_metrics['accept_counts'].append(metrics['accept_counts'])
        if i == -1:
            print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
            continue
        if hasattr(prof, "export_chrome_trace"):
            if use_tp:
                prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
            else:
                prof.export_chrome_trace(f"{profile}.json")
        device_sync(device=device) # MKG
        t = time.perf_counter() - t0

        if not interactive:
            print(tokenizer.decode(y.tolist()))
        else:
            print()
        tokens_generated = y.size(0) - prompt_length
        tokens_sec = tokens_generated / t
        aggregate_metrics['tokens_per_sec'].append(tokens_sec)
        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
    print("==========")
    if is_speculative:
        counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])]
        acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated]
        print(f"Acceptance probs: {acceptance_probs}")
        print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}")

    print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Your CLI description.')

    parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
    parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
    parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
    parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
    parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
    parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
    parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
    parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
    parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
    parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
    parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
    parser.add_argument('--device', type=str, default=default_device, help='Device to use')

    args = parser.parse_args()
    main(
        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
        args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
        args.speculate_k, args.device
    )
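For reference, a minimal, self-contained sketch (not part of the committed files; the function names and the values below are made up for illustration) of the sampling path generate.py uses: temperature scaling plus top-k filtering as in logits_to_probs, then the exponential-noise argmax trick from multinomial_sample_one_no_sync, which draws from the resulting categorical distribution without forcing a CUDA synchronization.

import torch

def top_k_probs(logits: torch.Tensor, temperature: float = 1.0, top_k: int = 50) -> torch.Tensor:
    # Scale by temperature, keep only the top_k logits, renormalize with softmax.
    logits = logits / max(temperature, 1e-5)
    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
    pivot = v.select(-1, -1).unsqueeze(-1)          # smallest logit that survives the cut
    logits = torch.where(logits < pivot, -float("Inf"), logits)
    return torch.nn.functional.softmax(logits, dim=-1)

def sample_no_sync(probs: torch.Tensor) -> torch.Tensor:
    # argmax(p / Exp(1)) is distributed like torch.multinomial(p, 1) (Gumbel-max style trick),
    # but avoids the op that would force a host/device sync during decoding.
    q = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)

logits = torch.randn(32000)                          # fake next-token logits for illustration
probs = top_k_probs(logits, temperature=0.8, top_k=200)
next_token = sample_no_sync(probs)
print(next_token.item(), probs[next_token].item())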
modules/gpt_fast/model.py
ADDED
@@ -0,0 +1,356 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)

class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization"""

    def __init__(self, d_model, norm) -> None:
        super(AdaptiveLayerNorm, self).__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
        if embedding is None:
            return self.norm(input)
        weight, bias = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        return weight * self.norm(input) + bias


@dataclass
class ModelArgs:
    block_size: int = 2048
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    intermediate_size: int = None
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    has_cross_attention: bool = False
    context_dim: int = 0
    uvit_skip_connection: bool = False

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        # self.head_dim = self.dim // self.n_head

    @classmethod
    def from_name(cls, name: str):
        if name in transformer_configs:
            return cls(**transformer_configs[name])
        # fuzzy search
        config = [config for config in transformer_configs if config.lower() in str(name).lower()]

        # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,
        # take longer name (as it have more symbols matched)
        if len(config) > 1:
            config.sort(key=len, reverse=True)
            assert len(config[0]) != len(config[1]), name # make sure only one 'best' match

        return cls(**transformer_configs[config[0]])


transformer_configs = {
    "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000),
    "7B": dict(n_layer=32, n_head=32, dim=4096),
    "13B": dict(n_layer=40, n_head=40, dim=5120),
    "30B": dict(n_layer=60, n_head=52, dim=6656),
    "34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016,
                rope_base=1000000), # CodeLlama-34B-Python-hf
    "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672),
    "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
    "stories15M": dict(n_layer=6, n_head=6, dim=288),
    "stories110M": dict(n_layer=12, n_head=12, dim=768),

    "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336,
                       vocab_size=128256, rope_base=500000),
    "llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672,
                        vocab_size=128256, rope_base=500000),
}


class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out


class Transformer(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
        self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))

        self.freqs_cis: Optional[Tensor] = None
        self.mask_cache: Optional[Tensor] = None
        self.max_batch_size = -1
        self.max_seq_length = -1

    def setup_caches(self, max_batch_size, max_seq_length, use_kv_cache=True):
        if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
            return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        dtype = self.norm.project_layer.weight.dtype
        device = self.norm.project_layer.weight.device

        if not self.training and use_kv_cache:
            for b in self.layers:
                b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype).to(device)

        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim,
                                              self.config.rope_base, dtype).to(device)
        self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)).to(device)
        self.use_kv_cache = use_kv_cache
        self.uvit_skip_connection = self.config.uvit_skip_connection
        if self.uvit_skip_connection:
            self.layers_emit_skip = [i for i in range(self.config.n_layer) if i < self.config.n_layer // 2]
            self.layers_receive_skip = [i for i in range(self.config.n_layer) if i > self.config.n_layer // 2]
        else:
            self.layers_emit_skip = []
            self.layers_receive_skip = []

    def forward(self,
                x: Tensor,
                c: Tensor,
                input_pos: Optional[Tensor] = None,
                mask: Optional[Tensor] = None,
                context: Optional[Tensor] = None,
                context_input_pos: Optional[Tensor] = None,
                cross_attention_mask: Optional[Tensor] = None,
                ) -> Tensor:
        assert self.freqs_cis is not None, "Caches must be initialized first"
        if mask is None: # in case of non-causal model
            if not self.training and self.use_kv_cache:
                mask = self.causal_mask[None, None, input_pos]
            else:
                mask = self.causal_mask[None, None, input_pos]
                mask = mask[..., input_pos]
        freqs_cis = self.freqs_cis[input_pos]
        if context is not None:
            context_freqs_cis = self.freqs_cis[context_input_pos]
        else:
            context_freqs_cis = None
        skip_in_x_list = []
        for i, layer in enumerate(self.layers):
            if self.uvit_skip_connection and i in self.layers_receive_skip:
                skip_in_x = skip_in_x_list.pop(-1)
            else:
                skip_in_x = None
            x = layer(x, c, input_pos, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask, skip_in_x)
            if self.uvit_skip_connection and i in self.layers_emit_skip:
                skip_in_x_list.append(x)
        x = self.norm(x, c)
        return x

    @classmethod
    def from_name(cls, name: str):
        return cls(ModelArgs.from_name(name))


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
        self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))

        if config.has_cross_attention:
            self.has_cross_attention = True
            self.cross_attention = Attention(config, is_cross_attention=True)
            self.cross_attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
        else:
            self.has_cross_attention = False

        if config.uvit_skip_connection:
            self.skip_in_linear = nn.Linear(config.dim * 2, config.dim)
            self.uvit_skip_connection = True
        else:
            self.uvit_skip_connection = False

    def forward(self,
                x: Tensor,
                c: Tensor,
                input_pos: Tensor,
                freqs_cis: Tensor,
                mask: Tensor,
                context: Optional[Tensor] = None,
                context_freqs_cis: Optional[Tensor] = None,
                cross_attention_mask: Optional[Tensor] = None,
                skip_in_x: Optional[Tensor] = None,
                ) -> Tensor:
        if self.uvit_skip_connection and skip_in_x is not None:
            x = self.skip_in_linear(torch.cat([x, skip_in_x], dim=-1))
        h = x + self.attention(self.attention_norm(x, c), freqs_cis, mask, input_pos)
        if self.has_cross_attention:
            h = h + self.cross_attention(self.cross_attention_norm(h, c), freqs_cis, cross_attention_mask, input_pos, context, context_freqs_cis)
        out = h + self.feed_forward(self.ffn_norm(h, c))
        return out


class Attention(nn.Module):
    def __init__(self, config: ModelArgs, is_cross_attention: bool = False):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        if is_cross_attention:
            self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
            self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False)
        else:
            self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        # self._register_load_state_dict_pre_hook(self.load_hook)

    # def load_hook(self, state_dict, prefix, *args):
    #     if prefix + "wq.weight" in state_dict:
    #         wq = state_dict.pop(prefix + "wq.weight")
    #         wk = state_dict.pop(prefix + "wk.weight")
    #         wv = state_dict.pop(prefix + "wv.weight")
    #         state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(self,
                x: Tensor,
                freqs_cis: Tensor,
                mask: Tensor,
                input_pos: Optional[Tensor] = None,
                context: Optional[Tensor] = None,
                context_freqs_cis: Optional[Tensor] = None,
                ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        if context is None:
            q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
            context_seqlen = seqlen
        else:
            q = self.wq(x)
            k, v = self.wkv(context).split([kv_size, kv_size], dim=-1)
            context_seqlen = context.shape[1]

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)

        y = self.wo(y)
        return y


class FeedForward(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(
    seq_len: int, n_elem: int, base: int = 10000,
    dtype: torch.dtype = torch.bfloat16
) -> Tensor:
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )

    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
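For reference, a minimal, self-contained sketch (not part of the committed files; the shapes and values are made up) of how the KVCache above is driven: prefill writes a whole prompt's worth of positions in one update, each decode step writes a single new position, and the full-length cache buffers are returned for attention.

import torch
import torch.nn as nn

class KVCache(nn.Module):
    # Same structure as the KVCache in model.py above, reduced to the update path.
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.float32):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val/v_val: [B, H, S, D]; scatter into the preallocated buffers.
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache

cache = KVCache(max_batch_size=1, max_seq_length=8, n_heads=2, head_dim=4)
# "Prefill": write positions 0..2 at once.
k = torch.randn(1, 2, 3, 4); v = torch.randn(1, 2, 3, 4)
cache.update(torch.arange(3), k, v)
# "Decode": write one new position per step.
k1 = torch.randn(1, 2, 1, 4); v1 = torch.randn(1, 2, 1, 4)
k_all, v_all = cache.update(torch.tensor([3]), k1, v1)
print(k_all.shape)  # torch.Size([1, 2, 8, 4]) -- full cache; positions >= 4 are still zero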
modules/gpt_fast/quantize.py
ADDED
@@ -0,0 +1,622 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
import time
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from tokenizer import get_tokenizer
|
13 |
+
|
14 |
+
try:
|
15 |
+
from GPTQ import GenericGPTQRunner, InputRecorder
|
16 |
+
from eval import get_task_dict, evaluate, lm_eval
|
17 |
+
except:
|
18 |
+
pass
|
19 |
+
|
20 |
+
from model import Transformer
|
21 |
+
|
22 |
+
##### Quantization Primitives ######
|
23 |
+
|
24 |
+
def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
|
25 |
+
# assumes symmetric quantization
|
26 |
+
# assumes axis == 0
|
27 |
+
# assumes dense memory format
|
28 |
+
# TODO(future): relax ^ as needed
|
29 |
+
|
30 |
+
# default setup for affine quantization of activations
|
31 |
+
eps = torch.finfo(torch.float32).eps
|
32 |
+
|
33 |
+
# get min and max
|
34 |
+
min_val, max_val = torch.aminmax(x, dim=1)
|
35 |
+
|
36 |
+
# calculate scales and zero_points based on min and max
|
37 |
+
# reference: https://fburl.com/code/srbiybme
|
38 |
+
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
|
39 |
+
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
|
40 |
+
device = min_val_neg.device
|
41 |
+
|
42 |
+
# reference: https://fburl.com/code/4wll53rk
|
43 |
+
max_val_pos = torch.max(-min_val_neg, max_val_pos)
|
44 |
+
scales = max_val_pos / (float(quant_max - quant_min) / 2)
|
45 |
+
# ensure scales is the same dtype as the original tensor
|
46 |
+
scales = torch.clamp(scales, min=eps).to(x.dtype)
|
47 |
+
zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
|
48 |
+
|
49 |
+
# quantize based on qmin/qmax/scales/zp
|
50 |
+
# reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
|
51 |
+
x_div = x / scales.unsqueeze(-1)
|
52 |
+
x_round = torch.round(x_div)
|
53 |
+
x_zp = x_round + zero_points.unsqueeze(-1)
|
54 |
+
quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
|
55 |
+
|
56 |
+
return quant, scales, zero_points
|
57 |
+
|
58 |
+
def get_group_qparams(w, n_bit=4, groupsize=128):
|
59 |
+
# needed for GPTQ with padding
|
60 |
+
if groupsize > w.shape[-1]:
|
61 |
+
groupsize = w.shape[-1]
|
62 |
+
assert groupsize > 1
|
63 |
+
assert w.shape[-1] % groupsize == 0
|
64 |
+
assert w.dim() == 2
|
65 |
+
|
66 |
+
to_quant = w.reshape(-1, groupsize)
|
67 |
+
assert torch.isnan(to_quant).sum() == 0
|
68 |
+
|
69 |
+
max_val = to_quant.amax(dim=1, keepdim=True)
|
70 |
+
min_val = to_quant.amin(dim=1, keepdim=True)
|
71 |
+
max_int = 2**n_bit - 1
|
72 |
+
scales = (max_val - min_val).clamp(min=1e-6) / max_int
|
73 |
+
zeros = min_val + scales * (2 ** (n_bit - 1))
|
74 |
+
return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
|
75 |
+
torch.bfloat16
|
76 |
+
).reshape(w.shape[0], -1)
|
77 |
+
|
78 |
+
|
79 |
+
def pack_scales_and_zeros(scales, zeros):
|
80 |
+
assert scales.shape == zeros.shape
|
81 |
+
assert scales.dtype == torch.bfloat16
|
82 |
+
assert zeros.dtype == torch.bfloat16
|
83 |
+
return (
|
84 |
+
torch.cat(
|
85 |
+
[
|
86 |
+
scales.reshape(scales.size(0), scales.size(1), 1),
|
87 |
+
zeros.reshape(zeros.size(0), zeros.size(1), 1),
|
88 |
+
],
|
89 |
+
2,
|
90 |
+
)
|
91 |
+
.transpose(0, 1)
|
92 |
+
.contiguous()
|
93 |
+
)
|
94 |
+
|
95 |
+
|
96 |
+
def unpack_scales_and_zeros(scales_and_zeros):
|
97 |
+
assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
|
98 |
+
assert scales_and_zeros.dtype == torch.float
|
99 |
+
return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
|
100 |
+
|
101 |
+
|
102 |
+
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
|
103 |
+
assert groupsize > 1
|
104 |
+
# needed for GPTQ single column quantize
|
105 |
+
if groupsize > w.shape[-1] and scales.shape[-1] == 1:
|
106 |
+
groupsize = w.shape[-1]
|
107 |
+
|
108 |
+
assert w.shape[-1] % groupsize == 0
|
109 |
+
assert w.dim() == 2
|
110 |
+
|
111 |
+
to_quant = w.reshape(-1, groupsize)
|
112 |
+
assert torch.isnan(to_quant).sum() == 0
|
113 |
+
|
114 |
+
scales = scales.reshape(-1, 1)
|
115 |
+
zeros = zeros.reshape(-1, 1)
|
116 |
+
min_val = zeros - scales * (2 ** (n_bit - 1))
|
117 |
+
max_int = 2**n_bit - 1
|
118 |
+
min_int = 0
|
119 |
+
w_int32 = (
|
120 |
+
to_quant.sub(min_val)
|
121 |
+
.div(scales)
|
122 |
+
.round()
|
123 |
+
.clamp_(min_int, max_int)
|
124 |
+
.to(torch.int32)
|
125 |
+
.reshape_as(w)
|
126 |
+
)
|
127 |
+
|
128 |
+
return w_int32
|
129 |
+
|
130 |
+
|
131 |
+
def group_quantize_tensor(w, n_bit=4, groupsize=128):
|
132 |
+
scales, zeros = get_group_qparams(w, n_bit, groupsize)
|
133 |
+
w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
|
134 |
+
scales_and_zeros = pack_scales_and_zeros(scales, zeros)
|
135 |
+
return w_int32, scales_and_zeros
|
136 |
+
|
137 |
+
|
138 |
+
def group_dequantize_tensor_from_qparams(
|
139 |
+
w_int32, scales, zeros, n_bit=4, groupsize=128
|
140 |
+
):
|
141 |
+
assert groupsize > 1
|
142 |
+
# needed for GPTQ single column dequantize
|
143 |
+
if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
|
144 |
+
groupsize = w_int32.shape[-1]
|
145 |
+
assert w_int32.shape[-1] % groupsize == 0
|
146 |
+
assert w_int32.dim() == 2
|
147 |
+
|
148 |
+
w_int32_grouped = w_int32.reshape(-1, groupsize)
|
149 |
+
scales = scales.reshape(-1, 1)
|
150 |
+
zeros = zeros.reshape(-1, 1)
|
151 |
+
|
152 |
+
w_dq = (
|
153 |
+
w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
|
154 |
+
)
|
155 |
+
return w_dq
|
156 |
+
|
157 |
+
|
158 |
+
def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
|
159 |
+
scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
|
160 |
+
return group_dequantize_tensor_from_qparams(
|
161 |
+
w_int32, scales, zeros, n_bit, groupsize
|
162 |
+
)
|
163 |
+
|
164 |
+
class QuantHandler:
|
165 |
+
def __init__(self, mod):
|
166 |
+
self.mod = mod
|
167 |
+
|
168 |
+
def create_quantized_state_dict(self) -> "StateDict":
|
169 |
+
pass
|
170 |
+
|
171 |
+
def convert_for_runtime(self) -> "nn.Module":
|
172 |
+
pass
|
173 |
+
|
174 |
+
class GPTQQuantHandler(QuantHandler):
|
175 |
+
"""
|
176 |
+
This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class.
|
177 |
+
Unlike the base QuantHandler class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement
|
178 |
+
__init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime.
|
179 |
+
|
180 |
+
The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and
|
181 |
+
create_quantized_state_dict. Here is a description of each function.
|
182 |
+
|
183 |
+
get_qparams_func:
|
184 |
+
A function that calculates the quantization qparams for an input tensor.
|
185 |
+
Args:
|
186 |
+
weight: A 2d weight tensor with non-integer dtype.
|
187 |
+
Returns:
|
188 |
+
qparams: it can have any format but will need to be handled by the other defined functions below.
|
189 |
+
|
190 |
+
quantize_func:
|
191 |
+
A function that applies quantization to an input tensor. It should be noted
|
192 |
+
that this function needs to be able to handle quantizing the entire weight tensor, a single group,
|
193 |
+
or a single column.
|
194 |
+
Args:
|
195 |
+
weight: A 2d weight tensor with non-integer dtype.
|
196 |
+
qparams: the output from get_qparams_func
|
197 |
+
Returns:
|
198 |
+
quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
|
199 |
+
|
200 |
+
|
201 |
+
dequantize_func:
|
202 |
+
A function that dequantizes an input quantized weight tensor. It should be noted
|
203 |
+
that this function needs to be able to handle dequantizing the entire weight tensor, a single group,
|
204 |
+
or a single column.
|
205 |
+
Args:
|
206 |
+
quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
|
207 |
+
qparams: the output from get_qparams_func
|
208 |
+
Returns:
|
209 |
+
weight: A 2d weight tensor with non-integer dtype.
|
210 |
+
|
211 |
+
combine_qparams_list_func:
|
212 |
+
A function that combines several qparams into one qparam.
|
213 |
+
Args:
|
214 |
+
qparams_list: a list of qparams objects, each obtained by calling get_qparams_func
|
215 |
+
on a single group from a weight tensor
|
216 |
+
Returns:
|
217 |
+
qparams: an object of the same format as the qparams above.
|
218 |
+
|
219 |
+
skip_layer_func:
|
220 |
+
A function that determines which linear layers should be skipped during GPTQ
|
221 |
+
Args:
|
222 |
+
weight: A 2d weight tensor with non-integer dtype.
|
223 |
+
Returns:
|
224 |
+
skip: boolean indicating whether layer should be skipped
|
225 |
+
|
226 |
+
make_names_and_values_dict_func:
|
227 |
+
A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they
|
228 |
+
should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here.
|
229 |
+
Args:
|
230 |
+
quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
|
231 |
+
qparams: the output from get_qparams_func
|
232 |
+
Returns:
|
233 |
+
names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the
|
234 |
+
corresponding quantized weights and qparams.
|
235 |
+
"""
|
236 |
+
def __init__(self):
|
237 |
+
assert self.mod is not None
|
238 |
+
assert self.get_qparams_func is not None
|
239 |
+
assert self.quantize_func is not None
|
240 |
+
assert self.dequantize_func is not None
|
241 |
+
assert self.combine_qparams_list_func is not None
|
242 |
+
assert self.make_names_and_values_dict_func is not None
|
243 |
+
|
244 |
+
@staticmethod
|
245 |
+
def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) -> "MultiInput":
|
246 |
+
input_recorder = InputRecorder(
|
247 |
+
model,
|
248 |
+
tokenizer,
|
249 |
+
calibration_seq_length,
|
250 |
+
pad_calibration_inputs,
|
251 |
+
)
|
252 |
+
|
253 |
+
try:
|
254 |
+
lm_eval.tasks.initialize_tasks()
|
255 |
+
except:
|
256 |
+
pass
|
257 |
+
task_dict = get_task_dict(calibration_tasks)
|
258 |
+
print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)
|
259 |
+
|
260 |
+
evaluate(
|
261 |
+
input_recorder,
|
262 |
+
task_dict,
|
263 |
+
limit=calibration_limit,
|
264 |
+
)
|
265 |
+
inputs = input_recorder.get_recorded_inputs()
|
266 |
+
assert inputs is not None, (
|
267 |
+
f"No inputs were collected, use a task other than {calibration_tasks}, "+
|
268 |
+
f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently "+
|
269 |
+
f"{calibration_seq_length})"
|
270 |
+
)
|
271 |
+
print(f"Obtained {len(inputs[0].values)} calibration samples")
|
272 |
+
return inputs
|
273 |
+
|
274 |
+
@torch.no_grad()
|
275 |
+
def create_quantized_state_dict(
|
276 |
+
self,
|
277 |
+
tokenizer,
|
278 |
+
blocksize,
|
279 |
+
percdamp,
|
280 |
+
groupsize,
|
281 |
+
calibration_tasks,
|
282 |
+
calibration_limit,
|
283 |
+
calibration_seq_length,
|
284 |
+
pad_calibration_inputs,
|
285 |
+
) -> "StateDict":
|
286 |
+
inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs)
|
287 |
+
print("Tracing model for GPTQ")
|
288 |
+
GPTQ_runner = GenericGPTQRunner(
|
289 |
+
self.mod,
|
290 |
+
inputs,
|
291 |
+
blocksize,
|
292 |
+
percdamp,
|
293 |
+
groupsize,
|
294 |
+
).configure_quantization_mode(
|
295 |
+
self.get_qparams_func,
|
296 |
+
self.quantize_func,
|
297 |
+
self.dequantize_func,
|
298 |
+
self.combine_qparams_list_func,
|
299 |
+
self.make_names_and_values_dict_func,
|
300 |
+
self.skip_layer_func
|
301 |
+
)
|
302 |
+
|
303 |
+
print("Applying GPTQ to weights")
|
304 |
+
GPTQ_runner.run()
|
305 |
+
return GPTQ_runner.get_quantized_state_dict()
|
306 |
+
|
307 |
+
def convert_for_runtime(self) -> "nn.Module":
|
308 |
+
pass
|
309 |
+
|
310 |
+
##### Weight-only int8 per-channel quantized code ######
|
311 |
+
|
312 |
+
def replace_linear_weight_only_int8_per_channel(module):
|
313 |
+
for name, child in module.named_children():
|
314 |
+
if isinstance(child, nn.Linear):
|
315 |
+
setattr(module, name, WeightOnlyInt8Linear(child.in_features, child.out_features))
|
316 |
+
else:
|
317 |
+
replace_linear_weight_only_int8_per_channel(child)
|
318 |
+
|
319 |
+
class WeightOnlyInt8QuantHandler:
|
320 |
+
def __init__(self, mod):
|
321 |
+
self.mod = mod
|
322 |
+
|
323 |
+
@torch.no_grad()
|
324 |
+
def create_quantized_state_dict(self):
|
325 |
+
cur_state_dict = self.mod.state_dict()
|
326 |
+
for fqn, mod in self.mod.named_modules():
|
327 |
+
if isinstance(mod, torch.nn.Linear):
|
328 |
+
int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8)
|
329 |
+
cur_state_dict[f"{fqn}.weight"] = int8_weight
|
330 |
+
cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
|
331 |
+
|
332 |
+
return cur_state_dict
|
333 |
+
|
334 |
+
def convert_for_runtime(self):
|
335 |
+
replace_linear_weight_only_int8_per_channel(self.mod)
|
336 |
+
return self.mod
|
337 |
+
|
338 |
+
|
339 |
+
class WeightOnlyInt8Linear(torch.nn.Module):
|
340 |
+
__constants__ = ['in_features', 'out_features']
|
341 |
+
in_features: int
|
342 |
+
out_features: int
|
343 |
+
weight: torch.Tensor
|
344 |
+
|
345 |
+
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
346 |
+
device=None, dtype=None) -> None:
|
347 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
348 |
+
super().__init__()
|
349 |
+
self.in_features = in_features
|
350 |
+
self.out_features = out_features
|
351 |
+
self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
|
352 |
+
self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
|
353 |
+
|
354 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
355 |
+
return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
|
356 |
+
|
357 |
+
##### weight only int4 per channel groupwise quantized code ######
|
358 |
+
|
359 |
+
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
|
360 |
+
weight_int32, scales_and_zeros = group_quantize_tensor(
|
361 |
+
weight_bf16, n_bit=4, groupsize=groupsize
|
362 |
+
)
|
363 |
+
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
|
364 |
+
return weight_int4pack, scales_and_zeros
|
365 |
+
|
366 |
+
|
367 |
+
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
|
368 |
+
origin_x_size = x.size()
|
369 |
+
x = x.reshape(-1, origin_x_size[-1])
|
370 |
+
c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros)
|
371 |
+
new_shape = origin_x_size[:-1] + (out_features,)
|
372 |
+
c = c.reshape(new_shape)
|
373 |
+
return c
|
374 |
+
|
375 |
+
|
376 |
+
def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1):
|
377 |
+
return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
|
378 |
+
|
379 |
+
def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
                setattr(module, name, WeightOnlyInt4Linear(
                    child.in_features, child.out_features, bias=False,
                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False,
                ))
            elif padding:
                setattr(module, name, WeightOnlyInt4Linear(
                    child.in_features, child.out_features, bias=False,
                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True,
                ))
        else:
            replace_linear_int4(child, groupsize, inner_k_tiles, padding)

class WeightOnlyInt4QuantHandler:
    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
        self.mod = mod
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        self.padding = padding
        assert groupsize in [32, 64, 128, 256]
        assert inner_k_tiles in [2, 4, 8]

    @torch.no_grad()
    def create_quantized_state_dict(self, use_cuda = True):
        if use_cuda:
            device = "cuda"
        else:
            device = "cpu"

        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if isinstance(mod, torch.nn.Linear):
                assert not mod.bias
                out_features = mod.out_features
                in_features = mod.in_features
                assert out_features % 8 == 0, "require out_features % 8 == 0"
                print(f"linear: {fqn}, in={in_features}, out={out_features}")

                weight = mod.weight.data
                if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
                    if self.padding:
                        from model import find_multiple
                        import torch.nn.functional as F
                        print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
                        padded_in_features = find_multiple(in_features, 1024)
                        weight = F.pad(weight, pad=(0, padded_in_features - in_features))
                    else:
                        print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
                              "and that groupsize and inner_k_tiles*16 evenly divide into it")
                        continue
                weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
                    weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles
                )
                cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')

        return cur_state_dict

    def convert_for_runtime(self):
        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
        return self.mod

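# A minimal sketch of how this handler is typically driven (quantize() below does
# the same thing): build the quantized state dict offline, swap the Linear modules
# for WeightOnlyInt4Linear, then load the quantized weights. The helper name is
# illustrative only and is not called anywhere in this file.
def _int4_quantize_example(model):
    handler = WeightOnlyInt4QuantHandler(model, groupsize=128)
    quantized_state_dict = handler.create_quantized_state_dict(use_cuda=torch.cuda.is_available())
    model = handler.convert_for_runtime()
    model.load_state_dict(quantized_state_dict, assign=True)
    return model
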
class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
        from model import find_multiple
        self.mod = mod
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        self.padding = padding
        self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize)
        self.quantize_func = lambda w, qparams: \
            group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize)
        self.dequantize_func = lambda q, qparams: \
            group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float()
        self.combine_qparams_list_func = lambda qparams_list: \
            [torch.cat(x, dim=1) for x in zip(*qparams_list)]
        # skip unless padding=True or the layer is already correctly sized
        self.skip_layer_func = lambda linear_weight: not (
            _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding
        )
        # we need to do the padding here, both for q and the qparams, if necessary
        def make_names_and_values_dict_func(q, qparams):
            k = q.shape[1]
            new_k = find_multiple(k, 1024)
            # how much we need to pad the weight
            delta_k = new_k - q.shape[1]
            final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles)
            scales_and_zeros = pack_scales_and_zeros(*qparams)
            # how many new groups we need for the padded weight
            delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
            final_s_and_z = F.pad(scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1)
            return {"weight": final_q, "scales_and_zeros": final_s_and_z}
        self.make_names_and_values_dict_func = make_names_and_values_dict_func
        super().__init__()

    def convert_for_runtime(self):
        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
        return self.mod

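# The GPTQQuantHandler base class (in scope from earlier in this module) drives the
# GPTQ procedure through the callbacks wired up above: get_qparams_func computes
# per-group scales and zeros, quantize_func / dequantize_func map weights to and
# from int4 groups, skip_layer_func filters out layers whose shapes cannot be
# packed, and make_names_and_values_dict_func produces the tensors that end up in
# the quantized state dict (padding to a multiple of 1024 columns when needed).
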
class WeightOnlyInt4Linear(torch.nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
            self, in_features: int, out_features: int,
            bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True,
    ) -> None:
        super().__init__()
        self.padding = padding
        if padding:
            from model import find_multiple
            self.origin_in_features = in_features
            in_features = find_multiple(in_features, 1024)

        self.in_features = in_features
        self.out_features = out_features
        assert not bias, "require bias=False"
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles

        assert out_features % 8 == 0, "require out_features % 8 == 0"
        assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
        self.register_buffer(
            "weight",
            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
        )
        self.register_buffer(
            "scales_and_zeros",
            torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = input.to(torch.bfloat16)
        if self.padding:
            import torch.nn.functional as F
            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
        return linear_forward_int4(
            input,
            self.weight, self.scales_and_zeros, self.out_features, self.groupsize
        )

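# A quick sanity check on the buffer shapes above, assuming the packing done by
# torch.ops.aten._convert_weight_to_int4pack: each int32 element holds eight 4-bit
# values, so (out_features // 8) * (in_features // (inner_k_tiles * 16)) * 32 *
# (inner_k_tiles // 2) int32 elements store exactly out_features * in_features
# int4 weights, while scales_and_zeros keeps one (scale, zero) pair per
# (group, output channel).
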
def quantize(
    checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
    mode: str = 'int8',
    # the following argument is only used for int4 quantization.
    groupsize: int = 128,
    # the following arguments are only used for GPTQ
    calibration_tasks: list = ["hellaswag"],
    calibration_limit: int = 1000,
    calibration_seq_length: int = 100,
    pad_calibration_inputs: bool = False,
    percdamp: float = .01,
    blocksize: int = 128,
    label: str = '',
) -> None:
    assert checkpoint_path.is_file(), checkpoint_path

    device = 'cpu'
    precision = torch.bfloat16

    print("Loading model ...")
    t0 = time.time()

    with torch.device('meta'):
        model = Transformer.from_name(checkpoint_path.parent.name)

    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
    model.load_state_dict(checkpoint, assign=True)
    model = model.to(dtype=precision, device=device)

    if mode == 'int8':
        print("Quantizing model weights for int8 weight-only symmetric per-channel quantization")
        quant_handler = WeightOnlyInt8QuantHandler(model)
        quantized_state_dict = quant_handler.create_quantized_state_dict()

        dir_name = checkpoint_path.parent
        base_name = checkpoint_path.name
        new_base_name = base_name.replace('.pth', f'{label}int8.pth')

    elif mode == 'int4':
        print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization")
        quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
        quantized_state_dict = quant_handler.create_quantized_state_dict()

        dir_name = checkpoint_path.parent
        base_name = checkpoint_path.name
        new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.pth")

    elif mode == 'int4-gptq':
        print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ...")
        quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize)

        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), str(tokenizer_path)
        tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

        quantized_state_dict = quant_handler.create_quantized_state_dict(
            tokenizer,
            blocksize,
            percdamp,
            groupsize,
            calibration_tasks,
            calibration_limit,
            calibration_seq_length,
            pad_calibration_inputs
        )

        dir_name = checkpoint_path.parent
        base_name = checkpoint_path.name
        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")
    else:
        raise ValueError(f"Invalid quantization mode {mode}, needs to be one of [int8, int4, int4-gptq]")

    quantize_path = dir_name / new_base_name
    print(f"Writing quantized weights to {quantize_path}")
    quantize_path.unlink(missing_ok=True)  # remove existing file if one is already there
    torch.save(quantized_state_dict, quantize_path)
    print(f"Quantization complete, took {time.time() - t0:.02f} seconds")
    return

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Quantize a model.')
    parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.')
    parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform')
    parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.')
    parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
    parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration')
    parser.add_argument('--calibration_seq_length', type=int, default=100, help='length of sequences to use for gptq calibration')
    parser.add_argument('--pad_calibration_inputs', type=bool, default=False, help='pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower')
    parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening')
    parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq')
    parser.add_argument('--label', type=str, default='_', help='label to add to output filename')

    args = parser.parse_args()
    quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label)
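# Example invocations, assuming the script is run from this module's directory and
# a checkpoint exists at the default path:
#   python quantize.py --checkpoint_path checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --mode int8
#   python quantize.py --checkpoint_path checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --mode int4 --groupsize 32
# The output is written next to the input checkpoint, e.g. model_int8.pth or
# model_int4.g32.pth with the default label '_'.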