Staticaliza committed
Commit be0c908 · verified · 1 Parent(s): e9923fc

Upload 3 files

modules/gpt_fast/generate.py ADDED
@@ -0,0 +1,436 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import itertools
7
+ import sys
8
+ import time
9
+ from pathlib import Path
10
+ from typing import Optional, Tuple
11
+
12
+ import torch
13
+ import torch._dynamo.config
14
+ import torch._inductor.config
15
+
16
+ def device_sync(device):
17
+ if "cuda" in device:
18
+ torch.cuda.synchronize(device)
19
+ elif ("cpu" in device) or ("mps" in device):
20
+ pass
21
+ else:
22
+ print(f"device={device} is not yet suppported")
23
+
24
+
25
+ torch._inductor.config.coordinate_descent_tuning = True
26
+ torch._inductor.config.triton.unique_kernel_names = True
27
+ torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
28
+
29
+ default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
30
+
31
+ # support running without installing as a package
32
+ wd = Path(__file__).parent.parent.resolve()
33
+ sys.path.append(str(wd))
34
+
35
+ from model import Transformer
36
+ from tokenizer import get_tokenizer
37
+
38
+ def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
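+ # Exponential-race trick: with E_i ~ Exp(1), argmax(p_i / E_i) returns index i
+ # with probability p_i, matching torch.multinomial without the host-device sync.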
39
+ q = torch.empty_like(probs_sort).exponential_(1)
40
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
41
+
42
+ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
43
+ logits = logits / max(temperature, 1e-5)
44
+
45
+ if top_k is not None:
46
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
47
+ pivot = v.select(-1, -1).unsqueeze(-1)
48
+ logits = torch.where(logits < pivot, -float("Inf"), logits)
49
+ probs = torch.nn.functional.softmax(logits, dim=-1)
50
+ return probs
51
+
52
+ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
53
+ probs = logits_to_probs(logits[0, -1], temperature, top_k)
54
+ idx_next = multinomial_sample_one_no_sync(probs)
55
+ return idx_next, probs
56
+
57
+ def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
58
+ # input_pos: [B, S]
59
+ logits = model(x, input_pos)
60
+ return sample(logits, **sampling_kwargs)[0]
61
+
62
+ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
63
+ # input_pos: [B, 1]
64
+ assert input_pos.shape[-1] == 1
65
+ logits = model(x, input_pos)
66
+ return sample(logits, **sampling_kwargs)
67
+
68
+ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
69
+ new_tokens, new_probs = [], []
70
+ for i in range(num_new_tokens):
71
+ with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
72
+ next_token, next_prob = decode_one_token(
73
+ model, cur_token, input_pos, **sampling_kwargs
74
+ )
75
+ input_pos += 1
76
+ new_tokens.append(next_token.clone())
77
+ callback(new_tokens[-1])
78
+ new_probs.append(next_prob.clone())
79
+ cur_token = next_token.view(1, -1)
80
+
81
+ return new_tokens, new_probs
82
+
83
+
84
+ def model_forward(model, x, input_pos):
85
+ return model(x, input_pos)
86
+
87
+ def speculative_decode(
88
+ model: Transformer,
89
+ draft_model: Transformer,
90
+ cur_token: torch.Tensor,
91
+ input_pos: int,
92
+ speculate_k: int,
93
+ **sampling_kwargs
94
+ ) -> torch.Tensor:
95
+ # draft model inference sequentially
96
+ device = cur_token.device
97
+ orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device)
98
+ draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs)
99
+
100
+ draft_tokens = torch.cat(draft_tokens)
101
+ # parallel inference on target model using draft tokens
102
+ target_logits = model_forward(
103
+ model,
104
+ torch.cat([cur_token.view(1), draft_tokens]).view(1, -1),
105
+ torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device)
106
+ )
107
+ target_probs = logits_to_probs(target_logits[0], **sampling_kwargs)
108
+ draft_probs = torch.stack(draft_probs)
109
+ # q: target prob, p: draft prob
110
+ # q >= p: always accept draft token
111
+ # q < p: q/p prob to accept draft token
112
+ p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
113
+ q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
114
+ accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k]/ p)
115
+ rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero()
116
+
117
+ if rejected_locations.shape[0] == 0: # All draft tokens have been accepted
118
+ accept_length = speculate_k + 1
119
+ last_token = multinomial_sample_one_no_sync(target_probs[-1])
120
+ # fill last token into draft model
121
+ model_forward(
122
+ draft_model,
123
+ draft_tokens[-1].view(1, -1),
124
+ orig_input_pos + speculate_k,
125
+ )
126
+ return torch.cat([draft_tokens, last_token])
127
+ else:
128
+ accept_length = rejected_locations[0].item()
129
+ p = draft_probs[accept_length]
130
+ q = target_probs[accept_length]
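+ # A draft token was rejected at position accept_length: resample it from the
+ # residual distribution norm(max(q - p, 0)), which keeps the output distributed
+ # as if sampling from the target model alone.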
131
+ new = q - p
132
+ new = torch.where(new > 0, new, 0.0)
133
+ new = new / new.sum()
134
+ next_token = multinomial_sample_one_no_sync(new)
135
+ return torch.cat([draft_tokens[:accept_length], next_token])
136
+
137
+ @torch.no_grad()
138
+ def generate(
139
+ model: Transformer,
140
+ prompt: torch.Tensor,
141
+ max_new_tokens: int,
142
+ *,
143
+ interactive: bool,
144
+ draft_model: Transformer,
145
+ speculate_k: Optional[int] = 8,
146
+ callback = lambda x: x,
147
+ **sampling_kwargs
148
+ ) -> torch.Tensor:
149
+ """
150
+ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
151
+ """
152
+
153
+ is_speculative = draft_model is not None
154
+ # create an empty tensor of the expected final shape and fill in the current tokens
155
+ T = prompt.size(0)
156
+ T_new = T + max_new_tokens
157
+ if interactive:
158
+ max_seq_length = 350
159
+ else:
160
+ max_seq_length = min(T_new, model.config.block_size)
161
+
162
+ device, dtype = prompt.device, prompt.dtype
163
+ max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
164
+ with torch.device(device):
165
+ model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
166
+ if is_speculative and draft_model is not model:
167
+ draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
168
+
169
+ # create an empty tensor of the expected final shape and fill in the current tokens
170
+ empty = torch.empty(T_new, dtype=dtype, device=device)
171
+ empty[:T] = prompt
172
+ seq = empty
173
+ input_pos = torch.arange(0, T, device=device)
174
+
175
+ next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone()
176
+ if is_speculative:
177
+ prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
178
+ seq[T] = next_token
179
+
180
+ input_pos = torch.tensor([T], device=device, dtype=torch.int)
181
+ accept_counts = [0] * (speculate_k + 1)
182
+
183
+ if is_speculative:
184
+ input_pos = input_pos.item() # for speculative decoding easier to keep on host
185
+ while input_pos < T_new - 1:
186
+ cur_token = next_token.view(())
187
+
188
+ next_tokens = speculative_decode(
189
+ model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs
190
+ )
191
+
192
+ accept_counts[len(next_tokens) - 1] += 1
193
+ num_added = min(T_new - input_pos - 1, len(next_tokens))
194
+ seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[: num_added]
195
+ for i in next_tokens[:num_added]:
196
+ callback(i)
197
+ input_pos = input_pos + num_added
198
+ next_token = next_tokens[-1]
199
+ else:
200
+ generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
201
+ seq[T + 1:] = torch.cat(generated_tokens)
202
+
203
+ generate_stats = {
204
+ 'accept_counts': accept_counts
205
+ }
206
+ return seq, generate_stats
207
+
208
+ def encode_tokens(tokenizer, string, bos=True, device=default_device):
209
+ tokens = tokenizer.encode(string)
210
+ if bos:
211
+ tokens = [tokenizer.bos_id()] + tokens
212
+ return torch.tensor(tokens, dtype=torch.int, device=device)
213
+
214
+ def _load_model(checkpoint_path, device, precision, use_tp):
215
+ use_cuda = 'cuda' in device
216
+ with torch.device('meta'):
217
+ model = Transformer.from_name(checkpoint_path.parent.name)
218
+
219
+ if "int8" in str(checkpoint_path):
220
+ print("Using int8 weight-only quantization!")
221
+ from quantize import WeightOnlyInt8QuantHandler
222
+ simple_quantizer = WeightOnlyInt8QuantHandler(model)
223
+ model = simple_quantizer.convert_for_runtime()
224
+
225
+ if "int4" in str(checkpoint_path):
226
+ print("Using int4 weight-only quantization!")
227
+ path_comps = checkpoint_path.name.split(".")
228
+ groupsize = int(path_comps[-2][1:])
229
+ from quantize import WeightOnlyInt4QuantHandler
230
+ simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
231
+ model = simple_quantizer.convert_for_runtime()
232
+
233
+ checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
234
+ if "model" in checkpoint and "stories" in str(checkpoint_path):
235
+ checkpoint = checkpoint["model"]
236
+ model.load_state_dict(checkpoint, assign=True)
237
+
238
+ if use_tp:
239
+ from tp import apply_tp
240
+ print("Applying tensor parallel to model ...")
241
+ apply_tp(model)
242
+
243
+ model = model.to(device=device, dtype=precision)
244
+ return model.eval()
245
+
246
+ def _get_model_size(model):
247
+ model_size = 0
248
+ for name, child in model.named_children():
249
+ if not isinstance(child, torch.nn.Embedding):
250
+ model_size += sum(
251
+ [
252
+ p.numel() * p.dtype.itemsize
253
+ for p in itertools.chain(child.parameters(), child.buffers())
254
+ ]
255
+ )
256
+ return model_size
257
+
258
+ B_INST, E_INST = "[INST]", "[/INST]"
259
+
260
+ def main(
261
+ prompt: str = "Hello, my name is",
262
+ interactive: bool = False,
263
+ num_samples: int = 5,
264
+ max_new_tokens: int = 100,
265
+ top_k: int = 200,
266
+ temperature: float = 0.8,
267
+ checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
268
+ compile: bool = True,
269
+ compile_prefill: bool = False,
270
+ profile: Optional[Path] = None,
271
+ draft_checkpoint_path: Optional[Path] = None,
272
+ speculate_k: int = 5,
273
+ device=default_device,
274
+ ) -> None:
275
+ """Generates text samples based on a pre-trained Transformer model and tokenizer.
276
+ """
277
+ assert checkpoint_path.is_file(), checkpoint_path
278
+
279
+ tokenizer_path = checkpoint_path.parent / "tokenizer.model"
280
+ assert tokenizer_path.is_file(), str(tokenizer_path)
281
+
282
+ global print
283
+ from tp import maybe_init_dist
284
+ rank = maybe_init_dist()
285
+ use_tp = rank is not None
286
+ if use_tp:
287
+ if rank != 0:
288
+ # only print on rank 0
289
+ print = lambda *args, **kwargs: None
290
+
291
+ print(f"Using device={device}")
292
+ precision = torch.bfloat16
293
+ is_speculative = draft_checkpoint_path is not None
294
+ is_chat = "chat" in str(checkpoint_path)
295
+
296
+ print("Loading model ...")
297
+ t0 = time.time()
298
+ model = _load_model(checkpoint_path, device, precision, use_tp)
299
+
300
+ if is_speculative:
301
+ draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp)
302
+ else:
303
+ draft_model = None
304
+
305
+ device_sync(device=device) # MKG
306
+ print(f"Time to load model: {time.time() - t0:.02f} seconds")
307
+
308
+ tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
309
+
310
+ encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
311
+ prompt_length = encoded.size(0)
312
+
313
+ torch.manual_seed(1234)
314
+ model_size = _get_model_size(model)
315
+ if compile:
316
+ if is_speculative and use_tp: # and ("cuda" in device):
317
+ torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
318
+
319
+ if is_speculative:
320
+ global model_forward, logits_to_probs
321
+ model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
322
+
323
+ global decode_one_token, prefill
324
+ decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
325
+
326
+ # Uncomment to squeeze more perf out of prefill
327
+ if compile_prefill:
328
+ prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
329
+
330
+
331
+ aggregate_metrics = {
332
+ 'tokens_per_sec': [],
333
+ 'accept_counts': [],
334
+ }
335
+ start = -1 if compile else 0
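+ # With compilation enabled, iteration i == -1 is a warm-up pass whose timing is
+ # reported as compilation time and excluded from the tokens/sec metrics below.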
336
+
337
+ for i in range(start, num_samples):
338
+ device_sync(device=device) # MKG
339
+ if i >= 0 and interactive:
340
+ prompt = input("What is your prompt? ")
341
+ if is_chat:
342
+ prompt = f"{B_INST} {prompt.strip()} {E_INST}"
343
+ encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
344
+
345
+ if interactive and i >= 0:
346
+ buffer = []
347
+ period_id = tokenizer.encode('.')[0]
348
+ done_generating = False
349
+ def callback(x):
350
+ nonlocal done_generating
351
+ if done_generating:
352
+ return
353
+ buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
354
+ if x.item() == tokenizer.eos_id():
355
+ done_generating = True
356
+ if len(buffer) == 4 or done_generating:
357
+ print(''.join(buffer), end='', flush=True)
358
+ buffer.clear()
359
+ # print(, end='', flush=True)
360
+ else:
361
+ callback = lambda x : x
362
+ t0 = time.perf_counter()
363
+ import contextlib
364
+ if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
365
+ prof = contextlib.nullcontext()
366
+ else:
367
+ torch.profiler._utils._init_for_cuda_graphs()
368
+ prof = torch.profiler.profile()
369
+ with prof:
370
+ y, metrics = generate(
371
+ model,
372
+ encoded,
373
+ max_new_tokens,
374
+ draft_model=draft_model,
375
+ speculate_k=speculate_k,
376
+ interactive=interactive,
377
+ callback=callback,
378
+ temperature=temperature,
379
+ top_k=top_k,
380
+ )
381
+ aggregate_metrics['accept_counts'].append(metrics['accept_counts'])
382
+ if i == -1:
383
+ print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
384
+ continue
385
+ if hasattr(prof, "export_chrome_trace"):
386
+ if use_tp:
387
+ prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
388
+ else:
389
+ prof.export_chrome_trace(f"{profile}.json")
390
+ device_sync(device=device) # MKG
391
+ t = time.perf_counter() - t0
392
+
393
+ if not interactive:
394
+ print(tokenizer.decode(y.tolist()))
395
+ else:
396
+ print()
397
+ tokens_generated = y.size(0) - prompt_length
398
+ tokens_sec = tokens_generated / t
399
+ aggregate_metrics['tokens_per_sec'].append(tokens_sec)
400
+ print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
401
+ print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
402
+ print("==========")
403
+ if is_speculative:
404
+ counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])]
405
+ acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated]
406
+ print(f"Acceptance probs: {acceptance_probs}")
407
+ print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}")
408
+
409
+ print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
410
+ print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
411
+
412
+
413
+ if __name__ == '__main__':
414
+ import argparse
415
+ parser = argparse.ArgumentParser(description='Your CLI description.')
416
+
417
+ parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
418
+ parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
419
+ parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
420
+ parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
421
+ parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
422
+ parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
423
+ parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
424
+ parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
425
+ parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
426
+ parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
427
+ parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
428
+ parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
429
+ parser.add_argument('--device', type=str, default=default_device, help='Device to use')
430
+
431
+ args = parser.parse_args()
432
+ main(
433
+ args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
434
+ args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
435
+ args.speculate_k, args.device
436
+ )
modules/gpt_fast/model.py ADDED
@@ -0,0 +1,356 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ from dataclasses import dataclass
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch import Tensor
12
+ from torch.nn import functional as F
13
+
14
+
15
+ def find_multiple(n: int, k: int) -> int:
16
+ if n % k == 0:
17
+ return n
18
+ return n + k - (n % k)
19
+
20
+ class AdaptiveLayerNorm(nn.Module):
21
+ r"""Adaptive Layer Normalization"""
22
+
23
+ def __init__(self, d_model, norm) -> None:
24
+ super(AdaptiveLayerNorm, self).__init__()
25
+ self.project_layer = nn.Linear(d_model, 2 * d_model)
26
+ self.norm = norm
27
+ self.d_model = d_model
28
+ self.eps = self.norm.eps
29
+
30
+ def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
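+ # FiLM-style conditioning: the embedding is projected to a per-feature (weight, bias)
+ # pair that scales and shifts the normalized input; with no embedding this reduces to the plain norm.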
31
+ if embedding is None:
32
+ return self.norm(input)
33
+ weight, bias = torch.split(
34
+ self.project_layer(embedding),
35
+ split_size_or_sections=self.d_model,
36
+ dim=-1,
37
+ )
38
+ return weight * self.norm(input) + bias
39
+
40
+
41
+ @dataclass
42
+ class ModelArgs:
43
+ block_size: int = 2048
44
+ vocab_size: int = 32000
45
+ n_layer: int = 32
46
+ n_head: int = 32
47
+ dim: int = 4096
48
+ intermediate_size: int = None
49
+ n_local_heads: int = -1
50
+ head_dim: int = 64
51
+ rope_base: float = 10000
52
+ norm_eps: float = 1e-5
53
+ has_cross_attention: bool = False
54
+ context_dim: int = 0
55
+ uvit_skip_connection: bool = False
56
+
57
+ def __post_init__(self):
58
+ if self.n_local_heads == -1:
59
+ self.n_local_heads = self.n_head
60
+ if self.intermediate_size is None:
61
+ hidden_dim = 4 * self.dim
62
+ n_hidden = int(2 * hidden_dim / 3)
63
+ self.intermediate_size = find_multiple(n_hidden, 256)
64
+ # self.head_dim = self.dim // self.n_head
65
+
66
+ @classmethod
67
+ def from_name(cls, name: str):
68
+ if name in transformer_configs:
69
+ return cls(**transformer_configs[name])
70
+ # fuzzy search
71
+ config = [config for config in transformer_configs if config.lower() in str(name).lower()]
72
+
73
+ # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,
74
+ # take the longer name (as it has more symbols matched)
75
+ if len(config) > 1:
76
+ config.sort(key=len, reverse=True)
77
+ assert len(config[0]) != len(config[1]), name # make sure only one 'best' match
78
+
79
+ return cls(**transformer_configs[config[0]])
80
+
81
+
82
+ transformer_configs = {
83
+ "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000),
84
+ "7B": dict(n_layer=32, n_head=32, dim=4096),
85
+ "13B": dict(n_layer=40, n_head=40, dim=5120),
86
+ "30B": dict(n_layer=60, n_head=52, dim=6656),
87
+ "34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016,
88
+ rope_base=1000000), # CodeLlama-34B-Python-hf
89
+ "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672),
90
+ "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
91
+ "stories15M": dict(n_layer=6, n_head=6, dim=288),
92
+ "stories110M": dict(n_layer=12, n_head=12, dim=768),
93
+
94
+ "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336,
95
+ vocab_size=128256, rope_base=500000),
96
+ "llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672,
97
+ vocab_size=128256, rope_base=500000),
98
+ }
99
+
100
+
101
+ class KVCache(nn.Module):
102
+ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
103
+ super().__init__()
104
+ cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
105
+ self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
106
+ self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
107
+
108
+ def update(self, input_pos, k_val, v_val):
109
+ # input_pos: [S], k_val: [B, H, S, D]
110
+ assert input_pos.shape[0] == k_val.shape[2]
111
+
112
+ k_out = self.k_cache
113
+ v_out = self.v_cache
114
+ k_out[:, :, input_pos] = k_val
115
+ v_out[:, :, input_pos] = v_val
116
+
117
+ return k_out, v_out
118
+
119
+
120
+ class Transformer(nn.Module):
121
+ def __init__(self, config: ModelArgs) -> None:
122
+ super().__init__()
123
+ self.config = config
124
+
125
+ self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
126
+ self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
127
+
128
+ self.freqs_cis: Optional[Tensor] = None
129
+ self.mask_cache: Optional[Tensor] = None
130
+ self.max_batch_size = -1
131
+ self.max_seq_length = -1
132
+
133
+ def setup_caches(self, max_batch_size, max_seq_length, use_kv_cache=True):
134
+ if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
135
+ return
136
+ head_dim = self.config.dim // self.config.n_head
137
+ max_seq_length = find_multiple(max_seq_length, 8)
138
+ self.max_seq_length = max_seq_length
139
+ self.max_batch_size = max_batch_size
140
+ dtype = self.norm.project_layer.weight.dtype
141
+ device = self.norm.project_layer.weight.device
142
+
143
+ if not self.training and use_kv_cache:
144
+ for b in self.layers:
145
+ b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype).to(device)
146
+
147
+ self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim,
148
+ self.config.rope_base, dtype).to(device)
149
+ self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)).to(device)
150
+ self.use_kv_cache = use_kv_cache
151
+ self.uvit_skip_connection = self.config.uvit_skip_connection
152
+ if self.uvit_skip_connection:
153
+ self.layers_emit_skip = [i for i in range(self.config.n_layer) if i < self.config.n_layer // 2]
154
+ self.layers_receive_skip = [i for i in range(self.config.n_layer) if i > self.config.n_layer // 2]
155
+ else:
156
+ self.layers_emit_skip = []
157
+ self.layers_receive_skip = []
158
+
159
+ def forward(self,
160
+ x: Tensor,
161
+ c: Tensor,
162
+ input_pos: Optional[Tensor] = None,
163
+ mask: Optional[Tensor] = None,
164
+ context: Optional[Tensor] = None,
165
+ context_input_pos: Optional[Tensor] = None,
166
+ cross_attention_mask: Optional[Tensor] = None,
167
+ ) -> Tensor:
168
+ assert self.freqs_cis is not None, "Caches must be initialized first"
169
+ if mask is None: # in case of non-causal model
170
+ if not self.training and self.use_kv_cache:
171
+ mask = self.causal_mask[None, None, input_pos]
172
+ else:
173
+ mask = self.causal_mask[None, None, input_pos]
174
+ mask = mask[..., input_pos]
175
+ freqs_cis = self.freqs_cis[input_pos]
176
+ if context is not None:
177
+ context_freqs_cis = self.freqs_cis[context_input_pos]
178
+ else:
179
+ context_freqs_cis = None
180
+ skip_in_x_list = []
181
+ for i, layer in enumerate(self.layers):
182
+ if self.uvit_skip_connection and i in self.layers_receive_skip:
183
+ skip_in_x = skip_in_x_list.pop(-1)
184
+ else:
185
+ skip_in_x = None
186
+ x = layer(x, c, input_pos, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask, skip_in_x)
187
+ if self.uvit_skip_connection and i in self.layers_emit_skip:
188
+ skip_in_x_list.append(x)
189
+ x = self.norm(x, c)
190
+ return x
191
+
192
+ @classmethod
193
+ def from_name(cls, name: str):
194
+ return cls(ModelArgs.from_name(name))
195
+
196
+
197
+ class TransformerBlock(nn.Module):
198
+ def __init__(self, config: ModelArgs) -> None:
199
+ super().__init__()
200
+ self.attention = Attention(config)
201
+ self.feed_forward = FeedForward(config)
202
+ self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
203
+ self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
204
+
205
+ if config.has_cross_attention:
206
+ self.has_cross_attention = True
207
+ self.cross_attention = Attention(config, is_cross_attention=True)
208
+ self.cross_attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
209
+ else:
210
+ self.has_cross_attention = False
211
+
212
+ if config.uvit_skip_connection:
213
+ self.skip_in_linear = nn.Linear(config.dim * 2, config.dim)
214
+ self.uvit_skip_connection = True
215
+ else:
216
+ self.uvit_skip_connection = False
217
+
218
+ def forward(self,
219
+ x: Tensor,
220
+ c: Tensor,
221
+ input_pos: Tensor,
222
+ freqs_cis: Tensor,
223
+ mask: Tensor,
224
+ context: Optional[Tensor] = None,
225
+ context_freqs_cis: Optional[Tensor] = None,
226
+ cross_attention_mask: Optional[Tensor] = None,
227
+ skip_in_x: Optional[Tensor] = None,
228
+ ) -> Tensor:
229
+ if self.uvit_skip_connection and skip_in_x is not None:
230
+ x = self.skip_in_linear(torch.cat([x, skip_in_x], dim=-1))
231
+ h = x + self.attention(self.attention_norm(x, c), freqs_cis, mask, input_pos)
232
+ if self.has_cross_attention:
233
+ h = h + self.cross_attention(self.cross_attention_norm(h, c), freqs_cis, cross_attention_mask, input_pos, context, context_freqs_cis)
234
+ out = h + self.feed_forward(self.ffn_norm(h, c))
235
+ return out
236
+
237
+
238
+ class Attention(nn.Module):
239
+ def __init__(self, config: ModelArgs, is_cross_attention: bool = False):
240
+ super().__init__()
241
+ assert config.dim % config.n_head == 0
242
+
243
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
244
+ # key, query, value projections for all heads, but in a batch
245
+ if is_cross_attention:
246
+ self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
247
+ self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False)
248
+ else:
249
+ self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
250
+ self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
251
+ self.kv_cache = None
252
+
253
+ self.n_head = config.n_head
254
+ self.head_dim = config.head_dim
255
+ self.n_local_heads = config.n_local_heads
256
+ self.dim = config.dim
257
+ # self._register_load_state_dict_pre_hook(self.load_hook)
258
+
259
+ # def load_hook(self, state_dict, prefix, *args):
260
+ # if prefix + "wq.weight" in state_dict:
261
+ # wq = state_dict.pop(prefix + "wq.weight")
262
+ # wk = state_dict.pop(prefix + "wk.weight")
263
+ # wv = state_dict.pop(prefix + "wv.weight")
264
+ # state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
265
+
266
+ def forward(self,
267
+ x: Tensor,
268
+ freqs_cis: Tensor,
269
+ mask: Tensor,
270
+ input_pos: Optional[Tensor] = None,
271
+ context: Optional[Tensor] = None,
272
+ context_freqs_cis: Optional[Tensor] = None,
273
+ ) -> Tensor:
274
+ bsz, seqlen, _ = x.shape
275
+
276
+ kv_size = self.n_local_heads * self.head_dim
277
+ if context is None:
278
+ q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
279
+ context_seqlen = seqlen
280
+ else:
281
+ q = self.wq(x)
282
+ k, v = self.wkv(context).split([kv_size, kv_size], dim=-1)
283
+ context_seqlen = context.shape[1]
284
+
285
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
286
+ k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
287
+ v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
288
+
289
+ q = apply_rotary_emb(q, freqs_cis)
290
+ k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis)
291
+
292
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
293
+
294
+ if self.kv_cache is not None:
295
+ k, v = self.kv_cache.update(input_pos, k, v)
296
+
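+ # Grouped-query attention: replicate the n_local_heads key/value heads so they
+ # line up with the n_head query heads before scaled_dot_product_attention.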
297
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
298
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
299
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
300
+
301
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)
302
+
303
+ y = self.wo(y)
304
+ return y
305
+
306
+
307
+ class FeedForward(nn.Module):
308
+ def __init__(self, config: ModelArgs) -> None:
309
+ super().__init__()
310
+ self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
311
+ self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
312
+ self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
313
+
314
+ def forward(self, x: Tensor) -> Tensor:
315
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
316
+
317
+
318
+ class RMSNorm(nn.Module):
319
+ def __init__(self, dim: int, eps: float = 1e-5):
320
+ super().__init__()
321
+ self.eps = eps
322
+ self.weight = nn.Parameter(torch.ones(dim))
323
+
324
+ def _norm(self, x):
325
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
326
+
327
+ def forward(self, x: Tensor) -> Tensor:
328
+ output = self._norm(x.float()).type_as(x)
329
+ return output * self.weight
330
+
331
+
332
+ def precompute_freqs_cis(
333
+ seq_len: int, n_elem: int, base: int = 10000,
334
+ dtype: torch.dtype = torch.bfloat16
335
+ ) -> Tensor:
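+ # Rotary position embedding (RoPE) table: for each position t and inverse frequency
+ # base^(-2i/n_elem), store the (cos, sin) pair consumed by apply_rotary_emb below.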
336
+ freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
337
+ t = torch.arange(seq_len, device=freqs.device)
338
+ freqs = torch.outer(t, freqs)
339
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
340
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
341
+ return cache.to(dtype=dtype)
342
+
343
+
344
+ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
345
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
346
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
347
+ x_out2 = torch.stack(
348
+ [
349
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
350
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
351
+ ],
352
+ -1,
353
+ )
354
+
355
+ x_out2 = x_out2.flatten(3)
356
+ return x_out2.type_as(x)
modules/gpt_fast/quantize.py ADDED
@@ -0,0 +1,622 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import time
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from tokenizer import get_tokenizer
13
+
14
+ try:
15
+ from GPTQ import GenericGPTQRunner, InputRecorder
16
+ from eval import get_task_dict, evaluate, lm_eval
17
+ except:
18
+ pass
19
+
20
+ from model import Transformer
21
+
22
+ ##### Quantization Primitives ######
23
+
24
+ def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
25
+ # assumes symmetric quantization
26
+ # assumes axis == 0
27
+ # assumes dense memory format
28
+ # TODO(future): relax ^ as needed
29
+
30
+ # default setup for affine quantization of activations
31
+ eps = torch.finfo(torch.float32).eps
32
+
33
+ # get min and max
34
+ min_val, max_val = torch.aminmax(x, dim=1)
35
+
36
+ # calculate scales and zero_points based on min and max
37
+ # reference: https://fburl.com/code/srbiybme
38
+ min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
39
+ max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
40
+ device = min_val_neg.device
41
+
42
+ # reference: https://fburl.com/code/4wll53rk
43
+ max_val_pos = torch.max(-min_val_neg, max_val_pos)
44
+ scales = max_val_pos / (float(quant_max - quant_min) / 2)
45
+ # ensure scales is the same dtype as the original tensor
46
+ scales = torch.clamp(scales, min=eps).to(x.dtype)
47
+ zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
48
+
49
+ # quantize based on qmin/qmax/scales/zp
50
+ # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
51
+ x_div = x / scales.unsqueeze(-1)
52
+ x_round = torch.round(x_div)
53
+ x_zp = x_round + zero_points.unsqueeze(-1)
54
+ quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
55
+
56
+ return quant, scales, zero_points
57
+
58
+ def get_group_qparams(w, n_bit=4, groupsize=128):
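+ # Per-group affine qparams: each row is split into groups of `groupsize` values,
+ # scale = (max - min) / (2**n_bit - 1), and `zeros` is the dequantization offset
+ # at the mid-point code 2**(n_bit - 1); both are returned in bfloat16.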
59
+ # needed for GPTQ with padding
60
+ if groupsize > w.shape[-1]:
61
+ groupsize = w.shape[-1]
62
+ assert groupsize > 1
63
+ assert w.shape[-1] % groupsize == 0
64
+ assert w.dim() == 2
65
+
66
+ to_quant = w.reshape(-1, groupsize)
67
+ assert torch.isnan(to_quant).sum() == 0
68
+
69
+ max_val = to_quant.amax(dim=1, keepdim=True)
70
+ min_val = to_quant.amin(dim=1, keepdim=True)
71
+ max_int = 2**n_bit - 1
72
+ scales = (max_val - min_val).clamp(min=1e-6) / max_int
73
+ zeros = min_val + scales * (2 ** (n_bit - 1))
74
+ return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
75
+ torch.bfloat16
76
+ ).reshape(w.shape[0], -1)
77
+
78
+
79
+ def pack_scales_and_zeros(scales, zeros):
80
+ assert scales.shape == zeros.shape
81
+ assert scales.dtype == torch.bfloat16
82
+ assert zeros.dtype == torch.bfloat16
83
+ return (
84
+ torch.cat(
85
+ [
86
+ scales.reshape(scales.size(0), scales.size(1), 1),
87
+ zeros.reshape(zeros.size(0), zeros.size(1), 1),
88
+ ],
89
+ 2,
90
+ )
91
+ .transpose(0, 1)
92
+ .contiguous()
93
+ )
94
+
95
+
96
+ def unpack_scales_and_zeros(scales_and_zeros):
97
+ assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
98
+ assert scales_and_zeros.dtype == torch.float
99
+ return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
100
+
101
+
102
+ def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
103
+ assert groupsize > 1
104
+ # needed for GPTQ single column quantize
105
+ if groupsize > w.shape[-1] and scales.shape[-1] == 1:
106
+ groupsize = w.shape[-1]
107
+
108
+ assert w.shape[-1] % groupsize == 0
109
+ assert w.dim() == 2
110
+
111
+ to_quant = w.reshape(-1, groupsize)
112
+ assert torch.isnan(to_quant).sum() == 0
113
+
114
+ scales = scales.reshape(-1, 1)
115
+ zeros = zeros.reshape(-1, 1)
116
+ min_val = zeros - scales * (2 ** (n_bit - 1))
117
+ max_int = 2**n_bit - 1
118
+ min_int = 0
119
+ w_int32 = (
120
+ to_quant.sub(min_val)
121
+ .div(scales)
122
+ .round()
123
+ .clamp_(min_int, max_int)
124
+ .to(torch.int32)
125
+ .reshape_as(w)
126
+ )
127
+
128
+ return w_int32
129
+
130
+
131
+ def group_quantize_tensor(w, n_bit=4, groupsize=128):
132
+ scales, zeros = get_group_qparams(w, n_bit, groupsize)
133
+ w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
134
+ scales_and_zeros = pack_scales_and_zeros(scales, zeros)
135
+ return w_int32, scales_and_zeros
136
+
137
+
138
+ def group_dequantize_tensor_from_qparams(
139
+ w_int32, scales, zeros, n_bit=4, groupsize=128
140
+ ):
141
+ assert groupsize > 1
142
+ # needed for GPTQ single column dequantize
143
+ if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
144
+ groupsize = w_int32.shape[-1]
145
+ assert w_int32.shape[-1] % groupsize == 0
146
+ assert w_int32.dim() == 2
147
+
148
+ w_int32_grouped = w_int32.reshape(-1, groupsize)
149
+ scales = scales.reshape(-1, 1)
150
+ zeros = zeros.reshape(-1, 1)
151
+
152
+ w_dq = (
153
+ w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
154
+ )
155
+ return w_dq
156
+
157
+
158
+ def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
159
+ scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
160
+ return group_dequantize_tensor_from_qparams(
161
+ w_int32, scales, zeros, n_bit, groupsize
162
+ )
163
+
164
+ class QuantHandler:
165
+ def __init__(self, mod):
166
+ self.mod = mod
167
+
168
+ def create_quantized_state_dict(self) -> "StateDict":
169
+ pass
170
+
171
+ def convert_for_runtime(self) -> "nn.Module":
172
+ pass
173
+
174
+ class GPTQQuantHandler(QuantHandler):
175
+ """
176
+ This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class.
177
+ Unlike the base QuantHandler class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement
178
+ __init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime.
179
+
180
+ The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and
181
+ create_quantized_state_dict. Here is a description of each function.
182
+
183
+ get_qparams_func:
184
+ A function that calculates the quantization qparams for an input tensor.
185
+ Args:
186
+ weight: A 2d weight tensor with non-integer dtype.
187
+ Returns:
188
+ qparams: it can have any format but will need to be handled by the other defined functions below.
189
+
190
+ quantize_func:
191
+ A function that applies quantization to an input tensor. It should be noted
192
+ that this function needs to be able to handle quantizing the entire weight tensor, a single group,
193
+ or a single column.
194
+ Args:
195
+ weight: A 2d weight tensor with non-integer dtype.
196
+ qparams: the output from get_qparams_func
197
+ Returns:
198
+ quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
199
+
200
+
201
+ dequantize_func:
202
+ A function that dequantizes an input quantized weight tensor. It should be noted
203
+ that this function needs to be able to handle dequantizing the entire weight tensor, a single group,
204
+ or a single column.
205
+ Args:
206
+ quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
207
+ qparams: the output from get_qparams_func
208
+ Returns:
209
+ weight: A 2d weight tensor with non-integer dtype.
210
+
211
+ combine_qparams_list_func:
212
+ A function that combines several qparams into one qparam.
213
+ Args:
214
+ qparams_list: a list of qparams objects, each obtained by calling get_qparams_func
215
+ on a single group from a weight tensor
216
+ Returns:
217
+ qparams: an object of the same format as the qparams above.
218
+
219
+ skip_layer_func:
220
+ A function that determines which linear layers should be skipped during GPTQ
221
+ Args:
222
+ weight: A 2d weight tensor with non-integer dtype.
223
+ Returns:
224
+ skip: boolean indicating whether layer should be skipped
225
+
226
+ make_names_and_values_dict_func:
227
+ A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they
228
+ should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here.
229
+ Args:
230
+ quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
231
+ qparams: the output from get_qparams_func
232
+ Returns:
233
+ names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the
234
+ corresponding quantized weights and qparams.
235
+ """
236
+ def __init__(self):
237
+ assert self.mod is not None
238
+ assert self.get_qparams_func is not None
239
+ assert self.quantize_func is not None
240
+ assert self.dequantize_func is not None
241
+ assert self.combine_qparams_list_func is not None
242
+ assert self.make_names_and_values_dict_func is not None
243
+
244
+ @staticmethod
245
+ def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) -> "MultiInput":
246
+ input_recorder = InputRecorder(
247
+ model,
248
+ tokenizer,
249
+ calibration_seq_length,
250
+ pad_calibration_inputs,
251
+ )
252
+
253
+ try:
254
+ lm_eval.tasks.initialize_tasks()
255
+ except:
256
+ pass
257
+ task_dict = get_task_dict(calibration_tasks)
258
+ print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)
259
+
260
+ evaluate(
261
+ input_recorder,
262
+ task_dict,
263
+ limit=calibration_limit,
264
+ )
265
+ inputs = input_recorder.get_recorded_inputs()
266
+ assert inputs is not None, (
267
+ f"No inputs were collected, use a task other than {calibration_tasks}, "+
268
+ f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently "+
269
+ f"{calibration_seq_length})"
270
+ )
271
+ print(f"Obtained {len(inputs[0].values)} calibration samples")
272
+ return inputs
273
+
274
+ @torch.no_grad()
275
+ def create_quantized_state_dict(
276
+ self,
277
+ tokenizer,
278
+ blocksize,
279
+ percdamp,
280
+ groupsize,
281
+ calibration_tasks,
282
+ calibration_limit,
283
+ calibration_seq_length,
284
+ pad_calibration_inputs,
285
+ ) -> "StateDict":
286
+ inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs)
287
+ print("Tracing model for GPTQ")
288
+ GPTQ_runner = GenericGPTQRunner(
289
+ self.mod,
290
+ inputs,
291
+ blocksize,
292
+ percdamp,
293
+ groupsize,
294
+ ).configure_quantization_mode(
295
+ self.get_qparams_func,
296
+ self.quantize_func,
297
+ self.dequantize_func,
298
+ self.combine_qparams_list_func,
299
+ self.make_names_and_values_dict_func,
300
+ self.skip_layer_func
301
+ )
302
+
303
+ print("Applying GPTQ to weights")
304
+ GPTQ_runner.run()
305
+ return GPTQ_runner.get_quantized_state_dict()
306
+
307
+ def convert_for_runtime(self) -> "nn.Module":
308
+ pass
309
+
310
+ ##### Weight-only int8 per-channel quantized code ######
311
+
312
+ def replace_linear_weight_only_int8_per_channel(module):
313
+ for name, child in module.named_children():
314
+ if isinstance(child, nn.Linear):
315
+ setattr(module, name, WeightOnlyInt8Linear(child.in_features, child.out_features))
316
+ else:
317
+ replace_linear_weight_only_int8_per_channel(child)
318
+
319
+ class WeightOnlyInt8QuantHandler:
320
+ def __init__(self, mod):
321
+ self.mod = mod
322
+
323
+ @torch.no_grad()
324
+ def create_quantized_state_dict(self):
325
+ cur_state_dict = self.mod.state_dict()
326
+ for fqn, mod in self.mod.named_modules():
327
+ if isinstance(mod, torch.nn.Linear):
328
+ int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8)
329
+ cur_state_dict[f"{fqn}.weight"] = int8_weight
330
+ cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
331
+
332
+ return cur_state_dict
333
+
334
+ def convert_for_runtime(self):
335
+ replace_linear_weight_only_int8_per_channel(self.mod)
336
+ return self.mod
337
+
338
+
339
+ class WeightOnlyInt8Linear(torch.nn.Module):
340
+ __constants__ = ['in_features', 'out_features']
341
+ in_features: int
342
+ out_features: int
343
+ weight: torch.Tensor
344
+
345
+ def __init__(self, in_features: int, out_features: int, bias: bool = True,
346
+ device=None, dtype=None) -> None:
347
+ factory_kwargs = {'device': device, 'dtype': dtype}
348
+ super().__init__()
349
+ self.in_features = in_features
350
+ self.out_features = out_features
351
+ self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
352
+ self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
353
+
354
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
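+ # Dequantize on the fly: the int8 weight is cast to the input dtype and the
+ # per-output-channel scales are applied to the matmul result.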
355
+ return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
356
+
357
+ ##### Weight-only int4 per-channel groupwise quantized code ######
358
+
359
+ def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
360
+ weight_int32, scales_and_zeros = group_quantize_tensor(
361
+ weight_bf16, n_bit=4, groupsize=groupsize
362
+ )
363
+ weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
364
+ return weight_int4pack, scales_and_zeros
365
+
366
+
367
+ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
368
+ origin_x_size = x.size()
369
+ x = x.reshape(-1, origin_x_size[-1])
370
+ c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros)
371
+ new_shape = origin_x_size[:-1] + (out_features,)
372
+ c = c.reshape(new_shape)
373
+ return c
374
+
375
+
376
+ def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1):
377
+ return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
378
+
379
+ def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
380
+ for name, child in module.named_children():
381
+ if isinstance(child, nn.Linear):
382
+ if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
383
+ setattr(module, name, WeightOnlyInt4Linear(
384
+ child.in_features, child.out_features, bias=False,
385
+ groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False,
386
+ ))
387
+ elif padding:
388
+ setattr(module, name, WeightOnlyInt4Linear(
389
+ child.in_features, child.out_features, bias=False,
390
+ groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True,
391
+ ))
392
+ else:
393
+ replace_linear_int4(child, groupsize, inner_k_tiles, padding)
394
+
395
+
396
+ class WeightOnlyInt4QuantHandler:
397
+ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
398
+ self.mod = mod
399
+ self.groupsize = groupsize
400
+ self.inner_k_tiles = inner_k_tiles
401
+ self.padding = padding
402
+ assert groupsize in [32, 64, 128, 256]
403
+ assert inner_k_tiles in [2, 4, 8]
404
+
405
+ @torch.no_grad()
406
+ def create_quantized_state_dict(self, use_cuda = True):
407
+ if use_cuda:
408
+ device="cuda"
409
+ else:
410
+ device="cpu"
411
+
412
+ cur_state_dict = self.mod.state_dict()
413
+ for fqn, mod in self.mod.named_modules():
414
+ if isinstance(mod, torch.nn.Linear):
415
+ assert not mod.bias
416
+ out_features = mod.out_features
417
+ in_features = mod.in_features
418
+ assert out_features % 8 == 0, "require out_features % 8 == 0"
419
+ print(f"linear: {fqn}, in={in_features}, out={out_features}")
420
+
421
+ weight = mod.weight.data
422
+ if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
423
+ if self.padding:
424
+ from model import find_multiple
425
+ import torch.nn.functional as F
426
+ print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
427
+ padded_in_features = find_multiple(in_features, 1024)
428
+ weight = F.pad(weight, pad=(0, padded_in_features - in_features))
429
+ else:
430
+ print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
431
+ "and that groupsize and inner_k_tiles*16 evenly divide into it")
432
+ continue
433
+ weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
434
+ weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles
435
+ )
436
+ cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
437
+ cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')
438
+
439
+ return cur_state_dict
440
+
441
+ def convert_for_runtime(self):
442
+ replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
443
+ return self.mod
444
+
445
+ class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
446
+ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
447
+ from model import find_multiple
448
+ self.mod = mod
449
+ self.groupsize = groupsize
450
+ self.inner_k_tiles = inner_k_tiles
451
+ self.padding = padding
452
+ self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize)
453
+ self.quantize_func = lambda w, qparams: \
454
+ group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize)
455
+ self.dequantize_func = lambda q, qparams: \
456
+ group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float()
457
+ self.combine_qparams_list_func = lambda qparams_list: \
458
+ [torch.cat(x, dim=1) for x in zip(*qparams_list)]
459
+ # skip unless padding=True or it's correctly sized
460
+ self.skip_layer_func = lambda linear_weight: not (
461
+ _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding
462
+ )
463
+ # we need to do the padding here, both for q and the qparams if necessary
464
+ def make_names_and_values_dict_func(q, qparams):
465
+ k = q.shape[1]
466
+ new_k = find_multiple(k, 1024)
467
+ # how much we need to pad the weight
468
+ delta_k = new_k - q.shape[1]
469
+ final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles)
470
+ scales_and_zeros = pack_scales_and_zeros(*qparams)
471
+ # how many new groups we need for padded weight
472
+ delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
473
+ final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1)
474
+ return {"weight": final_q, "scales_and_zeros": final_s_and_z}
475
+ self.make_names_and_values_dict_func = make_names_and_values_dict_func
476
+ super().__init__()
477
+
478
+
479
+ def convert_for_runtime(self):
480
+ replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
481
+ return self.mod
482
+
483
+ class WeightOnlyInt4Linear(torch.nn.Module):
484
+ __constants__ = ['in_features', 'out_features']
485
+ in_features: int
486
+ out_features: int
487
+ weight: torch.Tensor
488
+
489
+ def __init__(
490
+ self, in_features: int, out_features: int,
491
+ bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True,
492
+ ) -> None:
493
+ super().__init__()
494
+ self.padding = padding
495
+ if padding:
496
+ from model import find_multiple
497
+ self.origin_in_features = in_features
498
+ in_features = find_multiple(in_features, 1024)
499
+
500
+ self.in_features = in_features
501
+ self.out_features = out_features
502
+ assert not bias, "require bias=False"
503
+ self.groupsize = groupsize
504
+ self.inner_k_tiles = inner_k_tiles
505
+
506
+ assert out_features % 8 == 0, "require out_features % 8 == 0"
507
+ assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
508
+ self.register_buffer(
509
+ "weight",
510
+ torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
511
+ )
512
+ self.register_buffer(
513
+ "scales_and_zeros",
514
+ torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
515
+ )
516
+
517
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
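+ # Inputs are zero-padded up to the (possibly 1024-aligned) in_features expected by
+ # the packed int4 weight, then run through the int4 matmul kernel.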
518
+ input = input.to(torch.bfloat16)
519
+ if self.padding:
520
+ import torch.nn.functional as F
521
+ input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
522
+ return linear_forward_int4(
523
+ input,
524
+ self.weight, self.scales_and_zeros, self.out_features, self.groupsize
525
+ )
526
+
527
+
528
+ def quantize(
529
+ checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
530
+ mode: str = 'int8',
531
+ # following arguments only available when setting int4 quantization.
532
+ groupsize: int = 128,
533
+ # following arguments only used for GPTQ
534
+ calibration_tasks: list = ["hellaswag"],
535
+ calibration_limit: int = 1000,
536
+ calibration_seq_length: int = 100,
537
+ pad_calibration_inputs: bool = False,
538
+ percdamp: float = .01,
539
+ blocksize: int = 128,
540
+ label: str = '',
541
+ ) -> None:
542
+ assert checkpoint_path.is_file(), checkpoint_path
543
+
544
+ device = 'cpu'
545
+ precision = torch.bfloat16
546
+
547
+ print("Loading model ...")
548
+ t0 = time.time()
549
+
550
+ with torch.device('meta'):
551
+ model = Transformer.from_name(checkpoint_path.parent.name)
552
+
553
+ checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
554
+ model.load_state_dict(checkpoint, assign=True)
555
+ model = model.to(dtype=precision, device=device)
556
+
557
+ if mode == 'int8':
558
+ print("Quantizing model weights for int8 weight-only symmetric per-channel quantization")
559
+ quant_handler = WeightOnlyInt8QuantHandler(model)
560
+ quantized_state_dict = quant_handler.create_quantized_state_dict()
561
+
562
+ dir_name = checkpoint_path.parent
563
+ base_name = checkpoint_path.name
564
+ new_base_name = base_name.replace('.pth', f'{label}int8.pth')
565
+
566
+ elif mode == 'int4':
567
+ print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization")
568
+ quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
569
+ quantized_state_dict = quant_handler.create_quantized_state_dict()
570
+
571
+ dir_name = checkpoint_path.parent
572
+ base_name = checkpoint_path.name
573
+ new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.pth")
574
+
575
+ elif mode == 'int4-gptq':
576
+ print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ...")
577
+ quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize)
578
+
579
+ tokenizer_path = checkpoint_path.parent / "tokenizer.model"
580
+ assert tokenizer_path.is_file(), str(tokenizer_path)
581
+ tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
582
+
583
+ quantized_state_dict = quant_handler.create_quantized_state_dict(
584
+ tokenizer,
585
+ blocksize,
586
+ percdamp,
587
+ groupsize,
588
+ calibration_tasks,
589
+ calibration_limit,
590
+ calibration_seq_length,
591
+ pad_calibration_inputs
592
+ )
593
+
594
+ dir_name = checkpoint_path.parent
595
+ base_name = checkpoint_path.name
596
+ new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")
597
+ else:
598
+ raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]")
599
+
600
+ quantize_path = dir_name / new_base_name
601
+ print(f"Writing quantized weights to {quantize_path}")
602
+ quantize_path.unlink(missing_ok=True) # remove the existing file if one is already there
603
+ torch.save(quantized_state_dict, quantize_path)
604
+ print(f"Quantization complete took {time.time() - t0:.02f} seconds")
605
+ return
606
+
607
+ if __name__ == '__main__':
608
+ import argparse
609
+ parser = argparse.ArgumentParser(description='Quantize a model.')
610
+ parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.')
611
+ parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform')
612
+ parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.')
613
+ parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
614
+ parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration')
615
+ parser.add_argument('--calibration_seq_length', type=int, default=100, help='length of sequences to use for gptq calibration')
616
+ parser.add_argument('--pad_calibration_inputs', type=bool, default=False, help='pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower')
617
+ parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening')
618
+ parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq')
619
+ parser.add_argument('--label', type=str, default='_', help='label to add to output filename')
620
+
621
+ args = parser.parse_args()
622
+ quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label)