import math
from dataclasses import dataclass
from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    add_bias_linear: bool = False
    add_qkv_bias: bool = True
    apply_query_key_layer_scaling: bool = True
    apply_residual_connection_post_layernorm: bool = False
    attention_dropout: float = 0.0
    attention_softmax_in_fp32: bool = True
    bias_dropout_fusion: bool = True
    ffn_hidden_size: int = 13696
    fp32_residual_connection: bool = False
    hidden_dropout: float = 0.0
    hidden_size: int = 4096
    kv_channels: int = 128
    layernorm_epsilon: float = 1.5625e-07
    multi_query_attention: bool = True
    multi_query_group_num: int = 2
    num_attention_heads: int = 32
    num_hidden_layers: int = 40
    num_layers: int = 40
    rope_ratio: int = 500
    original_rope: bool = True
    padded_vocab_size: int = 151552
    post_layer_norm: bool = True
    rmsnorm: bool = True
    seq_length: int = 131072
    use_cache: bool = True
    torch_dtype: str = "bfloat16"
    tie_word_embeddings: bool = False

    def __post_init__(self):
        pass


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, rope_ratio=1, original_impl=False, dtype=None):
        super().__init__()
        self.inv_freq_type = dtype
        self.dim = dim
        self.original_impl = original_impl
        self.rope_ratio = rope_ratio

    def forward_impl(
        self, seq_len: int, n_elem: int, dtype: mx.Dtype, base: int = 10000
    ):
        """Enhanced Transformer with Rotary Position Embedding.

        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/rope/__init__.py.
        MIT License: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
        """
        base = base * self.rope_ratio
        theta = 1.0 / (base ** (mx.arange(0, n_elem, 2, dtype=mx.float16) / n_elem))

        # Positions [0, seq_len) outer-multiplied with the inverse frequencies.
        seq_idx = mx.arange(seq_len, dtype=mx.float16)
        idx_theta = mx.outer(seq_idx, theta).astype(mx.float16)

        # Cache layout: (seq_len, n_elem // 2, 2) with cos in [..., 0] and sin in [..., 1].
        cache = mx.stack([mx.cos(idx_theta), mx.sin(idx_theta)], axis=-1)

        # In this module `dtype` is the torch_dtype string from ModelArgs, so this
        # cast is effectively a no-op and the cache stays in float16.
        if dtype in (mx.float16, mx.bfloat16, mx.int8):
            cache = (
                cache.astype(mx.bfloat16)
                if dtype == mx.bfloat16
                else cache.astype(mx.float16)
            )
        return cache

    def __call__(self, max_seq_len, offset=0):
        return self.forward_impl(
            max_seq_len,
            self.dim,
            dtype=self.inv_freq_type,
        )


def apply_rotary_pos_emb(x: mx.array, rope_cache: mx.array) -> mx.array:
    # x: (batch, heads, seq, head_dim); rope_cache: (batch, seq, rot_dim // 2, 2).
    b, np, sq, hn = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
    rot_dim = rope_cache.shape[-2] * 2
    # Only the first `rot_dim` channels are rotated; the rest pass through unchanged.
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]

    rope_cache = rope_cache[:, :sq]
    xshaped = x.reshape(b, np, sq, rot_dim // 2, 2)
    rope_cache = rope_cache.reshape(-1, 1, sq, xshaped.shape[3], 2)
    # Complex-style rotation of channel pairs: (a + ib) * (cos + i sin).
    x_out2 = mx.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return mx.concatenate((x_out2, x_pass), axis=-1)
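

# A minimal, hedged sanity sketch of the RoPE helpers above. It is not used by the
# model and the sizes are arbitrary illustration values. It checks two properties
# that follow directly from the code: the shape is preserved, channels beyond
# rot_dim pass through untouched, and position 0 (cos=1, sin=0) acts as identity.
def _rope_sanity_sketch():
    rope = RotaryEmbedding(dim=4, rope_ratio=1, original_impl=True)
    cache = rope(3)[None]  # (1, seq, dim // 2, 2)
    x = mx.random.normal((1, 2, 3, 8))  # (batch, heads, seq, head_dim)
    y = apply_rotary_pos_emb(x, cache)
    assert y.shape == x.shape
    assert mx.allclose(y[..., 4:], x[..., 4:])  # pass-through half unchanged
    assert mx.allclose(y[:, :, 0], x[:, :, 0])  # position 0 rotates by angle 0
    return y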


class CoreAttention(nn.Module):
    def __init__(self, args: ModelArgs, layer_number):
        super().__init__()

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)

        projection_size = args.kv_channels * args.num_attention_heads

        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = projection_size // args.num_attention_heads
        self.num_attention_heads_per_partition = args.num_attention_heads

        # norm_factor/coeff mirror the reference implementation but are not used by
        # the fast attention path below, which scales by 1 / sqrt(head_dim).
        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
        self.coeff = coeff

        self.attention_dropout = nn.Dropout(args.attention_dropout)

    def __call__(self, query_layer, key_layer, value_layer, attention_mask):
        scale_factor = query_layer.shape[-1] ** -0.5

        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
            # Prefill with no explicit mask: build a standard additive causal mask.
            attention_mask = nn.MultiHeadAttention.create_additive_causal_mask(
                query_layer.shape[2]
            ).astype(query_layer.dtype)
            context_layer = mx.fast.scaled_dot_product_attention(
                query_layer, key_layer, value_layer, scale=scale_factor, mask=attention_mask
            )
        else:
            if attention_mask is not None:
                # get_masks marks disallowed positions with True; invert it before
                # handing it to the fused attention op.
                attention_mask = ~attention_mask
            context_layer = mx.fast.scaled_dot_product_attention(
                query_layer, key_layer, value_layer, scale=scale_factor, mask=attention_mask
            )
        # (batch, heads, seq, head_dim) -> (batch, seq, hidden_size)
        context_layer = context_layer.transpose((0, 2, 1, 3))
        new_context_layer_shape = context_layer.shape[:-2] + (self.hidden_size_per_partition,)
        context_layer = context_layer.reshape(*new_context_layer_shape)

        return context_layer


class SelfAttention(nn.Module):
    def __init__(self, args: ModelArgs, layer_number):
        super().__init__()
        self.layer_number = max(1, layer_number)

        self.projection_size = args.kv_channels * args.num_attention_heads

        self.hidden_size_per_attention_head = self.projection_size // args.num_attention_heads
        self.num_attention_heads_per_partition = args.num_attention_heads
        self.multi_query_attention = args.multi_query_attention
        self.qkv_hidden_size = 3 * self.projection_size
        if self.multi_query_attention:
            self.num_multi_query_groups_per_partition = args.multi_query_group_num
            self.qkv_hidden_size = (
                self.projection_size + 2 * self.hidden_size_per_attention_head * args.multi_query_group_num
            )
        self.query_key_value = nn.Linear(
            args.hidden_size,
            self.qkv_hidden_size,
            bias=args.add_bias_linear or args.add_qkv_bias,
        )

        self.core_attention = CoreAttention(args, self.layer_number)

        self.dense = nn.Linear(self.projection_size, args.hidden_size, bias=args.add_bias_linear)

    def __call__(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True):
        # hidden_states: (batch, seq, hidden_size)
        mixed_x_layer = self.query_key_value(hidden_states)

        if self.multi_query_attention:
            # Split the fused projection into one full-width query block and two
            # narrow key/value blocks (see the width sketch after this class).
            q_k_v_len = [
                self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
            ]
            mixs = mixed_x_layer.split(
                [
                    q_k_v_len[0],
                    q_k_v_len[0] + q_k_v_len[1],
                    q_k_v_len[0] + q_k_v_len[1] + q_k_v_len[2],
                ],
                axis=-1,
            )

            query_layer, key_layer, value_layer = mixs[0], mixs[1], mixs[2]
            query_layer = query_layer.reshape(
                query_layer.shape[:-1]
                + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )
            key_layer = key_layer.reshape(
                key_layer.shape[:-1]
                + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
            value_layer = value_layer.reshape(
                value_layer.shape[:-1]
                + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
        else:
            new_tensor_shape = mixed_x_layer.shape[:-1] + (
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            mixed_x_layer = mixed_x_layer.reshape(*new_tensor_shape)

            query_layer, key_layer, value_layer = mx.split(mixed_x_layer, 3, axis=-1)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        query_layer, key_layer, value_layer = [
            k.transpose((0, 2, 1, 3)) for k in [query_layer, key_layer, value_layer]
        ]

        if rotary_pos_emb is not None:
            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        if use_cache:
            key_layer, value_layer = kv_cache.update_and_fetch(key_layer, value_layer)
        else:
            kv_cache = None

        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)

        output = self.dense(context_layer)

        return output
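

# Hedged arithmetic note (not used by the model): the widths of the fused QKV split
# when multi_query_attention is enabled, written out with the default ModelArgs
# values quoted in the comments.
def _qkv_split_widths(args: ModelArgs):
    head_dim = args.kv_channels  # equals projection_size // num_attention_heads
    q_width = args.num_attention_heads * head_dim  # 32 * 128 = 4096 with the defaults
    kv_width = args.multi_query_group_num * head_dim  # 2 * 128 = 256 with the defaults
    total = q_width + 2 * kv_width  # 4608 == SelfAttention.qkv_hidden_size
    return q_width, kv_width, total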


class MLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.add_bias = args.add_bias_linear

        self.dense_h_to_4h = nn.Linear(
            args.hidden_size,
            args.ffn_hidden_size * 2,
            bias=self.add_bias,
        )

        def swiglu(x):
            x = mx.split(x, 2, axis=-1)
            return nn.silu(x[0]) * x[1]

        self.activation_func = swiglu

        self.dense_4h_to_h = nn.Linear(
            args.ffn_hidden_size,
            args.hidden_size,
            bias=self.add_bias,
        )

    def __call__(self, hidden_states):
        intermediate_parallel = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        output = self.dense_4h_to_h(intermediate_parallel)
        return output
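

# A minimal hedged sketch of the SwiGLU activation above (illustrative only):
# dense_h_to_4h emits gate and value halves in one matmul, swiglu applies
# silu(gate) * value, so dense_4h_to_h sees a width of ffn_hidden_size,
# not 2 * ffn_hidden_size.
def _swiglu_sketch():
    x = mx.arange(8, dtype=mx.float32)  # stand-in projection output of width 2 * 4
    gate, value = mx.split(x, 2, axis=-1)  # two halves of width 4
    return nn.silu(gate) * value  # width 4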


class GLMBlock(nn.Module):
    def __init__(self, args: ModelArgs, layer_number):
        super().__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm = args.apply_residual_connection_post_layernorm

        self.fp32_residual_connection = args.fp32_residual_connection

        LayerNormFunc = nn.RMSNorm if args.rmsnorm else nn.LayerNorm

        self.input_layernorm = LayerNormFunc(args.hidden_size, eps=args.layernorm_epsilon)

        self.self_attention = SelfAttention(args, layer_number)
        self.hidden_dropout = args.hidden_dropout

        self.dropout = nn.Dropout(self.hidden_dropout)

        self.post_attention_layernorm = LayerNormFunc(args.hidden_size, eps=args.layernorm_epsilon)

        self.mlp = MLP(args)

    def __call__(
        self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
    ):
        layernorm_output = self.input_layernorm(hidden_states)

        attention_output = self.self_attention(
            layernorm_output,
            attention_mask,
            rotary_pos_emb,
            kv_cache=kv_cache,
            use_cache=use_cache,
        )

        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = self.dropout(attention_output)
        layernorm_input = residual + layernorm_input

        layernorm_output = self.post_attention_layernorm(layernorm_input)

        mlp_output = self.mlp(layernorm_output)

        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self.dropout(mlp_output)
        output = residual + output

        return output


class GLMTransformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.fp32_residual_connection = args.fp32_residual_connection
        self.post_layer_norm = args.post_layer_norm

        self.num_layers = args.num_layers

        def build_layer(layer_number):
            return GLMBlock(args, layer_number)

        self.layers = [build_layer(i + 1) for i in range(self.num_layers)]

        if self.post_layer_norm:
            LayerNormFunc = nn.RMSNorm if args.rmsnorm else nn.LayerNorm
            self.final_layernorm = LayerNormFunc(args.hidden_size, eps=args.layernorm_epsilon)

        self.gradient_checkpointing = False

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        rotary_pos_emb,
        kv_caches=None,
        use_cache: Optional[bool] = True,
    ):
        if not kv_caches:
            kv_caches = [None for _ in range(self.num_layers)]

        for index in range(self.num_layers):
            layer = self._get_layer(index)
            hidden_states = layer(
                hidden_states,
                attention_mask,
                rotary_pos_emb,
                kv_cache=kv_caches[index],
                use_cache=use_cache,
            )

        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states


class Embedding(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.hidden_size = args.hidden_size

        self.word_embeddings = nn.Embedding(
            args.padded_vocab_size,
            self.hidden_size,
        )
        self.fp32_residual_connection = args.fp32_residual_connection

    def __call__(self, input_ids):
        embeddings = self.word_embeddings(input_ids)

        if self.fp32_residual_connection:
            # mx.array has no .float(); cast explicitly.
            embeddings = embeddings.astype(mx.float32)
        return embeddings


class ChatGLMModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.embedding = Embedding(args)
        self.num_layers = args.num_layers
        self.multi_query_group_num = args.multi_query_group_num

        self.kv_channels = args.kv_channels
        self.use_cache = args.use_cache
        self.use_return_dict = False
        self.output_hidden_states = False

        self.seq_length = args.seq_length
        rotary_dim = (
            args.hidden_size // args.num_attention_heads if args.kv_channels is None else args.kv_channels
        )

        self.rotary_pos_emb = RotaryEmbedding(
            rotary_dim // 2,
            rope_ratio=args.rope_ratio,
            original_impl=args.original_rope,
            dtype=args.torch_dtype,
        )
        self.encoder = GLMTransformer(args)
        self.output_layer = nn.Linear(args.hidden_size, args.padded_vocab_size, bias=False)

        self.new_position_id = None
        self.is_first_forward = True

    def get_input_embeddings(self):
        return self.embedding.word_embeddings

    def set_input_embeddings(self, value):
        self.embedding.word_embeddings = value

    def get_masks(self, input_ids, past_key_values, padding_mask=None):
        # Returns a boolean mask of shape (batch, 1, seq, past + seq) where True
        # marks positions that must not be attended to.
        batch_size, seq_length = input_ids.shape
        full_attention_mask = mx.ones((batch_size, seq_length, seq_length), dtype=input_ids.dtype)
        full_attention_mask = mx.tril(full_attention_mask)
        past_length = 0
        if past_key_values and past_key_values[0].keys is not None:
            past_length = past_key_values[0].offset
        if past_length:
            full_attention_mask = mx.concatenate(
                (
                    mx.ones((batch_size, seq_length, past_length), dtype=input_ids.dtype),
                    full_attention_mask,
                ),
                axis=-1,
            )
        if padding_mask is not None:
            full_attention_mask = full_attention_mask * mx.expand_dims(padding_mask, 1)
        if not past_length and padding_mask is not None:
            full_attention_mask -= mx.expand_dims(padding_mask, -1) - 1
        full_attention_mask = full_attention_mask < 0.5
        full_attention_mask = mx.expand_dims(full_attention_mask, 1)
        return full_attention_mask
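
    # Hedged illustration (not called anywhere): for a fresh 3-token prompt with no
    # padding and no past, get_masks reduces to the complement of a lower-triangular
    # matrix; CoreAttention later inverts it with `~` before the fused attention op.
    def _causal_mask_example(self):
        tokens = mx.zeros((1, 3), dtype=mx.int32)
        mask = self.get_masks(tokens, past_key_values=None)
        expected = mx.expand_dims(mx.tril(mx.ones((1, 3, 3), dtype=mx.int32)) < 0.5, 1)
        assert mx.array_equal(mask, expected)
        return mask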

    def get_position_ids(self, input_ids):
        batch_size, seq_length = input_ids.shape
        position_ids = mx.arange(seq_length, dtype=mx.int32)
        position_ids = mx.broadcast_to(position_ids, (batch_size, seq_length))
        return position_ids

    def __call__(
        self,
        input_ids,
        position_ids: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
        full_attention_mask: Optional[mx.array] = None,
        past_key_values: Optional[Tuple[Tuple[mx.array, mx.array], ...]] = None,
        inputs_embeds: Optional[mx.array] = None,
        use_cache: Optional[bool] = None,
    ):
        # Position ids are tracked statefully across calls: the first call derives
        # them from the prompt, later calls extend the stored ids by one step.
        if self.new_position_id is None:
            position_ids = self.get_position_ids(input_ids)
        else:
            position_ids = self.new_position_id

        new_position_id = position_ids[..., -1:]
        new_position_id += 1
        new_position_id = mx.concatenate([position_ids, new_position_id], axis=-1)
        self.new_position_id = new_position_id

        if past_key_values and past_key_values[0].offset > 0:
            # Incremental decoding: only the newest token and position are needed.
            position_ids = position_ids[..., -1:]
            input_ids = input_ids[:, -1:]

        batch_size, seq_length = input_ids.shape

        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        # Gather the rotary cache rows for the active positions.
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        if position_ids is not None:
            rotary_pos_emb = rotary_pos_emb[position_ids]
        else:
            rotary_pos_emb = rotary_pos_emb[None, :seq_length]

        hidden_states = self.encoder(
            inputs_embeds,
            full_attention_mask,
            rotary_pos_emb=rotary_pos_emb,
            kv_caches=past_key_values,
            use_cache=use_cache,
        )

        return hidden_states


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.transformer = ChatGLMModel(args)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.transformer(inputs, None, None, None, cache, None, True)
        if self.args.tie_word_embeddings:
            out = self.model.embedding.word_embeddings.as_linear(out)
        else:
            out = self.model.output_layer(out)
        return out

    def sanitize(self, weights):
        # Drop the rotary inverse-frequency buffers from the checkpoint; this
        # module recomputes the rotary cache on the fly.
        return {
            k: v for k, v in weights.items() if "transformer.rotary_pos_emb.inv_freq" not in k
        }

    @property
    def layers(self):
        return self.model.encoder.layers

    @property
    def head_dim(self):
        return self.args.hidden_size // self.args.num_attention_heads

    @property
    def n_kv_heads(self):
        return self.args.multi_query_group_num

    @property
    def model(self):
        return self.transformer
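

# A hedged end-to-end sketch (illustrative only, never called by the library): it runs
# one forward pass over a tiny randomly initialised configuration using a minimal
# stand-in cache that exposes the `keys`, `offset`, and `update_and_fetch` members this
# model relies on. A real run would use mlx_lm's prompt cache instead, and every size
# below is an arbitrary illustration value, not a GLM-4 setting.
def _example_forward_pass():
    class _StubCache:
        def __init__(self):
            self.keys = None
            self.values = None
            self.offset = 0

        def update_and_fetch(self, keys, values):
            if self.keys is None:
                self.keys, self.values = keys, values
            else:
                self.keys = mx.concatenate([self.keys, keys], axis=2)
                self.values = mx.concatenate([self.values, values], axis=2)
            self.offset = self.keys.shape[2]
            return self.keys, self.values

    args = ModelArgs(
        model_type="chatglm",  # placeholder string, only stored on the model
        num_layers=2,
        hidden_size=64,
        ffn_hidden_size=128,
        kv_channels=16,
        num_attention_heads=4,
        padded_vocab_size=128,
        seq_length=32,
    )
    model = Model(args)
    tokens = mx.array([[1, 2, 3]])
    cache = [_StubCache() for _ in range(args.num_layers)]
    logits = model(tokens, cache)
    assert logits.shape == (1, 3, args.padded_vocab_size)
    return logits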