Need help implementing codestral to mlx-lm
#20
by
Goekdeniz-Guelmez
- opened
Hey mistral team,
I'm a big fan of the Mamba Architecture, and have added support for the Mamba 1 architecture into apples MLX-LM package, after trying to add the Mamba 2 architecture I got on a bottleneck, so when inferencing with a origional model from state-spaces. It works perfectly fine, but trying the same with the Codestral model from this repo, it generates gibberish. Can it be possible because it doesnt use the tokenizer.v3?
Here is my code:
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import MambaCache
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
num_heads: int
head_dim: int
vocab_size: int
hidden_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
n_groups: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
residual_in_fp32: bool
rescale_prenorm_residual: bool
rms_norm: bool
chunk_size: int
tie_word_embeddings: bool
dim: int = None
intermediate_size: int = None
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
time_step_rank: Union[int, str] = "auto"
time_step_min: float = 0.001
time_step_max: float = 0.1
time_step_floor: float = 1e-4
A_init_min: float = 1.0
A_init_max: float = 16.0
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)
if not hasattr(self, "hidden_size"):
self.hidden_size = self.dim
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = mx.ones((hidden_size,))
self.variance_epsilon = eps
def __call__(self, hidden_states, gate=None):
if gate is not None:
hidden_states = hidden_states * nn.silu(gate)
variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
def silu(x):
return x * mx.sigmoid(x)
class DepthWiseConv1d(nn.Module):
def __init__(self, channels, kernel_size, bias=True, padding=0):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.padding = padding
self.weight = mx.random.normal((channels, kernel_size, 1))
self.bias = mx.zeros((channels,)) if bias else None
def __call__(self, x, cache=None):
B, L, C = x.shape
_, K, _ = self.weight.shape
if cache is not None:
x = mx.concatenate([cache, x], axis=1)
else:
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
y = mx.conv_general(x, self.weight, groups=C)
y = y + self.bias
return y, x[:, -K + 1:, :]
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
# Same dimensions as before
self.d_model = args.hidden_size
self.d_state = args.state_size
self.d_conv = args.conv_kernel
self.expand = args.expand
self.d_inner = int(self.expand * self.d_model)
self.n_groups = args.n_groups
self.n_heads = args.num_heads
self.d_head = self.d_inner // self.n_heads
# Input projection
d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)
# Improved initialization of dt
dt = mx.exp(
mx.random.uniform(
low=math.log(args.time_step_min),
high=math.log(args.time_step_max),
shape=(self.n_heads,)
)
)
dt = mx.clip(dt, args.time_step_floor, float('inf'))
inv_dt = dt + mx.log(-mx.exp(-dt) + 1) # Inverse softplus
self.dt_bias = mx.array(inv_dt)
# Improved A initialization
A = mx.random.uniform(
low=args.A_init_min,
high=args.A_init_max,
shape=(self.n_heads,)
)
self.A_log = mx.log(A)
# Same D initialization
self.D = mx.random.normal((self.n_heads,)) * args.initializer_range
# Convolution with proper initialization
self.conv1d = DepthWiseConv1d(
channels=self.d_inner + 2 * self.n_groups * self.d_state,
kernel_size=self.d_conv,
bias=args.use_conv_bias,
padding=self.d_conv-1
)
# Output projections
self.norm = MambaRMSNormGated(self.d_inner, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias)
def __call__(self, u: mx.array, cache=None):
batch_size, seq_len, _ = u.shape
# Project input
zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
# Split projections
z = zxbcdt[..., :self.d_inner]
xBC = zxbcdt[..., self.d_inner:self.d_inner + (self.d_inner + 2 * self.n_groups * self.d_state)]
dt = zxbcdt[..., -self.n_heads:]
# Process time steps - simplified to match PyTorch
dt = nn.softplus(dt + self.dt_bias) # (B, L, nheads)
xBC, conv_state = self.conv1d(xBC, cache[0] if cache else None) # (B, L, self.d_inner + 2 * ngroups * d_state)
if cache is not None:
cache[0] = conv_state
xBC = silu(xBC)
xBC = xBC[:, :seq_len, :]
# Split conv output and reshape
x = xBC[..., :self.d_inner]
B = mx.reshape(xBC[..., self.d_inner:self.d_inner + self.n_groups * self.d_state], (batch_size, seq_len, self.n_groups, -1))
C = mx.reshape(xBC[..., -self.n_groups * self.d_state:], (batch_size, seq_len, self.n_groups, -1))
# Reshape for SSM processing
x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head))
# Initialize state
if cache and cache[1] is not None:
# State initialization might need proper scaling
prev_state = cache[1]
else:
prev_state = mx.zeros((batch_size, self.n_heads, self.d_head, self.d_state))
# Compute dA - simplified to match PyTorch
A = -mx.exp(self.A_log)
dt = mx.reshape(dt, (batch_size, seq_len, self.n_heads))
dA = mx.exp(dt * mx.expand_dims(A, axis=(0, 1)))
# Process sequence
next_state = prev_state
outputs = []
for t in range(seq_len):
# Get current step tensors
xt = x[:, t] # [batch, n_heads, d_head]
Bt = B[:, t] # [batch, n_heads, d_state]
Ct = C[:, t] # [batch, n_heads, d_state]
dAt = dA[:, t] # [batch, n_heads]
# Compute dBx using einsum to match PyTorch
dBx = mx.einsum('bh,bgd,bhp->bhpd', dAt, Bt, xt) # dAt: (b,h), Bt: (b,g,d), xt: (b,h,p) -> (b,h,p,d)
# Update state
next_state = next_state * mx.expand_dims(dAt, axis=(-1, -2)) + dBx
# Compute output with groups
yt = mx.einsum('bhpd,bgd->bhp', next_state, Ct)
yt = yt + xt * mx.expand_dims(self.D, -1)
# Reshape and normalize
yt = mx.reshape(yt, (batch_size, 1, self.d_inner))
yt = self.norm(yt, z[:, t:t+1])
outputs.append(self.out_proj(yt))
# Update cache
if cache is not None:
cache[1] = next_state
return mx.concatenate(outputs, axis=1)
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.residual_in_fp32 = args.residual_in_fp32
self.mixer = Mamba2Block(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
if self.residual_in_fp32:
x = x.astype(mx.float32)
normed = self.norm(x)
output = self.mixer(normed, cache)
return output + x
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
hidden = x
for layer, c in zip(self.layers, cache):
hidden = layer(hidden, c)
return self.norm_f(hidden)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
hidden = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(hidden)
else:
logits = self.lm_head(hidden)
return logits
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.shape[-1] != 1:
weights[k] = v.moveaxis(2, 1)
return weights
def make_cache(self):
return [MambaCache() for _ in range(len(self.layers))]
@property
def layers(self):
return self.backbone.layers
here is a generation from Codestral:
Prompt:
python -m mlx_lm.generate --model /Users/gokdenizgulmez/Desktop/Mamba-Codestral-7B-v0.1-4bit --prompt "Write me a function that computes fibonacci in Rust" -m 120
Output:
==========
U
:%.*
❒
Станов
:%.*ICENSE Станов Станов❒❒ biologie≮ustr≮❒ Станов❒][<.❒ Станов Squad släktet Arbitro Станов Станов Становништво][< Floren Станов.CLUDING][<.Extra
ExtraFree Станов enthusExtra
.][<.][<amment Станов enthus."— släktet⌁][< /******/American.
biologie släktetntil /******/.
biologie.
AMD.][< Dy. companICENSEml Syd.
DJCLUDING
c
Net
Net
ityEngine
ityEngine
ityEngine
==========
Goekdeniz-Guelmez
changed discussion status to
closed
Goekdeniz-Guelmez
changed discussion status to
open
the new Code:
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import MambaCache
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
num_heads: int
head_dim: int
vocab_size: int
hidden_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
n_groups: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
residual_in_fp32: bool
chunk_size: int
tie_word_embeddings: bool
time_step_limit: Tuple[float, float]
time_step_rank: Union[int, str]
time_step_min: float
time_step_max: float
time_step_floor: float
norm_before_gate: bool = True
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
class DepthWiseConv1d(nn.Module):
def __init__(self, channels, kernel_size, bias=True, padding=0):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.padding = padding
self.weight = mx.random.normal((channels, kernel_size, 1))
self.bias = mx.zeros((channels,)) if bias else None
def __call__(self, x, cache=None):
B, L, C = x.shape
_, K, _ = self.weight.shape
if cache is not None:
x = mx.concatenate([cache, x], axis=1)
else:
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
y = mx.conv_general(x, self.weight, groups=C)
y = y + self.bias
return y, x[:, -K + 1:, :]
def segsum(x):
# x shape: [b, h, l]
b, h, l = x.shape
# Create a lower triangular mask
mask = mx.tril(mx.ones((l, l)), 0)
# Reshape x for broadcasting
x_expanded = x.reshape(b, h, l, 1) # [b, h, l, 1]
# Apply mask
masked_x = x_expanded * mask.reshape(1, 1, l, l) # [b, h, l, l]
# Sum along the appropriate dimension
result = mx.sum(masked_x, axis=2) # [b, h, l]
return result.reshape(b, h, 1, l) # Return in shape [b, h, 1, l]
def ssd_forward_attn(
x: mx.array,
dt: mx.array,
A: mx.array,
B: mx.array,
C: mx.array,
D: mx.array,
dt_bias: mx.array,
dt_min: float,
dt_max: float,
prev_state=None,
) -> Tuple[mx.array, mx.array]:
b, l, h, dh = x.shape
_, _, g, _ = B.shape
# Process dt
if dt_bias is not None:
dt = dt + dt_bias.reshape(1, 1, -1)
dt = nn.softplus(dt)
dt = mx.clip(dt, a_min=dt_min, a_max=dt_max)
# Reshape tensors
B_reshaped = mx.swapaxes(mx.swapaxes(B, 1, 3), 1, 2)
C_reshaped = mx.swapaxes(C, 1, 2)
# Compute CB
CB = C_reshaped @ B_reshaped
CB = mx.repeat(CB, repeats=h // g, axis=1)
# Compute decay terms
dtA = dt * A.reshape(1, 1, -1)
dtA = mx.swapaxes(dtA, 1, 2)
decay = mx.exp(-mx.abs(segsum(dtA)))
# Create attention matrix
surrogate_attention_matrix = mx.tril(CB * decay, 0)
# Apply attention
dtx = dt.reshape(b, l, h, 1) * x
y = surrogate_attention_matrix @ dtx.swapaxes(1, 2)
y = mx.swapaxes(y, 1, 2)
# Compute next state
decay_last = decay[:, :, -1, :].reshape(b, h, l).swapaxes(1, 2).reshape(b, l, h, 1)
B_for_state = mx.repeat(B_reshaped, h // g, axis=1).swapaxes(2, 3)
dtxdecay = dtx * decay_last
dtxdecay = dtxdecay.swapaxes(1, 2).swapaxes(2, 3)
# Calculate new state contribution
new_state_contribution = dtxdecay @ B_for_state
# Initialize or update state
if prev_state is not None:
decayed_prev_state = prev_state * decay[:, :, -1, :].reshape(b, h, 1, 1)
next_state = decayed_prev_state + new_state_contribution
else:
next_state = new_state_contribution
# Add skip connection if D is provided
if D is not None:
y += x * D.reshape(1, 1, h, 1)
# Reshape output
y = y.reshape(b, l, h * dh)
return y, next_state
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.d_model = args.hidden_size
self.d_state = args.state_size
self.d_conv = args.conv_kernel
self.expand = args.expand
self.d_inner = int(self.expand * self.d_model)
self.n_groups = args.n_groups
self.n_heads = args.num_heads
self.d_head = self.d_inner // self.n_heads
self.chunk_size = args.chunk_size
d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)
self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range
self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range
self.D = mx.random.normal((self.n_heads,)) * args.initializer_range
self.conv1d = DepthWiseConv1d(
channels=self.d_inner + 2 * self.n_groups * self.d_state,
kernel_size=self.d_conv,
bias=args.use_conv_bias,
padding=self.d_conv-1
)
self.norm = nn.RMSNorm(self.d_inner, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias)
def __call__(self, u: mx.array, cache=None):
batch_size, seq_len, _ = u.shape
# Get cache states
conv_state = None
ssm_state = None
if cache is not None:
conv_state = cache[0] # Access using index
ssm_state = cache[1] # Access using index
zxBCdt = self.in_proj(u)
z, xBC, dt = mx.split(
zxBCdt,
[self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state],
axis=-1
)
xBC, conv_state = self.conv1d(xBC, conv_state)
xBC = xBC * mx.sigmoid(xBC)
xBC = xBC[:, :seq_len, :]
x, B, C = mx.split(
xBC,
[self.d_inner, self.d_inner + self.d_state * self.n_groups],
axis=-1
)
x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head))
B = mx.reshape(B, (batch_size, seq_len, self.n_groups, -1))
C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1))
y, ssm_state = ssd_forward_attn(
x=x,
dt=dt,
A=-mx.exp(self.A_log),
B=B,
C=C,
D=self.D,
dt_bias=self.dt_bias,
dt_min=self.args.time_step_min,
dt_max=self.args.time_step_max,
prev_state=ssm_state
)
if self.args.norm_before_gate:
y = self.norm(y)
y = y * nn.silu(z)
else:
y = y * nn.silu(z)
y = self.norm(y)
y = self.out_proj(y)
# Update cache using indexing
if cache is not None:
cache[0] = conv_state
cache[1] = ssm_state
return y
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.residual_in_fp32 = args.residual_in_fp32
self.mixer = Mamba2Block(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
if self.residual_in_fp32:
x = x.astype(mx.float32)
normed = self.norm(x)
output = self.mixer(normed, cache)
return output + x
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
hidden = x
for layer, c in zip(self.layers, cache):
hidden = layer(hidden, c)
return self.norm_f(hidden)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
hidden = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(hidden)
else:
logits = self.lm_head(hidden)
return logits
def make_cache(self):
return [MambaCache() for _ in range(len(self.layers))]
@property
def layers(self):
return self.backbone.layers
Example generation:
python -m mlx_lm.generate --model mlx-community/Mamba-Codestral-7B-v0.1-4bit --prompt "<s>[INST] Write me a function that computes fibonacci in Rust [/INST] " -m 64 --extra-eos-token "</s>"
Fetching 7 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 90061.74it/s]
==========
100% natural minerals, vitamins, minerals, vitamins, minerals, vitamins, minerals, vitamins, minerals, vitamins, minerals, vitamins, minerals, vitamins, minerals, vitamins, minerals,
==========
Prompt: 19 tokens, 37.486 tokens-per-sec
Generation: 64 tokens, 5.202 tokens-per-sec
Peak memory: 4.679 GB
Is the prompt format also correct?