tinywell committed on
Commit
d17b9a9
·
verified ·
1 Parent(s): 906705e

Upload chatglm.py

Browse files
Files changed (1) hide show
  1. chatglm.py +608 -0
chatglm.py ADDED
@@ -0,0 +1,608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Dict, Optional, Tuple, Union
3
+ import math
4
+
5
+ import mlx.core as mx
6
+ import mlx.nn as nn
7
+
8
+ from .base import BaseModelArgs
9
+
10
+
11
@dataclass
class ModelArgs(BaseModelArgs):
    """Hyper-parameters describing a ChatGLM checkpoint.

    Only ``model_type`` is required; every other field has a default.
    Field order is part of the dataclass constructor signature and is
    therefore kept exactly as declared.
    """

    model_type: str

    # Linear-layer bias switches.
    add_bias_linear: bool = False
    add_qkv_bias: bool = True

    # Attention / residual behaviour flags.
    apply_query_key_layer_scaling: bool = True
    apply_residual_connection_post_layernorm: bool = False
    attention_dropout: float = 0.0
    attention_softmax_in_fp32: bool = True
    bias_dropout_fusion: bool = True

    # Layer sizes.
    ffn_hidden_size: int = 13696
    fp32_residual_connection: bool = False
    hidden_dropout: float = 0.0
    hidden_size: int = 4096
    kv_channels: int = 128
    layernorm_epsilon: float = 1.5625e-07

    # Grouped (multi-query) attention layout.
    multi_query_attention: bool = True
    multi_query_group_num: int = 2
    num_attention_heads: int = 32

    # Both layer counts are declared; the transformer stack in this file
    # reads ``num_layers``.
    num_hidden_layers: int = 40
    num_layers: int = 40

    # Rotary position embedding settings.
    rope_ratio: int = 500
    original_rope: bool = True

    padded_vocab_size: int = 151552
    post_layer_norm: bool = True
    rmsnorm: bool = True
    seq_length: int = 131072
    use_cache: bool = True
    # Kept as a string; it is forwarded as-is to the rotary embedding.
    torch_dtype: str = "bfloat16"
    tie_word_embeddings: bool = False

    def __post_init__(self):
        """Intentionally a no-op: no derived fields or validation."""
44
+
45
class RotaryEmbedding(nn.Module):
    """Builds the cos/sin lookup table used for rotary position embeddings.

    The module holds no weights; the table is recomputed on every call.
    """

    def __init__(self, dim, rope_ratio=1, original_impl=False, dtype=None):
        """Store the rotary configuration.

        Args:
            dim: Number of rotated channels; forwarded as ``n_elem`` to
                :meth:`forward_impl`.
            rope_ratio: Multiplier applied to the rotary base frequency.
            original_impl: Stored for configuration parity; not read by
                this class.
            dtype: Target dtype of the table, either an ``mx.Dtype`` or a
                dtype name such as ``"bfloat16"``.
        """
        super().__init__()
        self.inv_freq_type = dtype
        self.dim = dim
        self.original_impl = original_impl
        self.rope_ratio = rope_ratio

    @staticmethod
    def _resolve_dtype(dtype):
        """Turn a dtype name (e.g. ``"bfloat16"``) into an ``mx.Dtype``.

        Anything that is not a string is returned unchanged; an unknown
        name resolves to ``None`` so that no down-cast is applied.
        """
        if not isinstance(dtype, str):
            return dtype
        # Tolerate qualified names such as "torch.bfloat16".
        resolved = getattr(mx, dtype.split(".")[-1], None)
        return resolved if isinstance(resolved, mx.Dtype) else None

    def forward_impl(
        self, seq_len: int, n_elem: int, dtype: mx.Dtype, base: int = 10000
    ):
        """Enhanced Transformer with Rotary Position Embedding.

        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.

        Args:
            seq_len: Number of positions (rows) in the table.
            n_elem: Number of rotated channels; the table has
                ``n_elem // 2`` frequency columns.
            dtype: Requested table dtype (``mx.Dtype`` or dtype name).
            base: Rotary base before scaling by ``rope_ratio``.

        Returns:
            Array of shape ``(seq_len, n_elem // 2, 2)`` whose last axis
            holds ``(cos, sin)``.
        """
        dtype = self._resolve_dtype(dtype)

        # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        base = base * self.rope_ratio

        # All intermediate math is done in float32. float16 cannot
        # represent integers above 2048 exactly and overflows at 65504,
        # which corrupts both large position indices and the scaled base.
        theta = 1.0 / (base ** (mx.arange(0, n_elem, 2, dtype=mx.float32) / n_elem))

        # Position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = mx.arange(seq_len, dtype=mx.float32)

        # Outer product of position index and $\theta_i$
        idx_theta = mx.outer(seq_idx, theta).astype(mx.float32)

        cache = mx.stack([mx.cos(idx_theta), mx.sin(idx_theta)], axis=-1)

        # Down-cast only after cos/sin have been evaluated, to mimic the
        # behaviour of complex32 in the reference implementation.
        if dtype == mx.bfloat16:
            cache = cache.astype(mx.bfloat16)
        elif dtype == mx.float16 or dtype == mx.int8:
            cache = cache.astype(mx.float16)
        return cache

    def __call__(self, max_seq_len, offset=0):
        """Return the table for positions ``0 .. max_seq_len - 1``.

        ``offset`` is accepted for interface compatibility and is unused.
        """
        return self.forward_impl(
            max_seq_len, self.dim, dtype=self.inv_freq_type,
        )
85
+
86
def apply_rotary_pos_emb(x: mx.array, rope_cache: mx.array) -> mx.array:
    """Rotate the leading channels of ``x`` with a cos/sin table.

    Args:
        x: Array laid out as ``[b, np, sq, hn]``.
        rope_cache: cos/sin table whose second-to-last axis counts channel
            pairs and whose last axis holds ``(cos, sin)``.

    Returns:
        Array shaped like ``x``: the first ``2 * rope_cache.shape[-2]``
        channels are rotated, the remaining ones are passed through.
    """
    batch, heads, seq_len, _head_dim = (x.shape[axis] for axis in range(4))
    rot_dim = rope_cache.shape[-2] * 2

    rotated, passthrough = x[..., :rot_dim], x[..., rot_dim:]

    # Truncate the table to the current sequence length.
    rope_cache = rope_cache[:, :seq_len]

    # View the rotated channels as consecutive (even, odd) pairs.
    pairs = rotated.reshape(batch, heads, seq_len, rot_dim // 2, 2)
    table = rope_cache.reshape(-1, 1, seq_len, pairs.shape[3], 2)

    first, second = pairs[..., 0], pairs[..., 1]
    cos, sin = table[..., 0], table[..., 1]

    # 2-D rotation of every pair by its position-dependent angle.
    rotated_pairs = mx.stack(
        [
            first * cos - second * sin,
            second * cos + first * sin,
        ],
        -1,
    )
    return mx.concatenate((rotated_pairs.flatten(3), passthrough), axis=-1)
104
+
105
+ # class RMSNorm(nn.Module):
106
+ # def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
107
+ # super().__init__()
108
+ # self.weight = nn.empty(normalized_shape, device=device, dtype=dtype)
109
+ # self.eps = eps
110
+
111
+ # def __call__(self, hidden_states: mx.array):
112
+ # input_dtype = hidden_states.dtype
113
+ # variance = hidden_states.astype("float32").power(2).mean(-1, keepdims=True)
114
+ # hidden_states = hidden_states * variance.rsqrt()
115
+
116
+ # return (self.weight * hidden_states).astype(input_dtype)
117
+
118
+
119
class CoreAttention(nn.Module):
    """Scaled dot-product attention over already-projected Q/K/V."""

    def __init__(self, args: ModelArgs, layer_number):
        """Record head geometry and scaling options from ``args``."""
        super().__init__()

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        # Layer scaling forces the fp32-softmax flag on.
        self.attention_softmax_in_fp32 = (
            True if self.apply_query_key_layer_scaling else args.attention_softmax_in_fp32
        )
        self.layer_number = max(1, layer_number)

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = projection_size // args.num_attention_heads
        self.num_attention_heads_per_partition = args.num_attention_heads

        # These factors are recorded here but __call__ derives its own
        # scale from the query head size.
        coeff = self.layer_number if self.apply_query_key_layer_scaling else None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if coeff is not None:
            self.norm_factor *= coeff
        self.coeff = coeff

        self.attention_dropout = nn.Dropout(args.attention_dropout)

    def __call__(self, query_layer, key_layer, value_layer, attention_mask):
        """Run fused attention and merge the heads back together.

        With no mask and equal query/key lengths an additive causal mask is
        built. A supplied mask is logically inverted before use. Otherwise
        no mask is applied.

        Returns:
            The attention output transposed with ``(0, 2, 1, 3)`` and with
            its last two axes collapsed to ``hidden_size_per_partition``.
        """
        scale = query_layer.shape[-1] ** -0.5

        if attention_mask is not None:
            mask = ~attention_mask
        elif query_layer.shape[2] == key_layer.shape[2]:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(
                query_layer.shape[2]
            ).astype(query_layer.dtype)
        else:
            mask = None

        context = mx.fast.scaled_dot_product_attention(
            query_layer, key_layer, value_layer, scale=scale, mask=mask
        )

        context = context.transpose((0, 2, 1, 3))
        merged_shape = context.shape[:-2] + (self.hidden_size_per_partition,)
        return context.reshape(*merged_shape)
162
+
163
class SelfAttention(nn.Module):
    """Self-attention layer: fused QKV projection, rotary embedding,
    optional KV cache, core attention and output projection.
    """

    def __init__(self, args: ModelArgs, layer_number):
        """Create the fused QKV projection, core attention and output layer.

        Args:
            args: Model hyper-parameters.
            layer_number: Layer index; values below 1 are clamped to 1.
        """
        super().__init__()
        self.layer_number = max(1, layer_number)

        self.projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_attention_head = self.projection_size // args.num_attention_heads
        self.num_attention_heads_per_partition = args.num_attention_heads
        self.multi_query_attention = args.multi_query_attention
        self.qkv_hidden_size = 3 * self.projection_size
        if self.multi_query_attention:
            # Full-width queries, but only `multi_query_group_num` K/V heads.
            self.num_multi_query_groups_per_partition = args.multi_query_group_num
            self.qkv_hidden_size = (
                self.projection_size
                + 2 * self.hidden_size_per_attention_head * args.multi_query_group_num
            )
        self.query_key_value = nn.Linear(
            args.hidden_size,
            self.qkv_hidden_size,
            bias=args.add_bias_linear or args.add_qkv_bias,
        )

        self.core_attention = CoreAttention(args, self.layer_number)

        # Output.
        self.dense = nn.Linear(self.projection_size, args.hidden_size, bias=args.add_bias_linear)

    def __call__(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True):
        """Apply self-attention.

        Args:
            hidden_states: Input laid out as ``[b, sq, h]``.
            attention_mask: Mask forwarded to the core attention (may be
                ``None``).
            rotary_pos_emb: cos/sin table for the current positions, or
                ``None`` to skip the rotary embedding.
            kv_cache: Object exposing ``update_and_fetch(keys, values)``;
                ``None`` disables caching for this call.
            use_cache: When false the cache is left untouched.

        Returns:
            The result of the output projection applied to the attention
            context.
        """
        head_dim = self.hidden_size_per_attention_head
        num_heads = self.num_attention_heads_per_partition

        # Attention heads [b, sq, h] --> [b, sq, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        if self.multi_query_attention:
            num_groups = self.num_multi_query_groups_per_partition
            q_width = num_heads * head_dim
            kv_width = num_groups * head_dim
            # Two split points yield exactly the three Q / K / V slices.
            query_layer, key_layer, value_layer = mx.split(
                mixed_x_layer, [q_width, q_width + kv_width], axis=-1
            )
            query_layer = query_layer.reshape(
                *(query_layer.shape[:-1] + (num_heads, head_dim))
            )
            key_layer = key_layer.reshape(
                *(key_layer.shape[:-1] + (num_groups, head_dim))
            )
            value_layer = value_layer.reshape(
                *(value_layer.shape[:-1] + (num_groups, head_dim))
            )
        else:
            new_tensor_shape = mixed_x_layer.shape[:-1] + (num_heads, 3 * head_dim)
            mixed_x_layer = mixed_x_layer.reshape(*new_tensor_shape)

            # [b, sq, np, 3 * hn] --> 3 [b, sq, np, hn]
            # MLX has no `split_along_last_dim`; an even three-way split of
            # the last axis is the equivalent operation.
            query_layer, key_layer, value_layer = mx.split(mixed_x_layer, 3, axis=-1)

        # [b, sq, np, hn] -> [b, np, sq, hn]
        query_layer, key_layer, value_layer = [
            t.transpose((0, 2, 1, 3)) for t in (query_layer, key_layer, value_layer)
        ]

        # Apply relative positional encoding (rotary embedding).
        if rotary_pos_emb is not None:
            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # Extend keys/values with the cached history. Guard against a
        # missing cache so that a cache-less forward pass does not crash.
        if use_cache and kv_cache is not None:
            key_layer, value_layer = kv_cache.update_and_fetch(key_layer, value_layer)

        # K/V keep their (possibly smaller) group count here; no explicit
        # head expansion is done before the fused attention call.
        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)

        return self.dense(context_layer)
281
+
282
class MLP(nn.Module):
    """Feed-forward block: up-projection, SwiGLU gate, down-projection."""

    def __init__(self, args: ModelArgs):
        """Build both projections; bias usage follows ``add_bias_linear``."""
        super().__init__()

        self.add_bias = args.add_bias_linear

        # The up-projection is twice as wide because SwiGLU consumes two
        # halves, see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = nn.Linear(
            args.hidden_size,
            args.ffn_hidden_size * 2,
            bias=self.add_bias,
        )

        def swiglu(x):
            # First half is passed through SiLU and gates the second half.
            gate, value = mx.split(x, 2, axis=-1)
            return nn.silu(gate) * value

        self.activation_func = swiglu

        # Project back to the hidden size.
        self.dense_4h_to_h = nn.Linear(
            args.ffn_hidden_size,
            args.hidden_size,
            bias=self.add_bias,
        )

    def __call__(self, hidden_states):
        """Return ``dense_4h_to_h(swiglu(dense_h_to_4h(hidden_states)))``."""
        widened = self.dense_h_to_4h(hidden_states)
        gated = self.activation_func(widened)
        return self.dense_4h_to_h(gated)
315
+
316
+
317
class GLMBlock(nn.Module):
    """One transformer layer: norm -> attention -> residual -> norm -> MLP
    -> residual.
    """

    def __init__(self, args: ModelArgs, layer_number):
        """Create the two norms, the attention layer, dropout and the MLP."""
        super().__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm = args.apply_residual_connection_post_layernorm
        self.fp32_residual_connection = args.fp32_residual_connection

        # RMSNorm or LayerNorm, selected by the configuration.
        norm_cls = nn.RMSNorm if args.rmsnorm else nn.LayerNorm

        # Norm applied to the block input.
        self.input_layernorm = norm_cls(args.hidden_size, eps=args.layernorm_epsilon)

        # Self attention.
        self.self_attention = SelfAttention(args, layer_number)

        self.hidden_dropout = args.hidden_dropout
        self.dropout = nn.Dropout(self.hidden_dropout)

        # Norm applied to the attention output (after the first residual).
        self.post_attention_layernorm = norm_cls(args.hidden_size, eps=args.layernorm_epsilon)

        # Feed-forward network.
        self.mlp = MLP(args)

    def __call__(
        self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
    ):
        """Run the layer and return the updated hidden states.

        When ``apply_residual_connection_post_layernorm`` is set, each
        residual branch starts from the normalised tensor instead of the
        un-normalised one.
        """
        residual_after_norm = self.apply_residual_connection_post_layernorm

        # --- attention sub-layer ---
        normed_input = self.input_layernorm(hidden_states)
        attention_output = self.self_attention(
            normed_input,
            attention_mask,
            rotary_pos_emb,
            kv_cache=kv_cache,
            use_cache=use_cache,
        )
        shortcut = normed_input if residual_after_norm else hidden_states
        attention_residual = shortcut + self.dropout(attention_output)

        # --- feed-forward sub-layer ---
        normed_attention = self.post_attention_layernorm(attention_residual)
        mlp_output = self.mlp(normed_attention)
        shortcut = normed_attention if residual_after_norm else attention_residual

        return shortcut + self.dropout(mlp_output)
383
+
384
class GLMTransformer(nn.Module):
    """Stack of ``GLMBlock`` layers followed by an optional final norm."""

    def __init__(self, args: ModelArgs):
        """Instantiate ``args.num_layers`` blocks, numbered from 1."""
        super().__init__()

        self.fp32_residual_connection = args.fp32_residual_connection
        self.post_layer_norm = args.post_layer_norm

        # Number of layers.
        self.num_layers = args.num_layers

        # Transformer layers; layer numbers are 1-based.
        self.layers = [
            GLMBlock(args, layer_number)
            for layer_number in range(1, self.num_layers + 1)
        ]

        if self.post_layer_norm:
            norm_cls = nn.RMSNorm if args.rmsnorm else nn.LayerNorm
            # Final norm applied to the output of the last block.
            self.final_layernorm = norm_cls(args.hidden_size, eps=args.layernorm_epsilon)

        self.gradient_checkpointing = False

    def _get_layer(self, layer_number):
        """Return the block stored at 0-based index ``layer_number``."""
        return self.layers[layer_number]

    def __call__(
        self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
        use_cache: Optional[bool] = True,
    ):
        """Feed ``hidden_states`` through every block in order.

        Args:
            hidden_states: Input activations.
            attention_mask: Mask forwarded unchanged to every block.
            rotary_pos_emb: Rotary table forwarded unchanged to every block.
            kv_caches: One cache per layer; a falsy value means every
                layer receives ``None``.
            use_cache: Forwarded unchanged to every block.

        Returns:
            Output of the last block, normalised when ``post_layer_norm``
            is set.
        """
        per_layer_caches = kv_caches if kv_caches else [None] * self.num_layers

        for index in range(self.num_layers):
            hidden_states = self._get_layer(index)(
                hidden_states,
                attention_mask,
                rotary_pos_emb,
                kv_cache=per_layer_caches[index],
                use_cache=use_cache,
            )

        if not self.post_layer_norm:
            return hidden_states

        # Final layer norm.
        return self.final_layernorm(hidden_states)
433
+
434
class Embedding(nn.Module):
    """Token-embedding wrapper around ``nn.Embedding``."""

    def __init__(self, args: ModelArgs):
        """Create the ``padded_vocab_size x hidden_size`` embedding table."""
        super().__init__()

        self.hidden_size = args.hidden_size
        # Word embeddings.
        self.word_embeddings = nn.Embedding(
            args.padded_vocab_size,
            self.hidden_size,
        )
        self.fp32_residual_connection = args.fp32_residual_connection

    def __call__(self, input_ids):
        """Look up the embeddings for ``input_ids``.

        Returns:
            The embedding vectors, up-cast to float32 when
            ``fp32_residual_connection`` is enabled.
        """
        embeddings = self.word_embeddings(input_ids)
        # MLX arrays have no `.float()` method; `astype` is the MLX
        # equivalent of the original up-cast.
        if self.fp32_residual_connection:
            embeddings = embeddings.astype(mx.float32)
        return embeddings
454
+
455
+
456
class ChatGLMModel(nn.Module):
    """ChatGLM backbone: token embedding, rotary table and transformer stack.

    ``output_layer`` (the LM head) is owned by this module, but it is not
    applied inside ``__call__``.
    """

    def __init__(self, args: ModelArgs):
        """Build the embedding, rotary table generator, encoder and LM head."""
        super().__init__()

        self.embedding = Embedding(args)
        self.num_layers = args.num_layers
        self.multi_query_group_num = args.multi_query_group_num

        self.kv_channels = args.kv_channels
        self.use_cache = args.use_cache
        self.use_return_dict = False
        self.output_hidden_states = False

        # Rotary positional embeddings
        self.seq_length = args.seq_length
        rotary_dim = (
            args.hidden_size // args.num_attention_heads if args.kv_channels is None else args.kv_channels
        )

        # Only half of each head's channels are rotated.
        self.rotary_pos_emb = RotaryEmbedding(
            rotary_dim // 2,
            rope_ratio=args.rope_ratio,
            original_impl=args.original_rope,
            dtype=args.torch_dtype,
        )
        self.encoder = GLMTransformer(args)
        self.output_layer = nn.Linear(args.hidden_size, args.padded_vocab_size, bias=False)

        # Kept for attribute compatibility only. Positions are derived from
        # the KV-cache offset on every call, so no state is carried between
        # forward passes.
        self.new_position_id = None
        self.is_first_forward = True

    def get_input_embeddings(self):
        """Return the underlying ``nn.Embedding`` module."""
        return self.embedding.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the underlying ``nn.Embedding`` module."""
        self.embedding.word_embeddings = value

    def get_masks(self, input_ids, past_key_values, padding_mask=None):
        """Build a boolean attention mask for ``input_ids``.

        The mask is assembled as an integer causal (lower-triangular)
        matrix, extended with a block of ones for cached positions and
        optionally combined with ``padding_mask``. It is then thresholded
        with ``< 0.5``, so ``True`` marks entries whose integer value is 0.

        Returns:
            Boolean array of shape ``(batch, 1, seq, past + seq)``.
        """
        batch_size, seq_length = input_ids.shape
        full_attention_mask = mx.ones((batch_size, seq_length, seq_length), dtype=input_ids.dtype)
        full_attention_mask = mx.tril(full_attention_mask)
        past_length = 0
        if past_key_values and past_key_values[0].keys is not None:
            past_length = past_key_values[0].offset
        if past_length:
            # Every new token may attend to all cached positions.
            full_attention_mask = mx.concatenate(
                (
                    mx.ones((batch_size, seq_length, past_length), dtype=input_ids.dtype),
                    full_attention_mask,
                ),
                axis=-1,
            )
        if padding_mask is not None:
            full_attention_mask = full_attention_mask * mx.expand_dims(padding_mask, 1)
        if not past_length and padding_mask is not None:
            full_attention_mask -= mx.expand_dims(padding_mask, -1) - 1
        full_attention_mask = (full_attention_mask < 0.5)
        full_attention_mask = mx.expand_dims(full_attention_mask, 1)
        return full_attention_mask

    def get_position_ids(self, input_ids):
        """Return ``[0, 1, ..., seq - 1]`` broadcast to ``(batch, seq)``."""
        batch_size, seq_length = input_ids.shape
        position_ids = mx.arange(seq_length, dtype=mx.int32)
        position_ids = mx.broadcast_to(position_ids, (batch_size, seq_length))
        return position_ids

    def __call__(
        self,
        input_ids,
        position_ids: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
        full_attention_mask: Optional[mx.array] = None,
        past_key_values: Optional[Tuple[Tuple[mx.array, mx.array], ...]] = None,
        inputs_embeds: Optional[mx.array] = None,
        use_cache: Optional[bool] = None,
    ):
        """Run the backbone and return the final hidden states.

        Args:
            input_ids: Token ids of shape ``(batch, seq)``.
            position_ids: Accepted for interface compatibility; positions
                are always recomputed from the cache offset.
            attention_mask: Accepted for interface compatibility; unused.
            full_attention_mask: Mask forwarded to the encoder.
            past_key_values: Per-layer cache objects exposing ``offset``
                (and ``update_and_fetch`` in the attention layers).
            inputs_embeds: Pre-computed embeddings; skips the lookup.
            use_cache: ``None`` falls back to the ``use_cache`` value from
                the model configuration.

        Returns:
            Hidden states produced by the encoder.
        """
        # Number of tokens already held by the cache (0 for a fresh one).
        offset = 0
        if past_key_values and past_key_values[0].offset > 0:
            offset = past_key_values[0].offset
            # NOTE(review): with a warm cache only the last token is
            # processed. That is fine for one-token decoding but would drop
            # tokens if a multi-token chunk were passed - confirm callers.
            input_ids = input_ids[:, -1:]

        batch_size, seq_length = input_ids.shape

        # Positions follow from the cache offset. Nothing is stored on the
        # module, so a new sequence with a fresh cache restarts at 0.
        position_ids = mx.broadcast_to(
            mx.arange(offset, offset + seq_length, dtype=mx.int32),
            (batch_size, seq_length),
        )

        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        # Rotary positional embeddings: gather the rows for our positions.
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)[position_ids]

        # Honour the configured default. Caching is only meaningful when
        # cache objects were actually supplied.
        if use_cache is None:
            use_cache = self.use_cache
        use_cache = bool(use_cache) and bool(past_key_values)

        # Run encoder.
        hidden_states = self.encoder(
            inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
            kv_caches=past_key_values, use_cache=use_cache
        )

        return hidden_states
564
+
565
+
566
class Model(nn.Module):
    """Top-level causal language model: backbone plus vocabulary projection."""

    def __init__(self, args: ModelArgs):
        """Wrap a ``ChatGLMModel`` built from ``args``."""
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.transformer = ChatGLMModel(args)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Return vocabulary logits for ``inputs``.

        Args:
            inputs: Token ids of shape ``(batch, seq)``.
            cache: Optional per-layer KV cache objects.
        """
        out = self.transformer(inputs, past_key_values=cache, use_cache=True)
        if self.args.tie_word_embeddings:
            # `as_linear` lives on the inner nn.Embedding, not on the
            # wrapper Embedding module defined in this file.
            out = self.model.embedding.word_embeddings.as_linear(out)
        else:
            out = self.model.output_layer(out)
        return out

    def sanitize(self, weights):
        """Drop the precomputed rotary ``inv_freq`` entries from ``weights``.

        The rotary table is recomputed at run time, so those checkpoint
        entries have no matching parameter.
        """
        return {
            k: v for k, v in weights.items() if "transformer.rotary_pos_emb.inv_freq" not in k
        }

    @property
    def layers(self):
        """List of transformer blocks."""
        return self.model.encoder.layers

    @property
    def head_dim(self):
        """Per-head channel count, matching what the attention layers use.

        The attention layers size their heads from ``kv_channels``; the
        ``hidden_size // num_attention_heads`` fallback applies only when
        ``kv_channels`` is ``None``.
        """
        kv_channels = self.args.kv_channels
        if kv_channels is not None:
            return kv_channels
        return self.args.hidden_size // self.args.num_attention_heads

    @property
    def n_kv_heads(self):
        """Number of key/value head groups (``multi_query_group_num``)."""
        return self.args.multi_query_group_num

    @property
    def model(self):
        """Alias for ``self.transformer``."""
        return self.transformer
607
+
608
+