llmixer committed
Commit e47221b
1 Parent(s): 4c7d441

Added test inference from exllamav2, added gguf-py from llama.cpp

gguf/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from .constants import *
+ from .lazy import *
+ from .gguf_reader import *
+ from .gguf_writer import *
+ from .quants import *
+ from .tensor_mapping import *
+ from .vocab import *
+ from .utility import *
+ from .metadata import *
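The `__init__.py` hunk above simply flattens every submodule into the package namespace. A minimal sketch of what that enables (illustrative only, not part of the commit; it assumes the package is installed or on `sys.path`):

    import gguf

    print(hex(gguf.GGUF_MAGIC))                            # 0x46554747 ("GGUF"), from constants.py
    print(gguf.Keys.General.ARCHITECTURE)                  # "general.architecture"
    print(gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.LLAMA])    # "llama"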
gguf/constants.py ADDED
@@ -0,0 +1,1609 @@
1
+ from __future__ import annotations
2
+
3
+ from enum import Enum, IntEnum, auto
4
+ from typing import Any
5
+
6
+ #
7
+ # constants
8
+ #
9
+
10
+ GGUF_MAGIC = 0x46554747 # "GGUF"
11
+ GGUF_VERSION = 3
12
+ GGUF_DEFAULT_ALIGNMENT = 32
13
+ GGML_QUANT_VERSION = 2 # GGML_QNT_VERSION from ggml.h
14
+
15
+ #
16
+ # metadata keys
17
+ #
18
+
19
+
20
+ class Keys:
21
+ class General:
22
+ TYPE = "general.type"
23
+ ARCHITECTURE = "general.architecture"
24
+ QUANTIZATION_VERSION = "general.quantization_version"
25
+ ALIGNMENT = "general.alignment"
26
+ FILE_TYPE = "general.file_type"
27
+
28
+ # Authorship Metadata
29
+ NAME = "general.name"
30
+ AUTHOR = "general.author"
31
+ VERSION = "general.version"
32
+ ORGANIZATION = "general.organization"
33
+
34
+ FINETUNE = "general.finetune"
35
+ BASENAME = "general.basename"
36
+
37
+ DESCRIPTION = "general.description"
38
+ QUANTIZED_BY = "general.quantized_by"
39
+
40
+ SIZE_LABEL = "general.size_label"
41
+
42
+ # Licensing details
43
+ LICENSE = "general.license"
44
+ LICENSE_NAME = "general.license.name"
45
+ LICENSE_LINK = "general.license.link"
46
+
47
+ # Typically represents the converted GGUF repo (Unless native)
48
+ URL = "general.url" # Model Website/Paper
49
+ DOI = "general.doi"
50
+ UUID = "general.uuid"
51
+ REPO_URL = "general.repo_url" # Model Source Repository (git/svn/etc...)
52
+
53
+ # Model Source during conversion
54
+ SOURCE_URL = "general.source.url" # Model Website/Paper
55
+ SOURCE_DOI = "general.source.doi"
56
+ SOURCE_UUID = "general.source.uuid"
57
+ SOURCE_REPO_URL = "general.source.repo_url" # Model Source Repository (git/svn/etc...)
58
+
59
+ # Base Model Source. There can be more than one source if it's a merged
60
+ # model like with 'Mistral-7B-Merge-14-v0.1'. This will assist in
61
+ # tracing lineage of models as they are finetuned or merged over time.
62
+ BASE_MODEL_COUNT = "general.base_model.count"
63
+ BASE_MODEL_NAME = "general.base_model.{id}.name"
64
+ BASE_MODEL_AUTHOR = "general.base_model.{id}.author"
65
+ BASE_MODEL_VERSION = "general.base_model.{id}.version"
66
+ BASE_MODEL_ORGANIZATION = "general.base_model.{id}.organization"
67
+ BASE_MODEL_URL = "general.base_model.{id}.url" # Model Website/Paper
68
+ BASE_MODEL_DOI = "general.base_model.{id}.doi"
69
+ BASE_MODEL_UUID = "general.base_model.{id}.uuid"
70
+ BASE_MODEL_REPO_URL = "general.base_model.{id}.repo_url" # Model Source Repository (git/svn/etc...)
71
+
72
+ # Array based KV stores
73
+ TAGS = "general.tags"
74
+ LANGUAGES = "general.languages"
75
+ DATASETS = "general.datasets"
76
+
77
+ class LLM:
78
+ VOCAB_SIZE = "{arch}.vocab_size"
79
+ CONTEXT_LENGTH = "{arch}.context_length"
80
+ EMBEDDING_LENGTH = "{arch}.embedding_length"
81
+ BLOCK_COUNT = "{arch}.block_count"
82
+ LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count"
83
+ FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
84
+ EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length"
85
+ EXPERT_SHARED_FEED_FORWARD_LENGTH = "{arch}.expert_shared_feed_forward_length"
86
+ USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
87
+ TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
88
+ EXPERT_COUNT = "{arch}.expert_count"
89
+ EXPERT_USED_COUNT = "{arch}.expert_used_count"
90
+ EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
91
+ EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
92
+ POOLING_TYPE = "{arch}.pooling_type"
93
+ LOGIT_SCALE = "{arch}.logit_scale"
94
+ DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
95
+ ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
96
+ FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
97
+ SWIN_NORM = "{arch}.swin_norm"
98
+ RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers"
99
+ TIME_MIX_EXTRA_DIM = "{arch}.time_mix_extra_dim"
100
+ TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim"
101
+ RESIDUAL_SCALE = "{arch}.residual_scale"
102
+ EMBEDDING_SCALE = "{arch}.embedding_scale"
103
+
104
+ class Attention:
105
+ HEAD_COUNT = "{arch}.attention.head_count"
106
+ HEAD_COUNT_KV = "{arch}.attention.head_count_kv"
107
+ MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias"
108
+ CLAMP_KQV = "{arch}.attention.clamp_kqv"
109
+ KEY_LENGTH = "{arch}.attention.key_length"
110
+ VALUE_LENGTH = "{arch}.attention.value_length"
111
+ LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
112
+ LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
113
+ CAUSAL = "{arch}.attention.causal"
114
+ Q_LORA_RANK = "{arch}.attention.q_lora_rank"
115
+ KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
116
+ REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
117
+ SLIDING_WINDOW = "{arch}.attention.sliding_window"
118
+ SCALE = "{arch}.attention.scale"
119
+
120
+ class Rope:
121
+ DIMENSION_COUNT = "{arch}.rope.dimension_count"
122
+ FREQ_BASE = "{arch}.rope.freq_base"
123
+ SCALING_TYPE = "{arch}.rope.scaling.type"
124
+ SCALING_FACTOR = "{arch}.rope.scaling.factor"
125
+ SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
126
+ SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
127
+ SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
128
+ SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"
129
+
130
+ class Split:
131
+ LLM_KV_SPLIT_NO = "split.no"
132
+ LLM_KV_SPLIT_COUNT = "split.count"
133
+ LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"
134
+
135
+ class SSM:
136
+ CONV_KERNEL = "{arch}.ssm.conv_kernel"
137
+ INNER_SIZE = "{arch}.ssm.inner_size"
138
+ STATE_SIZE = "{arch}.ssm.state_size"
139
+ TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
140
+ DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
141
+
142
+ class WKV:
143
+ HEAD_SIZE = "{arch}.wkv.head_size"
144
+
145
+ class Tokenizer:
146
+ MODEL = "tokenizer.ggml.model"
147
+ PRE = "tokenizer.ggml.pre"
148
+ LIST = "tokenizer.ggml.tokens"
149
+ TOKEN_TYPE = "tokenizer.ggml.token_type"
150
+ TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count" # for BERT-style token types
151
+ SCORES = "tokenizer.ggml.scores"
152
+ MERGES = "tokenizer.ggml.merges"
153
+ BOS_ID = "tokenizer.ggml.bos_token_id"
154
+ EOS_ID = "tokenizer.ggml.eos_token_id"
155
+ EOT_ID = "tokenizer.ggml.eot_token_id"
156
+ EOM_ID = "tokenizer.ggml.eom_token_id"
157
+ UNK_ID = "tokenizer.ggml.unknown_token_id"
158
+ SEP_ID = "tokenizer.ggml.seperator_token_id"
159
+ PAD_ID = "tokenizer.ggml.padding_token_id"
160
+ CLS_ID = "tokenizer.ggml.cls_token_id"
161
+ MASK_ID = "tokenizer.ggml.mask_token_id"
162
+ ADD_BOS = "tokenizer.ggml.add_bos_token"
163
+ ADD_EOS = "tokenizer.ggml.add_eos_token"
164
+ ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
165
+ REMOVE_EXTRA_WS = "tokenizer.ggml.remove_extra_whitespaces"
166
+ PRECOMPILED_CHARSMAP = "tokenizer.ggml.precompiled_charsmap"
167
+ HF_JSON = "tokenizer.huggingface.json"
168
+ RWKV = "tokenizer.rwkv.world"
169
+ CHAT_TEMPLATE = "tokenizer.chat_template"
170
+ CHAT_TEMPLATE_N = "tokenizer.chat_template.{name}"
171
+ CHAT_TEMPLATES = "tokenizer.chat_templates"
172
+ # FIM/Infill special tokens constants
173
+ FIM_PRE_ID = "tokenizer.ggml.fim_pre_token_id"
174
+ FIM_SUF_ID = "tokenizer.ggml.fim_suf_token_id"
175
+ FIM_MID_ID = "tokenizer.ggml.fim_mid_token_id"
176
+ FIM_PAD_ID = "tokenizer.ggml.fim_pad_token_id"
177
+ FIM_REP_ID = "tokenizer.ggml.fim_rep_token_id"
178
+ FIM_SEP_ID = "tokenizer.ggml.fim_sep_token_id"
179
+ # deprecated:
180
+ PREFIX_ID = "tokenizer.ggml.prefix_token_id"
181
+ SUFFIX_ID = "tokenizer.ggml.suffix_token_id"
182
+ MIDDLE_ID = "tokenizer.ggml.middle_token_id"
183
+
184
+ class Adapter:
185
+ TYPE = "adapter.type"
186
+ LORA_ALPHA = "adapter.lora.alpha"
187
+
188
+ #
189
+ # recommended mapping of model tensor names for storage in gguf
190
+ #
191
+
192
+
193
+ class GGUFType:
194
+ MODEL = "model"
195
+ ADAPTER = "adapter"
196
+
197
+
198
+ class MODEL_ARCH(IntEnum):
199
+ LLAMA = auto()
200
+ FALCON = auto()
201
+ BAICHUAN = auto()
202
+ GROK = auto()
203
+ GPT2 = auto()
204
+ GPTJ = auto()
205
+ GPTNEOX = auto()
206
+ MPT = auto()
207
+ STARCODER = auto()
208
+ REFACT = auto()
209
+ BERT = auto()
210
+ NOMIC_BERT = auto()
211
+ JINA_BERT_V2 = auto()
212
+ BLOOM = auto()
213
+ STABLELM = auto()
214
+ QWEN = auto()
215
+ QWEN2 = auto()
216
+ QWEN2MOE = auto()
217
+ PHI2 = auto()
218
+ PHI3 = auto()
219
+ PLAMO = auto()
220
+ CODESHELL = auto()
221
+ ORION = auto()
222
+ INTERNLM2 = auto()
223
+ MINICPM = auto()
224
+ MINICPM3 = auto()
225
+ GEMMA = auto()
226
+ GEMMA2 = auto()
227
+ STARCODER2 = auto()
228
+ RWKV6 = auto()
229
+ MAMBA = auto()
230
+ XVERSE = auto()
231
+ COMMAND_R = auto()
232
+ DBRX = auto()
233
+ OLMO = auto()
234
+ OLMOE = auto()
235
+ OPENELM = auto()
236
+ ARCTIC = auto()
237
+ DEEPSEEK2 = auto()
238
+ CHATGLM = auto()
239
+ BITNET = auto()
240
+ T5 = auto()
241
+ T5ENCODER = auto()
242
+ JAIS = auto()
243
+ NEMOTRON = auto()
244
+ EXAONE = auto()
245
+ GRANITE = auto()
246
+ GRANITE_MOE = auto()
247
+ CHAMELEON = auto()
248
+
249
+
250
+ class MODEL_TENSOR(IntEnum):
251
+ TOKEN_EMBD = auto()
252
+ TOKEN_EMBD_NORM = auto()
253
+ TOKEN_TYPES = auto()
254
+ POS_EMBD = auto()
255
+ OUTPUT = auto()
256
+ OUTPUT_NORM = auto()
257
+ ROPE_FREQS = auto()
258
+ ROPE_FACTORS_LONG = auto()
259
+ ROPE_FACTORS_SHORT = auto()
260
+ ATTN_Q = auto()
261
+ ATTN_K = auto()
262
+ ATTN_V = auto()
263
+ ATTN_QKV = auto()
264
+ ATTN_OUT = auto()
265
+ ATTN_NORM = auto()
266
+ ATTN_NORM_2 = auto()
267
+ ATTN_OUT_NORM = auto()
268
+ ATTN_POST_NORM = auto()
269
+ ATTN_ROT_EMBD = auto()
270
+ FFN_GATE_INP = auto()
271
+ FFN_GATE_INP_SHEXP = auto()
272
+ FFN_NORM = auto()
273
+ FFN_PRE_NORM = auto()
274
+ FFN_POST_NORM = auto()
275
+ FFN_GATE = auto()
276
+ FFN_DOWN = auto()
277
+ FFN_UP = auto()
278
+ FFN_ACT = auto()
279
+ FFN_NORM_EXP = auto()
280
+ FFN_GATE_EXP = auto()
281
+ FFN_DOWN_EXP = auto()
282
+ FFN_UP_EXP = auto()
283
+ FFN_GATE_SHEXP = auto()
284
+ FFN_DOWN_SHEXP = auto()
285
+ FFN_UP_SHEXP = auto()
286
+ ATTN_Q_NORM = auto()
287
+ ATTN_K_NORM = auto()
288
+ LAYER_OUT_NORM = auto()
289
+ SSM_IN = auto()
290
+ SSM_CONV1D = auto()
291
+ SSM_X = auto()
292
+ SSM_DT = auto()
293
+ SSM_A = auto()
294
+ SSM_D = auto()
295
+ SSM_OUT = auto()
296
+ TIME_MIX_W1 = auto()
297
+ TIME_MIX_W2 = auto()
298
+ TIME_MIX_LERP_X = auto()
299
+ TIME_MIX_LERP_K = auto()
300
+ TIME_MIX_LERP_V = auto()
301
+ TIME_MIX_LERP_R = auto()
302
+ TIME_MIX_LERP_G = auto()
303
+ TIME_MIX_LERP_W = auto()
304
+ TIME_MIX_FIRST = auto()
305
+ TIME_MIX_DECAY = auto()
306
+ TIME_MIX_DECAY_W1 = auto()
307
+ TIME_MIX_DECAY_W2 = auto()
308
+ TIME_MIX_KEY = auto()
309
+ TIME_MIX_VALUE = auto()
310
+ TIME_MIX_RECEPTANCE = auto()
311
+ TIME_MIX_GATE = auto()
312
+ TIME_MIX_LN = auto()
313
+ TIME_MIX_OUTPUT = auto()
314
+ CHANNEL_MIX_LERP_K = auto()
315
+ CHANNEL_MIX_LERP_R = auto()
316
+ CHANNEL_MIX_KEY = auto()
317
+ CHANNEL_MIX_RECEPTANCE = auto()
318
+ CHANNEL_MIX_VALUE = auto()
319
+ ATTN_Q_A = auto()
320
+ ATTN_Q_B = auto()
321
+ ATTN_KV_A_MQA = auto()
322
+ ATTN_KV_B = auto()
323
+ ATTN_Q_A_NORM = auto()
324
+ ATTN_KV_A_NORM = auto()
325
+ FFN_SUB_NORM = auto()
326
+ ATTN_SUB_NORM = auto()
327
+ DEC_ATTN_NORM = auto()
328
+ DEC_ATTN_Q = auto()
329
+ DEC_ATTN_K = auto()
330
+ DEC_ATTN_V = auto()
331
+ DEC_ATTN_OUT = auto()
332
+ DEC_ATTN_REL_B = auto()
333
+ DEC_CROSS_ATTN_NORM = auto()
334
+ DEC_CROSS_ATTN_Q = auto()
335
+ DEC_CROSS_ATTN_K = auto()
336
+ DEC_CROSS_ATTN_V = auto()
337
+ DEC_CROSS_ATTN_OUT = auto()
338
+ DEC_CROSS_ATTN_REL_B = auto()
339
+ DEC_FFN_NORM = auto()
340
+ DEC_FFN_GATE = auto()
341
+ DEC_FFN_DOWN = auto()
342
+ DEC_FFN_UP = auto()
343
+ DEC_OUTPUT_NORM = auto()
344
+ ENC_ATTN_NORM = auto()
345
+ ENC_ATTN_Q = auto()
346
+ ENC_ATTN_K = auto()
347
+ ENC_ATTN_V = auto()
348
+ ENC_ATTN_OUT = auto()
349
+ ENC_ATTN_REL_B = auto()
350
+ ENC_FFN_NORM = auto()
351
+ ENC_FFN_GATE = auto()
352
+ ENC_FFN_DOWN = auto()
353
+ ENC_FFN_UP = auto()
354
+ ENC_OUTPUT_NORM = auto()
355
+ CLS = auto() # classifier
356
+ CLS_OUT = auto() # classifier output projection
357
+
358
+
359
+ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
360
+ MODEL_ARCH.LLAMA: "llama",
361
+ MODEL_ARCH.FALCON: "falcon",
362
+ MODEL_ARCH.BAICHUAN: "baichuan",
363
+ MODEL_ARCH.GROK: "grok",
364
+ MODEL_ARCH.GPT2: "gpt2",
365
+ MODEL_ARCH.GPTJ: "gptj",
366
+ MODEL_ARCH.GPTNEOX: "gptneox",
367
+ MODEL_ARCH.MPT: "mpt",
368
+ MODEL_ARCH.STARCODER: "starcoder",
369
+ MODEL_ARCH.REFACT: "refact",
370
+ MODEL_ARCH.BERT: "bert",
371
+ MODEL_ARCH.NOMIC_BERT: "nomic-bert",
372
+ MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
373
+ MODEL_ARCH.BLOOM: "bloom",
374
+ MODEL_ARCH.STABLELM: "stablelm",
375
+ MODEL_ARCH.QWEN: "qwen",
376
+ MODEL_ARCH.QWEN2: "qwen2",
377
+ MODEL_ARCH.QWEN2MOE: "qwen2moe",
378
+ MODEL_ARCH.PHI2: "phi2",
379
+ MODEL_ARCH.PHI3: "phi3",
380
+ MODEL_ARCH.PLAMO: "plamo",
381
+ MODEL_ARCH.CODESHELL: "codeshell",
382
+ MODEL_ARCH.ORION: "orion",
383
+ MODEL_ARCH.INTERNLM2: "internlm2",
384
+ MODEL_ARCH.MINICPM: "minicpm",
385
+ MODEL_ARCH.MINICPM3: "minicpm3",
386
+ MODEL_ARCH.GEMMA: "gemma",
387
+ MODEL_ARCH.GEMMA2: "gemma2",
388
+ MODEL_ARCH.STARCODER2: "starcoder2",
389
+ MODEL_ARCH.RWKV6: "rwkv6",
390
+ MODEL_ARCH.MAMBA: "mamba",
391
+ MODEL_ARCH.XVERSE: "xverse",
392
+ MODEL_ARCH.COMMAND_R: "command-r",
393
+ MODEL_ARCH.DBRX: "dbrx",
394
+ MODEL_ARCH.OLMO: "olmo",
395
+ MODEL_ARCH.OLMOE: "olmoe",
396
+ MODEL_ARCH.OPENELM: "openelm",
397
+ MODEL_ARCH.ARCTIC: "arctic",
398
+ MODEL_ARCH.DEEPSEEK2: "deepseek2",
399
+ MODEL_ARCH.CHATGLM: "chatglm",
400
+ MODEL_ARCH.BITNET: "bitnet",
401
+ MODEL_ARCH.T5: "t5",
402
+ MODEL_ARCH.T5ENCODER: "t5encoder",
403
+ MODEL_ARCH.JAIS: "jais",
404
+ MODEL_ARCH.NEMOTRON: "nemotron",
405
+ MODEL_ARCH.EXAONE: "exaone",
406
+ MODEL_ARCH.GRANITE: "granite",
407
+ MODEL_ARCH.GRANITE_MOE: "granitemoe",
408
+ MODEL_ARCH.CHAMELEON: "chameleon",
409
+ }
410
+
411
+ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
412
+ MODEL_TENSOR.TOKEN_EMBD: "token_embd",
413
+ MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm",
414
+ MODEL_TENSOR.TOKEN_TYPES: "token_types",
415
+ MODEL_TENSOR.POS_EMBD: "position_embd",
416
+ MODEL_TENSOR.OUTPUT_NORM: "output_norm",
417
+ MODEL_TENSOR.OUTPUT: "output",
418
+ MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
419
+ MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long",
420
+ MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short",
421
+ MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
422
+ MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
423
+ MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
424
+ MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
425
+ MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
426
+ MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
427
+ MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
428
+ MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
429
+ MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
430
+ MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
431
+ MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
432
+ MODEL_TENSOR.ATTN_POST_NORM: "blk.{bid}.post_attention_norm",
433
+ MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp",
434
+ MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp",
435
+ MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
436
+ MODEL_TENSOR.FFN_PRE_NORM: "blk.{bid}.ffn_norm",
437
+ MODEL_TENSOR.FFN_POST_NORM: "blk.{bid}.post_ffw_norm",
438
+ MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
439
+ MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
440
+ MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
441
+ MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp",
442
+ MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp",
443
+ MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp",
444
+ MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn",
445
+ MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps",
446
+ MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
447
+ MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
448
+ MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
449
+ MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
450
+ MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
451
+ MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
452
+ MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
453
+ MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
454
+ MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
455
+ MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
456
+ MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
457
+ MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
458
+ MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
459
+ MODEL_TENSOR.TIME_MIX_LERP_X: "blk.{bid}.time_mix_lerp_x",
460
+ MODEL_TENSOR.TIME_MIX_LERP_K: "blk.{bid}.time_mix_lerp_k",
461
+ MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v",
462
+ MODEL_TENSOR.TIME_MIX_LERP_R: "blk.{bid}.time_mix_lerp_r",
463
+ MODEL_TENSOR.TIME_MIX_LERP_G: "blk.{bid}.time_mix_lerp_g",
464
+ MODEL_TENSOR.TIME_MIX_LERP_W: "blk.{bid}.time_mix_lerp_w",
465
+ MODEL_TENSOR.TIME_MIX_FIRST: "blk.{bid}.time_mix_first",
466
+ MODEL_TENSOR.TIME_MIX_DECAY: "blk.{bid}.time_mix_decay",
467
+ MODEL_TENSOR.TIME_MIX_DECAY_W1: "blk.{bid}.time_mix_decay_w1",
468
+ MODEL_TENSOR.TIME_MIX_DECAY_W2: "blk.{bid}.time_mix_decay_w2",
469
+ MODEL_TENSOR.TIME_MIX_KEY: "blk.{bid}.time_mix_key",
470
+ MODEL_TENSOR.TIME_MIX_VALUE: "blk.{bid}.time_mix_value",
471
+ MODEL_TENSOR.TIME_MIX_RECEPTANCE: "blk.{bid}.time_mix_receptance",
472
+ MODEL_TENSOR.TIME_MIX_GATE: "blk.{bid}.time_mix_gate",
473
+ MODEL_TENSOR.TIME_MIX_LN: "blk.{bid}.time_mix_ln",
474
+ MODEL_TENSOR.TIME_MIX_OUTPUT: "blk.{bid}.time_mix_output",
475
+ MODEL_TENSOR.CHANNEL_MIX_LERP_K: "blk.{bid}.channel_mix_lerp_k",
476
+ MODEL_TENSOR.CHANNEL_MIX_LERP_R: "blk.{bid}.channel_mix_lerp_r",
477
+ MODEL_TENSOR.CHANNEL_MIX_KEY: "blk.{bid}.channel_mix_key",
478
+ MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: "blk.{bid}.channel_mix_receptance",
479
+ MODEL_TENSOR.CHANNEL_MIX_VALUE: "blk.{bid}.channel_mix_value",
480
+ MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a",
481
+ MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
482
+ MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
483
+ MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
484
+ MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
485
+ MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
486
+ MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
487
+ MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm",
488
+ MODEL_TENSOR.DEC_ATTN_NORM: "dec.blk.{bid}.attn_norm",
489
+ MODEL_TENSOR.DEC_ATTN_Q: "dec.blk.{bid}.attn_q",
490
+ MODEL_TENSOR.DEC_ATTN_K: "dec.blk.{bid}.attn_k",
491
+ MODEL_TENSOR.DEC_ATTN_V: "dec.blk.{bid}.attn_v",
492
+ MODEL_TENSOR.DEC_ATTN_OUT: "dec.blk.{bid}.attn_o",
493
+ MODEL_TENSOR.DEC_ATTN_REL_B: "dec.blk.{bid}.attn_rel_b",
494
+ MODEL_TENSOR.DEC_CROSS_ATTN_NORM: "dec.blk.{bid}.cross_attn_norm",
495
+ MODEL_TENSOR.DEC_CROSS_ATTN_Q: "dec.blk.{bid}.cross_attn_q",
496
+ MODEL_TENSOR.DEC_CROSS_ATTN_K: "dec.blk.{bid}.cross_attn_k",
497
+ MODEL_TENSOR.DEC_CROSS_ATTN_V: "dec.blk.{bid}.cross_attn_v",
498
+ MODEL_TENSOR.DEC_CROSS_ATTN_OUT: "dec.blk.{bid}.cross_attn_o",
499
+ MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: "dec.blk.{bid}.cross_attn_rel_b",
500
+ MODEL_TENSOR.DEC_FFN_NORM: "dec.blk.{bid}.ffn_norm",
501
+ MODEL_TENSOR.DEC_FFN_GATE: "dec.blk.{bid}.ffn_gate",
502
+ MODEL_TENSOR.DEC_FFN_DOWN: "dec.blk.{bid}.ffn_down",
503
+ MODEL_TENSOR.DEC_FFN_UP: "dec.blk.{bid}.ffn_up",
504
+ MODEL_TENSOR.DEC_OUTPUT_NORM: "dec.output_norm",
505
+ MODEL_TENSOR.ENC_ATTN_NORM: "enc.blk.{bid}.attn_norm",
506
+ MODEL_TENSOR.ENC_ATTN_Q: "enc.blk.{bid}.attn_q",
507
+ MODEL_TENSOR.ENC_ATTN_K: "enc.blk.{bid}.attn_k",
508
+ MODEL_TENSOR.ENC_ATTN_V: "enc.blk.{bid}.attn_v",
509
+ MODEL_TENSOR.ENC_ATTN_OUT: "enc.blk.{bid}.attn_o",
510
+ MODEL_TENSOR.ENC_ATTN_REL_B: "enc.blk.{bid}.attn_rel_b",
511
+ MODEL_TENSOR.ENC_FFN_NORM: "enc.blk.{bid}.ffn_norm",
512
+ MODEL_TENSOR.ENC_FFN_GATE: "enc.blk.{bid}.ffn_gate",
513
+ MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down",
514
+ MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up",
515
+ MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm",
516
+ MODEL_TENSOR.CLS: "cls",
517
+ MODEL_TENSOR.CLS_OUT: "cls.output",
518
+ }
519
+
520
+ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
521
+ MODEL_ARCH.LLAMA: [
522
+ MODEL_TENSOR.TOKEN_EMBD,
523
+ MODEL_TENSOR.OUTPUT_NORM,
524
+ MODEL_TENSOR.OUTPUT,
525
+ MODEL_TENSOR.ROPE_FREQS,
526
+ MODEL_TENSOR.ATTN_NORM,
527
+ MODEL_TENSOR.ATTN_Q,
528
+ MODEL_TENSOR.ATTN_K,
529
+ MODEL_TENSOR.ATTN_V,
530
+ MODEL_TENSOR.ATTN_OUT,
531
+ MODEL_TENSOR.ATTN_ROT_EMBD,
532
+ MODEL_TENSOR.FFN_GATE_INP,
533
+ MODEL_TENSOR.FFN_NORM,
534
+ MODEL_TENSOR.FFN_GATE,
535
+ MODEL_TENSOR.FFN_DOWN,
536
+ MODEL_TENSOR.FFN_UP,
537
+ MODEL_TENSOR.FFN_GATE_EXP,
538
+ MODEL_TENSOR.FFN_DOWN_EXP,
539
+ MODEL_TENSOR.FFN_UP_EXP,
540
+ ],
541
+ MODEL_ARCH.GROK: [
542
+ MODEL_TENSOR.TOKEN_EMBD,
543
+ MODEL_TENSOR.OUTPUT_NORM,
544
+ MODEL_TENSOR.OUTPUT,
545
+ MODEL_TENSOR.ROPE_FREQS,
546
+ MODEL_TENSOR.ATTN_NORM,
547
+ MODEL_TENSOR.ATTN_Q,
548
+ MODEL_TENSOR.ATTN_K,
549
+ MODEL_TENSOR.ATTN_V,
550
+ MODEL_TENSOR.ATTN_OUT,
551
+ MODEL_TENSOR.ATTN_ROT_EMBD,
552
+ MODEL_TENSOR.ATTN_OUT_NORM,
553
+ MODEL_TENSOR.FFN_GATE_INP,
554
+ MODEL_TENSOR.FFN_NORM,
555
+ MODEL_TENSOR.FFN_GATE,
556
+ MODEL_TENSOR.FFN_DOWN,
557
+ MODEL_TENSOR.FFN_UP,
558
+ MODEL_TENSOR.FFN_GATE_EXP,
559
+ MODEL_TENSOR.FFN_DOWN_EXP,
560
+ MODEL_TENSOR.FFN_UP_EXP,
561
+ MODEL_TENSOR.LAYER_OUT_NORM,
562
+ ],
563
+ MODEL_ARCH.GPTNEOX: [
564
+ MODEL_TENSOR.TOKEN_EMBD,
565
+ MODEL_TENSOR.OUTPUT_NORM,
566
+ MODEL_TENSOR.OUTPUT,
567
+ MODEL_TENSOR.ATTN_NORM,
568
+ MODEL_TENSOR.ATTN_QKV,
569
+ MODEL_TENSOR.ATTN_OUT,
570
+ MODEL_TENSOR.FFN_NORM,
571
+ MODEL_TENSOR.FFN_DOWN,
572
+ MODEL_TENSOR.FFN_UP,
573
+ ],
574
+ MODEL_ARCH.FALCON: [
575
+ MODEL_TENSOR.TOKEN_EMBD,
576
+ MODEL_TENSOR.OUTPUT_NORM,
577
+ MODEL_TENSOR.OUTPUT,
578
+ MODEL_TENSOR.ATTN_NORM,
579
+ MODEL_TENSOR.ATTN_NORM_2,
580
+ MODEL_TENSOR.ATTN_QKV,
581
+ MODEL_TENSOR.ATTN_OUT,
582
+ MODEL_TENSOR.FFN_DOWN,
583
+ MODEL_TENSOR.FFN_UP,
584
+ ],
585
+ MODEL_ARCH.BAICHUAN: [
586
+ MODEL_TENSOR.TOKEN_EMBD,
587
+ MODEL_TENSOR.OUTPUT_NORM,
588
+ MODEL_TENSOR.OUTPUT,
589
+ MODEL_TENSOR.ROPE_FREQS,
590
+ MODEL_TENSOR.ATTN_NORM,
591
+ MODEL_TENSOR.ATTN_Q,
592
+ MODEL_TENSOR.ATTN_K,
593
+ MODEL_TENSOR.ATTN_V,
594
+ MODEL_TENSOR.ATTN_OUT,
595
+ MODEL_TENSOR.ATTN_ROT_EMBD,
596
+ MODEL_TENSOR.FFN_NORM,
597
+ MODEL_TENSOR.FFN_GATE,
598
+ MODEL_TENSOR.FFN_DOWN,
599
+ MODEL_TENSOR.FFN_UP,
600
+ ],
601
+ MODEL_ARCH.STARCODER: [
602
+ MODEL_TENSOR.TOKEN_EMBD,
603
+ MODEL_TENSOR.POS_EMBD,
604
+ MODEL_TENSOR.OUTPUT_NORM,
605
+ MODEL_TENSOR.OUTPUT,
606
+ MODEL_TENSOR.ATTN_NORM,
607
+ MODEL_TENSOR.ATTN_QKV,
608
+ MODEL_TENSOR.ATTN_OUT,
609
+ MODEL_TENSOR.FFN_NORM,
610
+ MODEL_TENSOR.FFN_DOWN,
611
+ MODEL_TENSOR.FFN_UP,
612
+ ],
613
+ MODEL_ARCH.BERT: [
614
+ MODEL_TENSOR.TOKEN_EMBD,
615
+ MODEL_TENSOR.TOKEN_EMBD_NORM,
616
+ MODEL_TENSOR.TOKEN_TYPES,
617
+ MODEL_TENSOR.POS_EMBD,
618
+ MODEL_TENSOR.OUTPUT_NORM,
619
+ MODEL_TENSOR.ATTN_OUT_NORM,
620
+ MODEL_TENSOR.ATTN_Q,
621
+ MODEL_TENSOR.ATTN_K,
622
+ MODEL_TENSOR.ATTN_V,
623
+ MODEL_TENSOR.ATTN_OUT,
624
+ MODEL_TENSOR.FFN_DOWN,
625
+ MODEL_TENSOR.FFN_UP,
626
+ MODEL_TENSOR.LAYER_OUT_NORM,
627
+ MODEL_TENSOR.CLS,
628
+ MODEL_TENSOR.CLS_OUT,
629
+ ],
630
+ MODEL_ARCH.NOMIC_BERT: [
631
+ MODEL_TENSOR.TOKEN_EMBD,
632
+ MODEL_TENSOR.TOKEN_EMBD_NORM,
633
+ MODEL_TENSOR.TOKEN_TYPES,
634
+ MODEL_TENSOR.POS_EMBD,
635
+ MODEL_TENSOR.OUTPUT_NORM,
636
+ MODEL_TENSOR.ATTN_OUT_NORM,
637
+ MODEL_TENSOR.ATTN_QKV,
638
+ MODEL_TENSOR.ATTN_OUT,
639
+ MODEL_TENSOR.FFN_GATE,
640
+ MODEL_TENSOR.FFN_DOWN,
641
+ MODEL_TENSOR.FFN_UP,
642
+ MODEL_TENSOR.LAYER_OUT_NORM,
643
+ ],
644
+ MODEL_ARCH.JINA_BERT_V2: [
645
+ MODEL_TENSOR.TOKEN_EMBD,
646
+ MODEL_TENSOR.TOKEN_EMBD_NORM,
647
+ MODEL_TENSOR.TOKEN_TYPES,
648
+ MODEL_TENSOR.ATTN_NORM_2,
649
+ MODEL_TENSOR.ATTN_OUT_NORM,
650
+ MODEL_TENSOR.ATTN_Q,
651
+ MODEL_TENSOR.ATTN_Q_NORM,
652
+ MODEL_TENSOR.ATTN_K,
653
+ MODEL_TENSOR.ATTN_K_NORM,
654
+ MODEL_TENSOR.ATTN_V,
655
+ MODEL_TENSOR.ATTN_OUT,
656
+ MODEL_TENSOR.FFN_UP,
657
+ MODEL_TENSOR.FFN_GATE,
658
+ MODEL_TENSOR.FFN_DOWN,
659
+ MODEL_TENSOR.LAYER_OUT_NORM,
660
+ MODEL_TENSOR.CLS,
661
+ ],
662
+ MODEL_ARCH.MPT: [
663
+ MODEL_TENSOR.TOKEN_EMBD,
664
+ MODEL_TENSOR.OUTPUT_NORM,
665
+ MODEL_TENSOR.OUTPUT,
666
+ MODEL_TENSOR.ATTN_NORM,
667
+ MODEL_TENSOR.ATTN_QKV,
668
+ MODEL_TENSOR.ATTN_OUT,
669
+ MODEL_TENSOR.FFN_NORM,
670
+ MODEL_TENSOR.FFN_DOWN,
671
+ MODEL_TENSOR.FFN_UP,
672
+ MODEL_TENSOR.FFN_ACT,
673
+ MODEL_TENSOR.ATTN_Q_NORM,
674
+ MODEL_TENSOR.ATTN_K_NORM,
675
+ MODEL_TENSOR.POS_EMBD,
676
+ ],
677
+ MODEL_ARCH.GPTJ: [
678
+ MODEL_TENSOR.TOKEN_EMBD,
679
+ MODEL_TENSOR.OUTPUT_NORM,
680
+ MODEL_TENSOR.OUTPUT,
681
+ MODEL_TENSOR.ATTN_NORM,
682
+ MODEL_TENSOR.ATTN_Q,
683
+ MODEL_TENSOR.ATTN_K,
684
+ MODEL_TENSOR.ATTN_V,
685
+ MODEL_TENSOR.ATTN_OUT,
686
+ MODEL_TENSOR.FFN_DOWN,
687
+ MODEL_TENSOR.FFN_UP,
688
+ ],
689
+ MODEL_ARCH.REFACT: [
690
+ MODEL_TENSOR.TOKEN_EMBD,
691
+ MODEL_TENSOR.OUTPUT_NORM,
692
+ MODEL_TENSOR.OUTPUT,
693
+ MODEL_TENSOR.ATTN_NORM,
694
+ MODEL_TENSOR.ATTN_Q,
695
+ MODEL_TENSOR.ATTN_K,
696
+ MODEL_TENSOR.ATTN_V,
697
+ MODEL_TENSOR.ATTN_OUT,
698
+ MODEL_TENSOR.FFN_NORM,
699
+ MODEL_TENSOR.FFN_GATE,
700
+ MODEL_TENSOR.FFN_DOWN,
701
+ MODEL_TENSOR.FFN_UP,
702
+ ],
703
+ MODEL_ARCH.BLOOM: [
704
+ MODEL_TENSOR.TOKEN_EMBD,
705
+ MODEL_TENSOR.TOKEN_EMBD_NORM,
706
+ MODEL_TENSOR.OUTPUT_NORM,
707
+ MODEL_TENSOR.OUTPUT,
708
+ MODEL_TENSOR.ATTN_NORM,
709
+ MODEL_TENSOR.ATTN_QKV,
710
+ MODEL_TENSOR.ATTN_OUT,
711
+ MODEL_TENSOR.FFN_NORM,
712
+ MODEL_TENSOR.FFN_DOWN,
713
+ MODEL_TENSOR.FFN_UP,
714
+ ],
715
+ MODEL_ARCH.STABLELM: [
716
+ MODEL_TENSOR.TOKEN_EMBD,
717
+ MODEL_TENSOR.OUTPUT_NORM,
718
+ MODEL_TENSOR.OUTPUT,
719
+ MODEL_TENSOR.ROPE_FREQS,
720
+ MODEL_TENSOR.ATTN_NORM,
721
+ MODEL_TENSOR.ATTN_Q,
722
+ MODEL_TENSOR.ATTN_K,
723
+ MODEL_TENSOR.ATTN_V,
724
+ MODEL_TENSOR.ATTN_OUT,
725
+ MODEL_TENSOR.FFN_NORM,
726
+ MODEL_TENSOR.FFN_GATE,
727
+ MODEL_TENSOR.FFN_DOWN,
728
+ MODEL_TENSOR.FFN_UP,
729
+ MODEL_TENSOR.ATTN_Q_NORM,
730
+ MODEL_TENSOR.ATTN_K_NORM,
731
+ ],
732
+ MODEL_ARCH.QWEN: [
733
+ MODEL_TENSOR.TOKEN_EMBD,
734
+ MODEL_TENSOR.OUTPUT_NORM,
735
+ MODEL_TENSOR.OUTPUT,
736
+ MODEL_TENSOR.ROPE_FREQS,
737
+ MODEL_TENSOR.ATTN_NORM,
738
+ MODEL_TENSOR.ATTN_QKV,
739
+ MODEL_TENSOR.ATTN_OUT,
740
+ MODEL_TENSOR.ATTN_ROT_EMBD,
741
+ MODEL_TENSOR.FFN_NORM,
742
+ MODEL_TENSOR.FFN_GATE,
743
+ MODEL_TENSOR.FFN_DOWN,
744
+ MODEL_TENSOR.FFN_UP,
745
+ ],
746
+ MODEL_ARCH.QWEN2: [
747
+ MODEL_TENSOR.TOKEN_EMBD,
748
+ MODEL_TENSOR.OUTPUT_NORM,
749
+ MODEL_TENSOR.OUTPUT,
750
+ MODEL_TENSOR.ATTN_NORM,
751
+ MODEL_TENSOR.ATTN_Q,
752
+ MODEL_TENSOR.ATTN_K,
753
+ MODEL_TENSOR.ATTN_V,
754
+ MODEL_TENSOR.ATTN_OUT,
755
+ MODEL_TENSOR.FFN_NORM,
756
+ MODEL_TENSOR.FFN_GATE,
757
+ MODEL_TENSOR.FFN_DOWN,
758
+ MODEL_TENSOR.FFN_UP,
759
+ ],
760
+ MODEL_ARCH.QWEN2MOE: [
761
+ MODEL_TENSOR.TOKEN_EMBD,
762
+ MODEL_TENSOR.OUTPUT_NORM,
763
+ MODEL_TENSOR.OUTPUT,
764
+ MODEL_TENSOR.ATTN_NORM,
765
+ MODEL_TENSOR.ATTN_Q,
766
+ MODEL_TENSOR.ATTN_K,
767
+ MODEL_TENSOR.ATTN_V,
768
+ MODEL_TENSOR.ATTN_OUT,
769
+ MODEL_TENSOR.FFN_NORM,
770
+ MODEL_TENSOR.FFN_GATE_INP,
771
+ MODEL_TENSOR.FFN_GATE_EXP,
772
+ MODEL_TENSOR.FFN_DOWN_EXP,
773
+ MODEL_TENSOR.FFN_UP_EXP,
774
+ MODEL_TENSOR.FFN_GATE_INP_SHEXP,
775
+ MODEL_TENSOR.FFN_GATE_SHEXP,
776
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
777
+ MODEL_TENSOR.FFN_UP_SHEXP,
778
+ ],
779
+ MODEL_ARCH.PLAMO: [
780
+ MODEL_TENSOR.TOKEN_EMBD,
781
+ MODEL_TENSOR.OUTPUT_NORM,
782
+ MODEL_TENSOR.OUTPUT,
783
+ MODEL_TENSOR.ROPE_FREQS,
784
+ MODEL_TENSOR.ATTN_NORM,
785
+ MODEL_TENSOR.ATTN_Q,
786
+ MODEL_TENSOR.ATTN_K,
787
+ MODEL_TENSOR.ATTN_V,
788
+ MODEL_TENSOR.ATTN_OUT,
789
+ MODEL_TENSOR.ATTN_ROT_EMBD,
790
+ MODEL_TENSOR.FFN_GATE,
791
+ MODEL_TENSOR.FFN_DOWN,
792
+ MODEL_TENSOR.FFN_UP,
793
+ ],
794
+ MODEL_ARCH.GPT2: [
795
+ MODEL_TENSOR.TOKEN_EMBD,
796
+ MODEL_TENSOR.POS_EMBD,
797
+ MODEL_TENSOR.OUTPUT_NORM,
798
+ MODEL_TENSOR.OUTPUT,
799
+ MODEL_TENSOR.ATTN_NORM,
800
+ MODEL_TENSOR.ATTN_QKV,
801
+ MODEL_TENSOR.ATTN_OUT,
802
+ MODEL_TENSOR.FFN_NORM,
803
+ MODEL_TENSOR.FFN_DOWN,
804
+ MODEL_TENSOR.FFN_UP,
805
+ ],
806
+ MODEL_ARCH.PHI2: [
807
+ MODEL_TENSOR.TOKEN_EMBD,
808
+ MODEL_TENSOR.OUTPUT_NORM,
809
+ MODEL_TENSOR.OUTPUT,
810
+ MODEL_TENSOR.ATTN_NORM,
811
+ MODEL_TENSOR.ATTN_QKV,
812
+ MODEL_TENSOR.ATTN_Q,
813
+ MODEL_TENSOR.ATTN_K,
814
+ MODEL_TENSOR.ATTN_V,
815
+ MODEL_TENSOR.ATTN_OUT,
816
+ MODEL_TENSOR.FFN_NORM,
817
+ MODEL_TENSOR.FFN_DOWN,
818
+ MODEL_TENSOR.FFN_UP,
819
+ ],
820
+ MODEL_ARCH.PHI3: [
821
+ MODEL_TENSOR.TOKEN_EMBD,
822
+ MODEL_TENSOR.OUTPUT_NORM,
823
+ MODEL_TENSOR.OUTPUT,
824
+ MODEL_TENSOR.ROPE_FACTORS_LONG,
825
+ MODEL_TENSOR.ROPE_FACTORS_SHORT,
826
+ MODEL_TENSOR.ATTN_NORM,
827
+ MODEL_TENSOR.ATTN_QKV,
828
+ MODEL_TENSOR.ATTN_Q,
829
+ MODEL_TENSOR.ATTN_K,
830
+ MODEL_TENSOR.ATTN_V,
831
+ MODEL_TENSOR.ATTN_OUT,
832
+ MODEL_TENSOR.FFN_NORM,
833
+ MODEL_TENSOR.FFN_DOWN,
834
+ MODEL_TENSOR.FFN_UP,
835
+ ],
836
+ MODEL_ARCH.CODESHELL: [
837
+ MODEL_TENSOR.TOKEN_EMBD,
838
+ MODEL_TENSOR.POS_EMBD,
839
+ MODEL_TENSOR.OUTPUT_NORM,
840
+ MODEL_TENSOR.OUTPUT,
841
+ MODEL_TENSOR.ATTN_NORM,
842
+ MODEL_TENSOR.ATTN_QKV,
843
+ MODEL_TENSOR.ATTN_OUT,
844
+ MODEL_TENSOR.ATTN_ROT_EMBD,
845
+ MODEL_TENSOR.FFN_NORM,
846
+ MODEL_TENSOR.FFN_DOWN,
847
+ MODEL_TENSOR.FFN_UP,
848
+ ],
849
+ MODEL_ARCH.ORION: [
850
+ MODEL_TENSOR.TOKEN_EMBD,
851
+ MODEL_TENSOR.OUTPUT_NORM,
852
+ MODEL_TENSOR.OUTPUT,
853
+ MODEL_TENSOR.ROPE_FREQS,
854
+ MODEL_TENSOR.ATTN_NORM,
855
+ MODEL_TENSOR.ATTN_Q,
856
+ MODEL_TENSOR.ATTN_K,
857
+ MODEL_TENSOR.ATTN_V,
858
+ MODEL_TENSOR.ATTN_OUT,
859
+ MODEL_TENSOR.ATTN_ROT_EMBD,
860
+ MODEL_TENSOR.FFN_NORM,
861
+ MODEL_TENSOR.FFN_GATE,
862
+ MODEL_TENSOR.FFN_DOWN,
863
+ MODEL_TENSOR.FFN_UP,
864
+ ],
865
+ MODEL_ARCH.INTERNLM2: [
866
+ MODEL_TENSOR.TOKEN_EMBD,
867
+ MODEL_TENSOR.OUTPUT_NORM,
868
+ MODEL_TENSOR.OUTPUT,
869
+ MODEL_TENSOR.ATTN_NORM,
870
+ MODEL_TENSOR.ATTN_Q,
871
+ MODEL_TENSOR.ATTN_K,
872
+ MODEL_TENSOR.ATTN_V,
873
+ MODEL_TENSOR.ATTN_OUT,
874
+ MODEL_TENSOR.ATTN_ROT_EMBD,
875
+ MODEL_TENSOR.FFN_NORM,
876
+ MODEL_TENSOR.FFN_GATE,
877
+ MODEL_TENSOR.FFN_DOWN,
878
+ MODEL_TENSOR.FFN_UP,
879
+ ],
880
+ MODEL_ARCH.MINICPM: [
881
+ MODEL_TENSOR.TOKEN_EMBD,
882
+ MODEL_TENSOR.OUTPUT,
883
+ MODEL_TENSOR.OUTPUT_NORM,
884
+ MODEL_TENSOR.ROPE_FREQS,
885
+ MODEL_TENSOR.ATTN_NORM,
886
+ MODEL_TENSOR.ATTN_Q,
887
+ MODEL_TENSOR.ATTN_K,
888
+ MODEL_TENSOR.ATTN_V,
889
+ MODEL_TENSOR.ATTN_OUT,
890
+ MODEL_TENSOR.ATTN_ROT_EMBD,
891
+ MODEL_TENSOR.FFN_GATE_INP,
892
+ MODEL_TENSOR.FFN_NORM,
893
+ MODEL_TENSOR.FFN_GATE,
894
+ MODEL_TENSOR.FFN_DOWN,
895
+ MODEL_TENSOR.FFN_UP,
896
+ MODEL_TENSOR.FFN_GATE_EXP,
897
+ MODEL_TENSOR.FFN_DOWN_EXP,
898
+ MODEL_TENSOR.FFN_UP_EXP,
899
+ ],
900
+ MODEL_ARCH.MINICPM3: [
901
+ MODEL_TENSOR.TOKEN_EMBD,
902
+ MODEL_TENSOR.OUTPUT_NORM,
903
+ MODEL_TENSOR.OUTPUT,
904
+ MODEL_TENSOR.ROPE_FACTORS_LONG,
905
+ MODEL_TENSOR.ROPE_FACTORS_SHORT,
906
+ MODEL_TENSOR.ATTN_NORM,
907
+ MODEL_TENSOR.ATTN_Q_A,
908
+ MODEL_TENSOR.ATTN_Q_B,
909
+ MODEL_TENSOR.ATTN_KV_A_MQA,
910
+ MODEL_TENSOR.ATTN_KV_B,
911
+ MODEL_TENSOR.ATTN_Q_A_NORM,
912
+ MODEL_TENSOR.ATTN_KV_A_NORM,
913
+ MODEL_TENSOR.ATTN_OUT,
914
+ MODEL_TENSOR.FFN_NORM,
915
+ MODEL_TENSOR.FFN_GATE,
916
+ MODEL_TENSOR.FFN_DOWN,
917
+ MODEL_TENSOR.FFN_UP,
918
+ ],
919
+ MODEL_ARCH.GEMMA: [
920
+ MODEL_TENSOR.TOKEN_EMBD,
921
+ MODEL_TENSOR.OUTPUT_NORM,
922
+ MODEL_TENSOR.ATTN_NORM,
923
+ MODEL_TENSOR.ATTN_Q,
924
+ MODEL_TENSOR.ATTN_K,
925
+ MODEL_TENSOR.ATTN_V,
926
+ MODEL_TENSOR.ATTN_OUT,
927
+ MODEL_TENSOR.FFN_GATE,
928
+ MODEL_TENSOR.FFN_DOWN,
929
+ MODEL_TENSOR.FFN_UP,
930
+ MODEL_TENSOR.FFN_NORM,
931
+ ],
932
+ MODEL_ARCH.GEMMA2: [
933
+ MODEL_TENSOR.TOKEN_EMBD,
934
+ MODEL_TENSOR.OUTPUT_NORM,
935
+ MODEL_TENSOR.ATTN_Q,
936
+ MODEL_TENSOR.ATTN_K,
937
+ MODEL_TENSOR.ATTN_V,
938
+ MODEL_TENSOR.ATTN_OUT,
939
+ MODEL_TENSOR.FFN_GATE,
940
+ MODEL_TENSOR.FFN_DOWN,
941
+ MODEL_TENSOR.FFN_UP,
942
+ MODEL_TENSOR.ATTN_NORM,
943
+ MODEL_TENSOR.ATTN_POST_NORM,
944
+ MODEL_TENSOR.FFN_PRE_NORM,
945
+ MODEL_TENSOR.FFN_POST_NORM,
946
+ ],
947
+ MODEL_ARCH.STARCODER2: [
948
+ MODEL_TENSOR.TOKEN_EMBD,
949
+ MODEL_TENSOR.OUTPUT_NORM,
950
+ MODEL_TENSOR.OUTPUT,
951
+ MODEL_TENSOR.ROPE_FREQS,
952
+ MODEL_TENSOR.ATTN_NORM,
953
+ MODEL_TENSOR.ATTN_Q,
954
+ MODEL_TENSOR.ATTN_K,
955
+ MODEL_TENSOR.ATTN_V,
956
+ MODEL_TENSOR.ATTN_OUT,
957
+ MODEL_TENSOR.ATTN_ROT_EMBD,
958
+ MODEL_TENSOR.FFN_NORM,
959
+ MODEL_TENSOR.FFN_DOWN,
960
+ MODEL_TENSOR.FFN_UP,
961
+ ],
962
+ MODEL_ARCH.RWKV6: [
963
+ MODEL_TENSOR.TOKEN_EMBD,
964
+ MODEL_TENSOR.TOKEN_EMBD_NORM,
965
+ MODEL_TENSOR.OUTPUT_NORM,
966
+ MODEL_TENSOR.OUTPUT,
967
+ MODEL_TENSOR.ATTN_NORM,
968
+ MODEL_TENSOR.ATTN_NORM_2,
969
+ MODEL_TENSOR.TIME_MIX_W1,
970
+ MODEL_TENSOR.TIME_MIX_W2,
971
+ MODEL_TENSOR.TIME_MIX_LERP_X,
972
+ MODEL_TENSOR.TIME_MIX_LERP_K,
973
+ MODEL_TENSOR.TIME_MIX_LERP_V,
974
+ MODEL_TENSOR.TIME_MIX_LERP_R,
975
+ MODEL_TENSOR.TIME_MIX_LERP_G,
976
+ MODEL_TENSOR.TIME_MIX_LERP_W,
977
+ MODEL_TENSOR.TIME_MIX_FIRST,
978
+ MODEL_TENSOR.TIME_MIX_DECAY,
979
+ MODEL_TENSOR.TIME_MIX_DECAY_W1,
980
+ MODEL_TENSOR.TIME_MIX_DECAY_W2,
981
+ MODEL_TENSOR.TIME_MIX_KEY,
982
+ MODEL_TENSOR.TIME_MIX_VALUE,
983
+ MODEL_TENSOR.TIME_MIX_RECEPTANCE,
984
+ MODEL_TENSOR.TIME_MIX_GATE,
985
+ MODEL_TENSOR.TIME_MIX_LN,
986
+ MODEL_TENSOR.TIME_MIX_OUTPUT,
987
+ MODEL_TENSOR.CHANNEL_MIX_LERP_K,
988
+ MODEL_TENSOR.CHANNEL_MIX_LERP_R,
989
+ MODEL_TENSOR.CHANNEL_MIX_KEY,
990
+ MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE,
991
+ MODEL_TENSOR.CHANNEL_MIX_VALUE,
992
+ ],
993
+ MODEL_ARCH.MAMBA: [
994
+ MODEL_TENSOR.TOKEN_EMBD,
995
+ MODEL_TENSOR.OUTPUT_NORM,
996
+ MODEL_TENSOR.OUTPUT,
997
+ MODEL_TENSOR.ATTN_NORM,
998
+ MODEL_TENSOR.SSM_IN,
999
+ MODEL_TENSOR.SSM_CONV1D,
1000
+ MODEL_TENSOR.SSM_X,
1001
+ MODEL_TENSOR.SSM_DT,
1002
+ MODEL_TENSOR.SSM_A,
1003
+ MODEL_TENSOR.SSM_D,
1004
+ MODEL_TENSOR.SSM_OUT,
1005
+ ],
1006
+ MODEL_ARCH.XVERSE: [
1007
+ MODEL_TENSOR.TOKEN_EMBD,
1008
+ MODEL_TENSOR.OUTPUT_NORM,
1009
+ MODEL_TENSOR.OUTPUT,
1010
+ MODEL_TENSOR.ROPE_FREQS,
1011
+ MODEL_TENSOR.ATTN_NORM,
1012
+ MODEL_TENSOR.ATTN_Q,
1013
+ MODEL_TENSOR.ATTN_K,
1014
+ MODEL_TENSOR.ATTN_V,
1015
+ MODEL_TENSOR.ATTN_OUT,
1016
+ MODEL_TENSOR.ATTN_ROT_EMBD,
1017
+ MODEL_TENSOR.FFN_NORM,
1018
+ MODEL_TENSOR.FFN_GATE,
1019
+ MODEL_TENSOR.FFN_DOWN,
1020
+ MODEL_TENSOR.FFN_UP,
1021
+ ],
1022
+ MODEL_ARCH.COMMAND_R: [
1023
+ MODEL_TENSOR.TOKEN_EMBD,
1024
+ MODEL_TENSOR.OUTPUT_NORM,
1025
+ MODEL_TENSOR.ATTN_NORM,
1026
+ MODEL_TENSOR.ATTN_Q,
1027
+ MODEL_TENSOR.ATTN_K,
1028
+ MODEL_TENSOR.ATTN_V,
1029
+ MODEL_TENSOR.ATTN_OUT,
1030
+ MODEL_TENSOR.FFN_GATE,
1031
+ MODEL_TENSOR.FFN_DOWN,
1032
+ MODEL_TENSOR.FFN_UP,
1033
+ MODEL_TENSOR.ATTN_K_NORM,
1034
+ MODEL_TENSOR.ATTN_Q_NORM,
1035
+ ],
1036
+ MODEL_ARCH.DBRX: [
1037
+ MODEL_TENSOR.TOKEN_EMBD,
1038
+ MODEL_TENSOR.OUTPUT_NORM,
1039
+ MODEL_TENSOR.OUTPUT,
1040
+ MODEL_TENSOR.ATTN_NORM,
1041
+ MODEL_TENSOR.ATTN_QKV,
1042
+ MODEL_TENSOR.ATTN_OUT,
1043
+ MODEL_TENSOR.ATTN_OUT_NORM,
1044
+ MODEL_TENSOR.FFN_GATE_INP,
1045
+ MODEL_TENSOR.FFN_GATE_EXP,
1046
+ MODEL_TENSOR.FFN_DOWN_EXP,
1047
+ MODEL_TENSOR.FFN_UP_EXP,
1048
+ ],
1049
+ MODEL_ARCH.OLMO: [
1050
+ MODEL_TENSOR.TOKEN_EMBD,
1051
+ MODEL_TENSOR.OUTPUT,
1052
+ MODEL_TENSOR.ATTN_Q,
1053
+ MODEL_TENSOR.ATTN_K,
1054
+ MODEL_TENSOR.ATTN_V,
1055
+ MODEL_TENSOR.ATTN_OUT,
1056
+ MODEL_TENSOR.FFN_GATE,
1057
+ MODEL_TENSOR.FFN_DOWN,
1058
+ MODEL_TENSOR.FFN_UP,
1059
+ ],
1060
+ MODEL_ARCH.OLMOE: [
1061
+ MODEL_TENSOR.TOKEN_EMBD,
1062
+ MODEL_TENSOR.OUTPUT_NORM,
1063
+ MODEL_TENSOR.OUTPUT,
1064
+ MODEL_TENSOR.ATTN_OUT,
1065
+ MODEL_TENSOR.ATTN_Q,
1066
+ MODEL_TENSOR.ATTN_K,
1067
+ MODEL_TENSOR.ATTN_V,
1068
+ MODEL_TENSOR.ATTN_NORM,
1069
+ MODEL_TENSOR.ATTN_Q_NORM,
1070
+ MODEL_TENSOR.ATTN_K_NORM,
1071
+ MODEL_TENSOR.FFN_NORM,
1072
+ MODEL_TENSOR.FFN_GATE_INP,
1073
+ MODEL_TENSOR.FFN_GATE_EXP,
1074
+ MODEL_TENSOR.FFN_UP_EXP,
1075
+ MODEL_TENSOR.FFN_DOWN_EXP,
1076
+ ],
1077
+ MODEL_ARCH.OPENELM: [
1078
+ MODEL_TENSOR.TOKEN_EMBD,
1079
+ MODEL_TENSOR.OUTPUT_NORM,
1080
+ MODEL_TENSOR.ATTN_NORM,
1081
+ MODEL_TENSOR.ATTN_QKV,
1082
+ MODEL_TENSOR.ATTN_Q_NORM,
1083
+ MODEL_TENSOR.ATTN_K_NORM,
1084
+ MODEL_TENSOR.ATTN_OUT,
1085
+ MODEL_TENSOR.FFN_NORM,
1086
+ MODEL_TENSOR.FFN_GATE,
1087
+ MODEL_TENSOR.FFN_DOWN,
1088
+ MODEL_TENSOR.FFN_UP,
1089
+ ],
1090
+ MODEL_ARCH.ARCTIC: [
1091
+ MODEL_TENSOR.TOKEN_EMBD,
1092
+ MODEL_TENSOR.OUTPUT_NORM,
1093
+ MODEL_TENSOR.OUTPUT,
1094
+ MODEL_TENSOR.ROPE_FREQS,
1095
+ MODEL_TENSOR.ATTN_NORM,
1096
+ MODEL_TENSOR.ATTN_Q,
1097
+ MODEL_TENSOR.ATTN_K,
1098
+ MODEL_TENSOR.ATTN_V,
1099
+ MODEL_TENSOR.ATTN_OUT,
1100
+ MODEL_TENSOR.ATTN_ROT_EMBD,
1101
+ MODEL_TENSOR.FFN_GATE_INP,
1102
+ MODEL_TENSOR.FFN_NORM,
1103
+ MODEL_TENSOR.FFN_GATE,
1104
+ MODEL_TENSOR.FFN_DOWN,
1105
+ MODEL_TENSOR.FFN_UP,
1106
+ MODEL_TENSOR.FFN_NORM_EXP,
1107
+ MODEL_TENSOR.FFN_GATE_EXP,
1108
+ MODEL_TENSOR.FFN_DOWN_EXP,
1109
+ MODEL_TENSOR.FFN_UP_EXP,
1110
+ ],
1111
+ MODEL_ARCH.DEEPSEEK2: [
1112
+ MODEL_TENSOR.TOKEN_EMBD,
1113
+ MODEL_TENSOR.OUTPUT_NORM,
1114
+ MODEL_TENSOR.OUTPUT,
1115
+ MODEL_TENSOR.ROPE_FREQS,
1116
+ MODEL_TENSOR.ATTN_NORM,
1117
+ MODEL_TENSOR.ATTN_Q,
1118
+ MODEL_TENSOR.ATTN_Q_A,
1119
+ MODEL_TENSOR.ATTN_Q_B,
1120
+ MODEL_TENSOR.ATTN_KV_A_MQA,
1121
+ MODEL_TENSOR.ATTN_KV_B,
1122
+ MODEL_TENSOR.ATTN_Q_A_NORM,
1123
+ MODEL_TENSOR.ATTN_KV_A_NORM,
1124
+ MODEL_TENSOR.ATTN_OUT,
1125
+ MODEL_TENSOR.ATTN_ROT_EMBD,
1126
+ MODEL_TENSOR.FFN_GATE_INP,
1127
+ MODEL_TENSOR.FFN_NORM,
1128
+ MODEL_TENSOR.FFN_GATE,
1129
+ MODEL_TENSOR.FFN_DOWN,
1130
+ MODEL_TENSOR.FFN_UP,
1131
+ MODEL_TENSOR.FFN_GATE_EXP,
1132
+ MODEL_TENSOR.FFN_DOWN_EXP,
1133
+ MODEL_TENSOR.FFN_UP_EXP,
1134
+ MODEL_TENSOR.FFN_GATE_SHEXP,
1135
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
1136
+ MODEL_TENSOR.FFN_UP_SHEXP,
1137
+ ],
1138
+ MODEL_ARCH.CHATGLM : [
1139
+ MODEL_TENSOR.TOKEN_EMBD,
1140
+ MODEL_TENSOR.ROPE_FREQS,
1141
+ MODEL_TENSOR.OUTPUT_NORM,
1142
+ MODEL_TENSOR.OUTPUT,
1143
+ MODEL_TENSOR.ATTN_NORM,
1144
+ MODEL_TENSOR.ATTN_QKV,
1145
+ MODEL_TENSOR.ATTN_OUT,
1146
+ MODEL_TENSOR.FFN_NORM,
1147
+ MODEL_TENSOR.FFN_DOWN,
1148
+ MODEL_TENSOR.FFN_UP,
1149
+ ],
1150
+ MODEL_ARCH.BITNET: [
1151
+ MODEL_TENSOR.ATTN_Q,
1152
+ MODEL_TENSOR.ATTN_K,
1153
+ MODEL_TENSOR.ATTN_V,
1154
+ MODEL_TENSOR.TOKEN_EMBD,
1155
+ MODEL_TENSOR.OUTPUT_NORM,
1156
+ MODEL_TENSOR.ATTN_NORM,
1157
+ MODEL_TENSOR.ATTN_OUT,
1158
+ MODEL_TENSOR.FFN_NORM,
1159
+ MODEL_TENSOR.FFN_GATE,
1160
+ MODEL_TENSOR.FFN_DOWN,
1161
+ MODEL_TENSOR.FFN_UP,
1162
+ MODEL_TENSOR.ATTN_SUB_NORM,
1163
+ MODEL_TENSOR.FFN_SUB_NORM,
1164
+ ],
1165
+ MODEL_ARCH.T5: [
1166
+ MODEL_TENSOR.TOKEN_EMBD,
1167
+ MODEL_TENSOR.OUTPUT,
1168
+ MODEL_TENSOR.DEC_ATTN_NORM,
1169
+ MODEL_TENSOR.DEC_ATTN_Q,
1170
+ MODEL_TENSOR.DEC_ATTN_K,
1171
+ MODEL_TENSOR.DEC_ATTN_V,
1172
+ MODEL_TENSOR.DEC_ATTN_OUT,
1173
+ MODEL_TENSOR.DEC_ATTN_REL_B,
1174
+ MODEL_TENSOR.DEC_CROSS_ATTN_NORM,
1175
+ MODEL_TENSOR.DEC_CROSS_ATTN_Q,
1176
+ MODEL_TENSOR.DEC_CROSS_ATTN_K,
1177
+ MODEL_TENSOR.DEC_CROSS_ATTN_V,
1178
+ MODEL_TENSOR.DEC_CROSS_ATTN_OUT,
1179
+ MODEL_TENSOR.DEC_CROSS_ATTN_REL_B,
1180
+ MODEL_TENSOR.DEC_FFN_NORM,
1181
+ MODEL_TENSOR.DEC_FFN_GATE,
1182
+ MODEL_TENSOR.DEC_FFN_DOWN,
1183
+ MODEL_TENSOR.DEC_FFN_UP,
1184
+ MODEL_TENSOR.DEC_OUTPUT_NORM,
1185
+ MODEL_TENSOR.ENC_ATTN_NORM,
1186
+ MODEL_TENSOR.ENC_ATTN_Q,
1187
+ MODEL_TENSOR.ENC_ATTN_K,
1188
+ MODEL_TENSOR.ENC_ATTN_V,
1189
+ MODEL_TENSOR.ENC_ATTN_OUT,
1190
+ MODEL_TENSOR.ENC_ATTN_REL_B,
1191
+ MODEL_TENSOR.ENC_FFN_NORM,
1192
+ MODEL_TENSOR.ENC_FFN_GATE,
1193
+ MODEL_TENSOR.ENC_FFN_DOWN,
1194
+ MODEL_TENSOR.ENC_FFN_UP,
1195
+ MODEL_TENSOR.ENC_OUTPUT_NORM,
1196
+ ],
1197
+ MODEL_ARCH.T5ENCODER: [
1198
+ MODEL_TENSOR.TOKEN_EMBD,
1199
+ MODEL_TENSOR.OUTPUT,
1200
+ MODEL_TENSOR.ENC_ATTN_NORM,
1201
+ MODEL_TENSOR.ENC_ATTN_Q,
1202
+ MODEL_TENSOR.ENC_ATTN_K,
1203
+ MODEL_TENSOR.ENC_ATTN_V,
1204
+ MODEL_TENSOR.ENC_ATTN_OUT,
1205
+ MODEL_TENSOR.ENC_ATTN_REL_B,
1206
+ MODEL_TENSOR.ENC_FFN_NORM,
1207
+ MODEL_TENSOR.ENC_FFN_GATE,
1208
+ MODEL_TENSOR.ENC_FFN_DOWN,
1209
+ MODEL_TENSOR.ENC_FFN_UP,
1210
+ MODEL_TENSOR.ENC_OUTPUT_NORM,
1211
+ ],
1212
+ MODEL_ARCH.JAIS: [
1213
+ MODEL_TENSOR.TOKEN_EMBD,
1214
+ MODEL_TENSOR.OUTPUT_NORM,
1215
+ MODEL_TENSOR.OUTPUT,
1216
+ MODEL_TENSOR.ATTN_NORM,
1217
+ MODEL_TENSOR.ATTN_QKV,
1218
+ MODEL_TENSOR.ATTN_OUT,
1219
+ MODEL_TENSOR.FFN_NORM,
1220
+ MODEL_TENSOR.FFN_DOWN,
1221
+ MODEL_TENSOR.FFN_GATE,
1222
+ MODEL_TENSOR.FFN_UP,
1223
+ ],
1224
+ MODEL_ARCH.NEMOTRON: [
1225
+ MODEL_TENSOR.TOKEN_EMBD,
1226
+ MODEL_TENSOR.OUTPUT_NORM,
1227
+ MODEL_TENSOR.OUTPUT,
1228
+ MODEL_TENSOR.ROPE_FREQS,
1229
+ MODEL_TENSOR.ATTN_NORM,
1230
+ MODEL_TENSOR.ATTN_Q,
1231
+ MODEL_TENSOR.ATTN_K,
1232
+ MODEL_TENSOR.ATTN_V,
1233
+ MODEL_TENSOR.ATTN_OUT,
1234
+ MODEL_TENSOR.ATTN_ROT_EMBD,
1235
+ MODEL_TENSOR.FFN_NORM,
1236
+ MODEL_TENSOR.FFN_DOWN,
1237
+ MODEL_TENSOR.FFN_UP,
1238
+ ],
1239
+ MODEL_ARCH.EXAONE: [
1240
+ MODEL_TENSOR.TOKEN_EMBD,
1241
+ MODEL_TENSOR.OUTPUT_NORM,
1242
+ MODEL_TENSOR.OUTPUT,
1243
+ MODEL_TENSOR.ROPE_FREQS,
1244
+ MODEL_TENSOR.ATTN_NORM,
1245
+ MODEL_TENSOR.ATTN_Q,
1246
+ MODEL_TENSOR.ATTN_K,
1247
+ MODEL_TENSOR.ATTN_V,
1248
+ MODEL_TENSOR.ATTN_OUT,
1249
+ MODEL_TENSOR.ATTN_ROT_EMBD,
1250
+ MODEL_TENSOR.FFN_NORM,
1251
+ MODEL_TENSOR.FFN_GATE,
1252
+ MODEL_TENSOR.FFN_DOWN,
1253
+ MODEL_TENSOR.FFN_UP,
1254
+ ],
1255
+ MODEL_ARCH.GRANITE: [
1256
+ MODEL_TENSOR.TOKEN_EMBD,
1257
+ MODEL_TENSOR.OUTPUT_NORM,
1258
+ MODEL_TENSOR.OUTPUT,
1259
+ MODEL_TENSOR.ATTN_NORM,
1260
+ MODEL_TENSOR.ATTN_Q,
1261
+ MODEL_TENSOR.ATTN_K,
1262
+ MODEL_TENSOR.ATTN_V,
1263
+ MODEL_TENSOR.ATTN_OUT,
1264
+ MODEL_TENSOR.FFN_NORM,
1265
+ MODEL_TENSOR.FFN_GATE,
1266
+ MODEL_TENSOR.FFN_DOWN,
1267
+ MODEL_TENSOR.FFN_UP,
1268
+ ],
1269
+ MODEL_ARCH.GRANITE_MOE: [
1270
+ MODEL_TENSOR.TOKEN_EMBD,
1271
+ MODEL_TENSOR.OUTPUT_NORM,
1272
+ MODEL_TENSOR.OUTPUT,
1273
+ MODEL_TENSOR.ATTN_NORM,
1274
+ MODEL_TENSOR.ATTN_Q,
1275
+ MODEL_TENSOR.ATTN_K,
1276
+ MODEL_TENSOR.ATTN_V,
1277
+ MODEL_TENSOR.ATTN_OUT,
1278
+ MODEL_TENSOR.FFN_NORM,
1279
+ MODEL_TENSOR.FFN_GATE_INP,
1280
+ MODEL_TENSOR.FFN_GATE_EXP,
1281
+ MODEL_TENSOR.FFN_DOWN_EXP,
1282
+ MODEL_TENSOR.FFN_UP_EXP,
1283
+ ],
1284
+ MODEL_ARCH.CHAMELEON: [
1285
+ MODEL_TENSOR.TOKEN_EMBD,
1286
+ MODEL_TENSOR.OUTPUT_NORM,
1287
+ MODEL_TENSOR.OUTPUT,
1288
+ MODEL_TENSOR.ATTN_NORM,
1289
+ MODEL_TENSOR.ATTN_Q,
1290
+ MODEL_TENSOR.ATTN_Q_NORM,
1291
+ MODEL_TENSOR.ATTN_K,
1292
+ MODEL_TENSOR.ATTN_K_NORM,
1293
+ MODEL_TENSOR.ATTN_V,
1294
+ MODEL_TENSOR.ATTN_OUT,
1295
+ MODEL_TENSOR.FFN_NORM,
1296
+ MODEL_TENSOR.FFN_GATE,
1297
+ MODEL_TENSOR.FFN_DOWN,
1298
+ MODEL_TENSOR.FFN_UP,
1299
+ ],
1300
+ # TODO
1301
+ }
1302
+
1303
+ # tensors that will not be serialized
1304
+ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
1305
+ MODEL_ARCH.LLAMA: [
1306
+ MODEL_TENSOR.ROPE_FREQS,
1307
+ MODEL_TENSOR.ATTN_ROT_EMBD,
1308
+ ],
1309
+ MODEL_ARCH.BAICHUAN: [
1310
+ MODEL_TENSOR.ROPE_FREQS,
1311
+ MODEL_TENSOR.ATTN_ROT_EMBD,
1312
+ ],
1313
+ MODEL_ARCH.QWEN: [
1314
+ MODEL_TENSOR.ROPE_FREQS,
1315
+ MODEL_TENSOR.ATTN_ROT_EMBD,
1316
+ ],
1317
+ MODEL_ARCH.CODESHELL: [
1318
+ MODEL_TENSOR.ROPE_FREQS,
1319
+ MODEL_TENSOR.ATTN_ROT_EMBD,
1320
+ ],
1321
+ MODEL_ARCH.ORION: [
1322
+ MODEL_TENSOR.ROPE_FREQS,
1323
+ MODEL_TENSOR.ATTN_ROT_EMBD,
1324
+ ],
1325
+ MODEL_ARCH.STARCODER2: [
1326
+ MODEL_TENSOR.ROPE_FREQS,
1327
+ MODEL_TENSOR.ATTN_ROT_EMBD,
1328
+ ],
1329
+ MODEL_ARCH.XVERSE: [
1330
+ MODEL_TENSOR.ROPE_FREQS,
1331
+ MODEL_TENSOR.ATTN_ROT_EMBD,
1332
+ ],
1333
+ MODEL_ARCH.DEEPSEEK2: [
1334
+ MODEL_TENSOR.ROPE_FREQS,
1335
+ MODEL_TENSOR.ATTN_ROT_EMBD,
1336
+ ],
1337
+ MODEL_ARCH.CHATGLM: [
1338
+ MODEL_TENSOR.ROPE_FREQS,
1339
+ ],
1340
+ MODEL_ARCH.NEMOTRON: [
1341
+ MODEL_TENSOR.ROPE_FREQS,
1342
+ MODEL_TENSOR.ATTN_ROT_EMBD,
1343
+ ],
1344
+ }
1345
+
1346
+ #
1347
+ # types
1348
+ #
1349
+
1350
+
1351
+ class TokenType(IntEnum):
1352
+ NORMAL = 1
1353
+ UNKNOWN = 2
1354
+ CONTROL = 3
1355
+ USER_DEFINED = 4
1356
+ UNUSED = 5
1357
+ BYTE = 6
1358
+
1359
+
1360
+ class RopeScalingType(Enum):
1361
+ NONE = 'none'
1362
+ LINEAR = 'linear'
1363
+ YARN = 'yarn'
1364
+
1365
+
1366
+ class PoolingType(IntEnum):
1367
+ NONE = 0
1368
+ MEAN = 1
1369
+ CLS = 2
1370
+
1371
+
1372
+ class GGMLQuantizationType(IntEnum):
1373
+ F32 = 0
1374
+ F16 = 1
1375
+ Q4_0 = 2
1376
+ Q4_1 = 3
1377
+ Q5_0 = 6
1378
+ Q5_1 = 7
1379
+ Q8_0 = 8
1380
+ Q8_1 = 9
1381
+ Q2_K = 10
1382
+ Q3_K = 11
1383
+ Q4_K = 12
1384
+ Q5_K = 13
1385
+ Q6_K = 14
1386
+ Q8_K = 15
1387
+ IQ2_XXS = 16
1388
+ IQ2_XS = 17
1389
+ IQ3_XXS = 18
1390
+ IQ1_S = 19
1391
+ IQ4_NL = 20
1392
+ IQ3_S = 21
1393
+ IQ2_S = 22
1394
+ IQ4_XS = 23
1395
+ I8 = 24
1396
+ I16 = 25
1397
+ I32 = 26
1398
+ I64 = 27
1399
+ F64 = 28
1400
+ IQ1_M = 29
1401
+ BF16 = 30
1402
+ Q4_0_4_4 = 31
1403
+ Q4_0_4_8 = 32
1404
+ Q4_0_8_8 = 33
1405
+ TQ1_0 = 34
1406
+ TQ2_0 = 35
1407
+
1408
+
1409
+ # TODO: add GGMLFileType from ggml_ftype in ggml.h
1410
+
1411
+
1412
+ # from llama_ftype in llama.h
1413
+ # ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE.
1414
+ class LlamaFileType(IntEnum):
1415
+ ALL_F32 = 0
1416
+ MOSTLY_F16 = 1 # except 1d tensors
1417
+ MOSTLY_Q4_0 = 2 # except 1d tensors
1418
+ MOSTLY_Q4_1 = 3 # except 1d tensors
1419
+ # MOSTLY_Q4_1_SOME_F16 = 4 # tok_embeddings.weight and output.weight are F16
1420
+ # MOSTLY_Q4_2 = 5 # support has been removed
1421
+ # MOSTLY_Q4_3 = 6 # support has been removed
1422
+ MOSTLY_Q8_0 = 7 # except 1d tensors
1423
+ MOSTLY_Q5_0 = 8 # except 1d tensors
1424
+ MOSTLY_Q5_1 = 9 # except 1d tensors
1425
+ MOSTLY_Q2_K = 10 # except 1d tensors
1426
+ MOSTLY_Q3_K_S = 11 # except 1d tensors
1427
+ MOSTLY_Q3_K_M = 12 # except 1d tensors
1428
+ MOSTLY_Q3_K_L = 13 # except 1d tensors
1429
+ MOSTLY_Q4_K_S = 14 # except 1d tensors
1430
+ MOSTLY_Q4_K_M = 15 # except 1d tensors
1431
+ MOSTLY_Q5_K_S = 16 # except 1d tensors
1432
+ MOSTLY_Q5_K_M = 17 # except 1d tensors
1433
+ MOSTLY_Q6_K = 18 # except 1d tensors
1434
+ MOSTLY_IQ2_XXS = 19 # except 1d tensors
1435
+ MOSTLY_IQ2_XS = 20 # except 1d tensors
1436
+ MOSTLY_Q2_K_S = 21 # except 1d tensors
1437
+ MOSTLY_IQ3_XS = 22 # except 1d tensors
1438
+ MOSTLY_IQ3_XXS = 23 # except 1d tensors
1439
+ MOSTLY_IQ1_S = 24 # except 1d tensors
1440
+ MOSTLY_IQ4_NL = 25 # except 1d tensors
1441
+ MOSTLY_IQ3_S = 26 # except 1d tensors
1442
+ MOSTLY_IQ3_M = 27 # except 1d tensors
1443
+ MOSTLY_IQ2_S = 28 # except 1d tensors
1444
+ MOSTLY_IQ2_M = 29 # except 1d tensors
1445
+ MOSTLY_IQ4_XS = 30 # except 1d tensors
1446
+ MOSTLY_IQ1_M = 31 # except 1d tensors
1447
+ MOSTLY_BF16 = 32 # except 1d tensors
1448
+ MOSTLY_Q4_0_4_4 = 33 # except 1d tensors
1449
+ MOSTLY_Q4_0_4_8 = 34 # except 1d tensors
1450
+ MOSTLY_Q4_0_8_8 = 35 # except 1d tensors
1451
+ MOSTLY_TQ1_0 = 36 # except 1d tensors
1452
+ MOSTLY_TQ2_0 = 37 # except 1d tensors
1453
+
1454
+ GUESSED = 1024 # not specified in the model file
1455
+
1456
+
1457
+ class GGUFEndian(IntEnum):
1458
+ LITTLE = 0
1459
+ BIG = 1
1460
+
1461
+
1462
+ class GGUFValueType(IntEnum):
1463
+ UINT8 = 0
1464
+ INT8 = 1
1465
+ UINT16 = 2
1466
+ INT16 = 3
1467
+ UINT32 = 4
1468
+ INT32 = 5
1469
+ FLOAT32 = 6
1470
+ BOOL = 7
1471
+ STRING = 8
1472
+ ARRAY = 9
1473
+ UINT64 = 10
1474
+ INT64 = 11
1475
+ FLOAT64 = 12
1476
+
1477
+ @staticmethod
1478
+ def get_type(val: Any) -> GGUFValueType:
1479
+ if isinstance(val, (str, bytes, bytearray)):
1480
+ return GGUFValueType.STRING
1481
+ elif isinstance(val, list):
1482
+ return GGUFValueType.ARRAY
1483
+ elif isinstance(val, float):
1484
+ return GGUFValueType.FLOAT32
1485
+ elif isinstance(val, bool):
1486
+ return GGUFValueType.BOOL
1487
+ elif isinstance(val, int):
1488
+ return GGUFValueType.INT32
1489
+ # TODO: need help with 64-bit types in Python
1490
+ else:
1491
+ raise ValueError(f"Unknown type: {type(val)}")
1492
+
1493
+
1494
+ # Items here are (block size, type size)
1495
+ QK_K = 256
1496
+ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
1497
+ GGMLQuantizationType.F32: (1, 4),
1498
+ GGMLQuantizationType.F16: (1, 2),
1499
+ GGMLQuantizationType.Q4_0: (32, 2 + 16),
1500
+ GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16),
1501
+ GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16),
1502
+ GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16),
1503
+ GGMLQuantizationType.Q8_0: (32, 2 + 32),
1504
+ GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32),
1505
+ GGMLQuantizationType.Q2_K: (256, 2 + 2 + QK_K // 16 + QK_K // 4),
1506
+ GGMLQuantizationType.Q3_K: (256, 2 + QK_K // 4 + QK_K // 8 + 12),
1507
+ GGMLQuantizationType.Q4_K: (256, 2 + 2 + QK_K // 2 + 12),
1508
+ GGMLQuantizationType.Q5_K: (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12),
1509
+ GGMLQuantizationType.Q6_K: (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16),
1510
+ GGMLQuantizationType.Q8_K: (256, 4 + QK_K + QK_K // 8),
1511
+ GGMLQuantizationType.IQ2_XXS: (256, 2 + QK_K // 4),
1512
+ GGMLQuantizationType.IQ2_XS: (256, 2 + QK_K // 4 + QK_K // 32),
1513
+ GGMLQuantizationType.IQ3_XXS: (256, 2 + QK_K // 4 + QK_K // 8),
1514
+ GGMLQuantizationType.IQ1_S: (256, 2 + QK_K // 8 + QK_K // 16),
1515
+ GGMLQuantizationType.IQ4_NL: (32, 2 + 16),
1516
+ GGMLQuantizationType.IQ3_S: (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4),
1517
+ GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 16),
1518
+ GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 2 + QK_K // 64),
1519
+ GGMLQuantizationType.I8: (1, 1),
1520
+ GGMLQuantizationType.I16: (1, 2),
1521
+ GGMLQuantizationType.I32: (1, 4),
1522
+ GGMLQuantizationType.I64: (1, 8),
1523
+ GGMLQuantizationType.F64: (1, 8),
1524
+ GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32),
1525
+ GGMLQuantizationType.BF16: (1, 2),
1526
+ GGMLQuantizationType.Q4_0_4_4:(32, 2 + 16),
1527
+ GGMLQuantizationType.Q4_0_4_8:(32, 2 + 16),
1528
+ GGMLQuantizationType.Q4_0_8_8:(32, 2 + 16),
1529
+ GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13),
1530
+ GGMLQuantizationType.TQ2_0: (256, 2 + 64),
1531
+ }
1532
+
1533
+
1534
+ # Aliases for backward compatibility.
1535
+
1536
+ # general
1537
+ KEY_GENERAL_ARCHITECTURE = Keys.General.ARCHITECTURE
1538
+ KEY_GENERAL_QUANTIZATION_VERSION = Keys.General.QUANTIZATION_VERSION
1539
+ KEY_GENERAL_ALIGNMENT = Keys.General.ALIGNMENT
1540
+ KEY_GENERAL_NAME = Keys.General.NAME
1541
+ KEY_GENERAL_AUTHOR = Keys.General.AUTHOR
1542
+ KEY_GENERAL_URL = Keys.General.URL
1543
+ KEY_GENERAL_DESCRIPTION = Keys.General.DESCRIPTION
1544
+ KEY_GENERAL_LICENSE = Keys.General.LICENSE
1545
+ KEY_GENERAL_SOURCE_URL = Keys.General.SOURCE_URL
1546
+ KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE
1547
+
1548
+ # LLM
1549
+ KEY_VOCAB_SIZE = Keys.LLM.VOCAB_SIZE
1550
+ KEY_CONTEXT_LENGTH = Keys.LLM.CONTEXT_LENGTH
1551
+ KEY_EMBEDDING_LENGTH = Keys.LLM.EMBEDDING_LENGTH
1552
+ KEY_BLOCK_COUNT = Keys.LLM.BLOCK_COUNT
1553
+ KEY_FEED_FORWARD_LENGTH = Keys.LLM.FEED_FORWARD_LENGTH
1554
+ KEY_USE_PARALLEL_RESIDUAL = Keys.LLM.USE_PARALLEL_RESIDUAL
1555
+ KEY_TENSOR_DATA_LAYOUT = Keys.LLM.TENSOR_DATA_LAYOUT
1556
+
1557
+ # attention
1558
+ KEY_ATTENTION_HEAD_COUNT = Keys.Attention.HEAD_COUNT
1559
+ KEY_ATTENTION_HEAD_COUNT_KV = Keys.Attention.HEAD_COUNT_KV
1560
+ KEY_ATTENTION_MAX_ALIBI_BIAS = Keys.Attention.MAX_ALIBI_BIAS
1561
+ KEY_ATTENTION_CLAMP_KQV = Keys.Attention.CLAMP_KQV
1562
+ KEY_ATTENTION_LAYERNORM_EPS = Keys.Attention.LAYERNORM_EPS
1563
+ KEY_ATTENTION_LAYERNORM_RMS_EPS = Keys.Attention.LAYERNORM_RMS_EPS
1564
+
1565
+ # RoPE
1566
+ KEY_ROPE_DIMENSION_COUNT = Keys.Rope.DIMENSION_COUNT
1567
+ KEY_ROPE_FREQ_BASE = Keys.Rope.FREQ_BASE
1568
+ KEY_ROPE_SCALING_TYPE = Keys.Rope.SCALING_TYPE
1569
+ KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR
1570
+ KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN
1571
+ KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED
1572
+
1573
+ # SSM
1574
+ KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL
1575
+ KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE
1576
+ KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE
1577
+ KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
1578
+ KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS
1579
+
1580
+ # tokenization
1581
+ KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
1582
+ KEY_TOKENIZER_PRE = Keys.Tokenizer.PRE
1583
+ KEY_TOKENIZER_LIST = Keys.Tokenizer.LIST
1584
+ KEY_TOKENIZER_TOKEN_TYPE = Keys.Tokenizer.TOKEN_TYPE
1585
+ KEY_TOKENIZER_SCORES = Keys.Tokenizer.SCORES
1586
+ KEY_TOKENIZER_MERGES = Keys.Tokenizer.MERGES
1587
+ KEY_TOKENIZER_BOS_ID = Keys.Tokenizer.BOS_ID
1588
+ KEY_TOKENIZER_EOS_ID = Keys.Tokenizer.EOS_ID
1589
+ KEY_TOKENIZER_EOT_ID = Keys.Tokenizer.EOT_ID
1590
+ KEY_TOKENIZER_EOM_ID = Keys.Tokenizer.EOM_ID
1591
+ KEY_TOKENIZER_UNK_ID = Keys.Tokenizer.UNK_ID
1592
+ KEY_TOKENIZER_SEP_ID = Keys.Tokenizer.SEP_ID
1593
+ KEY_TOKENIZER_PAD_ID = Keys.Tokenizer.PAD_ID
1594
+ KEY_TOKENIZER_CLS_ID = Keys.Tokenizer.CLS_ID
1595
+ KEY_TOKENIZER_MASK_ID = Keys.Tokenizer.MASK_ID
1596
+ KEY_TOKENIZER_HF_JSON = Keys.Tokenizer.HF_JSON
1597
+ KEY_TOKENIZER_RWKV = Keys.Tokenizer.RWKV
1598
+
1599
+ KEY_TOKENIZER_FIM_PRE_ID = Keys.Tokenizer.FIM_PRE_ID
1600
+ KEY_TOKENIZER_FIM_SUF_ID = Keys.Tokenizer.FIM_SUF_ID
1601
+ KEY_TOKENIZER_FIM_MID_ID = Keys.Tokenizer.FIM_MID_ID
1602
+ KEY_TOKENIZER_FIM_PAD_ID = Keys.Tokenizer.FIM_PAD_ID
1603
+ KEY_TOKENIZER_FIM_REP_ID = Keys.Tokenizer.FIM_REP_ID
1604
+ KEY_TOKENIZER_FIM_SEP_ID = Keys.Tokenizer.FIM_SEP_ID
1605
+
1606
+ # deprecated
1607
+ KEY_TOKENIZER_PREFIX_ID = Keys.Tokenizer.PREFIX_ID
1608
+ KEY_TOKENIZER_SUFFIX_ID = Keys.Tokenizer.SUFFIX_ID
1609
+ KEY_TOKENIZER_MIDDLE_ID = Keys.Tokenizer.MIDDLE_ID
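The GGML_QUANT_SIZES entries at the top of this hunk map each quantization type to a (block_size, type_size) pair, and that pair is what converts an element count into a byte count. A minimal sketch of the calculation, using made-up tensor dimensions:

    from gguf.constants import GGML_QUANT_SIZES, GGMLQuantizationType

    # byte size of a hypothetical 4096 x 4096 tensor stored as BF16
    block_size, type_size = GGML_QUANT_SIZES[GGMLQuantizationType.BF16]
    n_elems = 4096 * 4096
    n_bytes = n_elems * type_size // block_size  # same formula gguf_reader.py uses
    print(n_bytes)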
gguf/gguf.py ADDED
@@ -0,0 +1,15 @@
1
+ # This file is left for compatibility. If you want to use the GGUF API from Python
2
+ # then don't import gguf/gguf.py directly. If you're looking for examples, see the
3
+ # examples/ directory for gguf-py
4
+
5
+ import importlib
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ sys.path.insert(0, str(Path(__file__).parent.parent))
10
+
11
+ # Compatibility for people trying to import gguf/gguf.py directly instead of as a package.
12
+ importlib.invalidate_caches()
13
+ import gguf # noqa: E402
14
+
15
+ importlib.reload(gguf)
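As the comments in this shim say, the supported path is to import the gguf package itself; gguf/gguf.py only exists so that code importing the file directly keeps working. A minimal sketch of the preferred import (the model path is hypothetical):

    import gguf  # import the package, not gguf/gguf.py

    reader = gguf.GGUFReader("model.gguf")  # hypothetical file path
    print(f"{len(reader.tensors)} tensors, {len(reader.fields)} metadata fields")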
gguf/gguf_reader.py ADDED
@@ -0,0 +1,317 @@
1
+ #
2
+ # GGUF file reading/modification support. For API usage information,
3
+ # please see the scripts in the scripts/ directory for some fairly simple examples.
4
+ #
5
+ from __future__ import annotations
6
+
7
+ import logging
8
+ import os
9
+ from collections import OrderedDict
10
+ from typing import Any, Literal, NamedTuple, TypeVar, Union
11
+
12
+ import numpy as np
13
+ import numpy.typing as npt
14
+
15
+ from .quants import quant_shape_to_byte_shape
16
+
17
+ if __name__ == "__main__":
18
+ import sys
19
+ from pathlib import Path
20
+
21
+ # Allow running file in package as a script.
22
+ sys.path.insert(0, str(Path(__file__).parent.parent))
23
+
24
+ from gguf.constants import (
25
+ GGML_QUANT_SIZES,
26
+ GGUF_DEFAULT_ALIGNMENT,
27
+ GGUF_MAGIC,
28
+ GGUF_VERSION,
29
+ GGMLQuantizationType,
30
+ GGUFValueType,
31
+ )
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+ READER_SUPPORTED_VERSIONS = [2, GGUF_VERSION]
36
+
37
+
38
+ class ReaderField(NamedTuple):
39
+ # Offset to start of this field.
40
+ offset: int
41
+
42
+ # Name of the field (not necessarily from file data).
43
+ name: str
44
+
45
+ # Data parts. Some types have multiple components, such as strings
46
+ # that consist of a length followed by the string data.
47
+ parts: list[npt.NDArray[Any]] = []
48
+
49
+ # Indexes into parts that we can call the actual data. For example
50
+ # an array of strings will be populated with indexes to the actual
51
+ # string data.
52
+ data: list[int] = [-1]
53
+
54
+ types: list[GGUFValueType] = []
55
+
56
+
57
+ class ReaderTensor(NamedTuple):
58
+ name: str
59
+ tensor_type: GGMLQuantizationType
60
+ shape: npt.NDArray[np.uint32]
61
+ n_elements: int
62
+ n_bytes: int
63
+ data_offset: int
64
+ data: npt.NDArray[Any]
65
+ field: ReaderField
66
+
67
+
68
+ class GGUFReader:
69
+ # I - same as host, S - swapped
70
+ byte_order: Literal['I', 'S'] = 'I'
71
+ alignment: int = GGUF_DEFAULT_ALIGNMENT
72
+ data_offset: int
73
+
74
+ # Note: Internal helper, API may change.
75
+ gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = {
76
+ GGUFValueType.UINT8: np.uint8,
77
+ GGUFValueType.INT8: np.int8,
78
+ GGUFValueType.UINT16: np.uint16,
79
+ GGUFValueType.INT16: np.int16,
80
+ GGUFValueType.UINT32: np.uint32,
81
+ GGUFValueType.INT32: np.int32,
82
+ GGUFValueType.FLOAT32: np.float32,
83
+ GGUFValueType.UINT64: np.uint64,
84
+ GGUFValueType.INT64: np.int64,
85
+ GGUFValueType.FLOAT64: np.float64,
86
+ GGUFValueType.BOOL: np.bool_,
87
+ }
88
+
89
+ def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
90
+ self.data = np.memmap(path, mode = mode)
91
+ offs = 0
92
+
93
+ # Check for GGUF magic
94
+ if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
95
+ raise ValueError('GGUF magic invalid')
96
+ offs += 4
97
+
98
+ # Check GGUF version
99
+ temp_version = self._get(offs, np.uint32)
100
+ if temp_version[0] & 65535 == 0:
101
+ # If we get 0 here that means it's (probably) a GGUF file created for
102
+ # the opposite byte order of the machine this script is running on.
103
+ self.byte_order = 'S'
104
+ temp_version = temp_version.newbyteorder(self.byte_order)
105
+ version = temp_version[0]
106
+ if version not in READER_SUPPORTED_VERSIONS:
107
+ raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle')
108
+ self.fields: OrderedDict[str, ReaderField] = OrderedDict()
109
+ self.tensors: list[ReaderTensor] = []
110
+ offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))
111
+
112
+ # Check tensor count and kv count
113
+ temp_counts = self._get(offs, np.uint64, 2)
114
+ offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64]))
115
+ offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64]))
116
+ tensor_count, kv_count = temp_counts
117
+ offs = self._build_fields(offs, kv_count)
118
+
119
+ # Build Tensor Info Fields
120
+ offs, tensors_fields = self._build_tensor_info(offs, tensor_count)
121
+ new_align = self.fields.get('general.alignment')
122
+ if new_align is not None:
123
+ if new_align.types != [GGUFValueType.UINT32]:
124
+ raise ValueError('Bad type for general.alignment field')
125
+ self.alignment = new_align.parts[-1][0]
126
+ padding = offs % self.alignment
127
+ if padding != 0:
128
+ offs += self.alignment - padding
129
+ self.data_offset = offs
130
+ self._build_tensors(offs, tensors_fields)
131
+
132
+ _DT = TypeVar('_DT', bound = npt.DTypeLike)
133
+
134
+ # Fetch a key/value metadata field by key.
135
+ def get_field(self, key: str) -> Union[ReaderField, None]:
136
+ return self.fields.get(key, None)
137
+
138
+ # Fetch a tensor from the list by index.
139
+ def get_tensor(self, idx: int) -> ReaderTensor:
140
+ return self.tensors[idx]
141
+
142
+ def _get(
143
+ self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
144
+ ) -> npt.NDArray[Any]:
145
+ count = int(count)
146
+ itemsize = int(np.empty([], dtype = dtype).itemsize)
147
+ end_offs = offset + itemsize * count
148
+ return (
149
+ self.data[offset:end_offs]
150
+ .view(dtype = dtype)[:count]
151
+ .newbyteorder(override_order or self.byte_order)
152
+ )
153
+
154
+ def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
155
+ if field.name in self.fields:
156
+ # TODO: add option to generate error on duplicate keys
157
+ # raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}')
158
+
159
+ logger.warning(f'Duplicate key {field.name} at offset {field.offset}')
160
+ self.fields[field.name + '_{}'.format(field.offset)] = field
161
+ else:
162
+ self.fields[field.name] = field
163
+ return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)
164
+
165
+ def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
166
+ slen = self._get(offset, np.uint64)
167
+ return slen, self._get(offset + 8, np.uint8, slen[0])
168
+
169
+ def _get_field_parts(
170
+ self, orig_offs: int, raw_type: int,
171
+ ) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]:
172
+ offs = orig_offs
173
+ types: list[GGUFValueType] = []
174
+ gtype = GGUFValueType(raw_type)
175
+ types.append(gtype)
176
+ # Handle strings.
177
+ if gtype == GGUFValueType.STRING:
178
+ sparts: list[npt.NDArray[Any]] = list(self._get_str(offs))
179
+ size = sum(int(part.nbytes) for part in sparts)
180
+ return size, sparts, [1], types
181
+ # Check if it's a simple scalar type.
182
+ nptype = self.gguf_scalar_to_np.get(gtype)
183
+ if nptype is not None:
184
+ val = self._get(offs, nptype)
185
+ return int(val.nbytes), [val], [0], types
186
+ # Handle arrays.
187
+ if gtype == GGUFValueType.ARRAY:
188
+ raw_itype = self._get(offs, np.uint32)
189
+ offs += int(raw_itype.nbytes)
190
+ alen = self._get(offs, np.uint64)
191
+ offs += int(alen.nbytes)
192
+ aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
193
+ data_idxs: list[int] = []
194
+ for idx in range(alen[0]):
195
+ curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0])
196
+ if idx == 0:
197
+ types += curr_types
198
+ idxs_offs = len(aparts)
199
+ aparts += curr_parts
200
+ data_idxs += (idx + idxs_offs for idx in curr_idxs)
201
+ offs += curr_size
202
+ return offs - orig_offs, aparts, data_idxs, types
203
+ # We can't deal with this one.
204
+ raise ValueError(f'Unknown/unhandled field type {gtype}')
205
+
206
+ def _get_tensor_info_field(self, orig_offs: int) -> ReaderField:
207
+ offs = orig_offs
208
+
209
+ # Get Tensor Name
210
+ name_len, name_data = self._get_str(offs)
211
+ offs += int(name_len.nbytes + name_data.nbytes)
212
+
213
+ # Get Tensor Dimensions Count
214
+ n_dims = self._get(offs, np.uint32)
215
+ offs += int(n_dims.nbytes)
216
+
217
+ # Get Tensor Dimension Array
218
+ dims = self._get(offs, np.uint64, n_dims[0])
219
+ offs += int(dims.nbytes)
220
+
221
+ # Get Tensor Encoding Scheme Type
222
+ raw_dtype = self._get(offs, np.uint32)
223
+ offs += int(raw_dtype.nbytes)
224
+
225
+ # Get Tensor Offset
226
+ offset_tensor = self._get(offs, np.uint64)
227
+ offs += int(offset_tensor.nbytes)
228
+
229
+ return ReaderField(
230
+ orig_offs,
231
+ str(bytes(name_data), encoding = 'utf-8'),
232
+ [name_len, name_data, n_dims, dims, raw_dtype, offset_tensor],
233
+ [1, 3, 4, 5],
234
+ )
235
+
236
+ def _build_fields(self, offs: int, count: int) -> int:
237
+ for _ in range(count):
238
+ orig_offs = offs
239
+ kv_klen, kv_kdata = self._get_str(offs)
240
+ offs += int(kv_klen.nbytes + kv_kdata.nbytes)
241
+ raw_kv_type = self._get(offs, np.uint32)
242
+ offs += int(raw_kv_type.nbytes)
243
+ parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
244
+ idxs_offs = len(parts)
245
+ field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0])
246
+ parts += field_parts
247
+ self._push_field(ReaderField(
248
+ orig_offs,
249
+ str(bytes(kv_kdata), encoding = 'utf-8'),
250
+ parts,
251
+ [idx + idxs_offs for idx in field_idxs],
252
+ field_types,
253
+ ), skip_sum = True)
254
+ offs += field_size
255
+ return offs
256
+
257
+ def _build_tensor_info(self, offs: int, count: int) -> tuple[int, list[ReaderField]]:
258
+ tensor_fields = []
259
+ for _ in range(count):
260
+ field = self._get_tensor_info_field(offs)
261
+ offs += sum(int(part.nbytes) for part in field.parts)
262
+ tensor_fields.append(field)
263
+ return offs, tensor_fields
264
+
265
+ def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
266
+ tensors = []
267
+ tensor_names = set() # keep track of names to prevent duplicated tensors
268
+ for field in fields:
269
+ _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts
270
+ # check if there's any tensor having same name already in the list
271
+ tensor_name = str(bytes(name_data), encoding = 'utf-8')
272
+ if tensor_name in tensor_names:
273
+ raise ValueError(f'Found duplicated tensor with name {tensor_name}')
274
+ tensor_names.add(tensor_name)
275
+ ggml_type = GGMLQuantizationType(raw_dtype[0])
276
+ n_elems = int(np.prod(dims))
277
+ np_dims = tuple(reversed(dims.tolist()))
278
+ block_size, type_size = GGML_QUANT_SIZES[ggml_type]
279
+ n_bytes = n_elems * type_size // block_size
280
+ data_offs = int(start_offs + offset_tensor[0])
281
+ item_type: npt.DTypeLike
282
+ if ggml_type == GGMLQuantizationType.F16:
283
+ item_count = n_elems
284
+ item_type = np.float16
285
+ elif ggml_type == GGMLQuantizationType.F32:
286
+ item_count = n_elems
287
+ item_type = np.float32
288
+ elif ggml_type == GGMLQuantizationType.F64:
289
+ item_count = n_elems
290
+ item_type = np.float64
291
+ elif ggml_type == GGMLQuantizationType.I8:
292
+ item_count = n_elems
293
+ item_type = np.int8
294
+ elif ggml_type == GGMLQuantizationType.I16:
295
+ item_count = n_elems
296
+ item_type = np.int16
297
+ elif ggml_type == GGMLQuantizationType.I32:
298
+ item_count = n_elems
299
+ item_type = np.int32
300
+ elif ggml_type == GGMLQuantizationType.I64:
301
+ item_count = n_elems
302
+ item_type = np.int64
303
+ else:
304
+ item_count = n_bytes
305
+ item_type = np.uint8
306
+ np_dims = quant_shape_to_byte_shape(np_dims, ggml_type)
307
+ tensors.append(ReaderTensor(
308
+ name = tensor_name,
309
+ tensor_type = ggml_type,
310
+ shape = dims,
311
+ n_elements = n_elems,
312
+ n_bytes = n_bytes,
313
+ data_offset = data_offs,
314
+ data = self._get(data_offs, item_type, item_count).reshape(np_dims),
315
+ field = field,
316
+ ))
317
+ self.tensors = tensors
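Putting GGUFReader together: fields holds ReaderField entries whose data list indexes into parts, and tensors holds ReaderTensor entries backed by the memory map. A hedged usage sketch, assuming a hypothetical model path:

    from gguf.gguf_reader import GGUFReader

    reader = GGUFReader("model.gguf")  # hypothetical path, memory-mapped read-only by default

    arch = reader.get_field("general.architecture")
    if arch is not None:
        # for string fields, data[-1] points at the raw UTF-8 bytes inside parts
        print(bytes(arch.parts[arch.data[-1]]).decode("utf-8"))

    for t in reader.tensors:
        print(t.name, t.tensor_type.name, list(t.shape), t.n_bytes)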
gguf/gguf_writer.py ADDED
@@ -0,0 +1,903 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ import shutil
6
+ import struct
7
+ import tempfile
8
+ from dataclasses import dataclass
9
+ from enum import Enum, auto
10
+ from math import prod
11
+ from pathlib import Path
12
+ from io import BufferedWriter
13
+ from typing import IO, Any, Sequence, Mapping
14
+ from string import ascii_letters, digits
15
+
16
+ import numpy as np
17
+
18
+ from .constants import (
19
+ GGUF_DEFAULT_ALIGNMENT,
20
+ GGUF_MAGIC,
21
+ GGUF_VERSION,
22
+ GGMLQuantizationType,
23
+ GGUFEndian,
24
+ GGUFValueType,
25
+ Keys,
26
+ RopeScalingType,
27
+ PoolingType,
28
+ TokenType,
29
+ )
30
+
31
+ from .quants import quant_shape_from_byte_shape
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
37
+
38
+
39
+ @dataclass
40
+ class TensorInfo:
41
+ shape: Sequence[int]
42
+ dtype: GGMLQuantizationType
43
+ nbytes: int
44
+ tensor: np.ndarray[Any, Any] | None = None
45
+
46
+
47
+ @dataclass
48
+ class GGUFValue:
49
+ value: Any
50
+ type: GGUFValueType
51
+
52
+
53
+ class WriterState(Enum):
54
+ NO_FILE = auto()
55
+ EMPTY = auto()
56
+ HEADER = auto()
57
+ KV_DATA = auto()
58
+ TI_DATA = auto()
59
+ WEIGHTS = auto()
60
+
61
+
62
+ class GGUFWriter:
63
+ fout: list[BufferedWriter] | None
64
+ path: Path | None
65
+ temp_file: tempfile.SpooledTemporaryFile[bytes] | None
66
+ tensors: list[dict[str, TensorInfo]]
67
+ kv_data: list[dict[str, GGUFValue]]
68
+ state: WriterState
69
+ _simple_value_packing = {
70
+ GGUFValueType.UINT8: "B",
71
+ GGUFValueType.INT8: "b",
72
+ GGUFValueType.UINT16: "H",
73
+ GGUFValueType.INT16: "h",
74
+ GGUFValueType.UINT32: "I",
75
+ GGUFValueType.INT32: "i",
76
+ GGUFValueType.FLOAT32: "f",
77
+ GGUFValueType.UINT64: "Q",
78
+ GGUFValueType.INT64: "q",
79
+ GGUFValueType.FLOAT64: "d",
80
+ GGUFValueType.BOOL: "?",
81
+ }
82
+
83
+ def __init__(
84
+ self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
85
+ split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
86
+ ):
87
+ self.fout = None
88
+ self.path = Path(path) if path else None
89
+ self.arch = arch
90
+ self.endianess = endianess
91
+ self.data_alignment = GGUF_DEFAULT_ALIGNMENT
92
+ self.use_temp_file = use_temp_file
93
+ self.temp_file = None
94
+ self.tensors = [{}]
95
+ self.kv_data = [{}]
96
+ self.split_max_tensors = split_max_tensors
97
+ self.split_max_size = split_max_size
98
+ self.dry_run = dry_run
99
+ self.small_first_shard = small_first_shard
100
+ logger.info("gguf: This GGUF file is for {0} Endian only".format(
101
+ "Big" if self.endianess == GGUFEndian.BIG else "Little",
102
+ ))
103
+ self.state = WriterState.NO_FILE
104
+
105
+ if self.small_first_shard:
106
+ self.tensors.append({})
107
+
108
+ self.add_architecture()
109
+
110
+ def get_total_parameter_count(self) -> tuple[int, int, int, int]:
111
+ total_params = 0
112
+ shared_params = 0
113
+ expert_params = 0
114
+
115
+ expert_sum = 0
116
+ n_expert_tensors = 0
117
+
118
+ last_lora_a: tuple[str, TensorInfo] | None = None
119
+
120
+ for tensors in self.tensors:
121
+ for name, info in tensors.items():
122
+
123
+ shape = info.shape
124
+
125
+ if name.endswith(".lora_a"):
126
+ last_lora_a = (name, info)
127
+ continue
128
+ elif name.endswith(".lora_b"):
129
+ if last_lora_a is None or last_lora_a[0] != name[:-1] + "a":
130
+ # Bail when the LoRA pair can't be found trivially
131
+ logger.warning("can't measure LoRA size correctly, tensor order is unusual")
132
+ return 0, 0, 0, 0
133
+ else:
134
+ shape = (*shape[:-1], last_lora_a[1].shape[-1])
135
+
136
+ size = prod(shape)
137
+
138
+ if "_exps." in name:
139
+ expert_params += (size // shape[-3])
140
+ expert_sum += shape[-3]
141
+ n_expert_tensors += 1
142
+ else:
143
+ shared_params += size
144
+
145
+ total_params += size
146
+
147
+ # Hopefully this should work even for variable-expert-count models
148
+ expert_count = (expert_sum // n_expert_tensors) if n_expert_tensors > 0 else 0
149
+
150
+ # Negate the total to signal it's likely not exact
151
+ if last_lora_a is not None:
152
+ total_params = -total_params
153
+
154
+ # NOTE: keep the output in the same order as accepted by 'size_label' in gguf-py/gguf/utility.py
155
+ return total_params, shared_params, expert_params, expert_count
156
+
157
+ def format_shard_names(self, path: Path) -> list[Path]:
158
+ if len(self.tensors) == 1:
159
+ return [path]
160
+ return [path.with_name(SHARD_NAME_FORMAT.format(path.stem, i + 1, len(self.tensors))) for i in range(len(self.tensors))]
161
+
162
+ def open_output_file(self, path: Path | None = None) -> None:
163
+ if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path):
164
+ # allow calling this multiple times as long as the path is the same
165
+ return
166
+
167
+ if self.state is not WriterState.NO_FILE:
168
+ raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
169
+
170
+ if path is not None:
171
+ self.path = path
172
+
173
+ if self.path is not None:
174
+ filenames = self.print_plan()
175
+ self.fout = [open(filename, "wb") for filename in filenames]
176
+ self.state = WriterState.EMPTY
177
+
178
+ def print_plan(self) -> list[Path]:
179
+ logger.info("Writing the following files:")
180
+ assert self.path is not None
181
+ filenames = self.format_shard_names(self.path)
182
+ assert len(filenames) == len(self.tensors)
183
+ for name, tensors in zip(filenames, self.tensors):
184
+ logger.info(f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in tensors.values()))}")
185
+
186
+ if self.dry_run:
187
+ logger.info("Dry run, not writing files")
188
+ for name in filenames:
189
+ print(name) # noqa: NP100
190
+ exit()
191
+
192
+ return filenames
193
+
194
+ def add_shard_kv_data(self) -> None:
195
+ if len(self.tensors) == 1:
196
+ return
197
+
198
+ total_tensors = sum(len(t) for t in self.tensors)
199
+ assert self.fout is not None
200
+ total_splits = len(self.fout)
201
+ self.kv_data.extend({} for _ in range(len(self.kv_data), total_splits))
202
+ for i, kv_data in enumerate(self.kv_data):
203
+ kv_data[Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(i, GGUFValueType.UINT16)
204
+ kv_data[Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(total_splits, GGUFValueType.UINT16)
205
+ kv_data[Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32)
206
+
207
+ def write_header_to_file(self, path: Path | None = None) -> None:
208
+ if len(self.tensors) == 1 and (self.split_max_tensors != 0 or self.split_max_size != 0):
209
+ logger.warning("Model fails split requirements, not splitting")
210
+
211
+ self.open_output_file(path)
212
+
213
+ if self.state is not WriterState.EMPTY:
214
+ raise ValueError(f'Expected output file to be empty, got {self.state}')
215
+
216
+ assert self.fout is not None
217
+ assert len(self.fout) == len(self.tensors)
218
+ assert len(self.kv_data) == 1
219
+
220
+ self.add_shard_kv_data()
221
+
222
+ for fout, tensors, kv_data in zip(self.fout, self.tensors, self.kv_data):
223
+ fout.write(self._pack("<I", GGUF_MAGIC, skip_pack_prefix = True))
224
+ fout.write(self._pack("I", GGUF_VERSION))
225
+ fout.write(self._pack("Q", len(tensors)))
226
+ fout.write(self._pack("Q", len(kv_data)))
227
+ fout.flush()
228
+ self.state = WriterState.HEADER
229
+
230
+ def write_kv_data_to_file(self) -> None:
231
+ if self.state is not WriterState.HEADER:
232
+ raise ValueError(f'Expected output file to contain the header, got {self.state}')
233
+ assert self.fout is not None
234
+
235
+ for fout, kv_data in zip(self.fout, self.kv_data):
236
+ kv_bytes = bytearray()
237
+
238
+ for key, val in kv_data.items():
239
+ kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
240
+ kv_bytes += self._pack_val(val.value, val.type, add_vtype=True)
241
+
242
+ fout.write(kv_bytes)
243
+
244
+ self.flush()
245
+ self.state = WriterState.KV_DATA
246
+
247
+ def write_ti_data_to_file(self) -> None:
248
+ if self.state is not WriterState.KV_DATA:
249
+ raise ValueError(f'Expected output file to contain KV data, got {self.state}')
250
+ assert self.fout is not None
251
+
252
+ for fout, tensors in zip(self.fout, self.tensors):
253
+ ti_data = bytearray()
254
+ offset_tensor = 0
255
+
256
+ for name, ti in tensors.items():
257
+ ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
258
+ n_dims = len(ti.shape)
259
+ ti_data += self._pack("I", n_dims)
260
+ for j in range(n_dims):
261
+ ti_data += self._pack("Q", ti.shape[n_dims - 1 - j])
262
+ ti_data += self._pack("I", ti.dtype)
263
+ ti_data += self._pack("Q", offset_tensor)
264
+ offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)
265
+
266
+ fout.write(ti_data)
267
+ fout.flush()
268
+ self.state = WriterState.TI_DATA
269
+
270
+ def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
271
+ if any(key in kv_data for kv_data in self.kv_data):
272
+ raise ValueError(f'Duplicated key name {key!r}')
273
+
274
+ self.kv_data[0][key] = GGUFValue(value=val, type=vtype)
275
+
276
+ def add_uint8(self, key: str, val: int) -> None:
277
+ self.add_key_value(key, val, GGUFValueType.UINT8)
278
+
279
+ def add_int8(self, key: str, val: int) -> None:
280
+ self.add_key_value(key, val, GGUFValueType.INT8)
281
+
282
+ def add_uint16(self, key: str, val: int) -> None:
283
+ self.add_key_value(key, val, GGUFValueType.UINT16)
284
+
285
+ def add_int16(self, key: str, val: int) -> None:
286
+ self.add_key_value(key, val, GGUFValueType.INT16)
287
+
288
+ def add_uint32(self, key: str, val: int) -> None:
289
+ self.add_key_value(key, val, GGUFValueType.UINT32)
290
+
291
+ def add_int32(self, key: str, val: int) -> None:
292
+ self.add_key_value(key, val, GGUFValueType.INT32)
293
+
294
+ def add_float32(self, key: str, val: float) -> None:
295
+ self.add_key_value(key, val, GGUFValueType.FLOAT32)
296
+
297
+ def add_uint64(self, key: str, val: int) -> None:
298
+ self.add_key_value(key, val, GGUFValueType.UINT64)
299
+
300
+ def add_int64(self, key: str, val: int) -> None:
301
+ self.add_key_value(key, val, GGUFValueType.INT64)
302
+
303
+ def add_float64(self, key: str, val: float) -> None:
304
+ self.add_key_value(key, val, GGUFValueType.FLOAT64)
305
+
306
+ def add_bool(self, key: str, val: bool) -> None:
307
+ self.add_key_value(key, val, GGUFValueType.BOOL)
308
+
309
+ def add_string(self, key: str, val: str) -> None:
310
+ if not val:
311
+ return
312
+ self.add_key_value(key, val, GGUFValueType.STRING)
313
+
314
+ def add_array(self, key: str, val: Sequence[Any]) -> None:
315
+ if len(val) == 0:
316
+ return
317
+ self.add_key_value(key, val, GGUFValueType.ARRAY)
318
+
319
+ @staticmethod
320
+ def ggml_pad(x: int, n: int) -> int:
321
+ return ((x + n - 1) // n) * n
322
+
323
+ def add_tensor_info(
324
+ self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype,
325
+ tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None,
326
+ ) -> None:
327
+ if self.state is not WriterState.NO_FILE:
328
+ raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
329
+
330
+ if any(name in tensors for tensors in self.tensors):
331
+ raise ValueError(f'Duplicated tensor name {name!r}')
332
+
333
+ if raw_dtype is None:
334
+ if tensor_dtype == np.float16:
335
+ dtype = GGMLQuantizationType.F16
336
+ elif tensor_dtype == np.float32:
337
+ dtype = GGMLQuantizationType.F32
338
+ elif tensor_dtype == np.float64:
339
+ dtype = GGMLQuantizationType.F64
340
+ elif tensor_dtype == np.int8:
341
+ dtype = GGMLQuantizationType.I8
342
+ elif tensor_dtype == np.int16:
343
+ dtype = GGMLQuantizationType.I16
344
+ elif tensor_dtype == np.int32:
345
+ dtype = GGMLQuantizationType.I32
346
+ elif tensor_dtype == np.int64:
347
+ dtype = GGMLQuantizationType.I64
348
+ else:
349
+ raise ValueError("Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now")
350
+ else:
351
+ dtype = raw_dtype
352
+ if tensor_dtype == np.uint8:
353
+ tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
354
+
355
+ # make sure there is at least one tensor before splitting
356
+ if len(self.tensors[-1]) > 0:
357
+ if ( # split when over tensor limit
358
+ self.split_max_tensors != 0
359
+ and len(self.tensors[-1]) >= self.split_max_tensors
360
+ ) or ( # split when over size limit
361
+ self.split_max_size != 0
362
+ and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size
363
+ ):
364
+ self.tensors.append({})
365
+
366
+ self.tensors[-1][name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes)
367
+
368
+ def add_tensor(
369
+ self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
370
+ raw_dtype: GGMLQuantizationType | None = None,
371
+ ) -> None:
372
+ if self.endianess == GGUFEndian.BIG:
373
+ tensor.byteswap(inplace=True)
374
+ if self.use_temp_file and self.temp_file is None:
375
+ fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024)
376
+ fp.seek(0)
377
+ self.temp_file = fp
378
+
379
+ shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
380
+ self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype)
381
+
382
+ if self.temp_file is None:
383
+ self.tensors[-1][name].tensor = tensor
384
+ return
385
+
386
+ tensor.tofile(self.temp_file)
387
+ self.write_padding(self.temp_file, tensor.nbytes)
388
+
389
+ def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None:
390
+ pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
391
+ if pad != 0:
392
+ fp.write(bytes([0] * pad))
393
+
394
+ def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
395
+ if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
396
+ raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
397
+ assert self.fout is not None
398
+
399
+ if self.endianess == GGUFEndian.BIG:
400
+ tensor.byteswap(inplace=True)
401
+
402
+ file_id = -1
403
+ for i, tensors in enumerate(self.tensors):
404
+ if len(tensors) > 0:
405
+ file_id = i
406
+ break
407
+
408
+ fout = self.fout[file_id]
409
+
410
+ # pop the first tensor info
411
+ # TODO: cleaner way to get the first key
412
+ first_tensor_name = [name for name, _ in zip(self.tensors[file_id].keys(), range(1))][0]
413
+ ti = self.tensors[file_id].pop(first_tensor_name)
414
+ assert ti.nbytes == tensor.nbytes
415
+
416
+ self.write_padding(fout, fout.tell())
417
+ tensor.tofile(fout)
418
+ self.write_padding(fout, tensor.nbytes)
419
+
420
+ self.state = WriterState.WEIGHTS
421
+
422
+ def write_tensors_to_file(self, *, progress: bool = False) -> None:
423
+ self.write_ti_data_to_file()
424
+
425
+ assert self.fout is not None
426
+
427
+ for fout in self.fout:
428
+ self.write_padding(fout, fout.tell())
429
+
430
+ if self.temp_file is None:
431
+ shard_bar = None
432
+ bar = None
433
+
434
+ if progress:
435
+ from tqdm import tqdm
436
+
437
+ total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values())
438
+
439
+ if len(self.fout) > 1:
440
+ shard_bar = tqdm(desc=f"Shard (0/{len(self.fout)})", total=None, unit="byte", unit_scale=True)
441
+ bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
442
+
443
+ for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)):
444
+ if shard_bar is not None:
445
+ shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})")
446
+ total = sum(ti.nbytes for ti in tensors.values())
447
+ shard_bar.reset(total=(total if total > 0 else None))
448
+
449
+ # relying on the fact that Python dicts preserve insertion order (since 3.7)
450
+ for ti in tensors.values():
451
+ assert ti.tensor is not None # can only iterate once over the tensors
452
+ assert ti.tensor.nbytes == ti.nbytes
453
+ ti.tensor.tofile(fout)
454
+ if shard_bar is not None:
455
+ shard_bar.update(ti.nbytes)
456
+ if bar is not None:
457
+ bar.update(ti.nbytes)
458
+ self.write_padding(fout, ti.nbytes)
459
+ ti.tensor = None
460
+ else:
461
+ self.temp_file.seek(0)
462
+
463
+ shutil.copyfileobj(self.temp_file, self.fout[0 if not self.small_first_shard else 1])
464
+ self.flush()
465
+ self.temp_file.close()
466
+
467
+ self.state = WriterState.WEIGHTS
468
+
469
+ def flush(self) -> None:
470
+ assert self.fout is not None
471
+ for fout in self.fout:
472
+ fout.flush()
473
+
474
+ def close(self) -> None:
475
+ if self.fout is not None:
476
+ for fout in self.fout:
477
+ fout.close()
478
+ self.fout = None
479
+
480
+ def add_type(self, type_name: str) -> None:
481
+ self.add_string(Keys.General.TYPE, type_name)
482
+
483
+ def add_architecture(self) -> None:
484
+ self.add_string(Keys.General.ARCHITECTURE, self.arch)
485
+
486
+ def add_quantization_version(self, quantization_version: int) -> None:
487
+ self.add_uint32(Keys.General.QUANTIZATION_VERSION, quantization_version)
488
+
489
+ def add_custom_alignment(self, alignment: int) -> None:
490
+ self.data_alignment = alignment
491
+ self.add_uint32(Keys.General.ALIGNMENT, alignment)
492
+
493
+ def add_file_type(self, ftype: int) -> None:
494
+ self.add_uint32(Keys.General.FILE_TYPE, ftype)
495
+
496
+ def add_name(self, name: str) -> None:
497
+ self.add_string(Keys.General.NAME, name)
498
+
499
+ def add_author(self, author: str) -> None:
500
+ self.add_string(Keys.General.AUTHOR, author)
501
+
502
+ def add_version(self, version: str) -> None:
503
+ self.add_string(Keys.General.VERSION, version)
504
+
505
+ def add_organization(self, organization: str) -> None:
506
+ self.add_string(Keys.General.ORGANIZATION, organization)
507
+
508
+ def add_finetune(self, finetune: str) -> None:
509
+ self.add_string(Keys.General.FINETUNE, finetune)
510
+
511
+ def add_basename(self, basename: str) -> None:
512
+ self.add_string(Keys.General.BASENAME, basename)
513
+
514
+ def add_description(self, description: str) -> None:
515
+ self.add_string(Keys.General.DESCRIPTION, description)
516
+
517
+ def add_quantized_by(self, quantized: str) -> None:
518
+ self.add_string(Keys.General.QUANTIZED_BY, quantized)
519
+
520
+ def add_size_label(self, size_label: str) -> None:
521
+ self.add_string(Keys.General.SIZE_LABEL, size_label)
522
+
523
+ def add_license(self, license: str) -> None:
524
+ self.add_string(Keys.General.LICENSE, license)
525
+
526
+ def add_license_name(self, license: str) -> None:
527
+ self.add_string(Keys.General.LICENSE_NAME, license)
528
+
529
+ def add_license_link(self, license: str) -> None:
530
+ self.add_string(Keys.General.LICENSE_LINK, license)
531
+
532
+ def add_url(self, url: str) -> None:
533
+ self.add_string(Keys.General.URL, url)
534
+
535
+ def add_doi(self, doi: str) -> None:
536
+ self.add_string(Keys.General.DOI, doi)
537
+
538
+ def add_uuid(self, uuid: str) -> None:
539
+ self.add_string(Keys.General.UUID, uuid)
540
+
541
+ def add_repo_url(self, repo_url: str) -> None:
542
+ self.add_string(Keys.General.REPO_URL, repo_url)
543
+
544
+ def add_source_url(self, url: str) -> None:
545
+ self.add_string(Keys.General.SOURCE_URL, url)
546
+
547
+ def add_source_doi(self, doi: str) -> None:
548
+ self.add_string(Keys.General.SOURCE_DOI, doi)
549
+
550
+ def add_source_uuid(self, uuid: str) -> None:
551
+ self.add_string(Keys.General.SOURCE_UUID, uuid)
552
+
553
+ def add_source_repo_url(self, repo_url: str) -> None:
554
+ self.add_string(Keys.General.SOURCE_REPO_URL, repo_url)
555
+
556
+ def add_base_model_count(self, source_count: int) -> None:
557
+ self.add_uint32(Keys.General.BASE_MODEL_COUNT, source_count)
558
+
559
+ def add_base_model_name(self, source_id: int, name: str) -> None:
560
+ self.add_string(Keys.General.BASE_MODEL_NAME.format(id=source_id), name)
561
+
562
+ def add_base_model_author(self, source_id: int, author: str) -> None:
563
+ self.add_string(Keys.General.BASE_MODEL_AUTHOR.format(id=source_id), author)
564
+
565
+ def add_base_model_version(self, source_id: int, version: str) -> None:
566
+ self.add_string(Keys.General.BASE_MODEL_VERSION.format(id=source_id), version)
567
+
568
+ def add_base_model_organization(self, source_id: int, organization: str) -> None:
569
+ self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization)
570
+
571
+ def add_base_model_url(self, source_id: int, url: str) -> None:
572
+ self.add_string(Keys.General.BASE_MODEL_URL.format(id=source_id), url)
573
+
574
+ def add_base_model_doi(self, source_id: int, doi: str) -> None:
575
+ self.add_string(Keys.General.BASE_MODEL_DOI.format(id=source_id), doi)
576
+
577
+ def add_base_model_uuid(self, source_id: int, uuid: str) -> None:
578
+ self.add_string(Keys.General.BASE_MODEL_UUID.format(id=source_id), uuid)
579
+
580
+ def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None:
581
+ self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=source_id), repo_url)
582
+
583
+ def add_tags(self, tags: Sequence[str]) -> None:
584
+ self.add_array(Keys.General.TAGS, tags)
585
+
586
+ def add_languages(self, languages: Sequence[str]) -> None:
587
+ self.add_array(Keys.General.LANGUAGES, languages)
588
+
589
+ def add_datasets(self, datasets: Sequence[str]) -> None:
590
+ self.add_array(Keys.General.DATASETS, datasets)
591
+
592
+ def add_tensor_data_layout(self, layout: str) -> None:
593
+ self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
594
+
595
+ def add_vocab_size(self, size: int) -> None:
596
+ self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size)
597
+
598
+ def add_context_length(self, length: int) -> None:
599
+ self.add_uint32(Keys.LLM.CONTEXT_LENGTH.format(arch=self.arch), length)
600
+
601
+ def add_embedding_length(self, length: int) -> None:
602
+ self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length)
603
+
604
+ def add_block_count(self, length: int) -> None:
605
+ self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length)
606
+
607
+ def add_leading_dense_block_count(self, length: int) -> None:
608
+ self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length)
609
+
610
+ def add_feed_forward_length(self, length: int | Sequence[int]) -> None:
611
+ if isinstance(length, int):
612
+ self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
613
+ else:
614
+ self.add_array(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
615
+
616
+ def add_expert_feed_forward_length(self, length: int) -> None:
617
+ self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
618
+
619
+ def add_expert_shared_feed_forward_length(self, length: int) -> None:
620
+ self.add_uint32(Keys.LLM.EXPERT_SHARED_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
621
+
622
+ def add_parallel_residual(self, use: bool) -> None:
623
+ self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
624
+
625
+ def add_decoder_start_token_id(self, id: int) -> None:
626
+ self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id)
627
+
628
+ def add_head_count(self, count: int | Sequence[int]) -> None:
629
+ if isinstance(count, int):
630
+ self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
631
+ else:
632
+ self.add_array(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
633
+
634
+ def add_head_count_kv(self, count: int | Sequence[int]) -> None:
635
+ if isinstance(count, int):
636
+ self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
637
+ else:
638
+ self.add_array(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
639
+
640
+ def add_key_length(self, length: int) -> None:
641
+ self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length)
642
+
643
+ def add_value_length(self, length: int) -> None:
644
+ self.add_uint32(Keys.Attention.VALUE_LENGTH.format(arch=self.arch), length)
645
+
646
+ def add_max_alibi_bias(self, bias: float) -> None:
647
+ self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias)
648
+
649
+ def add_clamp_kqv(self, value: float) -> None:
650
+ self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)
651
+
652
+ def add_logit_scale(self, value: float) -> None:
653
+ self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)
654
+
655
+ def add_attn_logit_softcapping(self, value: float) -> None:
656
+ self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
657
+
658
+ def add_final_logit_softcapping(self, value: float) -> None:
659
+ self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
660
+
661
+ def add_expert_count(self, count: int) -> None:
662
+ self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)
663
+
664
+ def add_expert_used_count(self, count: int) -> None:
665
+ self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count)
666
+
667
+ def add_expert_shared_count(self, count: int) -> None:
668
+ self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count)
669
+
670
+ def add_expert_weights_scale(self, value: float) -> None:
671
+ self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
672
+
673
+ def add_swin_norm(self, value: bool) -> None:
674
+ self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
675
+
676
+ def add_rescale_every_n_layers(self, count: int) -> None:
677
+ self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count)
678
+
679
+ def add_time_mix_extra_dim(self, dim: int) -> None:
680
+ self.add_uint32(Keys.LLM.TIME_MIX_EXTRA_DIM.format(arch=self.arch), dim)
681
+
682
+ def add_time_decay_extra_dim(self, dim: int) -> None:
683
+ self.add_uint32(Keys.LLM.TIME_DECAY_EXTRA_DIM.format(arch=self.arch), dim)
684
+
685
+ def add_residual_scale(self, value: float) -> None:
686
+ self.add_float32(Keys.LLM.RESIDUAL_SCALE.format(arch=self.arch), value)
687
+
688
+ def add_embedding_scale(self, value: float) -> None:
689
+ self.add_float32(Keys.LLM.EMBEDDING_SCALE.format(arch=self.arch), value)
690
+
691
+ def add_wkv_head_size(self, size: int) -> None:
692
+ self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size)
693
+
694
+ def add_layer_norm_eps(self, value: float) -> None:
695
+ self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
696
+
697
+ def add_layer_norm_rms_eps(self, value: float) -> None:
698
+ self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value)
699
+
700
+ def add_causal_attention(self, value: bool) -> None:
701
+ self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
702
+
703
+ def add_q_lora_rank(self, length: int) -> None:
704
+ self.add_uint32(Keys.Attention.Q_LORA_RANK.format(arch=self.arch), length)
705
+
706
+ def add_kv_lora_rank(self, length: int) -> None:
707
+ self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length)
708
+
709
+ def add_relative_attn_buckets_count(self, value: int) -> None:
710
+ self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)
711
+
712
+ def add_sliding_window(self, value: int) -> None:
713
+ self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)
714
+
715
+ def add_attention_scale(self, value: float) -> None:
716
+ self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)
717
+
718
+ def add_pooling_type(self, value: PoolingType) -> None:
719
+ self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
720
+
721
+ def add_rope_dimension_count(self, count: int) -> None:
722
+ self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
723
+
724
+ def add_rope_freq_base(self, value: float) -> None:
725
+ self.add_float32(Keys.Rope.FREQ_BASE.format(arch=self.arch), value)
726
+
727
+ def add_rope_scaling_type(self, value: RopeScalingType) -> None:
728
+ self.add_string(Keys.Rope.SCALING_TYPE.format(arch=self.arch), value.value)
729
+
730
+ def add_rope_scaling_factor(self, value: float) -> None:
731
+ self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value)
732
+
733
+ def add_rope_scaling_attn_factors(self, value: float) -> None:
734
+ self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value)
735
+
736
+ def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
737
+ self.add_uint32(Keys.Rope.SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)
738
+
739
+ def add_rope_scaling_finetuned(self, value: bool) -> None:
740
+ self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)
741
+
742
+ def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
743
+ self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)
744
+
745
+ def add_ssm_conv_kernel(self, value: int) -> None:
746
+ self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)
747
+
748
+ def add_ssm_inner_size(self, value: int) -> None:
749
+ self.add_uint32(Keys.SSM.INNER_SIZE.format(arch=self.arch), value)
750
+
751
+ def add_ssm_state_size(self, value: int) -> None:
752
+ self.add_uint32(Keys.SSM.STATE_SIZE.format(arch=self.arch), value)
753
+
754
+ def add_ssm_time_step_rank(self, value: int) -> None:
755
+ self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)
756
+
757
+ def add_ssm_dt_b_c_rms(self, value: bool) -> None:
758
+ self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
759
+
760
+ def add_tokenizer_model(self, model: str) -> None:
761
+ self.add_string(Keys.Tokenizer.MODEL, model)
762
+
763
+ def add_tokenizer_pre(self, pre: str) -> None:
764
+ self.add_string(Keys.Tokenizer.PRE, pre)
765
+
766
+ def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
767
+ self.add_array(Keys.Tokenizer.LIST, tokens)
768
+
769
+ def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
770
+ self.add_array(Keys.Tokenizer.MERGES, merges)
771
+
772
+ def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None:
773
+ self.add_array(Keys.Tokenizer.TOKEN_TYPE, types)
774
+
775
+ def add_token_type_count(self, value: int) -> None:
776
+ self.add_uint32(Keys.Tokenizer.TOKEN_TYPE_COUNT, value)
777
+
778
+ def add_token_scores(self, scores: Sequence[float]) -> None:
779
+ self.add_array(Keys.Tokenizer.SCORES, scores)
780
+
781
+ def add_bos_token_id(self, id: int) -> None:
782
+ self.add_uint32(Keys.Tokenizer.BOS_ID, id)
783
+
784
+ def add_eos_token_id(self, id: int) -> None:
785
+ self.add_uint32(Keys.Tokenizer.EOS_ID, id)
786
+
787
+ def add_unk_token_id(self, id: int) -> None:
788
+ self.add_uint32(Keys.Tokenizer.UNK_ID, id)
789
+
790
+ def add_sep_token_id(self, id: int) -> None:
791
+ self.add_uint32(Keys.Tokenizer.SEP_ID, id)
792
+
793
+ def add_pad_token_id(self, id: int) -> None:
794
+ self.add_uint32(Keys.Tokenizer.PAD_ID, id)
795
+
796
+ def add_cls_token_id(self, id: int) -> None:
797
+ self.add_uint32(Keys.Tokenizer.CLS_ID, id)
798
+
799
+ def add_mask_token_id(self, id: int) -> None:
800
+ self.add_uint32(Keys.Tokenizer.MASK_ID, id)
801
+
802
+ def add_add_bos_token(self, value: bool) -> None:
803
+ self.add_bool(Keys.Tokenizer.ADD_BOS, value)
804
+
805
+ def add_add_eos_token(self, value: bool) -> None:
806
+ self.add_bool(Keys.Tokenizer.ADD_EOS, value)
807
+
808
+ def add_add_space_prefix(self, value: bool) -> None:
809
+ self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
810
+
811
+ def add_remove_extra_whitespaces(self, value: bool) -> None:
812
+ self.add_bool(Keys.Tokenizer.REMOVE_EXTRA_WS, value)
813
+
814
+ def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None:
815
+ self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap)
816
+
817
+ def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
818
+ if not isinstance(value, str):
819
+ template_default = None
820
+ template_names = set()
821
+
822
+ for choice in value:
823
+ name = choice.get('name', '')
824
+ template = choice.get('template')
825
+
826
+ # Allowing non-alphanumerical characters in template name is probably not a good idea, so filter it
827
+ name = ''.join((c if c in ascii_letters + digits else '_' for c in name))
828
+
829
+ if name and template is not None:
830
+ if name == 'default':
831
+ template_default = template
832
+ else:
833
+ template_names.add(name)
834
+ self.add_string(Keys.Tokenizer.CHAT_TEMPLATE_N.format(name=name), template)
835
+
836
+ if template_names:
837
+ self.add_array(Keys.Tokenizer.CHAT_TEMPLATES, list(template_names))
838
+
839
+ if template_default is None:
840
+ return
841
+
842
+ value = template_default
843
+
844
+ self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)
845
+
846
+ def add_eot_token_id(self, id: int) -> None:
847
+ self.add_uint32(Keys.Tokenizer.EOT_ID, id)
848
+
849
+ def add_eom_token_id(self, id: int) -> None:
850
+ self.add_uint32(Keys.Tokenizer.EOM_ID, id)
851
+
852
+ def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
853
+ pack_prefix = ''
854
+ if not skip_pack_prefix:
855
+ pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>'
856
+ return struct.pack(f'{pack_prefix}{fmt}', value)
857
+
858
+ def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool) -> bytes:
859
+ kv_data = bytearray()
860
+
861
+ if add_vtype:
862
+ kv_data += self._pack("I", vtype)
863
+
864
+ pack_fmt = self._simple_value_packing.get(vtype)
865
+ if pack_fmt is not None:
866
+ kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
867
+ elif vtype == GGUFValueType.STRING:
868
+ encoded_val = val.encode("utf-8") if isinstance(val, str) else val
869
+ kv_data += self._pack("Q", len(encoded_val))
870
+ kv_data += encoded_val
871
+ elif vtype == GGUFValueType.ARRAY:
872
+
873
+ if not isinstance(val, Sequence):
874
+ raise ValueError("Invalid GGUF metadata array, expecting sequence")
875
+
876
+ if len(val) == 0:
877
+ raise ValueError("Invalid GGUF metadata array. Empty array")
878
+
879
+ if isinstance(val, bytes):
880
+ ltype = GGUFValueType.UINT8
881
+ else:
882
+ ltype = GGUFValueType.get_type(val[0])
883
+ if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
884
+ raise ValueError("All items in a GGUF array should be of the same type")
885
+ kv_data += self._pack("I", ltype)
886
+ kv_data += self._pack("Q", len(val))
887
+ for item in val:
888
+ kv_data += self._pack_val(item, ltype, add_vtype=False)
889
+ else:
890
+ raise ValueError("Invalid GGUF metadata value type or value")
891
+
892
+ return kv_data
893
+
894
+ @staticmethod
895
+ def format_n_bytes_to_str(num: int) -> str:
896
+ if num == 0:
897
+ return "negligible - metadata only"
898
+ fnum = float(num)
899
+ for unit in ("", "K", "M", "G"):
900
+ if abs(fnum) < 1000.0:
901
+ return f"{fnum:3.1f}{unit}"
902
+ fnum /= 1000.0
903
+ return f"{fnum:.1f}T - over 1TB, split recommended"
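GGUFWriter assumes a fixed call order: queue metadata and tensors while no file is open, then write the header, the key/value data, and finally the tensor data. A minimal single-shard sketch under that assumption; the output path, architecture string, and tensor contents are placeholders:

    import numpy as np
    from gguf.gguf_writer import GGUFWriter

    writer = GGUFWriter("out.gguf", arch="llama")  # hypothetical path and architecture
    writer.add_context_length(4096)
    writer.add_block_count(32)
    writer.add_tensor("token_embd.weight", np.zeros((16, 8), dtype=np.float32))

    writer.write_header_to_file()   # opens the output file(s) and writes the GGUF header
    writer.write_kv_data_to_file()  # writes the queued key/value metadata
    writer.write_tensors_to_file()  # writes tensor info, padding, and tensor data
    writer.close()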
gguf/lazy.py ADDED
@@ -0,0 +1,213 @@
1
+ from __future__ import annotations
2
+ from abc import ABC, ABCMeta, abstractmethod
3
+
4
+ import logging
5
+ from typing import Any, Callable
6
+
7
+ import numpy as np
8
+ from numpy.typing import DTypeLike
9
+
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class LazyMeta(ABCMeta):
15
+
16
+ def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs):
17
+ def __getattr__(self, name: str) -> Any:
18
+ meta_attr = getattr(self._meta, name)
19
+ if callable(meta_attr):
20
+ return type(self)._wrap_fn(
21
+ (lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)),
22
+ use_self=self,
23
+ )
24
+ elif isinstance(meta_attr, self._tensor_type):
25
+ # e.g. self.T with torch.Tensor should still be wrapped
26
+ return type(self)._wrap_fn(lambda s: getattr(s, name))(self)
27
+ else:
28
+ # no need to wrap non-tensor properties,
29
+ # and they likely don't depend on the actual contents of the tensor
30
+ return meta_attr
31
+
32
+ namespace["__getattr__"] = __getattr__
33
+
34
+ # need to make a builder for the wrapped wrapper to copy the name,
35
+ # or else it fails with very cryptic error messages,
36
+ # because somehow the same string would end up in every closure
37
+ def mk_wrap(op_name: str, *, meta_noop: bool = False):
38
+ # need to wrap the wrapper to get self
39
+ def wrapped_special_op(self, *args, **kwargs):
40
+ return type(self)._wrap_fn(
41
+ getattr(type(self)._tensor_type, op_name),
42
+ meta_noop=meta_noop,
43
+ )(self, *args, **kwargs)
44
+ return wrapped_special_op
45
+
46
+ # special methods bypass __getattr__, so they need to be added manually
47
+ # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
48
+ # NOTE: doing this from a metaclass is very convenient
49
+ # TODO: make this even more comprehensive
50
+ for binary_op in (
51
+ "lt", "le", "eq", "ne", "ge", "gt", "not"
52
+ "abs", "add", "and", "floordiv", "invert", "lshift", "mod", "mul", "matmul",
53
+ "neg", "or", "pos", "pow", "rshift", "sub", "truediv", "xor",
54
+ "iadd", "iand", "ifloordiv", "ilshift", "imod", "imul", "ior", "irshift", "isub", "ixor",
55
+ "radd", "rand", "rfloordiv", "rmul", "ror", "rpow", "rsub", "rtruediv", "rxor",
56
+ ):
57
+ attr_name = f"__{binary_op}__"
58
+ # the result of these operators usually has the same shape and dtype as the input,
59
+ # so evaluation on the meta tensor can be skipped.
60
+ namespace[attr_name] = mk_wrap(attr_name, meta_noop=True)
61
+
62
+ for special_op in (
63
+ "getitem", "setitem", "len",
64
+ ):
65
+ attr_name = f"__{special_op}__"
66
+ namespace[attr_name] = mk_wrap(attr_name, meta_noop=False)
67
+
68
+ return super().__new__(cls, name, bases, namespace, **kwargs)
69
+
70
+
71
+ # Tree of lazy tensors
72
+ class LazyBase(ABC, metaclass=LazyMeta):
73
+ _tensor_type: type
74
+ _meta: Any
75
+ _data: Any | None
76
+ _args: tuple
77
+ _kwargs: dict[str, Any]
78
+ _func: Callable[[Any], Any] | None
79
+
80
+ def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
81
+ super().__init__()
82
+ self._meta = meta
83
+ self._data = data
84
+ self._args = args
85
+ self._kwargs = kwargs if kwargs is not None else {}
86
+ self._func = func
87
+ assert self._func is not None or self._data is not None
88
+
89
+ def __init_subclass__(cls) -> None:
90
+ if "_tensor_type" not in cls.__dict__:
91
+ raise TypeError(f"property '_tensor_type' must be defined for {cls!r}")
92
+ return super().__init_subclass__()
93
+
94
+ @staticmethod
95
+ def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
96
+ # TODO: dict and set
97
+ if isinstance(o, (list, tuple)):
98
+ L = []
99
+ for item in o:
100
+ L.append(LazyBase._recurse_apply(item, fn))
101
+ if isinstance(o, tuple):
102
+ L = tuple(L)
103
+ return L
104
+ elif isinstance(o, LazyBase):
105
+ return fn(o)
106
+ else:
107
+ return o
108
+
109
+ @classmethod
110
+ def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False) -> Callable[[Any], Any]:
111
+ def wrapped_fn(*args, **kwargs):
112
+ if kwargs is None:
113
+ kwargs = {}
114
+ args = ((use_self,) if use_self is not None else ()) + args
115
+
116
+ meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
117
+ # TODO: maybe handle tensors in kwargs too
118
+
119
+ if isinstance(meta_noop, bool) and not meta_noop:
120
+ try:
121
+ res = fn(*meta_args, **kwargs)
122
+ except NotImplementedError:
123
+ # running some operations on PyTorch's Meta tensors can cause this exception
124
+ res = None
125
+ else:
126
+ # some operators don't need to actually run on the meta tensors
127
+ assert len(args) > 0
128
+ res = args[0]
129
+ assert isinstance(res, cls)
130
+ res = res._meta
131
+ # allow operations to override the dtype and shape
132
+ if meta_noop is not True:
133
+ if isinstance(meta_noop, tuple):
134
+ dtype, shape = meta_noop
135
+ assert callable(shape)
136
+ res = cls.meta_with_dtype_and_shape(dtype, shape(res.shape))
137
+ else:
138
+ res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
139
+
140
+ if isinstance(res, cls._tensor_type):
141
+ return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
142
+ else:
143
+ del res # not needed
144
+ # non-tensor return likely relies on the contents of the args
145
+ # (e.g. the result of torch.equal)
146
+ eager_args = cls.to_eager(args)
147
+ return fn(*eager_args, **kwargs)
148
+ return wrapped_fn
149
+
150
+ @classmethod
151
+ def to_eager(cls, t: Any) -> Any:
152
+ def simple_to_eager(_t: LazyBase) -> Any:
153
+ if _t._data is not None:
154
+ return _t._data
155
+
156
+ # NOTE: there's a recursion limit in Python (usually 1000)
157
+
158
+ assert _t._func is not None
159
+ _t._args = cls._recurse_apply(_t._args, simple_to_eager)
160
+ _t._data = _t._func(*_t._args, **_t._kwargs)
161
+ # sanity check
162
+ assert _t._data is not None
163
+ assert _t._data.dtype == _t._meta.dtype
164
+ assert _t._data.shape == _t._meta.shape
165
+
166
+ return _t._data
167
+
168
+ # recurse into lists and/or tuples, keeping their structure
169
+ return cls._recurse_apply(t, simple_to_eager)
170
+
171
+ @classmethod
172
+ def eager_to_meta(cls, t: Any) -> Any:
173
+ return cls.meta_with_dtype_and_shape(t.dtype, t.shape)
174
+
175
+ # must be overridden, meta tensor init is backend-specific
176
+ @classmethod
177
+ @abstractmethod
178
+ def meta_with_dtype_and_shape(cls, dtype: Any, shape: Any) -> Any: pass
179
+
180
+ @classmethod
181
+ def from_eager(cls, t: Any) -> Any:
182
+ if type(t) is cls:
183
+ # already lazy
184
+ return t
185
+ elif isinstance(t, cls._tensor_type):
186
+ return cls(meta=cls.eager_to_meta(t), data=t)
187
+ else:
188
+ raise TypeError(f"{type(t)!r} is not compatible with {cls._tensor_type!r}")
189
+
190
+
191
+ class LazyNumpyTensor(LazyBase):
192
+ _tensor_type = np.ndarray
193
+
194
+ shape: tuple[int, ...] # Makes the type checker happy in quants.py
195
+
196
+ @classmethod
197
+ def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> np.ndarray[Any, Any]:
198
+ # The initial idea was to use np.nan as the fill value,
199
+ # but non-float types like np.int16 can't use that.
200
+ # So zero it is.
201
+ cheat = np.zeros(1, dtype)
202
+ return np.lib.stride_tricks.as_strided(cheat, shape, (0 for _ in shape))
203
+
204
+ def astype(self, dtype, *args, **kwargs):
205
+ meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
206
+ full_args = (self, dtype,) + args
207
+ return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))
208
+
209
+ def tofile(self, *args, **kwargs):
210
+ eager = LazyNumpyTensor.to_eager(self)
211
+ return eager.tofile(*args, **kwargs)
212
+
213
+ # TODO: __array_function__
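For reference, a minimal usage sketch of the lazy wrapper above (assuming the package is importable as gguf; the array values are only illustrative). Operations are recorded against a zero-strided meta array and only run when to_eager() or tofile() forces evaluation:

import numpy as np
from gguf.lazy import LazyNumpyTensor

lazy = LazyNumpyTensor.from_eager(np.ones((4, 8), dtype=np.float32))  # wraps the eager array
lazy16 = lazy.astype(np.float16)               # only records the conversion; dtype/shape tracked on the meta tensor
print(lazy16._meta.dtype, lazy16._meta.shape)  # float16 (4, 8), nothing computed yet
eager = LazyNumpyTensor.to_eager(lazy16)       # replays the recorded function chain
assert eager.dtype == np.float16 and eager.shape == (4, 8)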
gguf/metadata.py ADDED
@@ -0,0 +1,510 @@
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ import json
5
+ import yaml
6
+ import logging
7
+ from pathlib import Path
8
+ from typing import Any, Literal, Optional
9
+ from dataclasses import dataclass
10
+
11
+ from .constants import Keys
12
+
13
+ import gguf
14
+
15
+ logger = logging.getLogger("metadata")
16
+
17
+
18
+ @dataclass
19
+ class Metadata:
20
+ # Authorship Metadata to be written to GGUF KV Store
21
+ name: Optional[str] = None
22
+ author: Optional[str] = None
23
+ version: Optional[str] = None
24
+ organization: Optional[str] = None
25
+ finetune: Optional[str] = None
26
+ basename: Optional[str] = None
27
+ description: Optional[str] = None
28
+ quantized_by: Optional[str] = None
29
+ size_label: Optional[str] = None
30
+ url: Optional[str] = None
31
+ doi: Optional[str] = None
32
+ uuid: Optional[str] = None
33
+ repo_url: Optional[str] = None
34
+ source_url: Optional[str] = None
35
+ source_doi: Optional[str] = None
36
+ source_uuid: Optional[str] = None
37
+ source_repo_url: Optional[str] = None
38
+ license: Optional[str] = None
39
+ license_name: Optional[str] = None
40
+ license_link: Optional[str] = None
41
+ base_models: Optional[list[dict]] = None
42
+ tags: Optional[list[str]] = None
43
+ languages: Optional[list[str]] = None
44
+ datasets: Optional[list[str]] = None
45
+
46
+ @staticmethod
47
+ def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
48
+ # This grabs as much contextual authorship metadata as possible from the model repository
49
+ # making any conversion as required to match the gguf kv store metadata format
50
+ # as well as giving users the ability to override any authorship metadata that may be incorrect
51
+
52
+ # Create a new Metadata instance
53
+ metadata = Metadata()
54
+
55
+ model_card = Metadata.load_model_card(model_path)
56
+ hf_params = Metadata.load_hf_parameters(model_path)
57
+ # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter
58
+
59
+ # heuristics
60
+ metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
61
+
62
+ # Metadata Override File Provided
63
+ # This is based on LLM_KV_NAMES mapping in llama.cpp
64
+ metadata_override = Metadata.load_metadata_override(metadata_override_path)
65
+
66
+ metadata.name = metadata_override.get(Keys.General.NAME, metadata.name)
67
+ metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author)
68
+ metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version)
69
+ metadata.organization = metadata_override.get(Keys.General.ORGANIZATION, metadata.organization)
70
+
71
+ metadata.finetune = metadata_override.get(Keys.General.FINETUNE, metadata.finetune)
72
+ metadata.basename = metadata_override.get(Keys.General.BASENAME, metadata.basename)
73
+
74
+ metadata.description = metadata_override.get(Keys.General.DESCRIPTION, metadata.description)
75
+ metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY, metadata.quantized_by)
76
+
77
+ metadata.size_label = metadata_override.get(Keys.General.SIZE_LABEL, metadata.size_label)
78
+ metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME, metadata.license_name)
79
+ metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK, metadata.license_link)
80
+
81
+ metadata.url = metadata_override.get(Keys.General.URL, metadata.url)
82
+ metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi)
83
+ metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid)
84
+ metadata.repo_url = metadata_override.get(Keys.General.REPO_URL, metadata.repo_url)
85
+
86
+ metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL, metadata.source_url)
87
+ metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI, metadata.source_doi)
88
+ metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID, metadata.source_uuid)
89
+ metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url)
90
+
91
+ # Base Models is received here as an array of models
92
+ metadata.base_models = metadata_override.get("general.base_models", metadata.base_models)
93
+
94
+ metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags)
95
+ metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages)
96
+ metadata.datasets = metadata_override.get(Keys.General.DATASETS, metadata.datasets)
97
+
98
+ # Direct Metadata Override (via direct cli argument)
99
+ if model_name is not None:
100
+ metadata.name = model_name
101
+
102
+ return metadata
103
+
104
+ @staticmethod
105
+ def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, Any]:
106
+ if metadata_override_path is None or not metadata_override_path.is_file():
107
+ return {}
108
+
109
+ with open(metadata_override_path, "r", encoding="utf-8") as f:
110
+ return json.load(f)
111
+
112
+ @staticmethod
113
+ def load_model_card(model_path: Optional[Path] = None) -> dict[str, Any]:
114
+ if model_path is None or not model_path.is_dir():
115
+ return {}
116
+
117
+ model_card_path = model_path / "README.md"
118
+
119
+ if not model_card_path.is_file():
120
+ return {}
121
+
122
+ # The model card metadata is assumed to always be in YAML
123
+ # ref: https://github.com/huggingface/transformers/blob/a5c642fe7a1f25d3bdcd76991443ba6ff7ee34b2/src/transformers/modelcard.py#L468-L473
124
+ with open(model_card_path, "r", encoding="utf-8") as f:
125
+ if f.readline() == "---\n":
126
+ raw = f.read().partition("---\n")[0]
127
+ data = yaml.safe_load(raw)
128
+ if isinstance(data, dict):
129
+ return data
130
+ else:
131
+ logger.error(f"while reading YAML model card frontmatter, data is {type(data)} instead of dict")
132
+ return {}
133
+ else:
134
+ return {}
135
+
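For reference, the frontmatter that load_model_card() reads is the YAML block between the leading '---' markers at the top of the repo's README.md; a small sketch with made-up values (yaml.safe_load is what the code above uses):

import yaml

frontmatter = (
    "license: apache-2.0\n"
    "base_model: mistralai/Mistral-7B-v0.1\n"
    "language: [en]\n"
    "tags: [merge]\n"
)
print(yaml.safe_load(frontmatter))
# {'license': 'apache-2.0', 'base_model': 'mistralai/Mistral-7B-v0.1', 'language': ['en'], 'tags': ['merge']}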
136
+ @staticmethod
137
+ def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]:
138
+ if model_path is None or not model_path.is_dir():
139
+ return {}
140
+
141
+ config_path = model_path / "config.json"
142
+
143
+ if not config_path.is_file():
144
+ return {}
145
+
146
+ with open(config_path, "r", encoding="utf-8") as f:
147
+ return json.load(f)
148
+
149
+ @staticmethod
150
+ def id_to_title(string):
151
+ # Convert capitalization into title form unless acronym or version number
152
+ return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()])
153
+
154
+ @staticmethod
155
+ def get_model_id_components(model_id: Optional[str] = None, total_params: int = 0) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]:
156
+ # Huggingface often stores the model id as '<org>/<model name>'
157
+ # so let's parse it and apply some heuristics if possible for model name components
158
+
159
+ if model_id is None:
160
+ # model ID missing
161
+ return None, None, None, None, None, None
162
+
163
+ if ' ' in model_id:
164
+ # model ID is actually a normal human sentence
165
+ # which means it's most likely just a normal model name
166
+ # not part of the hugging face naming standard, but whatever
167
+ return model_id, None, None, None, None, None
168
+
169
+ if '/' in model_id:
170
+ # model ID (huggingface style)
171
+ org_component, model_full_name_component = model_id.split('/', 1)
172
+ else:
173
+ # model ID but missing org components
174
+ org_component, model_full_name_component = None, model_id
175
+
176
+ # Check if we erroneously matched against './' or '../' etc...
177
+ if org_component is not None and len(org_component) > 0 and org_component[0] == '.':
178
+ org_component = None
179
+
180
+ name_parts: list[str] = model_full_name_component.split('-')
181
+
182
+ # Remove empty parts
183
+ for i in reversed(range(len(name_parts))):
184
+ if len(name_parts[i]) == 0:
185
+ del name_parts[i]
186
+
187
+ name_types: list[
188
+ set[Literal["basename", "size_label", "finetune", "version", "type"]]
189
+ ] = [set() for _ in name_parts]
190
+
191
+ # Annotate the name
192
+ for i, part in enumerate(name_parts):
193
+ # Version
194
+ if re.fullmatch(r'(v|iter)?\d+([.]\d+)*', part, re.IGNORECASE):
195
+ name_types[i].add("version")
196
+ # Quant type (should not be there for base models, but still annotated)
197
+ elif re.fullmatch(r'i?q\d(_\w)*|b?fp?(16|32)', part, re.IGNORECASE):
198
+ name_types[i].add("type")
199
+ name_parts[i] = part.upper()
200
+ # Model size
201
+ elif i > 0 and re.fullmatch(r'(([A]|\d+[x])?\d+([._]\d+)?[KMBT][\d]?|small|mini|medium|large|x?xl)', part, re.IGNORECASE):
202
+ part = part.replace("_", ".")
203
+ # Handle weird bloom-7b1 notation
204
+ if part[-1].isdecimal():
205
+ part = part[:-2] + "." + part[-1] + part[-2]
206
+ # Normalize the size suffixes
207
+ if len(part) > 1 and part[-2].isdecimal():
208
+ if part[-1] in "kmbt":
209
+ part = part[:-1] + part[-1].upper()
210
+ if total_params != 0:
211
+ try:
212
+ label_params = float(part[:-1]) * pow(1000, " KMBT".find(part[-1]))
213
+ # Only use it as a size label if it's close or bigger than the model size
214
+ # Note that LoRA adapters don't necessarily include all layers,
215
+ # so this is why bigger label sizes are accepted.
216
+ # Do not use the size label when it's smaller than 1/8 of the model size
217
+ if (total_params < 0 and label_params < abs(total_params) // 8) or (
218
+ # Check both directions when the current model isn't a LoRA adapter
219
+ total_params > 0 and abs(label_params - total_params) > 7 * total_params // 8
220
+ ):
221
+ # Likely a context length
222
+ name_types[i].add("finetune")
223
+ # Lowercase the size when it's a context length
224
+ part = part[:-1] + part[-1].lower()
225
+ except ValueError:
226
+ # Failed to convert the size label to float, use it anyway
227
+ pass
228
+ if len(name_types[i]) == 0:
229
+ name_types[i].add("size_label")
230
+ name_parts[i] = part
231
+ # Some easy to recognize finetune names
232
+ elif i > 0 and re.fullmatch(r'chat|instruct|vision|lora', part, re.IGNORECASE):
233
+ if total_params < 0 and part.lower() == "lora":
234
+ # ignore redundant "lora" in the finetune part when the output is a lora adapter
235
+ name_types[i].add("type")
236
+ else:
237
+ name_types[i].add("finetune")
238
+
239
+ # Ignore word-based size labels when there is at least a number-based one present
240
+ # TODO: should word-based size labels always be removed instead?
241
+ if any(c.isdecimal() for n, t in zip(name_parts, name_types) if "size_label" in t for c in n):
242
+ for n, t in zip(name_parts, name_types):
243
+ if "size_label" in t:
244
+ if all(c.isalpha() for c in n):
245
+ t.remove("size_label")
246
+
247
+ at_start = True
248
+ # Find the basename through the annotated name
249
+ for part, t in zip(name_parts, name_types):
250
+ if at_start and ((len(t) == 0 and part[0].isalpha()) or "version" in t):
251
+ t.add("basename")
252
+ else:
253
+ if at_start:
254
+ at_start = False
255
+ if len(t) == 0:
256
+ t.add("finetune")
257
+
258
+ # Remove the basename annotation from trailing version
259
+ for part, t in zip(reversed(name_parts), reversed(name_types)):
260
+ if "basename" in t and len(t) > 1:
261
+ t.remove("basename")
262
+ else:
263
+ break
264
+
265
+ basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None
266
+ # Deduplicate size labels using order-preserving 'dict' ('set' seems to sort the keys)
267
+ size_label = "-".join(dict.fromkeys(s for s, t in zip(name_parts, name_types) if "size_label" in t).keys()) or None
268
+ finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None
269
+ # TODO: should the basename version always be excluded?
270
+ # NOTE: multiple finetune versions are joined together
271
+ version = "-".join(v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t) or None
272
+
273
+ if size_label is None and finetune is None and version is None:
274
+ # Too ambiguous, output nothing
275
+ basename = None
276
+
277
+ return model_full_name_component, org_component, basename, finetune, version, size_label
278
+
279
+ @staticmethod
280
+ def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None, total_params: int = 0) -> Metadata:
281
+ # Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
282
+
283
+ # Model Card Heuristics
284
+ ########################
285
+ if model_card is not None:
286
+
287
+ def use_model_card_metadata(metadata_key: str, model_card_key: str):
288
+ if model_card_key in model_card and getattr(metadata, metadata_key, None) is None:
289
+ setattr(metadata, metadata_key, model_card.get(model_card_key))
290
+
291
+ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
292
+ # Note: Will append rather than replace if the value already exists
293
+ tags_value = model_card.get(model_card_key, None)
294
+ if tags_value is None:
295
+ return
296
+
297
+ current_value = getattr(metadata, metadata_key, None)
298
+ if current_value is None:
299
+ current_value = []
300
+
301
+ if isinstance(tags_value, str):
302
+ current_value.append(tags_value)
303
+ elif isinstance(tags_value, list):
304
+ current_value.extend(tags_value)
305
+
306
+ setattr(metadata, metadata_key, current_value)
307
+
308
+ # LLAMA.cpp's direct internal convention
309
+ # (Definitely not part of hugging face formal/informal standard)
310
+ #########################################
311
+ use_model_card_metadata("name", "name")
312
+ use_model_card_metadata("author", "author")
313
+ use_model_card_metadata("version", "version")
314
+ use_model_card_metadata("organization", "organization")
315
+ use_model_card_metadata("description", "description")
316
+ use_model_card_metadata("finetune", "finetune")
317
+ use_model_card_metadata("basename", "basename")
318
+ use_model_card_metadata("size_label", "size_label")
319
+ use_model_card_metadata("source_url", "url")
320
+ use_model_card_metadata("source_doi", "doi")
321
+ use_model_card_metadata("source_uuid", "uuid")
322
+ use_model_card_metadata("source_repo_url", "repo_url")
323
+
324
+ # LLAMA.cpp's huggingface style convention
325
+ # (Definitely not part of hugging face formal/informal standard... but with a model_ prefix to match their style)
326
+ ###########################################
327
+ use_model_card_metadata("name", "model_name")
328
+ use_model_card_metadata("author", "model_author")
329
+ use_model_card_metadata("version", "model_version")
330
+ use_model_card_metadata("organization", "model_organization")
331
+ use_model_card_metadata("description", "model_description")
332
+ use_model_card_metadata("finetune", "model_finetune")
333
+ use_model_card_metadata("basename", "model_basename")
334
+ use_model_card_metadata("size_label", "model_size_label")
335
+ use_model_card_metadata("source_url", "model_url")
336
+ use_model_card_metadata("source_doi", "model_doi")
337
+ use_model_card_metadata("source_uuid", "model_uuid")
338
+ use_model_card_metadata("source_repo_url", "model_repo_url")
339
+
340
+ # Hugging Face Direct Convention
341
+ #################################
342
+
343
+ # Not part of the huggingface model card standard, but some model creators use it
344
+ # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
345
+ use_model_card_metadata("name", "model_name")
346
+ use_model_card_metadata("author", "model_creator")
347
+ use_model_card_metadata("basename", "model_type")
348
+
349
+ if "base_model" in model_card:
350
+ # This represents the parent models that this is based on
351
+ # Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
352
+ # Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
353
+ metadata_base_models = []
354
+ base_model_value = model_card.get("base_model", None)
355
+
356
+ if base_model_value is not None:
357
+ if isinstance(base_model_value, str):
358
+ metadata_base_models.append(base_model_value)
359
+ elif isinstance(base_model_value, list):
360
+ metadata_base_models.extend(base_model_value)
361
+
362
+ if metadata.base_models is None:
363
+ metadata.base_models = []
364
+
365
+ for model_id in metadata_base_models:
366
+ # NOTE: model size of base model is assumed to be similar to the size of the current model
367
+ model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
368
+ base_model = {}
369
+ if model_full_name_component is not None:
370
+ base_model["name"] = Metadata.id_to_title(model_full_name_component)
371
+ if org_component is not None:
372
+ base_model["organization"] = Metadata.id_to_title(org_component)
373
+ if version is not None:
374
+ base_model["version"] = version
375
+ if org_component is not None and model_full_name_component is not None:
376
+ base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
377
+ metadata.base_models.append(base_model)
378
+
379
+ use_model_card_metadata("license", "license")
380
+ use_model_card_metadata("license_name", "license_name")
381
+ use_model_card_metadata("license_link", "license_link")
382
+
383
+ use_array_model_card_metadata("tags", "tags")
384
+ use_array_model_card_metadata("tags", "pipeline_tag")
385
+
386
+ use_array_model_card_metadata("languages", "languages")
387
+ use_array_model_card_metadata("languages", "language")
388
+
389
+ use_array_model_card_metadata("datasets", "datasets")
390
+ use_array_model_card_metadata("datasets", "dataset")
391
+
392
+ # Hugging Face Parameter Heuristics
393
+ ####################################
394
+
395
+ if hf_params is not None:
396
+
397
+ hf_name_or_path = hf_params.get("_name_or_path")
398
+ if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1:
399
+ # Use _name_or_path only if it's actually a model name and not a filesystem path
400
+ # e.g. 'meta-llama/Llama-2-7b-hf'
401
+ model_id = hf_name_or_path
402
+ model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
403
+ if metadata.name is None and model_full_name_component is not None:
404
+ metadata.name = Metadata.id_to_title(model_full_name_component)
405
+ if metadata.organization is None and org_component is not None:
406
+ metadata.organization = Metadata.id_to_title(org_component)
407
+ if metadata.basename is None and basename is not None:
408
+ metadata.basename = basename
409
+ if metadata.finetune is None and finetune is not None:
410
+ metadata.finetune = finetune
411
+ if metadata.version is None and version is not None:
412
+ metadata.version = version
413
+ if metadata.size_label is None and size_label is not None:
414
+ metadata.size_label = size_label
415
+
416
+ # Directory Folder Name Fallback Heuristics
417
+ ############################################
418
+ if model_path is not None:
419
+ model_id = model_path.name
420
+ model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
421
+ if metadata.name is None and model_full_name_component is not None:
422
+ metadata.name = Metadata.id_to_title(model_full_name_component)
423
+ if metadata.organization is None and org_component is not None:
424
+ metadata.organization = Metadata.id_to_title(org_component)
425
+ if metadata.basename is None and basename is not None:
426
+ metadata.basename = basename
427
+ if metadata.finetune is None and finetune is not None:
428
+ metadata.finetune = finetune
429
+ if metadata.version is None and version is not None:
430
+ metadata.version = version
431
+ if metadata.size_label is None and size_label is not None:
432
+ metadata.size_label = size_label
433
+
434
+ return metadata
435
+
436
+ def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
437
+ assert self.name is not None
438
+ gguf_writer.add_name(self.name)
439
+
440
+ if self.author is not None:
441
+ gguf_writer.add_author(self.author)
442
+ if self.version is not None:
443
+ gguf_writer.add_version(self.version)
444
+ if self.organization is not None:
445
+ gguf_writer.add_organization(self.organization)
446
+
447
+ if self.finetune is not None:
448
+ gguf_writer.add_finetune(self.finetune)
449
+ if self.basename is not None:
450
+ gguf_writer.add_basename(self.basename)
451
+
452
+ if self.description is not None:
453
+ gguf_writer.add_description(self.description)
454
+ if self.quantized_by is not None:
455
+ gguf_writer.add_quantized_by(self.quantized_by)
456
+
457
+ if self.size_label is not None:
458
+ gguf_writer.add_size_label(self.size_label)
459
+
460
+ if self.license is not None:
461
+ gguf_writer.add_license(self.license)
462
+ if self.license_name is not None:
463
+ gguf_writer.add_license_name(self.license_name)
464
+ if self.license_link is not None:
465
+ gguf_writer.add_license_link(self.license_link)
466
+
467
+ if self.url is not None:
468
+ gguf_writer.add_url(self.url)
469
+ if self.doi is not None:
470
+ gguf_writer.add_doi(self.doi)
471
+ if self.uuid is not None:
472
+ gguf_writer.add_uuid(self.uuid)
473
+ if self.repo_url is not None:
474
+ gguf_writer.add_repo_url(self.repo_url)
475
+
476
+ if self.source_url is not None:
477
+ gguf_writer.add_source_url(self.source_url)
478
+ if self.source_doi is not None:
479
+ gguf_writer.add_source_doi(self.source_doi)
480
+ if self.source_uuid is not None:
481
+ gguf_writer.add_source_uuid(self.source_uuid)
482
+ if self.source_repo_url is not None:
483
+ gguf_writer.add_source_repo_url(self.source_repo_url)
484
+
485
+ if self.base_models is not None:
486
+ gguf_writer.add_base_model_count(len(self.base_models))
487
+ for key, base_model_entry in enumerate(self.base_models):
488
+ if "name" in base_model_entry:
489
+ gguf_writer.add_base_model_name(key, base_model_entry["name"])
490
+ if "author" in base_model_entry:
491
+ gguf_writer.add_base_model_author(key, base_model_entry["author"])
492
+ if "version" in base_model_entry:
493
+ gguf_writer.add_base_model_version(key, base_model_entry["version"])
494
+ if "organization" in base_model_entry:
495
+ gguf_writer.add_base_model_organization(key, base_model_entry["organization"])
496
+ if "url" in base_model_entry:
497
+ gguf_writer.add_base_model_url(key, base_model_entry["url"])
498
+ if "doi" in base_model_entry:
499
+ gguf_writer.add_base_model_doi(key, base_model_entry["doi"])
500
+ if "uuid" in base_model_entry:
501
+ gguf_writer.add_base_model_uuid(key, base_model_entry["uuid"])
502
+ if "repo_url" in base_model_entry:
503
+ gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"])
504
+
505
+ if self.tags is not None:
506
+ gguf_writer.add_tags(self.tags)
507
+ if self.languages is not None:
508
+ gguf_writer.add_languages(self.languages)
509
+ if self.datasets is not None:
510
+ gguf_writer.add_datasets(self.datasets)
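For reference, a sketch of what the name heuristics in get_model_id_components() produce for a typical Hugging Face model id (the expected output is my reading of the code above, not a captured run):

from gguf.metadata import Metadata

components = Metadata.get_model_id_components("mistralai/Mixtral-8x7B-Instruct-v0.1", total_params=0)
name, org, basename, finetune, version, size_label = components
# Roughly: name='Mixtral-8x7B-Instruct-v0.1', org='mistralai',
#          basename='Mixtral', finetune='Instruct', version='v0.1', size_label='8x7B'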
gguf/py.typed ADDED
File without changes
gguf/quants.py ADDED
@@ -0,0 +1,1269 @@
1
+ from __future__ import annotations
2
+ from abc import ABC, abstractmethod
3
+ from typing import Any, Callable, Sequence
4
+ from math import log2, ceil
5
+
6
+ from numpy.typing import DTypeLike
7
+
8
+ from .constants import GGML_QUANT_SIZES, GGMLQuantizationType, QK_K
9
+ from .lazy import LazyNumpyTensor
10
+
11
+ import numpy as np
12
+
13
+
14
+ def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
15
+ block_size, type_size = GGML_QUANT_SIZES[quant_type]
16
+ if shape[-1] % block_size != 0:
17
+ raise ValueError(f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})")
18
+ return (*shape[:-1], shape[-1] // block_size * type_size)
19
+
20
+
21
+ def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
22
+ block_size, type_size = GGML_QUANT_SIZES[quant_type]
23
+ if shape[-1] % type_size != 0:
24
+ raise ValueError(f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {quant_type.name} type size ({type_size})")
25
+ return (*shape[:-1], shape[-1] // type_size * block_size)
26
+
27
+
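A quick worked example of the shape conversion above, assuming the package is installed and using Q4_0's entry in GGML_QUANT_SIZES (32 values per block, 18 bytes per block: a 2-byte f16 scale plus 16 bytes of 4-bit quants):

from gguf.constants import GGMLQuantizationType
from gguf.quants import quant_shape_to_byte_shape, quant_shape_from_byte_shape

print(quant_shape_to_byte_shape((4096, 4096), GGMLQuantizationType.Q4_0))    # (4096, 2304), i.e. 4096 // 32 * 18
print(quant_shape_from_byte_shape((4096, 2304), GGMLQuantizationType.Q4_0))  # (4096, 4096)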
28
+ # This is faster than np.vectorize and np.apply_along_axis because it works on more than one row at a time
29
+ def _apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray:
30
+ rows = arr.reshape((-1, arr.shape[-1]))
31
+ osize = 1
32
+ for dim in oshape:
33
+ osize *= dim
34
+ out = np.empty(shape=osize, dtype=otype)
35
+ # compute over groups of 16 rows (arbitrary, but seems good for performance)
36
+ n_groups = (rows.shape[0] // 16) or 1
37
+ np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out)
38
+ return out.reshape(oshape)
39
+
40
+
41
+ # round away from zero
42
+ # ref: https://stackoverflow.com/a/59143326/22827863
43
+ def np_roundf(n: np.ndarray) -> np.ndarray:
44
+ a = abs(n)
45
+ floored = np.floor(a)
46
+ b = floored + np.floor(2 * (a - floored))
47
+ return np.sign(n) * b
48
+
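The point of np_roundf() is rounding halves away from zero (like C's roundf), whereas np.round rounds halves to even; a quick check of the values it should return:

import numpy as np

x = np.array([-2.5, -0.5, 0.5, 2.5], dtype=np.float32)
# np.round(x)  -> [-2., -0.,  0.,  2.]   (ties to even)
# np_roundf(x) -> [-3., -1.,  1.,  3.]   (ties away from zero)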
49
+
50
+ class QuantError(Exception): ...
51
+
52
+
53
+ _type_traits: dict[GGMLQuantizationType, type[__Quant]] = {}
54
+
55
+
56
+ def quantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
57
+ if qtype == GGMLQuantizationType.F32:
58
+ return data.astype(np.float32, copy=False)
59
+ elif qtype == GGMLQuantizationType.F16:
60
+ return data.astype(np.float16, copy=False)
61
+ elif (q := _type_traits.get(qtype)) is not None:
62
+ return q.quantize(data)
63
+ else:
64
+ raise NotImplementedError(f"Quantization for {qtype.name} is not yet implemented")
65
+
66
+
67
+ def dequantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
68
+ if qtype == GGMLQuantizationType.F32:
69
+ return data.view(np.float32)
70
+ elif qtype == GGMLQuantizationType.F16:
71
+ return data.view(np.float16).astype(np.float32)
72
+ elif (q := _type_traits.get(qtype)) is not None:
73
+ return q.dequantize(data)
74
+ else:
75
+ raise NotImplementedError(f"Dequantization for {qtype.name} is not yet implemented")
76
+
77
+
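A typical round-trip through the two dispatchers above (a sketch assuming the package is installed; Q8_0 is registered further down in this file):

import numpy as np
from gguf.quants import quantize, dequantize
from gguf.constants import GGMLQuantizationType

data = np.random.rand(64, 256).astype(np.float32)        # last dim must be a multiple of the block size (32)
packed = quantize(data, GGMLQuantizationType.Q8_0)        # uint8, shape (64, 256 // 32 * 34) = (64, 272)
restored = dequantize(packed, GGMLQuantizationType.Q8_0)  # float32, same shape as data, small rounding error
assert restored.shape == data.shape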
78
+ class __Quant(ABC):
79
+ qtype: GGMLQuantizationType
80
+ block_size: int
81
+ type_size: int
82
+
83
+ grid: np.ndarray[Any, np.dtype[np.float32]] | None = None
84
+ grid_shape: tuple[int, int] = (0, 0)
85
+ grid_map: tuple[int | float, ...] = ()
86
+ grid_hex: bytes | None = None
87
+
88
+ def __init__(self):
89
+ raise TypeError("Quant conversion classes can't have instances")
90
+
91
+ def __init_subclass__(cls, qtype: GGMLQuantizationType) -> None:
92
+ cls.qtype = qtype
93
+ cls.block_size, cls.type_size = GGML_QUANT_SIZES[qtype]
94
+ cls.__quantize_lazy = LazyNumpyTensor._wrap_fn(
95
+ cls.__quantize_array,
96
+ meta_noop=(np.uint8, cls.__shape_to_bytes)
97
+ )
98
+ cls.__dequantize_lazy = LazyNumpyTensor._wrap_fn(
99
+ cls.__dequantize_array,
100
+ meta_noop=(np.float32, cls.__shape_from_bytes)
101
+ )
102
+ assert qtype not in _type_traits
103
+ _type_traits[qtype] = cls
104
+
105
+ @classmethod
106
+ def init_grid(cls):
107
+ if cls.grid is not None or cls.grid_hex is None:
108
+ return
109
+
110
+ bits_per_elem = ceil(log2(len(cls.grid_map)))
111
+ assert bits_per_elem != 0, cls.qtype.name
112
+ elems_per_byte = 8 // bits_per_elem
113
+
114
+ grid = np.frombuffer(cls.grid_hex, dtype=np.uint8)
115
+ # decode hexadecimal chars from grid
116
+ grid = grid.reshape((-1, 2))
117
+ grid = (np.where(grid > 0x40, grid + 9, grid) & 0x0F) << np.array([4, 0], dtype=np.uint8).reshape((1, 2))
118
+ grid = grid[..., 0] | grid[..., 1]
119
+ # unpack the grid values
120
+ grid = grid.reshape((-1, 1)) >> np.array([i for i in range(0, 8, 8 // elems_per_byte)], dtype=np.uint8).reshape((1, elems_per_byte))
121
+ grid = (grid & ((1 << bits_per_elem) - 1)).reshape((-1, 1))
122
+ grid_map = np.array(cls.grid_map, dtype=np.float32).reshape((1, -1))
123
+ grid = np.take_along_axis(grid_map, grid, axis=-1)
124
+ cls.grid = grid.reshape((1, 1, *cls.grid_shape))
125
+
126
+ @classmethod
127
+ @abstractmethod
128
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
129
+ raise NotImplementedError
130
+
131
+ @classmethod
132
+ @abstractmethod
133
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
134
+ raise NotImplementedError
135
+
136
+ @classmethod
137
+ def quantize_rows(cls, rows: np.ndarray) -> np.ndarray:
138
+ rows = rows.astype(np.float32, copy=False)
139
+ shape = rows.shape
140
+ n_blocks = rows.size // cls.block_size
141
+ blocks = rows.reshape((n_blocks, cls.block_size))
142
+ blocks = cls.quantize_blocks(blocks)
143
+ assert blocks.dtype == np.uint8
144
+ assert blocks.shape[-1] == cls.type_size
145
+ return blocks.reshape(cls.__shape_to_bytes(shape))
146
+
147
+ @classmethod
148
+ def dequantize_rows(cls, rows: np.ndarray) -> np.ndarray:
149
+ rows = rows.view(np.uint8)
150
+ shape = rows.shape
151
+ n_blocks = rows.size // cls.type_size
152
+ blocks = rows.reshape((n_blocks, cls.type_size))
153
+ blocks = cls.dequantize_blocks(blocks)
154
+ assert blocks.dtype == np.float32
155
+ assert blocks.shape[-1] == cls.block_size
156
+ return blocks.reshape(cls.__shape_from_bytes(shape))
157
+
158
+ @classmethod
159
+ def __shape_to_bytes(cls, shape: Sequence[int]):
160
+ return quant_shape_to_byte_shape(shape, cls.qtype)
161
+
162
+ @classmethod
163
+ def __shape_from_bytes(cls, shape: Sequence[int]):
164
+ return quant_shape_from_byte_shape(shape, cls.qtype)
165
+
166
+ @classmethod
167
+ def __quantize_array(cls, array: np.ndarray) -> np.ndarray:
168
+ return _apply_over_grouped_rows(cls.quantize_rows, arr=array, otype=np.uint8, oshape=cls.__shape_to_bytes(array.shape))
169
+
170
+ @classmethod
171
+ def __dequantize_array(cls, array: np.ndarray) -> np.ndarray:
172
+ cls.init_grid()
173
+ return _apply_over_grouped_rows(cls.dequantize_rows, arr=array, otype=np.float32, oshape=cls.__shape_from_bytes(array.shape))
174
+
175
+ @classmethod
176
+ def __quantize_lazy(cls, lazy_tensor: LazyNumpyTensor, /) -> Any:
177
+ pass
178
+
179
+ @classmethod
180
+ def __dequantize_lazy(cls, lazy_tensor: LazyNumpyTensor, /) -> Any:
181
+ pass
182
+
183
+ @classmethod
184
+ def can_quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> bool:
185
+ return tensor.shape[-1] % cls.block_size == 0
186
+
187
+ @classmethod
188
+ def quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray:
189
+ if not cls.can_quantize(tensor):
190
+ raise QuantError(f"Can't quantize tensor with shape {tensor.shape} to {cls.qtype.name}")
191
+ if isinstance(tensor, LazyNumpyTensor):
192
+ return cls.__quantize_lazy(tensor)
193
+ else:
194
+ return cls.__quantize_array(tensor)
195
+
196
+ @classmethod
197
+ def dequantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray:
198
+ if isinstance(tensor, LazyNumpyTensor):
199
+ return cls.__dequantize_lazy(tensor)
200
+ else:
201
+ return cls.__dequantize_array(tensor)
202
+
203
+
204
+ class BF16(__Quant, qtype=GGMLQuantizationType.BF16):
205
+ @classmethod
206
+ # same as ggml_compute_fp32_to_bf16 in ggml-impl.h
207
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
208
+ n = blocks.view(np.uint32)
209
+ # force nan to quiet
210
+ n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n)
211
+ # round to nearest even
212
+ n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16
213
+ return n.astype(np.uint16).view(np.uint8)
214
+
215
+ @classmethod
216
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
217
+ return (blocks.view(np.int16).astype(np.int32) << 16).view(np.float32)
218
+
219
+
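BF16 as implemented above is simply the upper 16 bits of the IEEE float32 bit pattern (with round-to-nearest-even and NaN quieting); a small illustration of the bit-level relationship:

import numpy as np

x = np.array([1.0, -2.5], dtype=np.float32)
bits = x.view(np.uint32)
print([hex(int(b)) for b in bits])        # ['0x3f800000', '0xc0200000']
print([hex(int(b)) for b in bits >> 16])  # ['0x3f80', '0xc020'], what quantize_blocks stores (low bits are zero here, so no rounding)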
220
+ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
221
+ @classmethod
222
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
223
+ n_blocks = blocks.shape[0]
224
+
225
+ imax = abs(blocks).argmax(axis=-1, keepdims=True)
226
+ max = np.take_along_axis(blocks, imax, axis=-1)
227
+
228
+ d = max / -8
229
+ with np.errstate(divide="ignore"):
230
+ id = np.where(d == 0, 0, 1 / d)
231
+ # FIXME: Q4_0's reference rounding is cursed and depends on FMA
232
+ qs = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15)
233
+
234
+ qs = qs.reshape((n_blocks, 2, cls.block_size // 2))
235
+ qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4))
236
+
237
+ d = d.astype(np.float16).view(np.uint8)
238
+
239
+ return np.concatenate([d, qs], axis=-1)
240
+
241
+ @classmethod
242
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
243
+ n_blocks = blocks.shape[0]
244
+
245
+ d, qs = np.hsplit(blocks, [2])
246
+
247
+ d = d.view(np.float16).astype(np.float32)
248
+
249
+ qs = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
250
+ qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1)).astype(np.int8) - np.int8(8)
251
+
252
+ return (d * qs.astype(np.float32))
253
+
254
+
255
+ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
256
+ @classmethod
257
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
258
+ n_blocks = blocks.shape[0]
259
+
260
+ max = blocks.max(axis=-1, keepdims=True)
261
+ min = blocks.min(axis=-1, keepdims=True)
262
+
263
+ d = (max - min) / 15
264
+ with np.errstate(divide="ignore"):
265
+ id = np.where(d == 0, 0, 1 / d)
266
+ qs = np.trunc((blocks - min) * id + np.float32(0.5), dtype=np.float32).astype(np.uint8).clip(0, 15)
267
+
268
+ qs = qs.reshape((n_blocks, 2, cls.block_size // 2))
269
+ qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4))
270
+
271
+ d = d.astype(np.float16).view(np.uint8)
272
+ m = min.astype(np.float16).view(np.uint8)
273
+
274
+ return np.concatenate([d, m, qs], axis=-1)
275
+
276
+ @classmethod
277
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
278
+ n_blocks = blocks.shape[0]
279
+
280
+ d, rest = np.hsplit(blocks, [2])
281
+ m, qs = np.hsplit(rest, [2])
282
+
283
+ d = d.view(np.float16).astype(np.float32)
284
+ m = m.view(np.float16).astype(np.float32)
285
+
286
+ qs = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
287
+ qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1)).astype(np.float32)
288
+
289
+ return (d * qs) + m
290
+
291
+
292
+ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
293
+ @classmethod
294
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
295
+ n_blocks = blocks.shape[0]
296
+
297
+ imax = abs(blocks).argmax(axis=-1, keepdims=True)
298
+ max = np.take_along_axis(blocks, imax, axis=-1)
299
+
300
+ d = max / -16
301
+ with np.errstate(divide="ignore"):
302
+ id = np.where(d == 0, 0, 1 / d)
303
+ # FIXME: Q5_0's reference rounding is cursed and depends on FMA
304
+ q = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31)
305
+
306
+ qs = q.reshape((n_blocks, 2, cls.block_size // 2))
307
+ qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4))
308
+
309
+ qh = np.packbits(q.reshape((n_blocks, 1, 32)) >> np.uint8(4), axis=-1, bitorder="little").reshape(n_blocks, 4)
310
+
311
+ d = d.astype(np.float16).view(np.uint8)
312
+
313
+ return np.concatenate([d, qh, qs], axis=-1)
314
+
315
+ @classmethod
316
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
317
+ n_blocks = blocks.shape[0]
318
+
319
+ d, rest = np.hsplit(blocks, [2])
320
+ qh, qs = np.hsplit(rest, [4])
321
+
322
+ d = d.view(np.float16).astype(np.float32)
323
+ qh = qh.view(np.uint32)
324
+
325
+ qh = qh.reshape((n_blocks, 1)) >> np.array([i for i in range(32)], dtype=np.uint32).reshape((1, 32))
326
+ ql = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
327
+ qh = (qh & np.uint32(0x01)).astype(np.uint8)
328
+ ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1))
329
+
330
+ qs = (ql | (qh << np.uint8(4))).astype(np.int8) - np.int8(16)
331
+
332
+ return (d * qs.astype(np.float32))
333
+
334
+
335
+ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
336
+ @classmethod
337
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
338
+ n_blocks = blocks.shape[0]
339
+
340
+ max = blocks.max(axis=-1, keepdims=True)
341
+ min = blocks.min(axis=-1, keepdims=True)
342
+
343
+ d = (max - min) / 31
344
+ with np.errstate(divide="ignore"):
345
+ id = np.where(d == 0, 0, 1 / d)
346
+ q = np.trunc((blocks - min) * id + np.float32(0.5), dtype=np.float32).astype(np.uint8).clip(0, 31)
347
+
348
+ qs = q.reshape((n_blocks, 2, cls.block_size // 2))
349
+ qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4))
350
+
351
+ qh = np.packbits(q.reshape((n_blocks, 1, 32)) >> np.uint8(4), axis=-1, bitorder="little").reshape(n_blocks, 4)
352
+
353
+ d = d.astype(np.float16).view(np.uint8)
354
+ m = min.astype(np.float16).view(np.uint8)
355
+
356
+ return np.concatenate([d, m, qh, qs], axis=-1)
357
+
358
+ @classmethod
359
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
360
+ n_blocks = blocks.shape[0]
361
+
362
+ d, rest = np.hsplit(blocks, [2])
363
+ m, rest = np.hsplit(rest, [2])
364
+ qh, qs = np.hsplit(rest, [4])
365
+
366
+ d = d.view(np.float16).astype(np.float32)
367
+ m = m.view(np.float16).astype(np.float32)
368
+ qh = qh.view(np.uint32)
369
+
370
+ qh = qh.reshape((n_blocks, 1)) >> np.array([i for i in range(32)], dtype=np.uint32).reshape((1, 32))
371
+ ql = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
372
+ qh = (qh & np.uint32(0x01)).astype(np.uint8)
373
+ ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1))
374
+
375
+ qs = (ql | (qh << np.uint8(4))).astype(np.float32)
376
+
377
+ return (d * qs) + m
378
+
379
+
380
+ class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
381
+ @classmethod
382
+ # Implementation of Q8_0 with bit-exact same results as reference implementation in ggml-quants.c
383
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
384
+
385
+ d = abs(blocks).max(axis=1, keepdims=True) / 127
386
+ with np.errstate(divide="ignore"):
387
+ id = np.where(d == 0, 0, 1 / d)
388
+ qs = np_roundf(blocks * id)
389
+
390
+ # (n_blocks, 2)
391
+ d = d.astype(np.float16).view(np.uint8)
392
+ # (n_blocks, block_size)
393
+ qs = qs.astype(np.int8).view(np.uint8)
394
+
395
+ return np.concatenate([d, qs], axis=1)
396
+
397
+ @classmethod
398
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
399
+ d, x = np.split(blocks, [2], axis=1)
400
+ d = d.view(np.float16).astype(np.float32)
401
+ x = x.view(np.int8).astype(np.float32)
402
+
403
+ return (x * d)
404
+
405
+
406
+ class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K):
407
+ @classmethod
408
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
409
+ n_blocks = blocks.shape[0]
410
+
411
+ scales, rest = np.hsplit(blocks, [QK_K // 16])
412
+ qs, rest = np.hsplit(rest, [QK_K // 4])
413
+ d, dmin = np.hsplit(rest, [2])
414
+
415
+ d = d.view(np.float16).astype(np.float32)
416
+ dmin = dmin.view(np.float16).astype(np.float32)
417
+
418
+ # (n_blocks, 16, 1)
419
+ dl = (d * (scales & 0xF).astype(np.float32)).reshape((n_blocks, QK_K // 16, 1))
420
+ ml = (dmin * (scales >> 4).astype(np.float32)).reshape((n_blocks, QK_K // 16, 1))
421
+
422
+ shift = np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
423
+
424
+ qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & np.uint8(3)
425
+
426
+ qs = qs.reshape((n_blocks, QK_K // 16, 16)).astype(np.float32)
427
+
428
+ qs = dl * qs - ml
429
+
430
+ return qs.reshape((n_blocks, -1))
431
+
432
+
433
+ class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K):
434
+ @classmethod
435
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
436
+ n_blocks = blocks.shape[0]
437
+
438
+ hmask, rest = np.hsplit(blocks, [QK_K // 8])
439
+ qs, rest = np.hsplit(rest, [QK_K // 4])
440
+ scales, d = np.hsplit(rest, [12])
441
+
442
+ d = d.view(np.float16).astype(np.float32)
443
+
444
+ # The scales are packed at 6-bit each in this pattern:
445
+ # 0: IIIIAAAA
446
+ # 1: JJJJBBBB
447
+ # 2: KKKKCCCC
448
+ # 3: LLLLDDDD
449
+ # 4: MMMMEEEE
450
+ # 5: NNNNFFFF
451
+ # 6: OOOOGGGG
452
+ # 7: PPPPHHHH
453
+ # 8: MMIIEEAA
454
+ # 9: NNJJFFBB
455
+ # 10: OOKKGGCC
456
+ # 11: PPLLHHDD
457
+ lscales, hscales = np.hsplit(scales, [8])
458
+ lscales = lscales.reshape((n_blocks, 1, 8)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 2, 1))
459
+ lscales = lscales.reshape((n_blocks, 16))
460
+ hscales = hscales.reshape((n_blocks, 1, 4)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 4, 1))
461
+ hscales = hscales.reshape((n_blocks, 16))
462
+ scales = (lscales & np.uint8(0x0F)) | ((hscales & np.uint8(0x03)) << np.uint8(4))
463
+ scales = (scales.astype(np.int8) - np.int8(32)).astype(np.float32)
464
+
465
+ dl = (d * scales).reshape((n_blocks, 16, 1))
466
+
467
+ ql = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
468
+ qh = hmask.reshape(n_blocks, -1, 1, 32) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8, 1))
469
+ ql = ql.reshape((n_blocks, 16, QK_K // 16)) & np.uint8(3)
470
+ qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & np.uint8(1))
471
+ qh = qh ^ np.uint8(1) # strangely, the offset is zero when the bitmask is 1
472
+ q = (ql.astype(np.int8) - (qh << np.uint8(2)).astype(np.int8)).astype(np.float32)
473
+
474
+ return (dl * q).reshape((n_blocks, QK_K))
475
+
476
+
477
+ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
478
+ K_SCALE_SIZE = 12
479
+
480
+ @staticmethod
481
+ def get_scale_min(scales: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
482
+ n_blocks = scales.shape[0]
483
+ scales = scales.view(np.uint8)
484
+ ### Unpacking the following: ###
485
+ # 0 EEAAAAAA
486
+ # 1 FFBBBBBB
487
+ # 2 GGCCCCCC
488
+ # 3 HHDDDDDD
489
+ # 4 eeaaaaaa
490
+ # 5 ffbbbbbb
491
+ # 6 ggcccccc
492
+ # 7 hhdddddd
493
+ # 8 eeeeEEEE
494
+ # 9 ffffFFFF
495
+ # 10 ggggGGGG
496
+ # 11 hhhhHHHH
497
+ scales = scales.reshape((n_blocks, 3, 4))
498
+ d, m, m_d = np.split(scales, 3, axis=-2)
499
+
500
+ sc = np.concatenate([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], axis=-1)
501
+ min = np.concatenate([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], axis=-1)
502
+
503
+ return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))
504
+
505
+ @classmethod
506
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
507
+ n_blocks = blocks.shape[0]
508
+
509
+ d, rest = np.hsplit(blocks, [2])
510
+ dmin, rest = np.hsplit(rest, [2])
511
+ scales, qs = np.hsplit(rest, [cls.K_SCALE_SIZE])
512
+
513
+ d = d.view(np.float16).astype(np.float32)
514
+ dmin = dmin.view(np.float16).astype(np.float32)
515
+
516
+ sc, m = Q4_K.get_scale_min(scales)
517
+
518
+ d = (d * sc.astype(np.float32)).reshape((n_blocks, -1, 1))
519
+ dm = (dmin * m.astype(np.float32)).reshape((n_blocks, -1, 1))
520
+
521
+ qs = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
522
+ qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1, 32)).astype(np.float32)
523
+
524
+ return (d * qs - dm).reshape((n_blocks, QK_K))
525
+
526
+
527
+ class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K):
528
+ @classmethod
529
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
530
+ n_blocks = blocks.shape[0]
531
+
532
+ d, rest = np.hsplit(blocks, [2])
533
+ dmin, rest = np.hsplit(rest, [2])
534
+ scales, rest = np.hsplit(rest, [Q4_K.K_SCALE_SIZE])
535
+ qh, qs = np.hsplit(rest, [QK_K // 8])
536
+
537
+ d = d.view(np.float16).astype(np.float32)
538
+ dmin = dmin.view(np.float16).astype(np.float32)
539
+
540
+ sc, m = Q4_K.get_scale_min(scales)
541
+
542
+ d = (d * sc.astype(np.float32)).reshape((n_blocks, -1, 1))
543
+ dm = (dmin * m.astype(np.float32)).reshape((n_blocks, -1, 1))
544
+
545
+ ql = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
546
+ qh = qh.reshape((n_blocks, -1, 1, 32)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8, 1))
547
+ ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1, 32))
548
+ qh = (qh & np.uint8(0x01)).reshape((n_blocks, -1, 32))
549
+ q = (ql | (qh << np.uint8(4))).astype(np.float32)
550
+
551
+ return (d * q - dm).reshape((n_blocks, QK_K))
552
+
553
+
554
+ class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K):
555
+ @classmethod
556
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
557
+ n_blocks = blocks.shape[0]
558
+
559
+ ql, rest = np.hsplit(blocks, [QK_K // 2])
560
+ qh, rest = np.hsplit(rest, [QK_K // 4])
561
+ scales, d = np.hsplit(rest, [QK_K // 16])
562
+
563
+ scales = scales.view(np.int8).astype(np.float32)
564
+ d = d.view(np.float16).astype(np.float32)
565
+ d = (d * scales).reshape((n_blocks, QK_K // 16, 1))
566
+
567
+ ql = ql.reshape((n_blocks, -1, 1, 64)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
568
+ ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1, 32))
569
+ qh = qh.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
570
+ qh = (qh & np.uint8(0x03)).reshape((n_blocks, -1, 32))
571
+ q = (ql | (qh << np.uint8(4))).astype(np.int8) - np.int8(32)
572
+ q = q.reshape((n_blocks, QK_K // 16, -1)).astype(np.float32)
573
+
574
+ return (d * q).reshape((n_blocks, QK_K))
575
+
576
+
577
+ class TQ1_0(__Quant, qtype=GGMLQuantizationType.TQ1_0):
578
+ @classmethod
579
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
580
+ n_blocks = blocks.shape[0]
581
+
582
+ d = abs(blocks).max(axis=-1, keepdims=True)
583
+ with np.errstate(divide="ignore"):
584
+ id = np.where(d == 0, 0, 1 / d)
585
+ qs = np_roundf(blocks * id)
586
+ qs = (qs.astype(np.int8) + np.int8(1)).astype(np.uint8)
587
+
588
+ qs0, qs1, qh = qs[..., :(32 * 5)], qs[..., (32 * 5):(48 * 5)], qs[..., (48 * 5):]
589
+ qs0 = qs0.reshape((n_blocks, -1, 5, 32)) * np.array([81, 27, 9, 3, 1], dtype=np.uint8).reshape((1, 1, 5, 1))
590
+ qs0 = np.sum(qs0, axis=-2).reshape((n_blocks, -1))
591
+ qs1 = qs1.reshape((n_blocks, -1, 5, 16)) * np.array([81, 27, 9, 3, 1], dtype=np.uint8).reshape((1, 1, 5, 1))
592
+ qs1 = np.sum(qs1, axis=-2).reshape((n_blocks, -1))
593
+ qh = qh.reshape((n_blocks, -1, 4, 4)) * np.array([81, 27, 9, 3], dtype=np.uint8).reshape((1, 1, 4, 1))
594
+ qh = np.sum(qh, axis=-2).reshape((n_blocks, -1))
595
+ qs = np.concatenate([qs0, qs1, qh], axis=-1)
596
+ qs = (qs.astype(np.uint16) * 256 + (243 - 1)) // 243
597
+
598
+ qs = qs.astype(np.uint8)
599
+ d = d.astype(np.float16).view(np.uint8)
600
+
601
+ return np.concatenate([qs, d], axis=-1)
602
+
603
+ @classmethod
604
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
605
+ n_blocks = blocks.shape[0]
606
+
607
+ qs, rest = np.hsplit(blocks, [(QK_K - 4 * QK_K // 64) // 5])
608
+ qh, d = np.hsplit(rest, [QK_K // 64])
609
+
610
+ d = d.view(np.float16).astype(np.float32)
611
+
612
+ qs0, qs1 = qs[..., :32], qs[..., 32:]
613
+ qs0 = qs0.reshape((n_blocks, -1, 1, 32)) * np.array([1, 3, 9, 27, 81], dtype=np.uint8).reshape((1, 1, 5, 1))
614
+ qs0 = qs0.reshape((n_blocks, -1))
615
+ qs1 = qs1.reshape((n_blocks, -1, 1, 16)) * np.array([1, 3, 9, 27, 81], dtype=np.uint8).reshape((1, 1, 5, 1))
616
+ qs1 = qs1.reshape((n_blocks, -1))
617
+ qh = qh.reshape((n_blocks, -1, 1, 4)) * np.array([1, 3, 9, 27], dtype=np.uint8).reshape((1, 1, 4, 1))
618
+ qh = qh.reshape((n_blocks, -1))
619
+ qs = np.concatenate([qs0, qs1, qh], axis=-1)
620
+ qs = ((qs.astype(np.uint16) * 3) >> 8).astype(np.int8) - np.int8(1)
621
+
622
+ return (d * qs.astype(np.float32))
623
+
624
+
625
+ class TQ2_0(__Quant, qtype=GGMLQuantizationType.TQ2_0):
626
+ @classmethod
627
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
628
+ n_blocks = blocks.shape[0]
629
+
630
+ d = abs(blocks).max(axis=-1, keepdims=True)
631
+ with np.errstate(divide="ignore"):
632
+ id = np.where(d == 0, 0, 1 / d)
633
+ qs = np_roundf(blocks * id)
634
+ qs = (qs.astype(np.int8) + np.int8(1)).astype(np.uint8)
635
+
636
+ qs = qs.reshape((n_blocks, -1, 4, 32)) << np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
637
+ qs = qs[..., 0, :] | qs[..., 1, :] | qs[..., 2, :] | qs[..., 3, :]
638
+ qs = qs.reshape((n_blocks, -1))
639
+
640
+ d = d.astype(np.float16).view(np.uint8)
641
+
642
+ return np.concatenate([qs, d], axis=-1)
643
+
644
+ @classmethod
645
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
646
+ n_blocks = blocks.shape[0]
647
+
648
+ qs, d = np.hsplit(blocks, [QK_K // 4])
649
+
650
+ d = d.view(np.float16).astype(np.float32)
651
+
652
+ qs = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
653
+ qs = (qs & 0x03).reshape((n_blocks, -1)).astype(np.int8) - np.int8(1)
654
+
655
+ return (d * qs.astype(np.float32))
656
+
657
+
658
+ class IQ2_XXS(__Quant, qtype=GGMLQuantizationType.IQ2_XXS):
659
+ ksigns: bytes = (
660
+ b"\x00\x81\x82\x03\x84\x05\x06\x87\x88\x09\x0a\x8b\x0c\x8d\x8e\x0f"
661
+ b"\x90\x11\x12\x93\x14\x95\x96\x17\x18\x99\x9a\x1b\x9c\x1d\x1e\x9f"
662
+ b"\xa0\x21\x22\xa3\x24\xa5\xa6\x27\x28\xa9\xaa\x2b\xac\x2d\x2e\xaf"
663
+ b"\x30\xb1\xb2\x33\xb4\x35\x36\xb7\xb8\x39\x3a\xbb\x3c\xbd\xbe\x3f"
664
+ b"\xc0\x41\x42\xc3\x44\xc5\xc6\x47\x48\xc9\xca\x4b\xcc\x4d\x4e\xcf"
665
+ b"\x50\xd1\xd2\x53\xd4\x55\x56\xd7\xd8\x59\x5a\xdb\x5c\xdd\xde\x5f"
666
+ b"\x60\xe1\xe2\x63\xe4\x65\x66\xe7\xe8\x69\x6a\xeb\x6c\xed\xee\x6f"
667
+ b"\xf0\x71\x72\xf3\x74\xf5\xf6\x77\x78\xf9\xfa\x7b\xfc\x7d\x7e\xff"
668
+ )
669
+
670
+ # iq2xxs_grid, but with each byte of the original packed in 2 bits,
671
+ # by mapping 0x08 to 0, 0x19 to 1, and 0x2b to 2.
672
+ grid_shape = (256, 8)
673
+ grid_map = (0x08, 0x19, 0x2b)
674
+ grid_hex = (
675
+ b"00000200050008000a00110014002000220028002a0041004400500058006100"
676
+ b"6400800082008a00a20001010401100115014001840198010002020222028202"
677
+ b"010404041004210424044004420448046004810484049004a404000502050805"
678
+ b"200546056905800591050906100640068406a406000805080808140828084108"
679
+ b"440850085208880804094009020a140a01100410101021104010601084109010"
680
+ b"951000110811201150115a118011241245120014081420142514491480141815"
681
+ b"6215001616160118041810184018811800190519a019511a002002200a204420"
682
+ b"6120802082202921482100220222012404241024402456240025412564259026"
683
+ b"082820289428442a014004401040184021402440404048405640604081408440"
684
+ b"9040004120416141804185410142104248425642684200440844204480449944"
685
+ b"124524450046014804481048404845480049584961498249454a904a00500850"
686
+ b"1150195020508050885004514251a4519152905492540a550156545600581158"
687
+ b"195864584059085a046010604060686000615561186260620064056410651265"
688
+ b"84654268008002800a8041808280048118814081118201840484108415844084"
689
+ b"608400854685948509864086608602880489118a0490109024904090a1901691"
690
+ b"8091459200942294449451958198209902a050a085a009a100a218a450a804a9"
691
+ )
692
+
693
+ @classmethod
694
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
695
+ n_blocks = blocks.shape[0]
696
+
697
+ d, qs = np.hsplit(blocks, [2])
698
+
699
+ d = d.view(np.float16).astype(np.float32)
700
+
701
+ qs = qs.view(np.uint32).reshape(n_blocks, -1, 2)
702
+
703
+ db = d * (np.float32(0.5) + (qs[..., 1] >> 28).astype(np.float32)) * np.float32(0.25)
704
+ db = db.reshape((n_blocks, -1, 1, 1))
705
+
706
+ # get the sign indices and unpack the bits
707
+ signs = qs[..., 1].reshape((n_blocks, -1, 1)) >> np.array([0, 7, 14, 21], dtype=np.uint32).reshape((1, 1, 4))
708
+ ksigns = np.frombuffer(cls.ksigns, dtype=np.uint8).reshape((1, 1, 1, 128))
709
+ signs = (signs & np.uint32(0x7F)).reshape((n_blocks, -1, 4, 1))
710
+ signs = np.take_along_axis(ksigns, signs, axis=-1)
711
+ signs = signs.reshape((n_blocks, -1, 4, 1)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 1, 8))
712
+ signs = signs & np.uint8(0x01)
713
+ signs = np.where(signs == 0, np.float32(1), np.float32(-1))
714
+ signs = signs.reshape((n_blocks, -1, 4, 8))
715
+
716
+ assert cls.grid is not None
717
+ grid = np.take_along_axis(cls.grid, qs[..., 0].copy().view(np.uint8).reshape((n_blocks, -1, 1, 1)), axis=-2)
718
+ grid = grid.reshape((n_blocks, -1, 4, 8))
719
+
720
+ return (db * grid * signs).reshape((n_blocks, -1))
721
+
722
+
723
+ class IQ2_XS(__Quant, qtype=GGMLQuantizationType.IQ2_XS):
724
+ # iq2xs_grid, but with each byte of the original packed in 2 bits,
725
+ # by mapping 0x08 to 0, 0x19 to 1, and 0x2b to 2.
726
+ grid_shape = (512, 8)
727
+ grid_map = (0x08, 0x19, 0x2b)
728
+ grid_hex = (
729
+ b"00000200050008000a0011001400160019002000220025002800410044004600"
730
+ b"49005000520055005800610064008000820085008800910094009900a0000101"
731
+ b"04010601090110011201150118011a0121012401400142014501480151015401"
732
+ b"6001680181018401900100020202050208021102140220024102440250025502"
733
+ b"80028a0201040404060409041004120415041804210424044004420445044804"
734
+ b"5104540456046004810484049004000502050505080511051405200541054405"
735
+ b"500561058005010604061006260640064206840600080208050808080a081108"
736
+ b"14082008250841084408500858088008a008aa08010904091009400981098909"
737
+ b"000a200a280a960aa00a01100410061009101010121015101810211024104010"
738
+ b"4210451048105110541060106a10811084109010001102110511081111111411"
739
+ b"2011411144115011801194119611011204120612101240126012001402140514"
740
+ b"0814111414142014411444144914501464148014011504151015401500161416"
741
+ b"49160118041810181218401854188618001905196619511aa91a002002200520"
742
+ b"08200a201120142020204120442050208020a020012104211021402148216521"
743
+ b"002222228022a82201240424102429244024002541255225992501261a26a626"
744
+ b"002808280a28202855288828a22868299029082a202a822a882a8a2a01400440"
745
+ b"0640094010401240154018402140244040404240454048404a40514054406040"
746
+ b"6540814084409040004102410541084111411441204141414441504180418541"
747
+ b"a241014204421042124229424042004402440544084411441444194420444144"
748
+ b"4444504480449444014504451045244540459a4500460a464446504601480448"
749
+ b"1048404845485448624800491149444950496949044a00500250055008501150"
750
+ b"145020502850415044505050805001510451105115514051425100524452aa52"
751
+ b"0154045410542154405460548154a154005508558055885521566856a1560058"
752
+ b"14584158505899581a5940594259855a0160046010604060546062608660a960"
753
+ b"006124624a62926200641664106540654565a46501686a682569066a546a626a"
754
+ b"00800280058008801180148020802a8041804480508080808280a880aa800181"
755
+ b"0481068110814081518159810082208280828282a082a8820184048410841284"
756
+ b"158440846084898400854485a58518866a860088088825885a8880888288a888"
757
+ b"0689228a808a888a968aa88a0190049010904090569084900091229164915692"
758
+ b"89920094059444945094589429959095929541965198a6984999159a609a00a0"
759
+ b"02a008a00aa020a02aa0a0a051a159a1a6a100a202a208a22aa280a2a0a240a4"
760
+ b"95a465a698a60aa820a822a828a8a0a8a8a804a984a986a928aa2aaa91aaaaaa"
761
+ )
762
+
763
+ @classmethod
764
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
765
+ n_blocks = blocks.shape[0]
766
+
767
+ d, rest = np.hsplit(blocks, [2])
768
+ qs, scales = np.hsplit(rest, [2 * QK_K // 8])
769
+
770
+ d = d.view(np.float16).astype(np.float32)
771
+ qs = qs.view(np.uint16)
772
+
773
+ scales = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
774
+ scales = (scales & 0x0F).reshape((n_blocks, -1))
775
+ db = d * (np.float32(0.5) + scales) * np.float32(0.25)
776
+ db = db.reshape((n_blocks, -1, 1, 1))
777
+
778
+ # get the sign indices and unpack the bits
779
+ signs = np.frombuffer(IQ2_XXS.ksigns, dtype=np.uint8).reshape(1, 1, 128)
780
+ signs = np.take_along_axis(signs, (qs >> 9).reshape((n_blocks, -1, 1)), axis=-1)
781
+ signs = signs.reshape((n_blocks, -1, 1)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8))
782
+ signs = signs & np.uint8(0x01)
783
+ signs = np.where(signs == 0, np.float32(1), np.float32(-1))
784
+ signs = signs.reshape((n_blocks, -1, 2, 8))
785
+
786
+ assert cls.grid is not None
787
+ grid = np.take_along_axis(cls.grid, (qs & np.uint16(511)).reshape((n_blocks, -1, 1, 1)), axis=-2)
788
+ grid = grid.reshape((n_blocks, -1, 2, 8))
789
+
790
+ return (db * grid * signs).reshape((n_blocks, -1))
791
+
792
+
793
+ class IQ2_S(__Quant, qtype=GGMLQuantizationType.IQ2_S):
794
+ # iq2s_grid, but with each byte of the original packed in 2 bits,
795
+ # by mapping 0x08 to 0, 0x19 to 1, and 0x2b to 2.
796
+ grid_shape = (1024, 8)
797
+ grid_map = (0x08, 0x19, 0x2b)
798
+ grid_hex = (
799
+ b"00000200050008000a0011001400160019002000220025002800410044004600"
800
+ b"490050005200550058006100640066006900800082008500880091009400a000"
801
+ b"a500aa0001010401060109011001120115011801210124014001420145014801"
802
+ b"510154015601590160016501680181018401900192019501a101a40100020202"
803
+ b"050208021102140220022a02410244024602490250025502800285028a029402"
804
+ b"a202010404040604090410041204150418042104240426042904400442044504"
805
+ b"48044a0451045404560459046004620465048104840486048904900495049804"
806
+ b"a104a40400050205050508050a05110514051605190520052505280541054405"
807
+ b"46054905500552055505580561056405800582058505880591059405a0050106"
808
+ b"0406060609061006150640064506480651065406600681068406900600080208"
809
+ b"050808081108140816081908200825082a084108440846084908500852085508"
810
+ b"580861086408800885089408aa08010904091009120915091809210940094509"
811
+ b"480951095409600981099009000a110a140a220a280a2a0a500a990a01100410"
812
+ b"0610091010101210151018102110241026104010421045104810511054105610"
813
+ b"59106010621065106810811084108610901095109810a110a410001102110511"
814
+ b"08110a1111111411161119112011221125112811411144114611491150115211"
815
+ b"5511581161116411801182118511881191119411011204120912101215122112"
816
+ b"2412401245125112541281128412901200140214051408141114141416141914"
817
+ b"2014251428144114441446144914501452145514581461146414801482148514"
818
+ b"881491149414a014011504150615091510151215151518152115241540154215"
819
+ b"4515481551155415601581158415901500160516081611161416201641164416"
820
+ b"50168016aa160118041806180918101815181818211840184218451848185118"
821
+ b"541860188118841800190219051908191119141920194119441950196919a219"
822
+ b"041a101a401a561a00200220052008201120142016201920202025202a204120"
823
+ b"4420502052205520642080208a209420aa200121042110211221152121214021"
824
+ b"4221452151215421602181218421902100220a22222228222a22442250228822"
825
+ b"8a22a82201240424062409241024152418242124242440244224452448245124"
826
+ b"5424602481248424902400250525082511251425202541254425502566258025"
827
+ b"0126042610264026592600280528112814284128442850288a28aa2801290429"
828
+ b"102995290a2a222a642a882a8a2a014004400640094010401240154018401a40"
829
+ b"21402440264040404240454048404a4051405440564059406040624065408140"
830
+ b"8440904095409840a140a4400041024105410841114114411641194120412241"
831
+ b"2541414144414641494150415241554158416141644180418241854188419141"
832
+ b"9441a04101420442104212421542184224424042454248425142544260428142"
833
+ b"844200440244054408440a441144144416441944204422442544284441444444"
834
+ b"46444944504452445544584461446444804482448544884491449444a0440145"
835
+ b"0445064509451045124515451845214524454045424545454845514554456045"
836
+ b"6a4581458445904500460246054608461146144620464146444650468046a546"
837
+ b"0148044809481048124815481848214824484048424845484848514854486048"
838
+ b"84489048004902490549084911491449204941494449504980499649014a044a"
839
+ b"104a404a00500250055008501150145016501950205022502550285041504450"
840
+ b"4650495050505250555058506150645080508250855088509150945001510451"
841
+ b"0651095110511251155118512151245140514251455148515151545160518151"
842
+ b"8451905100520552085211521452205241524452505269528052015404540654"
843
+ b"0954105412541554185421542454405442544554485451545454605481548454"
844
+ b"9054005502550555085511551455205541554455505580550156045610562656"
845
+ b"405600580258055808581158145820584158445850585a588058015904591059"
846
+ b"4059005a195a855aa85a01600460066010601260156018602160246040604560"
847
+ b"4860516054606060846090600061026105610861116114612061416144615061"
848
+ b"806199610462106240625662a162006405640864116414642064416444645064"
849
+ b"806401650465106540654a656865926500669466016804681068656898680069"
850
+ b"2a69426aa16a0080028005800880118014801980208025804180448050805280"
851
+ b"5580588061808080858091809480018104810981108112811581188121812481"
852
+ b"408142814581488151815481818184819081a981008205820a82118214824182"
853
+ b"4482508201840484068409841084128415841884218440844284458448845184"
854
+ b"5484608481848484908400850285058508851185148520854185448550858085"
855
+ b"8a85018604861086298640860088058811881488418844885088a28801890489"
856
+ b"40896589228a588a5a8a828aa28a019004900990109012901590189024904090"
857
+ b"4290459048905190549060908190849090900091059111911491419144915091"
858
+ b"5a910192049210924092a6920094029405940894119414942094419444945094"
859
+ b"8094969401950495109540959895a19500964696649601980498109826984098"
860
+ b"a998009949995299909a00a005a00aa014a022a02aa041a044a050a0a2a0aaa0"
861
+ b"40a165a102a20aa222a228a22aa282a288a28aa2a8a201a404a410a440a489a4"
862
+ b"a4a400a519a551a60aa828a8a2a854a986a908aa0aaa20aa22aa28aa88aaaaaa"
863
+ )
864
+
865
+ @classmethod
866
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
867
+ n_blocks = blocks.shape[0]
868
+
869
+ d, rest = np.hsplit(blocks, [2])
870
+ qs, rest = np.hsplit(rest, [QK_K // 8])
871
+ signs, rest = np.hsplit(rest, [QK_K // 8])
872
+ qh, scales = np.hsplit(rest, [QK_K // 32])
873
+
874
+ d = d.view(np.float16).astype(np.float32)
875
+
876
+ scales = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
877
+ scales = (scales & 0x0F).reshape((n_blocks, -1))
878
+ db = d * (np.float32(0.5) + scales) * np.float32(0.25)
879
+ db = db.reshape((n_blocks, -1, 1, 1))
880
+
881
+ # unpack the sign bits
882
+ signs = signs.reshape((n_blocks, -1, 1)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8))
883
+ signs = signs & np.uint8(0x01)
884
+ signs = np.where(signs == 0, np.float32(1), np.float32(-1))
885
+ signs = signs.reshape((n_blocks, -1, 2, 8))
886
+
887
+ qh = qh.reshape((n_blocks, -1, 1)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4))
888
+ qs = qs.astype(np.uint16) | ((qh & 0x03).astype(np.uint16) << 8).reshape((n_blocks, -1))
889
+
890
+ assert cls.grid is not None
891
+ grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2)
892
+ grid = grid.reshape((n_blocks, -1, 2, 8))
893
+
894
+ return (db * grid * signs).reshape((n_blocks, -1))
895
+
896
+
897
+ class IQ3_XXS(__Quant, qtype=GGMLQuantizationType.IQ3_XXS):
898
+ grid_shape = (256, 4)
899
+ grid_map = (0x04, 0x0c, 0x14, 0x1c, 0x24, 0x2c, 0x34, 0x3e)
900
+ grid_hex = (
901
+ b"0000020004001100130017002000220031004200730075000101030110011201"
902
+ b"2101250130013201410154017001000202020402110220022202310233023702"
903
+ b"5102570275020103070310031203250370031304370444045704730475040105"
904
+ b"0705320552053506640610071407160743076107011003101010121021102310"
905
+ b"3010321034104710501000110211111120112211011203121012121221123012"
906
+ b"7212001302132013311346136613011405145014201524154615711505162217"
907
+ b"4017002002201120132020202220262031204220012103210521102112212121"
908
+ b"3021632167217021002202221122172220222222372240225522012310231423"
909
+ b"7023742335245324032527254125742501270327162745270130103012302130"
910
+ b"2330503065307230003102312031313144314631013203321032253252327232"
911
+ b"1133333330344734723400350635223555351436363663363337603704401740"
912
+ b"3540374053405740744120423742404260426642074345430444514464442545"
913
+ b"4345704505471047124730471250415070500051065126515551145232527252"
914
+ b"0253535310542354275472540255315550562457425724604460466064602161"
915
+ b"6161176264623063366344640565526533660367216703700570077010703270"
916
+ b"5270267140711272457252720073157333736073217441740075027524753076"
917
+ )
918
+
919
+ @classmethod
920
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
921
+ n_blocks = blocks.shape[0]
922
+
923
+ d, rest = np.hsplit(blocks, [2])
924
+ qs, scales = np.hsplit(rest, [QK_K // 4])
925
+
926
+ d = d.view(np.float16).astype(np.float32)
927
+ scales = scales.view(np.uint32)
928
+
929
+ db = d * (np.float32(0.5) + (scales >> 28).astype(np.float32)) * np.float32(0.5)
930
+ db = db.reshape((n_blocks, -1, 1, 1))
931
+
932
+ # get the sign indices and unpack the bits
933
+ signs = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 7, 14, 21], dtype=np.uint32).reshape((1, 1, 4))
934
+ ksigns = np.frombuffer(IQ2_XXS.ksigns, dtype=np.uint8).reshape((1, 1, 1, 128))
935
+ signs = (signs & np.uint32(0x7F)).reshape((n_blocks, -1, 4, 1))
936
+ signs = np.take_along_axis(ksigns, signs, axis=-1)
937
+ signs = signs.reshape((n_blocks, -1, 4, 1)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 1, 8))
938
+ signs = signs & np.uint8(0x01)
939
+ signs = np.where(signs == 0, np.float32(1), np.float32(-1))
940
+ signs = signs.reshape((n_blocks, -1, 4, 8))
941
+
942
+ assert cls.grid is not None
943
+ grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2)
944
+ grid = grid.reshape((n_blocks, -1, 4, 8))
945
+
946
+ return (db * grid * signs).reshape((n_blocks, -1))
947
+
948
+
949
+ class IQ3_S(__Quant, qtype=GGMLQuantizationType.IQ3_S):
950
+ grid_shape = (512, 4)
951
+ grid_map = (0x01, 0x03, 0x05, 0x07, 0x09, 0x0b, 0x0d, 0x0f)
952
+ grid_hex = (
953
+ b"0000010002000500070010001100120014001600200021002500330040004200"
954
+ b"4500470051005300600062007100740077000001010102010401100111011501"
955
+ b"2001230127013101350144016101650172010002010205020702100213021602"
956
+ b"2102250230023402420245024702510253027002730203031103150320032203"
957
+ b"3103330336034403500352036703710375030004130417042104240432044004"
958
+ b"4304510470040205040520052205260533054105450547056605730506061106"
959
+ b"1306310652067106000702070407200722072607330750075407001001100210"
960
+ b"0410101011101310151017102010221031103410361054105610611072100011"
961
+ b"0111031106111011141121113011331141115011521170117611001212121512"
962
+ b"1712201224123212401243125512601272120113041307131013131321132713"
963
+ b"3013341341136213701303140514121414143114331442144614501454140115"
964
+ b"1015131521153015321551152016241627164416461601170317101712172117"
965
+ b"3517411762177017002001200320052007201020122014201620212023202720"
966
+ b"3020322041204320452050205220672070207320752000210221102113211721"
967
+ b"2221252131213421422151210122042207222122232230223722412253225722"
968
+ b"7122742200230223052311232223242331233323422350236623012407242024"
969
+ b"2324322435244124722475240425112522253725402553257025002602260726"
970
+ b"2126552661260527112726273027432750270230113013301530173022303130"
971
+ b"3330353042304430473051306330713001310331053114312131233140316031"
972
+ b"7231763100321232203232323432503201331033143321332333273330334133"
973
+ b"4333473355337333033411341634223431345234603464340135103512352535"
974
+ b"3235443556357335163641360137033720372237353700400440124020402440"
975
+ b"2740324041405040704002410741114113412241304135414341514155410142"
976
+ b"0342104215422142334240425742624270420443114313432043224331433543"
977
+ b"0044024424443744404471440545074521456245134634466046104715473047"
978
+ b"4347514702501050145022504050445047505250665074500151035105511251"
979
+ b"2151325172510052115223523052365253520253075310532753445351536553"
980
+ b"7353015404542054325446541255265551555355425602570457225711601360"
981
+ b"1560316033606060006120612761646112623462426255626262706200631463"
982
+ b"2163406325644364626400650365346560650566406611671367007004700770"
983
+ b"2070227036704070547062700271117124714371457101720472107216722172"
984
+ b"3072517202733273357353730174057413742074507422754275027631760077"
985
+ )
986
+
987
+ @classmethod
988
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
989
+ n_blocks = blocks.shape[0]
990
+
991
+ d, rest = np.hsplit(blocks, [2])
992
+ qs, rest = np.hsplit(rest, [QK_K // 4])
993
+ qh, rest = np.hsplit(rest, [QK_K // 32])
994
+ signs, scales = np.hsplit(rest, [QK_K // 8])
995
+
996
+ d = d.view(np.float16).astype(np.float32)
997
+
998
+ scales = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
999
+ scales = (scales & 0x0F).reshape((n_blocks, -1))
1000
+ db = d * (1 + 2 * scales)
1001
+ db = db.reshape((n_blocks, -1, 1, 1))
1002
+
1003
+ # unpack the sign bits
1004
+ signs = signs.reshape((n_blocks, -1, 1)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8))
1005
+ signs = signs & np.uint8(0x01)
1006
+ signs = np.where(signs == 0, np.float32(1), np.float32(-1))
1007
+ signs = signs.reshape((n_blocks, -1, 4, 8))
1008
+
1009
+ qh = qh.reshape((n_blocks, -1, 1)) >> np.array([i for i in range(8)], dtype=np.uint8)
1010
+ qh = (qh & 0x01).astype(np.uint16).reshape((n_blocks, -1))
1011
+ qs = qs.astype(np.uint16) | (qh << 8)
1012
+
1013
+ assert cls.grid is not None
1014
+ grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2)
1015
+ grid = grid.reshape((n_blocks, -1, 4, 8))
1016
+
1017
+ return (db * grid * signs).reshape((n_blocks, -1))
1018
+
1019
+
1020
+ class IQ1_S(__Quant, qtype=GGMLQuantizationType.IQ1_S):
1021
+ # iq1s_grid, with each byte packed into 2 bits
1022
+ # -1, 0, 1 <=> 0, 1, 2
1023
+ grid_shape = (2048, 8)
1024
+ grid_map = (-1, 0, 1)
1025
+ grid_hex = (
1026
+ b"00000200050008000a00110015002000220028002a0045005100540056006500"
1027
+ b"8000820088008a009500a000a200a800aa000401050111011401160119011a01"
1028
+ b"2501410146014901520155015a0161016401660168018501910194019601a501"
1029
+ b"0002020208020a0215022002220228022a024502510259026402690280028202"
1030
+ b"88028a02910295029902a002a202a802aa021104140416042504410449045504"
1031
+ b"5a046404650491049904a5040105040505050605150518051a05290540054505"
1032
+ b"4a0550055105540555055605590560056205650568056a058105910595059805"
1033
+ b"9a05a105a405a505a605a9051406190641064406500652065506580660066106"
1034
+ b"6606690685069106940699060008020808080a0815082008220828082a084508"
1035
+ b"5108560865088008820888088a089508a008a208a808aa080509110914091909"
1036
+ b"2409250941095009510955096109640969099109940996099909a509000a020a"
1037
+ b"080a0a0a150a200a220a280a2a0a450a510a590a610a650a800a820a850a880a"
1038
+ b"8a0a950aa00aa20aa80aaa0a1010111014101910241025104110441050105510"
1039
+ b"58106110641065106910911094109610a110a510011104110611091110111211"
1040
+ b"1511181121112411291145114a11501151115211541155115611591160116511"
1041
+ b"841192119511a111a41111121412161225124012461249125212551258125a12"
1042
+ b"641266128512911294129612a512011406140914141415141814191421142614"
1043
+ b"41144514461448144a1451145414551456145914621465146814841489149014"
1044
+ b"94149514981499149a14a114a414a514a914021505150a151115141515151615"
1045
+ b"191520152215251528152a154115441545154615511552155415551556155915"
1046
+ b"5a1561156415651566156915801582158415851588158a159015911594159515"
1047
+ b"961599159a15a015a215a51501160416051606161516161618161a1621162616"
1048
+ b"401642164416451648164a165116551656165816591661166416651668166916"
1049
+ b"6a1686168a1692169516a416a916111816182518411844184618491850185518"
1050
+ b"58185a1860186118641866186918851891189418a5181019121915191a192119"
1051
+ b"25194219441945194819511954195519561959195a19601965196a1989199119"
1052
+ b"921995199819a119a619a919091a161a241a261a441a461a491a501a521a551a"
1053
+ b"581a611a661a691a851a911a961a9a1a0020022008200a201520202022202520"
1054
+ b"28202a20452051205920612065208020822088208a209520a020a220a520a820"
1055
+ b"aa2005211121142119212521422144214921552158215a216121642165216621"
1056
+ b"8521902196219921a521012208220a22112215222022222228222a2245225122"
1057
+ b"562259226522812288228a2291229522a022a222a822aa220524142416241924"
1058
+ b"252444244524462449245224552458245a2466248524912494249924a124a524"
1059
+ b"0925152521252925402545254825512554255525592562256525682589259025"
1060
+ b"9425952598259a25a125a425a625a92505261026122619262526412649265526"
1061
+ b"6026612669268426862690269a260028022808280a2815282028222828282a28"
1062
+ b"45285128542865288028822888288a28a028a228a828aa280929112914291929"
1063
+ b"2529462949295229552961296429662969298529902996299929a429a529002a"
1064
+ b"022a082a0a2a202a222a282a2a2a452a512a562a592a652a802a822a882a8a2a"
1065
+ b"952aa02aa22aa82aaa2a054011401640254049405240554058405a4061406440"
1066
+ b"664094409940a140a6400041014104410641094112411541164118411a412141"
1067
+ b"26412941454148414a41514154415541564159415a41654168416a4181418441"
1068
+ b"8641904192419541a041a141a241054211421442164225424142524255425a42"
1069
+ b"6442694289429442a5420144154419442944454448444a445144544455445644"
1070
+ b"61446244654468446a44814486448944904492449544a044a144a94401450245"
1071
+ b"05450a4511451445154516451945204525452a45414544454545464549455045"
1072
+ b"5145544555455645584559456145644565456645694582458445854588459145"
1073
+ b"94459545964599459a45a545a845aa450146054609461446154618461a462146"
1074
+ b"2446294640464246454648465046514652465546564659466246654668468146"
1075
+ b"85468a4694469546a146a446a6460548114815481a4825484248494850485548"
1076
+ b"5848614864486648694885489148944896489948a5480149054906490a491049"
1077
+ b"144915491849214924492649404945494a495149524954495549564959496049"
1078
+ b"6249654966496a49864989499249954996499849a149a449a649a949164a444a"
1079
+ b"464a494a554a584a5a4a644a694a944aa54a0150045005500650095012501550"
1080
+ b"1a50215024502950405045504850515054505550565059506550685086508950"
1081
+ b"95509850a050a150a650a9500551085109510a51115114511551165118511951"
1082
+ b"20512551265128512a5141514451455146514951505151515251545155515651"
1083
+ b"585159515a51615164516551665169518251855191519451955196519951a051"
1084
+ b"a551aa5101520652125215521a5221522452425245524a525152545255525652"
1085
+ b"595262526552855290529252955299529a52a452045405541154145415541654"
1086
+ b"185419542154255428542a54415444544554465449544a545054515454545554"
1087
+ b"5654585459545a54615462546454655466546954805488548a54915494549554"
1088
+ b"96549954a154a454a554aa540155025504550555065509551055115512551455"
1089
+ b"1555165519551a55215524552555265529554055415542554455455546554855"
1090
+ b"4955505551555255545555555655585559555a55605561556455655566556855"
1091
+ b"69556a5581558455855589558a559055915594559555965598559955a155a455"
1092
+ b"a555a655a9550056015602560456065608560956115614561556185619562056"
1093
+ b"2156225624562556265628562956415645564656485649564a56505651565256"
1094
+ b"545655565656585659565a566156645665566956825685568656885689568a56"
1095
+ b"915695569a56a256a556a656a856a95604580558065809581058155818582158"
1096
+ b"2a58455848584a58515854585558565858585958605862586458655882588958"
1097
+ b"9058925895589858a158a9580159025905590a59115914591559165919592559"
1098
+ b"41594459455946594959505951595259545955595659585959595a5961596459"
1099
+ b"655966596959815985598959915994599559965998599959a559045a085a155a"
1100
+ b"1a5a205a255a265a295a455a485a495a515a555a565a585a595a625a655a685a"
1101
+ b"6a5a815a8a5a925a955a965a985a9a5aa15a0560146016601960256044605060"
1102
+ b"5560566058605a60616064606660696081609660a56001610461066109611261"
1103
+ b"15612161226126612961456149615161556156615961656166616a6184618a61"
1104
+ b"92619561a161a661a96111621662196240624162466255625662586260628562"
1105
+ b"91629662a56211641264156416641a6421642664296440644264456448644a64"
1106
+ b"516454645564566459645a646064626465648464856489649064926494649564"
1107
+ b"966498649a64a164a464a964056508650a651165156516651965446545654665"
1108
+ b"496550655165546555655665596561656465656566656965866589658a659165"
1109
+ b"9565966599659a65a265a565a665a86502660966156620662666286629664066"
1110
+ b"456648664a66516654665566566658665a666066656668668066826685668a66"
1111
+ b"9466966698669966a066a466a666aa661668196825684168526855685a686168"
1112
+ b"6968856891689868a66801690469106915692169246926692969406941694569"
1113
+ b"4669486951695469556956695969606965696a69826984698a699569a169a469"
1114
+ b"a569a969116a166a186a416a446a496a506a556a586a5a6a646a656a696a866a"
1115
+ b"946a986a9a6aa66a0080028008800a802080228028802a804580508051805480"
1116
+ b"5680598065808080828088808a809580a080a280a880aa800581118114811681"
1117
+ b"1981258141814481498150815281558156815881598164816681698185818981"
1118
+ b"948196819981a5810082028208820a8215822082228228822a82518254825982"
1119
+ b"65828082828288828a829582a082a282a882aa82148419844184448451845584"
1120
+ b"5a846184648469849484998401850985128515851a8526852985408541854585"
1121
+ b"4885518554855585568559855a856585668568856a8581858485868589859085"
1122
+ b"928595859885a68511861686198625864186448649864a865086558659865a86"
1123
+ b"618666866a86858691869a86a4860088028808880a8815882088228828882a88"
1124
+ b"41884588518854885988658869888088828888888a889588a088a288a888aa88"
1125
+ b"05890689118914891689258941894489468949895089528955895a8961896489"
1126
+ b"858996899989a589008a028a088a0a8a158a208a228a288a2a8a458a518a548a"
1127
+ b"568a808a828a888a8a8a958aa08aa28aa88aaa8a059011901690189019902590"
1128
+ b"419046904990559058905a9069906a9085909190949096909990a59001910491"
1129
+ b"069109911091159118911a912191249126912991409145915091519154915591"
1130
+ b"569159916291659184918691929195919891a191a491a691a991059211921492"
1131
+ b"19922592449246924992509252925592589266926992859294929692a9920194"
1132
+ b"04940694109415941894269440944a9451945494559456945894599460946194"
1133
+ b"62946594849486949294949495949894a194a9940095059508950a9510951195"
1134
+ b"14951595169519952195259529952a9541954495459546954995509551955295"
1135
+ b"549555955695589559955a956195649565956695699581958595889591959295"
1136
+ b"94959595969599959a95a095a295a595a895aa95019604961096159619962096"
1137
+ b"2696299645964896499651965296559656965996659668968296849689968a96"
1138
+ b"929694969596a496a696a9960598169819982598419846985098529855985698"
1139
+ b"5a98649865988598919896989998a59804990699099910991299159918991a99"
1140
+ b"209921992499269940994299459948994a995199549955995699599962996599"
1141
+ b"66996a99819984999099929995999a99a199a699059a159a259a449a469a499a"
1142
+ b"509a559a589a619a859a919a949a959a969a00a002a008a00aa015a020a022a0"
1143
+ b"28a02aa045a051a054a056a059a080a082a088a08aa095a0a0a0a2a0a8a0aaa0"
1144
+ b"05a109a111a114a116a119a11aa146a149a151a155a158a15aa161a164a185a1"
1145
+ b"90a192a196a199a102a208a20aa210a219a222a228a22aa245a251a256a259a2"
1146
+ b"65a280a282a288a28aa295a2a0a2a2a2a8a2aaa219a425a441a444a450a454a4"
1147
+ b"55a458a45aa461a465a466a468a469a485a406a509a510a512a515a518a526a5"
1148
+ b"29a542a545a551a554a555a556a559a565a56aa581a584a585a586a589a592a5"
1149
+ b"95a598a505a611a616a61aa621a625a644a646a64aa652a655a656a658a660a6"
1150
+ b"62a686a690a695a696a699a6a1a6a4a6a6a600a802a808a80aa820a822a828a8"
1151
+ b"2aa851a854a856a859a880a882a888a88aa895a8a0a8a2a8a8a8aaa805a914a9"
1152
+ b"19a921a925a941a950a955a95aa961a966a969a990a996a900aa02aa08aa0aaa"
1153
+ b"20aa22aa28aa2aaa51aa54aa56aa80aa82aa88aa8aaa95aaa0aaa2aaa8aaaaaa"
1154
+ )
1155
+
1156
+ delta = np.float32(0.125)
1157
+
1158
+ @classmethod
1159
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
1160
+ n_blocks = blocks.shape[0]
1161
+
1162
+ d, rest = np.hsplit(blocks, [2])
1163
+ qs, qh = np.hsplit(rest, [QK_K // 8])
1164
+
1165
+ d = d.view(np.float16).astype(np.float32)
1166
+ qh = qh.view(np.uint16)
1167
+
1168
+ dl = d * (2 * ((qh >> 12) & 7) + 1)
1169
+ dl = dl.reshape((n_blocks, -1, 1, 1))
1170
+ delta = np.where((qh & np.uint16(0x8000)) == 0, cls.delta, -cls.delta)
1171
+ delta = delta.reshape((n_blocks, -1, 1, 1))
1172
+
1173
+ qh = qh.reshape((n_blocks, -1, 1)) >> np.array([0, 3, 6, 9], dtype=np.uint16).reshape((1, 1, 4))
1174
+ qs = qs.astype(np.uint16) | ((qh & 7) << 8).reshape((n_blocks, -1))
1175
+
1176
+ assert cls.grid is not None
1177
+ grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2)
1178
+ grid = grid.reshape((n_blocks, -1, 4, 8))
1179
+
1180
+ return (dl * (grid + delta)).reshape((n_blocks, -1))
1181
+
1182
+
1183
+ class IQ1_M(__Quant, qtype=GGMLQuantizationType.IQ1_M):
1184
+ grid_shape = IQ1_S.grid_shape
1185
+ grid_map = IQ1_S.grid_map
1186
+ grid_hex = IQ1_S.grid_hex
1187
+
1188
+ delta = IQ1_S.delta
1189
+
1190
+ # Okay *this* type is weird. It's the only one which stores the f16 scales in multiple parts.
1191
+ @classmethod
1192
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
1193
+ n_blocks = blocks.shape[0]
1194
+
1195
+ qs, rest = np.hsplit(blocks, [QK_K // 8])
1196
+ qh, scales = np.hsplit(rest, [QK_K // 16])
1197
+
1198
+ # The f16 scale is packed across multiple bytes
1199
+ scales = scales.view(np.uint16)
1200
+ d = (scales.reshape((n_blocks, 4)) & np.uint16(0xF000)) >> np.array([12, 8, 4, 0], dtype=np.uint16).reshape((1, 4))
1201
+ d = d[..., 0] | d[..., 1] | d[..., 2] | d[..., 3]
1202
+ d = d.view(np.float16).astype(np.float32).reshape((n_blocks, 1))
1203
+
1204
+ scales = scales.reshape(n_blocks, -1, 1) >> np.array([0, 3, 6, 9], dtype=np.uint16).reshape((1, 1, 4))
1205
+ scales = (scales & 0x07).reshape((n_blocks, -1))
1206
+ dl = d * (2 * scales + 1)
1207
+ dl = dl.reshape((n_blocks, -1, 2, 1, 1))
1208
+
1209
+ qh = qh.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
1210
+ qs = qs.astype(np.uint16) | ((qh & 0x07).astype(np.uint16) << 8).reshape((n_blocks, -1))
1211
+
1212
+        delta = np.where((qh & 0x08) == 0, cls.delta, -cls.delta)  # note: & binds tighter than == in Python
1213
+ delta = delta.reshape((n_blocks, -1, 2, 2, 1))
1214
+
1215
+ assert cls.grid is not None
1216
+ grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2)
1217
+ grid = grid.reshape((n_blocks, -1, 2, 2, 8))
1218
+
1219
+ return (dl * (grid + delta)).reshape((n_blocks, -1))
1220
+
1221
+
1222
+ class IQ4_NL(__Quant, qtype=GGMLQuantizationType.IQ4_NL):
1223
+ kvalues = (-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113)
1224
+
1225
+ @classmethod
1226
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
1227
+ n_blocks = blocks.shape[0]
1228
+
1229
+ d, qs = np.hsplit(blocks, [2])
1230
+
1231
+ d = d.view(np.float16).astype(np.float32)
1232
+
1233
+ qs = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
1234
+
1235
+ qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1, 1))
1236
+
1237
+ kvalues = np.array(cls.kvalues, dtype=np.int8).reshape(1, 1, 16)
1238
+ qs = np.take_along_axis(kvalues, qs, axis=-1).astype(np.float32).reshape((n_blocks, -1))
1239
+
1240
+ return (d * qs)
1241
+
1242
+
1243
+ class IQ4_XS(__Quant, qtype=GGMLQuantizationType.IQ4_XS):
1244
+ @classmethod
1245
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
1246
+ n_blocks = blocks.shape[0]
1247
+
1248
+ d, rest = np.hsplit(blocks, [2])
1249
+ scales_h, rest = np.hsplit(rest, [2])
1250
+ scales_l, qs = np.hsplit(rest, [QK_K // 64])
1251
+
1252
+ d = d.view(np.float16).astype(np.float32)
1253
+ scales_h = scales_h.view(np.uint16)
1254
+
1255
+ scales_l = scales_l.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
1256
+ scales_h = scales_h.reshape((n_blocks, 1, -1)) >> np.array([2 * i for i in range(QK_K // 32)], dtype=np.uint16).reshape((1, -1, 1))
1257
+ scales_l = scales_l.reshape((n_blocks, -1)) & np.uint8(0x0F)
1258
+ scales_h = scales_h.reshape((n_blocks, -1)).astype(np.uint8) & np.uint8(0x03)
1259
+
1260
+ scales = (scales_l | (scales_h << np.uint8(4))).astype(np.int8) - np.int8(32)
1261
+ dl = (d * scales.astype(np.float32)).reshape((n_blocks, -1, 1))
1262
+
1263
+ qs = qs.reshape((n_blocks, -1, 1, 16)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
1264
+ qs = qs.reshape((n_blocks, -1, 32, 1)) & np.uint8(0x0F)
1265
+
1266
+ kvalues = np.array(IQ4_NL.kvalues, dtype=np.int8).reshape((1, 1, 1, -1))
1267
+ qs = np.take_along_axis(kvalues, qs, axis=-1).astype(np.float32).reshape((n_blocks, -1, 32))
1268
+
1269
+ return (dl * qs).reshape((n_blocks, -1))
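
For orientation, here is a minimal round-trip sketch of how these block codecs are typically driven. It assumes the module-level quantize()/dequantize() dispatchers defined earlier in gguf/quants.py (above this excerpt) and uses Q8_0, a type that also implements quantize_blocks; the IQ* classes above only provide dequantize_blocks.

import numpy as np
from gguf.constants import GGMLQuantizationType
from gguf import quants

# Round-trip a small float32 tensor through Q8_0 (block size 32, so the
# last dimension must be a multiple of 32).
data = np.random.uniform(-1.0, 1.0, size=(4, 256)).astype(np.float32)
packed = quants.quantize(data, GGMLQuantizationType.Q8_0)        # raw uint8 blocks
restored = quants.dequantize(packed, GGMLQuantizationType.Q8_0)  # back to float32
print(float(np.abs(restored - data).max()))                      # small quantization error
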
gguf/tensor_mapping.py ADDED
@@ -0,0 +1,769 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Sequence
4
+
5
+ from .constants import MODEL_ARCH, MODEL_TENSOR, MODEL_TENSORS, TENSOR_NAMES
6
+
7
+
8
+ class TensorNameMap:
9
+ mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
10
+ # Token embeddings
11
+ MODEL_TENSOR.TOKEN_EMBD: (
12
+ "gpt_neox.embed_in", # gptneox
13
+ "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
14
+ "transformer.word_embeddings", # falcon
15
+ "word_embeddings", # bloom
16
+ "model.embed_tokens", # llama-hf nemotron olmoe
17
+ "tok_embeddings", # llama-pth
18
+ "embeddings.word_embeddings", # bert nomic-bert
19
+ "language_model.embedding.word_embeddings", # persimmon
20
+ "wte", # gpt2
21
+ "transformer.embd.wte", # phi2
22
+ "model.tok_embeddings", # internlm2
23
+ "model.embedding", # mamba-qbert
24
+ "backbone.embedding", # mamba
25
+ "backbone.embeddings", # mamba-hf
26
+ "transformer.in_out_embed", # Grok
27
+ "embedding.word_embeddings", # chatglm
28
+ "transformer.token_embeddings", # openelm
29
+ "shared", # t5
30
+ "rwkv.embeddings", # rwkv
31
+ ),
32
+
33
+ # Token type embeddings
34
+ MODEL_TENSOR.TOKEN_TYPES: (
35
+ "embeddings.token_type_embeddings", # bert nomic-bert
36
+ ),
37
+
38
+ # Normalization of token embeddings
39
+ MODEL_TENSOR.TOKEN_EMBD_NORM: (
40
+ "word_embeddings_layernorm", # bloom
41
+ "embeddings.LayerNorm", # bert
42
+ "emb_ln", # nomic-bert
43
+ "transformer.norm", # openelm
44
+ "rwkv.blocks.0.pre_ln", # rwkv
45
+ ),
46
+
47
+ # Position embeddings
48
+ MODEL_TENSOR.POS_EMBD: (
49
+ "transformer.wpe", # gpt2
50
+ "embeddings.position_embeddings", # bert
51
+ "wpe", # gpt2
52
+ ),
53
+
54
+ # Output
55
+ MODEL_TENSOR.OUTPUT: (
56
+ "embed_out", # gptneox
57
+ "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe
58
+ "output", # llama-pth bloom internlm2
59
+ "word_embeddings_for_head", # persimmon
60
+ "lm_head.linear", # phi2
61
+ "output_layer", # chatglm
62
+ "head", # rwkv
63
+ ),
64
+
65
+ # Output norm
66
+ MODEL_TENSOR.OUTPUT_NORM: (
67
+ "gpt_neox.final_layer_norm", # gptneox
68
+ "transformer.ln_f", # gpt2 gpt-j falcon jais exaone
69
+ "model.norm", # llama-hf baichuan internlm2 olmoe
70
+ "norm", # llama-pth
71
+ "transformer.norm_f", # mpt dbrx
72
+ "ln_f", # refact bloom qwen gpt2
73
+ "language_model.encoder.final_layernorm", # persimmon
74
+ "model.final_layernorm", # persimmon
75
+ "lm_head.ln", # phi2
76
+ "model.norm_f", # mamba-qbert
77
+ "backbone.norm_f", # mamba
78
+ "transformer.rms_norm", # Grok
79
+ "encoder.final_layernorm", # chatglm
80
+ "transformer.norm", # openelm
81
+ "model.norm", # nemotron
82
+ "rwkv.ln_out", # rwkv
83
+ ),
84
+
85
+ # Rope frequencies
86
+ MODEL_TENSOR.ROPE_FREQS: (
87
+ "rope.freqs", # llama-pth
88
+ "rotary_pos_emb.inv_freq", # chatglm
89
+ ),
90
+
91
+ MODEL_TENSOR.ROPE_FACTORS_LONG: (),
92
+ MODEL_TENSOR.ROPE_FACTORS_SHORT: (),
93
+ }
94
+
95
+ block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
96
+ # Attention norm
97
+ MODEL_TENSOR.ATTN_NORM: (
98
+ "gpt_neox.layers.{bid}.input_layernorm", # gptneox
99
+ "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais exaone
100
+ "transformer.blocks.{bid}.norm_1", # mpt
101
+ "transformer.h.{bid}.input_layernorm", # falcon7b
102
+ "h.{bid}.input_layernorm", # bloom
103
+ "transformer.h.{bid}.ln_mlp", # falcon40b
104
+ "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe
105
+ "layers.{bid}.attention_norm", # llama-pth
106
+ "language_model.encoder.layers.{bid}.input_layernorm", # persimmon
107
+ "model.layers.{bid}.ln1", # yi
108
+ "h.{bid}.ln_1", # gpt2
109
+ "transformer.h.{bid}.ln", # phi2
110
+ "model.layers.layers.{bid}.norm", # plamo
111
+ "model.layers.{bid}.attention_norm", # internlm2
112
+ "model.layers.{bid}.norm", # mamba-qbert
113
+ "backbone.layers.{bid}.norm", # mamba
114
+ "transformer.decoder_layer.{bid}.rms_norm", # Grok
115
+ "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
116
+ "encoder.layers.{bid}.input_layernorm", # chatglm
117
+ "transformer.layers.{bid}.attn_norm", # openelm
118
+ "rwkv.blocks.{bid}.ln1", # rwkv
119
+ ),
120
+
121
+ # Attention norm 2
122
+ MODEL_TENSOR.ATTN_NORM_2: (
123
+ "transformer.h.{bid}.ln_attn", # falcon40b
124
+ "encoder.layer.{bid}.layer_norm_1", # jina-v2-code
125
+ "rwkv.blocks.{bid}.ln2", # rwkv
126
+ ),
127
+
128
+ # Attention query-key-value
129
+ MODEL_TENSOR.ATTN_QKV: (
130
+ "gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
131
+ "transformer.h.{bid}.attn.c_attn", # gpt2 qwen jais
132
+ "transformer.blocks.{bid}.attn.Wqkv", # mpt
133
+ "transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv", # dbrx
134
+ "transformer.h.{bid}.self_attention.query_key_value", # falcon
135
+ "h.{bid}.self_attention.query_key_value", # bloom
136
+ "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
137
+ "model.layers.{bid}.self_attn.query_key_value", # persimmon
138
+ "h.{bid}.attn.c_attn", # gpt2
139
+ "transformer.h.{bid}.mixer.Wqkv", # phi2
140
+ "encoder.layers.{bid}.attn.Wqkv", # nomic-bert
141
+ "model.layers.{bid}.self_attn.qkv_proj", # phi3
142
+ "encoder.layers.{bid}.self_attention.query_key_value", # chatglm
143
+ "transformer.layers.{bid}.attn.qkv_proj", # openelm
144
+ ),
145
+
146
+ # Attention query
147
+ MODEL_TENSOR.ATTN_Q: (
148
+ "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe
149
+ "layers.{bid}.attention.wq", # llama-pth
150
+ "encoder.layer.{bid}.attention.self.query", # bert
151
+ "transformer.h.{bid}.attn.q_proj", # gpt-j
152
+ "model.layers.layers.{bid}.self_attn.q_proj", # plamo
153
+ "model.layers.{bid}.attention.wq", # internlm2
154
+ "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
155
+ "transformer.h.{bid}.attn.attention.q_proj", # exaone
156
+ ),
157
+
158
+ # Attention key
159
+ MODEL_TENSOR.ATTN_K: (
160
+ "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe
161
+ "layers.{bid}.attention.wk", # llama-pth
162
+ "encoder.layer.{bid}.attention.self.key", # bert
163
+ "transformer.h.{bid}.attn.k_proj", # gpt-j
164
+ "transformer.h.{bid}.attn.k", # refact
165
+ "model.layers.layers.{bid}.self_attn.k_proj", # plamo
166
+ "model.layers.{bid}.attention.wk", # internlm2
167
+ "transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
168
+ "transformer.h.{bid}.attn.attention.k_proj", # exaone
169
+ ),
170
+
171
+ # Attention value
172
+ MODEL_TENSOR.ATTN_V: (
173
+ "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe
174
+ "layers.{bid}.attention.wv", # llama-pth
175
+ "encoder.layer.{bid}.attention.self.value", # bert
176
+ "transformer.h.{bid}.attn.v_proj", # gpt-j
177
+ "transformer.h.{bid}.attn.v", # refact
178
+ "model.layers.layers.{bid}.self_attn.v_proj", # plamo
179
+ "model.layers.{bid}.attention.wv", # internlm2
180
+ "transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok
181
+ "transformer.h.{bid}.attn.attention.v_proj", # exaone
182
+ ),
183
+
184
+ # Attention output
185
+ MODEL_TENSOR.ATTN_OUT: (
186
+ "gpt_neox.layers.{bid}.attention.dense", # gptneox
187
+ "transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen jais
188
+ "transformer.blocks.{bid}.attn.out_proj", # mpt
189
+ "transformer.h.{bid}.self_attention.dense", # falcon
190
+ "h.{bid}.self_attention.dense", # bloom
191
+ "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe
192
+ "layers.{bid}.attention.wo", # llama-pth
193
+ "encoder.layer.{bid}.attention.output.dense", # bert
194
+ "transformer.h.{bid}.attn.out_proj", # gpt-j
195
+ "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
196
+ "model.layers.{bid}.self_attn.dense", # persimmon
197
+ "h.{bid}.attn.c_proj", # gpt2
198
+ "transformer.h.{bid}.mixer.out_proj", # phi2
199
+ "model.layers.layers.{bid}.self_attn.o_proj", # plamo
200
+ "model.layers.{bid}.attention.wo", # internlm2
201
+ "encoder.layers.{bid}.attn.out_proj", # nomic-bert
202
+ "transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
203
+ "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
204
+ "encoder.layers.{bid}.self_attention.dense", # chatglm
205
+ "transformer.layers.{bid}.attn.out_proj", # openelm
206
+ "transformer.h.{bid}.attn.attention.out_proj", # exaone
207
+ ),
208
+
209
+ # Attention output norm
210
+ MODEL_TENSOR.ATTN_OUT_NORM: (
211
+ "encoder.layer.{bid}.attention.output.LayerNorm", # bert
212
+ "encoder.layers.{bid}.norm1", # nomic-bert
213
+ "transformer.decoder_layer.{bid}.rms_norm_1", # Grok
214
+ "transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
215
+ ),
216
+
217
+ MODEL_TENSOR.ATTN_POST_NORM: (
218
+ "model.layers.{bid}.post_attention_layernorm", # gemma2
219
+ ),
220
+
221
+ # Rotary embeddings
222
+ MODEL_TENSOR.ATTN_ROT_EMBD: (
223
+ "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf
224
+ "layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth
225
+ "model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo
226
+ "transformer.h.{bid}.attn.rotary_emb.inv_freq", # codeshell
227
+ ),
228
+
229
+ # Feed-forward norm
230
+ MODEL_TENSOR.FFN_NORM: (
231
+ "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
232
+ "transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone
233
+ "h.{bid}.post_attention_layernorm", # bloom
234
+ "transformer.blocks.{bid}.norm_2", # mpt
235
+ "model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe
236
+ "layers.{bid}.ffn_norm", # llama-pth
237
+ "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
238
+ "model.layers.{bid}.ln2", # yi
239
+ "h.{bid}.ln_2", # gpt2
240
+ "model.layers.{bid}.ffn_norm", # internlm2
241
+ "transformer.decoder_layer.{bid}.rms_norm_2", # Grok
242
+ "encoder.layers.{bid}.post_attention_layernorm", # chatglm
243
+ "transformer.layers.{bid}.ffn_norm", # openelm
244
+ ),
245
+
246
+        # Pre feed-forward norm
247
+ MODEL_TENSOR.FFN_PRE_NORM: (
248
+ "model.layers.{bid}.pre_feedforward_layernorm", # gemma2
249
+ ),
250
+
251
+ # Post feed-forward norm
252
+ MODEL_TENSOR.FFN_POST_NORM: (
253
+ "model.layers.{bid}.post_feedforward_layernorm", # gemma2
254
+ ),
255
+
256
+ MODEL_TENSOR.FFN_GATE_INP: (
257
+ "layers.{bid}.feed_forward.gate", # mixtral
258
+ "model.layers.{bid}.block_sparse_moe.gate", # mixtral
259
+ "model.layers.{bid}.mlp.gate", # qwen2moe olmoe
260
+ "transformer.decoder_layer.{bid}.router", # Grok
261
+ "transformer.blocks.{bid}.ffn.router.layer", # dbrx
262
+ "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
263
+ ),
264
+
265
+ MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
266
+ "model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe
267
+ ),
268
+
269
+ # Feed-forward up
270
+ MODEL_TENSOR.FFN_UP: (
271
+ "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
272
+ "transformer.h.{bid}.mlp.c_fc", # gpt2 jais
273
+ "transformer.blocks.{bid}.ffn.up_proj", # mpt
274
+ "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
275
+ "h.{bid}.mlp.dense_h_to_4h", # bloom
276
+ "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron
277
+ "layers.{bid}.feed_forward.w3", # llama-pth
278
+ "encoder.layer.{bid}.intermediate.dense", # bert
279
+ "transformer.h.{bid}.mlp.fc_in", # gpt-j
280
+ "transformer.h.{bid}.mlp.linear_3", # refact
281
+ "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
282
+ "model.layers.{bid}.mlp.dense_h_to_4h", # persimmon
283
+ "transformer.h.{bid}.mlp.w1", # qwen
284
+ "h.{bid}.mlp.c_fc", # gpt2
285
+ "transformer.h.{bid}.mlp.fc1", # phi2
286
+ "model.layers.{bid}.mlp.fc1", # phi2
287
+ "model.layers.{bid}.mlp.gate_up_proj", # phi3
288
+ "model.layers.layers.{bid}.mlp.up_proj", # plamo
289
+ "model.layers.{bid}.feed_forward.w3", # internlm2
290
+ "encoder.layers.{bid}.mlp.fc11", # nomic-bert
291
+ "model.layers.{bid}.mlp.c_fc", # starcoder2
292
+ "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
293
+ "model.layers.{bid}.residual_mlp.w3", # arctic
294
+ "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
295
+ "transformer.h.{bid}.mlp.c_fc_1", # exaone
296
+ ),
297
+
298
+ MODEL_TENSOR.FFN_UP_EXP: (
299
+ "layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
300
+ "transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
301
+ "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
302
+ "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
303
+ ),
304
+
305
+ MODEL_TENSOR.FFN_UP_SHEXP: (
306
+ "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
307
+ "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek2
308
+ ),
309
+
310
+ # AWQ-activation gate
311
+ MODEL_TENSOR.FFN_ACT: (
312
+ "transformer.blocks.{bid}.ffn.act", # mpt
313
+ ),
314
+
315
+ # Feed-forward gate
316
+ MODEL_TENSOR.FFN_GATE: (
317
+ "model.layers.{bid}.mlp.gate_proj", # llama-hf refact
318
+ "layers.{bid}.feed_forward.w1", # llama-pth
319
+ "transformer.h.{bid}.mlp.w2", # qwen
320
+ "transformer.h.{bid}.mlp.c_fc2", # jais
321
+ "model.layers.layers.{bid}.mlp.gate_proj", # plamo
322
+ "model.layers.{bid}.feed_forward.w1", # internlm2
323
+ "encoder.layers.{bid}.mlp.fc12", # nomic-bert
324
+ "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2
325
+ "transformer.h.{bid}.mlp.linear_1", # refact
326
+ "model.layers.{bid}.residual_mlp.w1", # arctic
327
+ "transformer.h.{bid}.mlp.c_fc_0", # exaone
328
+ ),
329
+
330
+ MODEL_TENSOR.FFN_GATE_EXP: (
331
+ "layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
332
+ "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
333
+ "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
334
+ "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
335
+ ),
336
+
337
+ MODEL_TENSOR.FFN_GATE_SHEXP: (
338
+ "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
339
+ "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek2
340
+ ),
341
+
342
+ # Feed-forward down
343
+ MODEL_TENSOR.FFN_DOWN: (
344
+ "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
345
+ "transformer.h.{bid}.mlp.c_proj", # gpt2 refact qwen jais
346
+ "transformer.blocks.{bid}.ffn.down_proj", # mpt
347
+ "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
348
+ "h.{bid}.mlp.dense_4h_to_h", # bloom
349
+ "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron
350
+ "layers.{bid}.feed_forward.w2", # llama-pth
351
+ "encoder.layer.{bid}.output.dense", # bert
352
+ "transformer.h.{bid}.mlp.fc_out", # gpt-j
353
+ "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
354
+ "model.layers.{bid}.mlp.dense_4h_to_h", # persimmon
355
+ "h.{bid}.mlp.c_proj", # gpt2
356
+ "transformer.h.{bid}.mlp.fc2", # phi2
357
+ "model.layers.{bid}.mlp.fc2", # phi2
358
+ "model.layers.layers.{bid}.mlp.down_proj", # plamo
359
+ "model.layers.{bid}.feed_forward.w2", # internlm2
360
+ "encoder.layers.{bid}.mlp.fc2", # nomic-bert
361
+ "model.layers.{bid}.mlp.c_proj", # starcoder2
362
+ "encoder.layer.{bid}.mlp.wo", # jina-bert-v2
363
+ "transformer.layers.{bid}.ffn.proj_2", # openelm
364
+ "model.layers.{bid}.residual_mlp.w2", # arctic
365
+ "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
366
+ "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
367
+ "model.layers.h.{bid}.mlp.c_proj", # exaone
368
+ ),
369
+
370
+ MODEL_TENSOR.FFN_DOWN_EXP: (
371
+ "layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
372
+ "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
373
+ "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
374
+ "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
375
+ "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
376
+ ),
377
+
378
+ MODEL_TENSOR.FFN_DOWN_SHEXP: (
379
+ "model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
380
+ "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek2
381
+ ),
382
+
383
+ MODEL_TENSOR.ATTN_Q_NORM: (
384
+ "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
385
+ "model.layers.{bid}.self_attn.q_layernorm", # persimmon
386
+ "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon
387
+ "transformer.blocks.{bid}.attn.q_ln", # sea-lion
388
+ "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
389
+ "transformer.layers.{bid}.attn.q_norm", # openelm
390
+ ),
391
+
392
+ MODEL_TENSOR.ATTN_K_NORM: (
393
+ "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
394
+ "model.layers.{bid}.self_attn.k_layernorm", # persimmon
395
+ "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon
396
+ "transformer.blocks.{bid}.attn.k_ln", # sea-lion
397
+ "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
398
+ "transformer.layers.{bid}.attn.k_norm", # openelm
399
+ ),
400
+
401
+ MODEL_TENSOR.ROPE_FREQS: (
402
+ "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon
403
+ ),
404
+
405
+ MODEL_TENSOR.LAYER_OUT_NORM: (
406
+ "encoder.layer.{bid}.output.LayerNorm", # bert
407
+ "encoder.layers.{bid}.norm2", # nomic-bert
408
+ "transformer.decoder_layer.{bid}.rms_norm_3", # Grok
409
+ "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2
410
+ "encoder.layer.{bid}.layer_norm_2" # jina-v2-code
411
+ ),
412
+
413
+ MODEL_TENSOR.SSM_IN: (
414
+ "model.layers.{bid}.in_proj",
415
+ "backbone.layers.{bid}.mixer.in_proj",
416
+ ),
417
+
418
+ MODEL_TENSOR.SSM_CONV1D: (
419
+ "model.layers.{bid}.conv1d",
420
+ "backbone.layers.{bid}.mixer.conv1d",
421
+ ),
422
+
423
+ MODEL_TENSOR.SSM_X: (
424
+ "model.layers.{bid}.x_proj",
425
+ "backbone.layers.{bid}.mixer.x_proj",
426
+ ),
427
+
428
+ MODEL_TENSOR.SSM_DT: (
429
+ "model.layers.{bid}.dt_proj",
430
+ "backbone.layers.{bid}.mixer.dt_proj",
431
+ ),
432
+
433
+ MODEL_TENSOR.SSM_A: (
434
+ "model.layers.{bid}.A_log",
435
+ "backbone.layers.{bid}.mixer.A_log",
436
+ ),
437
+
438
+ MODEL_TENSOR.SSM_D: (
439
+ "model.layers.{bid}.D",
440
+ "backbone.layers.{bid}.mixer.D",
441
+ ),
442
+
443
+ MODEL_TENSOR.SSM_OUT: (
444
+ "model.layers.{bid}.out_proj",
445
+ "backbone.layers.{bid}.mixer.out_proj",
446
+ ),
447
+
448
+ MODEL_TENSOR.TIME_MIX_W1: (
449
+ "rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv v6
450
+ ),
451
+
452
+ MODEL_TENSOR.TIME_MIX_W2: (
453
+ "rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv v6
454
+ ),
455
+
456
+ MODEL_TENSOR.TIME_MIX_LERP_X: (
457
+ "rwkv.blocks.{bid}.attention.time_maa_x", # rwkv v6
458
+ ),
459
+
460
+ MODEL_TENSOR.TIME_MIX_LERP_K: (
461
+ "rwkv.blocks.{bid}.attention.time_maa_k", # rwkv v6
462
+ ),
463
+
464
+ MODEL_TENSOR.TIME_MIX_LERP_V: (
465
+ "rwkv.blocks.{bid}.attention.time_maa_v", # rwkv v6
466
+ ),
467
+
468
+ MODEL_TENSOR.TIME_MIX_LERP_R: (
469
+ "rwkv.blocks.{bid}.attention.time_maa_r", # rwkv v6
470
+ ),
471
+
472
+ MODEL_TENSOR.TIME_MIX_LERP_G: (
473
+ "rwkv.blocks.{bid}.attention.time_maa_g", # rwkv v6
474
+ ),
475
+
476
+ MODEL_TENSOR.TIME_MIX_LERP_W: (
477
+ "rwkv.blocks.{bid}.attention.time_maa_w", # rwkv v6
478
+ ),
479
+
480
+ MODEL_TENSOR.TIME_MIX_FIRST: (
481
+ "rwkv.blocks.{bid}.attention.time_faaaa", # rwkv v6
482
+ ),
483
+
484
+ MODEL_TENSOR.TIME_MIX_DECAY: (
485
+ "rwkv.blocks.{bid}.attention.time_decay", # rwkv v6
486
+ ),
487
+
488
+ MODEL_TENSOR.TIME_MIX_DECAY_W1: (
489
+ "rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv v6
490
+ ),
491
+
492
+ MODEL_TENSOR.TIME_MIX_DECAY_W2: (
493
+ "rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv v6
494
+ ),
495
+
496
+ MODEL_TENSOR.TIME_MIX_KEY: (
497
+ "rwkv.blocks.{bid}.attention.key", # rwkv
498
+ ),
499
+
500
+ MODEL_TENSOR.TIME_MIX_VALUE: (
501
+ "rwkv.blocks.{bid}.attention.value", # rwkv
502
+ ),
503
+
504
+ MODEL_TENSOR.TIME_MIX_RECEPTANCE: (
505
+ "rwkv.blocks.{bid}.attention.receptance", # rwkv
506
+ ),
507
+
508
+ MODEL_TENSOR.TIME_MIX_GATE: (
509
+ "rwkv.blocks.{bid}.attention.gate", # rwkv
510
+ ),
511
+
512
+ MODEL_TENSOR.TIME_MIX_LN: (
513
+ "rwkv.blocks.{bid}.attention.ln_x", # rwkv
514
+ ),
515
+
516
+ MODEL_TENSOR.TIME_MIX_OUTPUT: (
517
+ "rwkv.blocks.{bid}.attention.output", # rwkv
518
+ ),
519
+
520
+ MODEL_TENSOR.CHANNEL_MIX_LERP_K: (
521
+ "rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv v6
522
+ ),
523
+
524
+ MODEL_TENSOR.CHANNEL_MIX_LERP_R: (
525
+ "rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv v6
526
+ ),
527
+
528
+ MODEL_TENSOR.CHANNEL_MIX_KEY: (
529
+ "rwkv.blocks.{bid}.feed_forward.key", # rwkv
530
+ ),
531
+
532
+ MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: (
533
+ "rwkv.blocks.{bid}.feed_forward.receptance", # rwkv
534
+ ),
535
+
536
+ MODEL_TENSOR.CHANNEL_MIX_VALUE: (
537
+ "rwkv.blocks.{bid}.feed_forward.value", # rwkv
538
+ ),
539
+
540
+ MODEL_TENSOR.ATTN_Q_A: (
541
+ "model.layers.{bid}.self_attn.q_a_proj", # deepseek2
542
+ ),
543
+
544
+ MODEL_TENSOR.ATTN_Q_B: (
545
+ "model.layers.{bid}.self_attn.q_b_proj", # deepseek2
546
+ ),
547
+
548
+ MODEL_TENSOR.ATTN_KV_A_MQA: (
549
+ "model.layers.{bid}.self_attn.kv_a_proj_with_mqa", # deepseek2
550
+ ),
551
+
552
+ MODEL_TENSOR.ATTN_KV_B: (
553
+ "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
554
+ ),
555
+
556
+ MODEL_TENSOR.ATTN_Q_A_NORM: (
557
+ "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
558
+ ),
559
+
560
+ MODEL_TENSOR.ATTN_KV_A_NORM: (
561
+ "model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2
562
+ ),
563
+
564
+ MODEL_TENSOR.ATTN_SUB_NORM: (
565
+ "model.layers.{bid}.self_attn.inner_attn_ln", # bitnet
566
+ ),
567
+
568
+ MODEL_TENSOR.FFN_SUB_NORM: (
569
+ "model.layers.{bid}.mlp.ffn_layernorm", # bitnet
570
+ ),
571
+
572
+ MODEL_TENSOR.DEC_ATTN_NORM: (
573
+ "decoder.block.{bid}.layer.0.layer_norm", # t5
574
+ ),
575
+
576
+ MODEL_TENSOR.DEC_ATTN_Q: (
577
+ "decoder.block.{bid}.layer.0.SelfAttention.q", # t5
578
+ ),
579
+
580
+ MODEL_TENSOR.DEC_ATTN_K: (
581
+ "decoder.block.{bid}.layer.0.SelfAttention.k", # t5
582
+ ),
583
+
584
+ MODEL_TENSOR.DEC_ATTN_V: (
585
+ "decoder.block.{bid}.layer.0.SelfAttention.v", # t5
586
+ ),
587
+
588
+ MODEL_TENSOR.DEC_ATTN_OUT: (
589
+ "decoder.block.{bid}.layer.0.SelfAttention.o", # t5
590
+ ),
591
+
592
+ MODEL_TENSOR.DEC_ATTN_REL_B: (
593
+ "decoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias", # t5
594
+ ),
595
+
596
+ MODEL_TENSOR.DEC_CROSS_ATTN_NORM: (
597
+ "decoder.block.{bid}.layer.1.layer_norm", # t5
598
+ ),
599
+
600
+ MODEL_TENSOR.DEC_CROSS_ATTN_Q: (
601
+ "decoder.block.{bid}.layer.1.EncDecAttention.q", # t5
602
+ ),
603
+
604
+ MODEL_TENSOR.DEC_CROSS_ATTN_K: (
605
+ "decoder.block.{bid}.layer.1.EncDecAttention.k", # t5
606
+ ),
607
+
608
+ MODEL_TENSOR.DEC_CROSS_ATTN_V: (
609
+ "decoder.block.{bid}.layer.1.EncDecAttention.v", # t5
610
+ ),
611
+
612
+ MODEL_TENSOR.DEC_CROSS_ATTN_OUT: (
613
+ "decoder.block.{bid}.layer.1.EncDecAttention.o", # t5
614
+ ),
615
+
616
+ MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: (
617
+ "decoder.block.{bid}.layer.1.EncDecAttention.relative_attention_bias", # t5
618
+ ),
619
+
620
+ MODEL_TENSOR.DEC_FFN_NORM: (
621
+ "decoder.block.{bid}.layer.2.layer_norm", # t5
622
+ ),
623
+
624
+ MODEL_TENSOR.DEC_FFN_GATE: (
625
+ "decoder.block.{bid}.layer.2.DenseReluDense.wi_0", # flan-t5
626
+ ),
627
+
628
+ MODEL_TENSOR.DEC_FFN_UP: (
629
+ "decoder.block.{bid}.layer.2.DenseReluDense.wi", # t5
630
+ "decoder.block.{bid}.layer.2.DenseReluDense.wi_1", # flan-t5
631
+ ),
632
+
633
+ MODEL_TENSOR.DEC_FFN_DOWN: (
634
+ "decoder.block.{bid}.layer.2.DenseReluDense.wo", # t5
635
+ ),
636
+
637
+ MODEL_TENSOR.DEC_OUTPUT_NORM: (
638
+ "decoder.final_layer_norm", # t5
639
+ ),
640
+
641
+ MODEL_TENSOR.ENC_ATTN_NORM: (
642
+ "encoder.block.{bid}.layer.0.layer_norm", # t5
643
+ ),
644
+
645
+ MODEL_TENSOR.ENC_ATTN_Q: (
646
+ "encoder.block.{bid}.layer.0.SelfAttention.q", # t5
647
+ ),
648
+
649
+ MODEL_TENSOR.ENC_ATTN_K: (
650
+ "encoder.block.{bid}.layer.0.SelfAttention.k", # t5
651
+ ),
652
+
653
+ MODEL_TENSOR.ENC_ATTN_V: (
654
+ "encoder.block.{bid}.layer.0.SelfAttention.v", # t5
655
+ ),
656
+
657
+ MODEL_TENSOR.ENC_ATTN_OUT: (
658
+ "encoder.block.{bid}.layer.0.SelfAttention.o", # t5
659
+ ),
660
+
661
+ MODEL_TENSOR.ENC_ATTN_REL_B: (
662
+ "encoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias", # t5
663
+ ),
664
+
665
+ MODEL_TENSOR.ENC_FFN_NORM: (
666
+ "encoder.block.{bid}.layer.1.layer_norm", # t5
667
+ ),
668
+
669
+ MODEL_TENSOR.ENC_FFN_GATE: (
670
+ "encoder.block.{bid}.layer.1.DenseReluDense.wi_0", # flan-t5
671
+ ),
672
+
673
+ MODEL_TENSOR.ENC_FFN_UP: (
674
+ "encoder.block.{bid}.layer.1.DenseReluDense.wi", # t5
675
+ "encoder.block.{bid}.layer.1.DenseReluDense.wi_1", # flan-t5
676
+ ),
677
+
678
+ MODEL_TENSOR.ENC_FFN_DOWN: (
679
+ "encoder.block.{bid}.layer.1.DenseReluDense.wo", # t5
680
+ ),
681
+
682
+ MODEL_TENSOR.ENC_OUTPUT_NORM: (
683
+ "encoder.final_layer_norm", # t5
684
+ ),
685
+
686
+ MODEL_TENSOR.CLS: (
687
+ "classifier", # jina
688
+ "classifier.dense", # roberta
689
+ ),
690
+
691
+ MODEL_TENSOR.CLS_OUT: (
692
+ "classifier.out_proj", # roberta
693
+ ),
694
+ }
695
+
696
+ # architecture-specific block mappings
697
+ arch_block_mappings_cfg: dict[MODEL_ARCH, dict[MODEL_TENSOR, tuple[str, ...]]] = {
698
+ MODEL_ARCH.ARCTIC: {
699
+ MODEL_TENSOR.FFN_NORM: (
700
+ "model.layers.{bid}.residual_layernorm",
701
+ ),
702
+ MODEL_TENSOR.FFN_NORM_EXP: (
703
+ "model.layers.{bid}.post_attention_layernorm",
704
+ ),
705
+ },
706
+ }
707
+
708
+ mapping: dict[str, tuple[MODEL_TENSOR, str]]
709
+
710
+ def __init__(self, arch: MODEL_ARCH, n_blocks: int):
711
+ self.mapping = {}
712
+ for tensor, keys in self.mappings_cfg.items():
713
+ if tensor not in MODEL_TENSORS[arch]:
714
+ continue
715
+ tensor_name = TENSOR_NAMES[tensor]
716
+ self.mapping[tensor_name] = (tensor, tensor_name)
717
+ for key in keys:
718
+ self.mapping[key] = (tensor, tensor_name)
719
+ if arch in self.arch_block_mappings_cfg:
720
+ self.block_mappings_cfg.update(self.arch_block_mappings_cfg[arch])
721
+ for bid in range(n_blocks):
722
+ for tensor, keys in self.block_mappings_cfg.items():
723
+ if tensor not in MODEL_TENSORS[arch]:
724
+ continue
725
+
726
+ tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
727
+ self.mapping[tensor_name] = (tensor, tensor_name)
728
+ for key in keys:
729
+ key = key.format(bid = bid)
730
+ self.mapping[key] = (tensor, tensor_name)
731
+
732
+ def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
733
+ result = self.mapping.get(key)
734
+ if result is not None:
735
+ return result
736
+ for suffix in try_suffixes:
737
+ if key.endswith(suffix):
738
+ result = self.mapping.get(key[:-len(suffix)])
739
+ if result is not None:
740
+ return result[0], result[1] + suffix
741
+ return None
742
+
743
+ def get_name(self, key: str, try_suffixes: Sequence[str] = ()) -> str | None:
744
+ result = self.get_type_and_name(key, try_suffixes = try_suffixes)
745
+ if result is None:
746
+ return None
747
+ return result[1]
748
+
749
+ def get_type(self, key: str, try_suffixes: Sequence[str] = ()) -> MODEL_TENSOR | None:
750
+ result = self.get_type_and_name(key, try_suffixes = try_suffixes)
751
+ if result is None:
752
+ return None
753
+ return result[0]
754
+
755
+ def __getitem__(self, key: str) -> str:
756
+ try:
757
+ return self.mapping[key][1]
758
+ except KeyError:
759
+ raise KeyError(key)
760
+
761
+ def __contains__(self, key: str) -> bool:
762
+ return key in self.mapping
763
+
764
+ def __repr__(self) -> str:
765
+ return repr(self.mapping)
766
+
767
+
768
+ def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> TensorNameMap:
769
+ return TensorNameMap(arch, n_blocks)
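A minimal usage sketch for the name mapping above (MODEL_ARCH.LLAMA, the block count, and the example tensor key are illustrative assumptions; the actual GGUF name is taken from TENSOR_NAMES):

    import gguf

    # build the lookup table once per model, sized for its block count
    name_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAMA, n_blocks=32)

    # resolve a HF checkpoint key; try_suffixes lets ".weight"/".bias" pass through unchanged
    gguf_name = name_map.get_name("model.layers.0.self_attn.q_proj.weight", try_suffixes=(".weight", ".bias"))
    # expected to be of the form "blk.0.attn_q.weight" for llama-style architectures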
gguf/utility.py ADDED
@@ -0,0 +1,69 @@
+ from __future__ import annotations
+
+ from typing import Literal
+
+
+ def fill_templated_filename(filename: str, output_type: str | None) -> str:
+     # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
+     ftype_lowercase: str = output_type.lower() if output_type is not None else ""
+     ftype_uppercase: str = output_type.upper() if output_type is not None else ""
+     return filename.format(ftype_lowercase,
+                            outtype=ftype_lowercase, ftype=ftype_lowercase,
+                            OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)
+
+
+ def model_weight_count_rounded_notation(model_params_count: int, min_digits: int = 2) -> str:
+     if model_params_count > 1e12 :
+         # Trillions Of Parameters
+         scaled_model_params = model_params_count * 1e-12
+         scale_suffix = "T"
+     elif model_params_count > 1e9 :
+         # Billions Of Parameters
+         scaled_model_params = model_params_count * 1e-9
+         scale_suffix = "B"
+     elif model_params_count > 1e6 :
+         # Millions Of Parameters
+         scaled_model_params = model_params_count * 1e-6
+         scale_suffix = "M"
+     else:
+         # Thousands Of Parameters
+         scaled_model_params = model_params_count * 1e-3
+         scale_suffix = "K"
+
+     fix = max(min_digits - len(str(round(scaled_model_params)).lstrip('0')), 0)
+
+     return f"{scaled_model_params:.{fix}f}{scale_suffix}"
+
+
+ def size_label(total_params: int, shared_params: int, expert_params: int, expert_count: int) -> str:
+
+     if expert_count > 0:
+         pretty_size = model_weight_count_rounded_notation(abs(shared_params) + abs(expert_params), min_digits=2)
+         size_class = f"{expert_count}x{pretty_size}"
+     else:
+         size_class = model_weight_count_rounded_notation(abs(total_params), min_digits=2)
+
+     return size_class
+
+
+ def naming_convention(model_name: str | None, base_name: str | None, finetune_string: str | None, version_string: str | None, size_label: str | None, output_type: str | None, model_type: Literal['vocab', 'LoRA'] | None = None) -> str:
+     # Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention
+
+     if base_name is not None:
+         name = base_name.strip().replace(' ', '-').replace('/', '-')
+     elif model_name is not None:
+         name = model_name.strip().replace(' ', '-').replace('/', '-')
+     else:
+         name = "ggml-model"
+
+     parameters = f"-{size_label}" if size_label is not None else ""
+
+     finetune = f"-{finetune_string.strip().replace(' ', '-')}" if finetune_string is not None else ""
+
+     version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else ""
+
+     encoding = f"-{output_type.strip().replace(' ', '-').upper()}" if output_type is not None else ""
+
+     kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
+
+     return f"{name}{parameters}{finetune}{version}{encoding}{kind}"
gguf/vocab.py ADDED
@@ -0,0 +1,487 @@
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ import logging
5
+ import json
6
+ import os
7
+ from pathlib import Path
8
+ from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVar, runtime_checkable
9
+
10
+ from sentencepiece import SentencePieceProcessor
11
+
12
+ import gguf
13
+
14
+ from .gguf_writer import GGUFWriter
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class SpecialVocab:
20
+ merges: list[str]
21
+ add_special_token: dict[str, bool]
22
+ special_token_ids: dict[str, int]
23
+ chat_template: str | Sequence[Mapping[str, str]] | None
24
+
25
+ def __init__(
26
+ self, path: str | os.PathLike[str], load_merges: bool = False,
27
+ special_token_types: Iterable[str] | None = None,
28
+ n_vocab: int | None = None,
29
+ ):
30
+ self.special_token_ids = {}
31
+ self.add_special_token = {}
32
+ self.n_vocab = n_vocab
33
+ self.load_merges = load_merges
34
+ self.merges = []
35
+ self.chat_template = None
36
+ if special_token_types is not None:
37
+ self.special_token_types = special_token_types
38
+ else:
39
+ self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad', 'cls', 'mask')
40
+ self._load(Path(path))
41
+
42
+ def __repr__(self) -> str:
43
+ return '<SpecialVocab with {} merges, special tokens {}, add special tokens {}>'.format(
44
+ len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset",
45
+ )
46
+
47
+ def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
48
+ if self.merges:
49
+ if not quiet:
50
+ logger.info(f'Adding {len(self.merges)} merge(s).')
51
+ gw.add_token_merges(self.merges)
52
+ elif self.load_merges:
53
+ logger.warning('Adding merges requested but no merges found, output may be non-functional.')
54
+ for typ, tokid in self.special_token_ids.items():
55
+ id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
56
+ if id_handler is None:
57
+ logger.warning(f'No handler for special token type {typ} with id {tokid} - skipping')
58
+ continue
59
+ if not quiet:
60
+ logger.info(f'Setting special token type {typ} to {tokid}')
61
+ id_handler(tokid)
62
+ for typ, value in self.add_special_token.items():
63
+ add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None)
64
+ if add_handler is None:
65
+ logger.warning(f'No handler for add_{typ}_token with value {value} - skipping')
66
+ continue
67
+ if not quiet:
68
+ logger.info(f'Setting add_{typ}_token to {value}')
69
+ add_handler(value)
70
+ if self.chat_template is not None:
71
+ if not quiet:
72
+ logger.info(f'Setting chat_template to {self.chat_template}')
73
+ gw.add_chat_template(self.chat_template)
74
+
75
+ def _load(self, path: Path) -> None:
76
+ self._try_load_from_tokenizer_json(path)
77
+ self._try_load_from_config_json(path)
78
+ if self.load_merges and not self.merges:
79
+ self._try_load_merges_txt(path)
80
+
81
+ def _try_load_merges_txt(self, path: Path) -> bool:
82
+ merges_file = path / 'merges.txt'
83
+ if not merges_file.is_file():
84
+ return False
85
+ with open(merges_file, 'r', encoding = 'utf-8') as fp:
86
+ first_line = next(fp, '').strip()
87
+ if not first_line.startswith('#'):
88
+ fp.seek(0)
89
+ line_num = 0
90
+ else:
91
+ line_num = 1
92
+ merges = []
93
+ for line in fp:
94
+ line_num += 1
95
+ line = line.strip()
96
+ if not line:
97
+ continue
98
+ parts = line.split(None, 3)
99
+ if len(parts) != 2:
100
+ logger.warning(f'{merges_file.name}: Line {line_num}: Entry malformed, ignoring')
101
+ continue
102
+ merges.append(f'{parts[0]} {parts[1]}')
103
+ self.merges = merges
104
+ return True
105
+
106
+ def _set_special_token(self, typ: str, tid: Any) -> None:
107
+ if not isinstance(tid, int):
108
+ return
109
+ if tid < 0:
110
+ raise ValueError(f'invalid value for special token type {typ}: {tid}')
111
+ if self.n_vocab is None or tid < self.n_vocab:
112
+ if typ in self.special_token_ids:
113
+ return
114
+ self.special_token_ids[typ] = tid
115
+ return
116
+ logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')
117
+
118
+ def _try_load_from_tokenizer_json(self, path: Path) -> bool:
119
+ tokenizer_file = path / 'tokenizer.json'
120
+ if tokenizer_file.is_file():
121
+ with open(tokenizer_file, encoding = 'utf-8') as f:
122
+ tokenizer = json.load(f)
123
+ if self.load_merges:
124
+ merges = tokenizer.get('model', {}).get('merges')
125
+ if isinstance(merges, list) and merges:
126
+ if isinstance(merges[0], str):
127
+ self.merges = merges
128
+ elif isinstance(merges[0], list) and len(merges[0]) == 2 and isinstance(merges[0][0], str):
129
+ # New format since transformers 4.45 to support spaces in merges
130
+ # ref: https://github.com/ggerganov/llama.cpp/issues/9692
131
+ # TODO: internally store as the new format instead of converting to old
132
+ if any(' ' in s for pair in merges for s in pair):
133
+ logger.warning(f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}')
134
+ self.merges = [
135
+ ' '.join(
136
+ [
137
+ # ensure the spaces are properly encoded
138
+ ''.join(
139
+ chr(ord(c) + 256) if c == ' ' else c
140
+ for c in part
141
+ )
142
+ for part in pair
143
+ ]
144
+ )
145
+ for pair in merges
146
+ ]
147
+ else:
148
+ raise ValueError("Unknown tokenizer merges format")
149
+ added_tokens = tokenizer.get('added_tokens', {})
150
+ else:
151
+ added_tokens = {}
152
+ tokenizer_config_file = path / 'tokenizer_config.json'
153
+ if not tokenizer_config_file.is_file():
154
+ return True
155
+ with open(tokenizer_config_file, encoding = 'utf-8') as f:
156
+ tokenizer_config = json.load(f)
157
+ chat_template = tokenizer_config.get('chat_template')
158
+ if chat_template is None or isinstance(chat_template, (str, list)):
159
+ self.chat_template = chat_template
160
+ else:
161
+ logger.warning(f'Bad type for chat_template field in {tokenizer_config_file!r} - ignoring')
162
+ for typ in self.special_token_types:
163
+ add_entry = tokenizer_config.get(f'add_{typ}_token')
164
+ if isinstance(add_entry, bool):
165
+ self.add_special_token[typ] = add_entry
166
+ entry = tokenizer_config.get(f'{typ}_token')
167
+ if isinstance(entry, str):
168
+ tc_content = entry
169
+ elif isinstance(entry, dict):
170
+ entry_content = entry.get('content')
171
+ if not isinstance(entry_content, str):
172
+ continue
173
+ tc_content = entry_content
174
+ else:
175
+ continue
176
+ # We only need the first match here.
177
+ maybe_token_id = next(
178
+ (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content),
179
+ None,
180
+ )
181
+ self._set_special_token(typ, maybe_token_id)
182
+ return True
183
+
184
+ def _try_load_from_config_json(self, path: Path) -> bool:
185
+ config_file = path / 'config.json'
186
+ if not config_file.is_file():
187
+ return False
188
+ with open(config_file, encoding = 'utf-8') as f:
189
+ config = json.load(f)
190
+ for typ in self.special_token_types:
191
+ self._set_special_token(typ, config.get(f'{typ}_token_id'))
192
+ return True
193
+
194
+
195
+ @runtime_checkable
196
+ class BaseVocab(Protocol):
197
+ tokenizer_model: ClassVar[str]
198
+ name: ClassVar[str]
199
+
200
+
201
+ @runtime_checkable
202
+ class Vocab(BaseVocab, Protocol):
203
+ vocab_size: int
204
+ added_tokens_dict: dict[str, int]
205
+ added_tokens_list: list[str]
206
+ fname_tokenizer: Path
207
+
208
+ def __init__(self, base_path: Path): ...
209
+ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ...
210
+
211
+
212
+ class NoVocab(BaseVocab):
213
+ tokenizer_model = "no_vocab"
214
+ name = "no_vocab"
215
+
216
+ def __repr__(self) -> str:
217
+ return "<NoVocab for a model without integrated vocabulary>"
218
+
219
+
220
+ class BpeVocab(Vocab):
221
+ tokenizer_model = "gpt2"
222
+ name = "bpe"
223
+
224
+ def __init__(self, base_path: Path):
225
+ added_tokens: dict[str, int] = {}
226
+
227
+ if (fname_tokenizer := base_path / 'vocab.json').exists():
228
+ # "slow" tokenizer
229
+ with open(fname_tokenizer, encoding="utf-8") as f:
230
+ self.vocab = json.load(f)
231
+
232
+ try:
233
+ # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
234
+ with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
235
+ added_tokens = json.load(f)
236
+ except FileNotFoundError:
237
+ pass
238
+ else:
239
+ # "fast" tokenizer
240
+ fname_tokenizer = base_path / 'tokenizer.json'
241
+
242
+ # if this fails, FileNotFoundError propagates to caller
243
+ with open(fname_tokenizer, encoding="utf-8") as f:
244
+ tokenizer_json = json.load(f)
245
+
246
+ tokenizer_model: dict[str, Any] = tokenizer_json['model']
247
+ if (
248
+ tokenizer_model['type'] != 'BPE' or tokenizer_model.get('byte_fallback', False)
249
+ or tokenizer_json['decoder']['type'] != 'ByteLevel'
250
+ ):
251
+ raise FileNotFoundError('Cannot find GPT-2 BPE tokenizer')
252
+
253
+ self.vocab = tokenizer_model["vocab"]
254
+
255
+ if (added := tokenizer_json.get('added_tokens')) is not None:
256
+ # Added tokens here can be duplicates of the main vocabulary.
257
+ added_tokens = {item['content']: item['id']
258
+ for item in added
259
+ if item['content'] not in self.vocab}
260
+
261
+ vocab_size = len(self.vocab)
262
+ expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
263
+ actual_ids = sorted(added_tokens.values())
264
+ if expected_ids != actual_ids:
265
+ expected_end_id = vocab_size + len(actual_ids) - 1
266
+ raise ValueError(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range "
267
+ f"{vocab_size} - {expected_end_id}; got {actual_ids}")
268
+
269
+ items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
270
+ self.added_tokens_dict = added_tokens
271
+ self.added_tokens_list = [text for (text, idx) in items]
272
+ self.vocab_size_base = vocab_size
273
+ self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
274
+ self.fname_tokenizer = fname_tokenizer
275
+
276
+ def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
277
+ reverse_vocab = {id: encoded_tok for encoded_tok, id in self.vocab.items()}
278
+
279
+ for i, _ in enumerate(self.vocab):
280
+ yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL
281
+
282
+ def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
283
+ for text in self.added_tokens_list:
284
+ score = -1000.0
285
+ yield text.encode("utf-8"), score, gguf.TokenType.CONTROL
286
+
287
+ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
288
+ yield from self.bpe_tokens()
289
+ yield from self.added_tokens()
290
+
291
+ def __repr__(self) -> str:
292
+ return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
293
+
294
+
295
+ class SentencePieceVocab(Vocab):
296
+ tokenizer_model = "llama"
297
+ name = "spm"
298
+
299
+ def __init__(self, base_path: Path):
300
+ added_tokens: dict[str, int] = {}
301
+ if (fname_tokenizer := base_path / 'tokenizer.model').exists():
302
+ # normal location
303
+ try:
304
+ with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
305
+ added_tokens = json.load(f)
306
+ except FileNotFoundError:
307
+ pass
308
+ elif not (fname_tokenizer := base_path.parent / 'tokenizer.model').exists():
309
+ # not found in alternate location either
310
+ raise FileNotFoundError('Cannot find tokenizer.model')
311
+
312
+ self.sentencepiece_tokenizer = SentencePieceProcessor()
313
+ self.sentencepiece_tokenizer.LoadFromFile(str(fname_tokenizer))
314
+ vocab_size = self.sentencepiece_tokenizer.vocab_size()
315
+
316
+ new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
317
+ expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
318
+ actual_new_ids = sorted(new_tokens.keys())
319
+
320
+ if expected_new_ids != actual_new_ids:
321
+ raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")
322
+
323
+ # Token pieces that were added to the base vocabulary.
324
+ self.added_tokens_dict = added_tokens
325
+ self.added_tokens_list = [new_tokens[id] for id in actual_new_ids]
326
+ self.vocab_size_base = vocab_size
327
+ self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
328
+ self.fname_tokenizer = fname_tokenizer
329
+
330
+ def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
331
+ tokenizer = self.sentencepiece_tokenizer
332
+ for i in range(tokenizer.vocab_size()):
333
+ piece = tokenizer.IdToPiece(i)
334
+ text = piece.encode("utf-8")
335
+ score: float = tokenizer.GetScore(i)
336
+
337
+ toktype = gguf.TokenType.NORMAL
338
+ if tokenizer.IsUnknown(i):
339
+ toktype = gguf.TokenType.UNKNOWN
340
+ if tokenizer.IsControl(i):
341
+ toktype = gguf.TokenType.CONTROL
342
+
343
+ # NOTE: I think added_tokens are user defined.
344
+ # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
345
+ # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED
346
+
347
+ if tokenizer.IsUnused(i):
348
+ toktype = gguf.TokenType.UNUSED
349
+ if tokenizer.IsByte(i):
350
+ toktype = gguf.TokenType.BYTE
351
+
352
+ yield text, score, toktype
353
+
354
+ def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
355
+ for text in self.added_tokens_list:
356
+ score = -1000.0
357
+ yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED
358
+
359
+ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
360
+ yield from self.sentencepiece_tokens()
361
+ yield from self.added_tokens()
362
+
363
+ def __repr__(self) -> str:
364
+ return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
365
+
366
+
367
+ class LlamaHfVocab(Vocab):
368
+ tokenizer_model = "llama"
369
+ name = "hfft"
370
+
371
+ def __init__(self, base_path: Path):
372
+ fname_tokenizer = base_path / 'tokenizer.json'
373
+ # if this fails, FileNotFoundError propagates to caller
374
+ with open(fname_tokenizer, encoding='utf-8') as f:
375
+ tokenizer_json = json.load(f)
376
+
377
+ # pre-check so we know if we need transformers
378
+ tokenizer_model: dict[str, Any] = tokenizer_json['model']
379
+ is_llama3 = (
380
+ tokenizer_model['type'] == 'BPE' and tokenizer_model.get('ignore_merges', False)
381
+ and not tokenizer_model.get('byte_fallback', True)
382
+ )
383
+ if is_llama3:
384
+ raise TypeError('Llama 3 must be converted with BpeVocab')
385
+
386
+ if not is_llama3 and (
387
+ tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
388
+ or tokenizer_json['decoder']['type'] != 'Sequence'
389
+ ):
390
+ raise FileNotFoundError('Cannot find Llama BPE tokenizer')
391
+
392
+ try:
393
+ from transformers import AutoTokenizer
394
+ except ImportError as e:
395
+ raise ImportError(
396
+ "To use LlamaHfVocab, please install the `transformers` package. "
397
+ "You can install it with `pip install transformers`."
398
+ ) from e
399
+
400
+ # Allow the tokenizer to default to slow or fast versions.
401
+ # Explicitly set tokenizer to use local paths.
402
+ self.tokenizer = AutoTokenizer.from_pretrained(
403
+ base_path,
404
+ cache_dir=base_path,
405
+ local_files_only=True,
406
+ )
407
+ assert self.tokenizer.is_fast # assume tokenizer.json is used
408
+
409
+ # Initialize lists and dictionaries for added tokens
410
+ self.added_tokens_list = []
411
+ self.added_tokens_dict = dict()
412
+ self.added_tokens_ids = set()
413
+
414
+ # Process added tokens
415
+ for tok, tokidx in sorted(
416
+ self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
417
+ ):
418
+ # Only consider added tokens that are not in the base vocabulary
419
+ if tokidx >= self.tokenizer.vocab_size:
420
+ self.added_tokens_list.append(tok)
421
+ self.added_tokens_dict[tok] = tokidx
422
+ self.added_tokens_ids.add(tokidx)
423
+
424
+ # Store special tokens and their IDs
425
+ self.specials = {
426
+ tok: self.tokenizer.get_vocab()[tok]
427
+ for tok in self.tokenizer.all_special_tokens
428
+ }
429
+ self.special_ids = set(self.tokenizer.all_special_ids)
430
+
431
+ # Set vocabulary sizes
432
+ self.vocab_size_base = self.tokenizer.vocab_size
433
+ self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
434
+
435
+ self.fname_tokenizer = fname_tokenizer
436
+
437
+ def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
438
+ reverse_vocab = {
439
+ id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
440
+ }
441
+
442
+ for token_id in range(self.vocab_size_base):
443
+ # Skip processing added tokens here
444
+ if token_id in self.added_tokens_ids:
445
+ continue
446
+
447
+ # Convert token text to bytes
448
+ token_text = reverse_vocab[token_id].encode("utf-8")
449
+
450
+ # Yield token text, score, and type
451
+ yield token_text, self.get_token_score(token_id), self.get_token_type(
452
+ token_id, token_text, self.special_ids # Reuse already stored special IDs
453
+ )
454
+
455
+ def get_token_type(self, token_id: int, token_text: bytes, special_ids: set[int]) -> gguf.TokenType:
456
+ # Special case for byte tokens
457
+ if re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
458
+ return gguf.TokenType.BYTE
459
+
460
+ # Determine token type based on whether it's a special token
461
+ return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL
462
+
463
+ def get_token_score(self, token_id: int) -> float:
464
+ # Placeholder for actual logic to determine the token's score
465
+ # This needs to be implemented based on specific requirements
466
+ return -1000.0 # Default score
467
+
468
+ def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
469
+ for text in self.added_tokens_list:
470
+ if text in self.specials:
471
+ toktype = self.get_token_type(self.specials[text], b'', self.special_ids)
472
+ score = self.get_token_score(self.specials[text])
473
+ else:
474
+ toktype = gguf.TokenType.USER_DEFINED
475
+ score = -1000.0
476
+
477
+ yield text.encode("utf-8"), score, toktype
478
+
479
+ def has_newline_token(self):
480
+ return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
481
+
482
+ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
483
+ yield from self.hf_tokens()
484
+ yield from self.added_tokens()
485
+
486
+ def __repr__(self) -> str:
487
+ return f"<LlamaHfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
test_inference.py ADDED
@@ -0,0 +1,617 @@
1
+
2
+ from exllamav2 import(
3
+ ExLlamaV2,
4
+ ExLlamaV2Config,
5
+ ExLlamaV2Cache,
6
+ ExLlamaV2Cache_8bit,
7
+ ExLlamaV2Cache_Q4,
8
+ ExLlamaV2Cache_Q6,
9
+ ExLlamaV2Cache_Q8,
10
+ ExLlamaV2Cache_TP,
11
+ ExLlamaV2Tokenizer,
12
+ model_init,
13
+ )
14
+
15
+ from exllamav2.generator import (
16
+ ExLlamaV2BaseGenerator,
17
+ ExLlamaV2Sampler
18
+ )
19
+
20
+ from exllamav2.attn import ExLlamaV2Attention
21
+ from exllamav2.mlp import ExLlamaV2MLP
22
+ from exllamav2.moe_mlp import ExLlamaV2MoEMLP
23
+ from exllamav2.parallel_decoder import ExLlamaV2ParallelDecoder
24
+
25
+ import argparse, os, math, time
26
+ import torch
27
+ import torch.nn.functional as F
28
+ from exllamav2.conversion.tokenize import get_tokens
29
+ from exllamav2.conversion.quantize import list_live_tensors
30
+ import gc
31
+
32
+ # from exllamav2.mlp import set_catch
33
+
34
+ import sys
35
+ import json
36
+
37
+ torch.cuda._lazy_init()
38
+ torch.set_printoptions(precision = 5, sci_mode = False, linewidth = 150)
39
+
40
+ # torch.backends.cuda.matmul.allow_tf32 = True
41
+ # torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
42
+ # torch.set_float32_matmul_precision("medium")
43
+
44
+ # (!!!) NOTE: These go on top of the engine arguments that can be found in `model_init.py` (!!!)
45
+ parser = argparse.ArgumentParser(description = "Test inference on ExLlamaV2 model")
46
+ parser.add_argument("-ed", "--eval_dataset", type = str, help = "Perplexity evaluation dataset (.parquet file)")
47
+ parser.add_argument("-er", "--eval_rows", type = int, default = 128, help = "Number of rows to apply from dataset")
48
+ parser.add_argument("-el", "--eval_length", type = int, default = 2048, help = "Max no. tokens per sample")
49
+ parser.add_argument("-et", "--eval_token", action = "store_true", help = "Evaluate perplexity on token-by-token inference using cache")
50
+ parser.add_argument("-e8", "--eval_token_8bit", action = "store_true", help = "Evaluate perplexity on token-by-token inference using 8-bit (FP8) cache")
51
+ parser.add_argument("-eq4", "--eval_token_q4", action = "store_true", help = "Evaluate perplexity on token-by-token inference using Q4 cache")
52
+ parser.add_argument("-eq6", "--eval_token_q6", action = "store_true", help = "Evaluate perplexity on token-by-token inference using Q6 cache")
53
+ parser.add_argument("-eq8", "--eval_token_q8", action = "store_true", help = "Evaluate perplexity on token-by-token inference using Q8 cache")
54
+ parser.add_argument("-ecl", "--eval_context_lens", action = "store_true", help = "Evaluate perplexity at range of context lengths")
55
+ # parser.add_argument("-eb", "--eval_bos", action = "store_true", help = "Add BOS token to every row in perplexity test (required by Gemma and maybe other models.)")
56
+ parser.add_argument("-p", "--prompt", type = str, help = "Generate from prompt (basic sampling settings)")
57
+ parser.add_argument("-pnb", "--prompt_no_bos", action = "store_true", help = "Don't add BOS token to prompt")
58
+ parser.add_argument("-t", "--tokens", type = int, default = 128, help = "Max no. tokens")
59
+ parser.add_argument("-ps", "--prompt_speed", action = "store_true", help = "Test prompt processing (batch) speed over context length")
60
+ parser.add_argument("-s", "--speed", action = "store_true", help = "Test raw generation speed over context length")
61
+ parser.add_argument("-mix", "--mix_layers", type = str, help = "Load replacement layers from secondary model. Example: --mix_layers 1,6-7:/mnt/models/other_model")
62
+ parser.add_argument("-nwu", "--no_warmup", action = "store_true", help = "Skip warmup before testing model")
63
+ parser.add_argument("-sl", "--stream_layers", action = "store_true", help = "Load model layer by layer (perplexity evaluation only)")
64
+ parser.add_argument("-sp", "--standard_perplexity", choices = ["wiki2"], help = "Run standard (HF) perplexity test, stride 512 (experimental)")
65
+ parser.add_argument("-rr", "--rank_reduce", type = str, help = "Rank-reduction for MLP layers of model, in reverse order (for experimentation)")
66
+ parser.add_argument("-mol", "--max_output_len", type = int, help = "Set max output chunk size (incompatible with ppl tests)")
67
+
68
+ # Initialize model and tokenizer
69
+
70
+ model_init.add_args(parser)
71
+ args = parser.parse_args()
72
+
73
+ # Check conflicting settings
74
+
75
+ if args.stream_layers:
76
+ if args.eval_token or args.eval_token_8bit or args.eval_token_q4 or args.eval_token_q6 or args.eval_token_q8:
77
+ print(" ## Can't test token ppl while streaming layers")
78
+ sys.exit()
79
+ if args.prompt:
80
+ print(" ## Can't generate while streaming layers")
81
+ sys.exit()
82
+ if args.speed or args.prompt_speed:
83
+ print(" ## Can't test speed while streaming layers")
84
+ sys.exit()
85
+ if args.gpu_split:
86
+ print(" ## Can only use one GPU when streaming layers")
87
+ sys.exit()
88
+ if args.eval_context_lens and args.stream_layers:
89
+ print(" ## eval_context_lens not compatible with stream_layers")
90
+ sys.exit()
91
+ if args.eval_dataset:
92
+ if args.length and args.eval_length != args.length:
93
+ print(" !! Overriding model context length to match eval row length")
94
+ args.length = args.eval_length
95
+
96
+ # Init
97
+
98
+ model_init.check_args(args)
99
+ model_init.print_options(args)
100
+ model, tokenizer = model_init.init(
101
+ args,
102
+ allow_auto_split = True,
103
+ skip_load = args.stream_layers,
104
+ benchmark = True,
105
+ max_output_len = args.max_output_len,
106
+ progress = True
107
+ )
108
+ cache = None
109
+
110
+ # Auto split
111
+
112
+ if not model.loaded and not args.stream_layers:
113
+
114
+ if args.mix_layers:
115
+ print(" !! Warning, auto split does not account for VRAM requirement of replacement layers")
116
+
117
+ print(" -- Loading model...")
118
+ cache = ExLlamaV2Cache(model, lazy = True)
119
+ t = time.time()
120
+ model.load_autosplit(cache, progress = True)
121
+ t = time.time() - t
122
+ print(f" -- Loaded model in {t:.4f} seconds")
123
+
124
+ if args.stream_layers:
125
+
126
+ stream_batch_size = 2
127
+ model.config.max_batch_size = stream_batch_size
128
+ model.load(lazy = True)
129
+
130
+ # Rank reduction
131
+
132
+ if args.rank_reduce:
133
+
134
+ if args.stream_layers:
135
+ print(" ## --rank_reduce can not be combined with --stream_layers")
136
+ sys.exit()
137
+
138
+ rr = args.rank_reduce.split(",")
139
+ idx = len(model.modules) - 1
140
+ for r in rr:
141
+ k = float(r)
142
+
143
+ while True:
144
+ idx -= 1
145
+ module = model.modules[idx]
146
+ if isinstance(module, ExLlamaV2ParallelDecoder): break
147
+ if isinstance(module, ExLlamaV2MLP): break
148
+ if isinstance(module, ExLlamaV2MoEMLP): break
149
+ if idx < 0:
150
+ print(" ## Not enough layers")
151
+ sys.exit()
152
+
153
+ print(f" -- Reducing {module.key} ({module.name}) to {k * 100:.2f}%")
154
+ module.rank_reduce(k)
155
+
156
+ # Replacement
157
+
158
+ if args.mix_layers:
159
+ intervals_, extra_dir = args.mix_layers.split(":")
160
+
161
+ print(f" -- Loading replacement layers from: {extra_dir}")
162
+
163
+ extra_config = ExLlamaV2Config()
164
+ extra_config.model_dir = extra_dir
165
+ extra_config.prepare()
166
+ intervals = intervals_.split(",")
167
+ for interval in intervals:
168
+ ab = interval.split("-")
169
+ a, b = int(ab[0]), int(ab[-1])
170
+ for idx in range(a, b + 1):
171
+ print(f" -- Layer {idx}...")
172
+ layerkey = "model.layers." + str(idx) + "."
173
+ remove = [k for k in model.config.tensor_file_map.keys() if k.startswith(layerkey)]
174
+ replace = [k for k in extra_config.tensor_file_map.keys() if k.startswith(layerkey)]
175
+ # reload = [k for k in model.modules_dict.keys() if k.startswith(layerkey)]
176
+ for k in remove: del model.config.tensor_file_map[k]
177
+ for k in replace: model.config.tensor_file_map[k] = extra_config.tensor_file_map[k]
178
+ # for k in reload:
179
+ # model.modules_dict[k].unload()
180
+ # model.modules_dict[k].load()
181
+ if not args.stream_layers:
182
+ model.modules[idx * 2 + 1].reload()
183
+ model.modules[idx * 2 + 2].reload()
184
+
185
+ # Test generation
186
+
187
+ if args.prompt:
188
+
189
+ with torch.inference_mode():
190
+
191
+ if cache is None:
192
+ cache = ExLlamaV2Cache(model) if not model.tp_context else ExLlamaV2Cache_TP(model)
193
+
194
+ ids = tokenizer.encode(args.prompt)
195
+ tokens_prompt = ids.shape[-1]
196
+
197
+ print(f" -- Warmup...")
198
+
199
+ generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
200
+ if not args.no_warmup: generator.warmup()
201
+
202
+ print(f" -- Generating...")
203
+ print()
204
+
205
+ settings = ExLlamaV2Sampler.Settings()
206
+ settings.temperature = 1.0
207
+ settings.top_k = 0
208
+ settings.top_p = 0.8
209
+ settings.token_repetition_penalty = 1.02
210
+ settings.disallow_tokens(tokenizer, [tokenizer.eos_token_id])
211
+
212
+ time_begin = time.time()
213
+
214
+ output = generator.generate_simple(args.prompt, settings, args.tokens, token_healing = True, add_bos = not args.prompt_no_bos)
215
+
216
+ torch.cuda.synchronize()
217
+ time_prompt = time.time()
218
+
219
+ time_end = time.time()
220
+
221
+ print(output)
222
+ print()
223
+
224
+ total_gen = time_end - time_begin
225
+ print(f" -- Response generated in {total_gen:.2f} seconds, {args.tokens} tokens, {args.tokens / total_gen:.2f} tokens/second (includes prompt eval.)")
226
+
227
+
228
+ # Test perplexity
229
+
230
+ if args.eval_dataset or args.standard_perplexity:
231
+
232
+ with torch.inference_mode():
233
+
234
+ print(f" -- Running perplexity test")
235
+
236
+ if args.standard_perplexity:
237
+
238
+ eval_length = args.eval_length
239
+ if args.eval_dataset:
240
+ print(f" !! Note, overriding specified --eval_dataset with {args.standard_perplexity}")
241
+
242
+ from datasets import load_dataset
243
+
244
+ if args.standard_perplexity == "wiki2":
245
+ ds = "wikitext"
246
+ part = "wikitext-2-raw-v1"
247
+ split = "test"
248
+ # if args.standard_perplexity == "c4":
249
+ # ds = "allenai/c4"
250
+ # part = "allenai--c4"
251
+ # split = "train"
252
+
253
+ print(f" -- Loading dataset {ds}, {part}, {split}...")
254
+ test = load_dataset(ds, part, split = split)
255
+
256
+ print(f" -- Tokenizing samples...")
257
+ text = "\n\n".join(test["text"])
258
+ eval_tokens = tokenizer.encode(text)
259
+
260
+ stride = 512
261
+ seqs = []
262
+ eval_len = []
263
+ a = 0
264
+ while True:
265
+ b = a + model.config.max_seq_len
266
+ if b > eval_tokens.shape[-1]: break
267
+ seqs.append(eval_tokens[:, a:b])
268
+ eval_len.append(b if a == 0 else stride)
269
+ a += stride
270
+
271
+ eval_tokens = torch.cat(seqs, dim = 0)
272
+
273
+ else:
274
+
275
+ eval_dataset = args.eval_dataset
276
+ eval_rows = args.eval_rows
277
+ eval_length = args.eval_length
278
+
279
+ print(f" -- Dataset: {eval_dataset}")
280
+ print(f" -- Tokenizing eval data, {eval_rows} rows x {eval_length} tokens...")
281
+
282
+ eval_tokens = get_tokens(eval_rows, eval_length, eval_dataset, tokenizer)
283
+ eval_len = [eval_tokens.shape[1]] * eval_tokens.shape[0]
284
+
285
+ # if args.eval_bos:
286
+ if model.config.arch.requires_bos:
287
+ boss = torch.full((eval_tokens.shape[0], 1), tokenizer.bos_token_id, dtype = torch.long)
288
+ eval_tokens = torch.cat((boss, eval_tokens[:, :-1]), dim = 1)
289
+
290
+ if args.eval_context_lens:
291
+ logprob_sum = []
292
+ logprob_count = []
293
+ else:
294
+ logprob_sum = 0.0
295
+ logprob_count = 0
296
+
297
+ def ppl(input_ids__, logits__, lengths__, bins = False):
298
+
299
+ logits_device = model.modules[-1].device() if not model.tp_context else \
300
+ torch.device(model.tp_context.device)
301
+
302
+ if bins:
303
+ num_bins = (max(lengths__) + 255) // 256
304
+ logprob_sum_ = [0.0] * num_bins
305
+ logprob_count_ = [0] * num_bins
306
+ else:
307
+ logprob_sum_ = 0.0
308
+ logprob_count_ = 0
309
+
310
+ assert logits__.shape[0] == input_ids__.shape[0]
311
+ ll = logits__.shape[1]
312
+
313
+ for bi in range(logits__.shape[0]):
314
+ cl = max(ll - lengths__[bi], 0)
315
+ logits_ = logits__[bi:bi+1, cl:, :]
316
+ input_ids_ = input_ids__[bi:bi+1, cl:]
317
+
318
+ if bins:
319
+ chunksize = 256
320
+ else:
321
+ chunksize = logits_.shape[1] * 4000 // logits_.shape[2] + 1
322
+ b_ = 0
323
+ while b_ < logits_.shape[1]:
324
+ a_ = b_
325
+ b_ = min(b_ + chunksize, logits_.shape[1])
326
+
327
+ logits_f = logits_[:, a_:b_, :].to(logits_device).float() + 1e-10
328
+ target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_f.device)
329
+
330
+ log_probs = F.log_softmax(logits_f, dim=-1)
331
+ token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
332
+ if bins:
333
+ # for cbin in range(a_ // 256 + 1):
334
+ cbin = a_ // 256
335
+ logprob_sum_[cbin] += token_log_probs.sum().item()
336
+ logprob_count_[cbin] += target_ids.numel()
337
+ else:
338
+ logprob_sum_ += token_log_probs.sum().item()
339
+ logprob_count_ += target_ids.numel()
340
+
341
+ return logprob_sum_, logprob_count_
342
+
343
+ if args.stream_layers:
344
+
345
+ print(f" -- Inference (streamed)", end = "")
346
+ sys.stdout.flush()
347
+
348
+ batch_size, seq_len = eval_tokens.shape
349
+ attn_params = ExLlamaV2Attention.Params(stream_batch_size, seq_len, 0, None, None)
350
+ # attn_mask = model.build_attn_mask(stream_batch_size, seq_len, 0, None, "cuda:0")
351
+
352
+ for idx, module in enumerate(model.modules):
353
+ module.set_device_idx(-1 if idx == 0 else 0)
354
+
355
+ model.modules[0].load()
356
+ hidden_state = model.modules[0].forward(eval_tokens)
357
+ model.modules[0].unload()
358
+
359
+ for idx, module in enumerate(model.modules):
360
+ if idx == 0: continue
361
+
362
+ print(".", end = "")
363
+ sys.stdout.flush()
364
+ module.load()
365
+
366
+ b = 0
367
+ while b < eval_tokens.shape[0]:
368
+ a = b
369
+ b = min(b + stream_batch_size, eval_tokens.shape[0])
370
+ x = hidden_state[a:b, :, :].to("cuda:0")
371
+ x = module.forward(x, cache = None, attn_params = attn_params, past_len = 0, loras = None)
372
+
373
+ if idx < len(model.modules) - 1:
374
+ hidden_state[a:b, :, :] = x.to("cpu")
375
+
376
+ else:
377
+ input_ids = eval_tokens[a:b, :]
378
+ logits = x[:, :-1, :]
379
+
380
+ # if model.config.logit_scale != 1:
381
+ # logits.mul_(model.config.logit_scale)
382
+
383
+ logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[a:b])
384
+ logprob_sum += logprob_sum__
385
+ logprob_count += logprob_count__
386
+
387
+ module.unload()
388
+
389
+ print()
390
+
391
+ else:
392
+
393
+ print(f" -- Inference", end = "")
394
+ sys.stdout.flush()
395
+
396
+ if cache is None:
397
+ if eval_length > model.config.max_input_len:
398
+ cache = ExLlamaV2Cache(model, max_seq_len = eval_length) if not model.tp_context else ExLlamaV2Cache_TP(model, max_seq_len = eval_length)
399
+ else:
400
+ cache = None
401
+
402
+ for i in range(eval_tokens.shape[0]):
403
+
404
+ if i % 10 == 0: print(".", end = "")
405
+ sys.stdout.flush()
406
+
407
+ input_ids = eval_tokens[i:i+1, :]
408
+
409
+ input_ids = input_ids[:, :]
410
+ if cache is not None: cache.current_seq_len = 0
411
+ logits = model.forward(input_ids, cache, cpu_logits = input_ids.numel() > 2048)
412
+ logits = logits[:, :-1, :]
413
+
414
+ logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[i:i+1], args.eval_context_lens)
415
+ if args.eval_context_lens:
416
+ while len(logprob_sum) < len(logprob_sum__):
417
+ logprob_sum.append(0.0)
418
+ logprob_count.append(0)
419
+ for j in range(len(logprob_sum__)):
420
+ logprob_sum[j] += logprob_sum__[j]
421
+ logprob_count[j] += logprob_count__[j]
422
+ else:
423
+ logprob_sum += logprob_sum__
424
+ logprob_count += logprob_count__
425
+
426
+ if not args.eval_context_lens:
427
+ print()
428
+ mean_log_prob = logprob_sum / logprob_count
429
+ perplexity = math.exp(-mean_log_prob)
430
+ print(f" -- Evaluation perplexity: {perplexity:.4f}")
431
+ else:
432
+ print()
433
+ for j in range(len(logprob_sum__)):
434
+ mean_log_prob = logprob_sum[j] / logprob_count[j]
435
+ perplexity = math.exp(-mean_log_prob)
436
+ dl = min((j + 1) * 256, eval_length)
437
+ print(f" -- Evaluation perplexity: {dl} {perplexity:.4f}")
438
+
439
+ def test_ppl_token():
440
+ global logprob_sum, logprob_count, i, input_ids
441
+ global logits, target_ids, log_probs, token_log_probs
442
+ global mean_log_prob, perplexity
443
+
444
+ # set_catch("model.layers.3")
445
+
446
+ logprob_sum = 0
447
+ logprob_count = 0
448
+
449
+ for i in range(eval_tokens.shape[0]):
450
+
451
+ cache.current_seq_len = 0
452
+
453
+ for j in range(eval_tokens.shape[1] - 1):
454
+ if j % 256 == 0: print(".", end = "")
455
+ sys.stdout.flush()
456
+
457
+ input_ids = eval_tokens[i:i + 1, j:j + 1]
458
+ logits = model.forward(input_ids, cache)
459
+ logits = logits.float() + 1e-10
460
+
461
+ log_probs = F.log_softmax(logits, dim = -1)
462
+ logprob_sum += log_probs[0, 0, eval_tokens[i, j+1]]
463
+ logprob_count += 1
464
+
465
+ # mean_log_prob = logprob_sum / logprob_count
466
+ # perplexity = math.exp(-mean_log_prob)
467
+ # print(f" -- Token {j}: {perplexity:.4f}")
468
+
469
+ print()
470
+
471
+ mean_log_prob = logprob_sum / logprob_count
472
+ perplexity = math.exp(-mean_log_prob)
473
+ print(f" -- Evaluation perplexity: {perplexity:.4f}")
474
+
475
+ if args.eval_token:
476
+ if args.standard_perplexity:
477
+ print(f" !! Note, can't evalutate token perplexity on standard test")
478
+ else:
479
+ print(f" -- Inference (token)", end = "")
480
+ sys.stdout.flush()
481
+ cache = ExLlamaV2Cache(model, max_seq_len = eval_length) if not model.tp_context else \
482
+ ExLlamaV2Cache_TP(model, max_seq_len = eval_length)
483
+ test_ppl_token()
484
+
485
+ if args.eval_token_8bit:
486
+ if args.standard_perplexity:
487
+ print(f" !! Note, can't evalutate token perplexity on standard test")
488
+ else:
489
+ print(f" -- Inference (token, 8-bit cache)", end = "")
490
+ sys.stdout.flush()
491
+ cache = ExLlamaV2Cache_8bit(model, max_seq_len = eval_length) if not model.tp_context else \
492
+ ExLlamaV2Cache_TP(model, max_seq_len = eval_length, base = ExLlamaV2Cache_8bit)
493
+ test_ppl_token()
494
+
495
+ if args.eval_token_q4:
496
+ if args.standard_perplexity:
497
+ print(f" !! Note, can't evalutate token perplexity on standard test")
498
+ else:
499
+ print(f" -- Inference (token, Q4 cache)", end = "")
500
+ sys.stdout.flush()
501
+ cache = ExLlamaV2Cache_Q4(model, max_seq_len = eval_length) if not model.tp_context else \
502
+ ExLlamaV2Cache_TP(model, max_seq_len = eval_length, base = ExLlamaV2Cache_Q4)
503
+ # cache.calibrate(tokenizer)
504
+ test_ppl_token()
505
+
506
+ if args.eval_token_q6:
507
+ if args.standard_perplexity:
508
+ print(f" !! Note, can't evalutate token perplexity on standard test")
509
+ else:
510
+ print(f" -- Inference (token, Q6 cache)", end = "")
511
+ sys.stdout.flush()
512
+ cache = ExLlamaV2Cache_Q6(model, max_seq_len = eval_length) if not model.tp_context else \
513
+ ExLlamaV2Cache_TP(model, max_seq_len = eval_length, base = ExLlamaV2Cache_Q6)
514
+ # cache.calibrate(tokenizer)
515
+ test_ppl_token()
516
+
517
+ if args.eval_token_q8:
518
+ if args.standard_perplexity:
519
+ print(f" !! Note, can't evalutate token perplexity on standard test")
520
+ else:
521
+ print(f" -- Inference (token, Q8 cache)", end = "")
522
+ sys.stdout.flush()
523
+ cache = ExLlamaV2Cache_Q8(model, max_seq_len = eval_length) if not model.tp_context else \
524
+ ExLlamaV2Cache_TP(model, max_seq_len = eval_length, base = ExLlamaV2Cache_Q8)
525
+ # cache.calibrate(tokenizer)
526
+ test_ppl_token()
527
+
528
+
529
+ # Test prompt speed
530
+
531
+ if args.prompt_speed:
532
+
533
+ with torch.inference_mode():
534
+
535
+ if cache is None:
536
+ cache = ExLlamaV2Cache(model) if not model.tp_context else ExLlamaV2Cache_TP(model)
537
+
538
+ ids = torch.randint(0, model.config.vocab_size - 1, (1, model.config.max_seq_len))
539
+
540
+ print(f" -- Warmup...")
541
+
542
+ if not args.no_warmup:
543
+ model.forward(ids[:, -1:])
544
+
545
+ print(f" -- Measuring prompt speed...")
546
+
547
+ torch.cuda.synchronize()
548
+
549
+ current_len = 128
550
+ step = 128
551
+ prompt_iters = 3
552
+ while True:
553
+
554
+ total_time = 0
555
+ for i in range(prompt_iters):
556
+
557
+ torch.cuda.synchronize()
558
+ time_begin = time.time()
559
+
560
+ cache.current_seq_len = 0
561
+ model.forward(ids[:, :current_len], cache, preprocess_only = True)
562
+
563
+ torch.cuda.synchronize()
564
+ time_end = time.time()
565
+ total_time += time_end - time_begin
566
+
567
+ tps = current_len / (total_time / prompt_iters)
568
+
569
+ print(f" ** Length {current_len:>5} tokens: {tps:>11.4f} t/s")
570
+
571
+ if current_len >= 1024: step = 1024
572
+ if current_len >= 4096: step = 4096
573
+ if current_len >= 16384: step = 8192
574
+
575
+ current_len_ = current_len
576
+ current_len = min(current_len + step, model.config.max_seq_len)
577
+ if current_len == current_len_: break
578
+
579
+
580
+ # Test token speed
581
+
582
+ if args.speed:
583
+
584
+ with torch.inference_mode():
585
+
586
+ if cache is None:
587
+ cache = ExLlamaV2Cache(model) if not model.tp_context else ExLlamaV2Cache_TP(model)
588
+ cache.current_seq_len = 0
589
+
590
+ print(f" -- Measuring token speed...")
591
+ ids = tokenizer.encode("X")
592
+ model.forward(ids[:, :])
593
+
594
+ current_idx = ids.shape[-1]
595
+ next_stop = 128
596
+
597
+ while True:
598
+
599
+ time_begin = time.time()
600
+
601
+ tokens = next_stop - current_idx
602
+ for i in range(tokens):
603
+
604
+ logits = model.forward(ids[:, -1:], cache)
605
+ sample = torch.argmax(logits[0, -1]).cpu().unsqueeze(0).unsqueeze(0)
606
+ sample.clamp_(0, tokenizer.get_vocab_size() - 1)
607
+ ids = torch.cat((ids, sample), dim=-1)
608
+
609
+ time_end = time.time()
610
+ tps = tokens / (time_end - time_begin)
611
+
612
+ print(f" ** Position {current_idx:>5} + {tokens:>3} tokens: {tps:>9.4f} t/s")
613
+
614
+ current_idx = next_stop
615
+ next_stop = min(next_stop + 128, model.config.max_seq_len)
616
+ if next_stop == current_idx: break
617
+