Need help implementing Codestral in mlx-lm

#20
by Goekdeniz-Guelmez - opened

Hey Mistral team,
I'm a big fan of the Mamba architecture and have already added support for Mamba 1 to Apple's MLX-LM package. While adding Mamba 2 I hit a roadblock: inference with an original model from state-spaces works perfectly fine, but running the same code with the Codestral model from this repo produces gibberish. Could this be because it doesn't use tokenizer.v3?
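
For reference, a quick way to sanity-check the tokenizer theory would be something like this (a rough sketch; the path is just my local download, and it assumes the repo ships a Hugging Face tokenizer that AutoTokenizer can load):

# Rough tokenizer round-trip check (illustrative only, not part of the port).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("/Users/gokdenizgulmez/Desktop/Mamba-Codestral-7B-v0.1-4bit")
text = "Write me a function that computes fibonacci in Rust"
ids = tok.encode(text)
print(ids)
print(tok.decode(ids))  # if this round-trips cleanly, the tokenizer is probably not the culprit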

Here is my code:

import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs
from .cache import MambaCache

@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    num_heads: int
    head_dim: int
    vocab_size: int
    hidden_size: int
    state_size: int
    num_hidden_layers: int
    layer_norm_epsilon: float
    expand: int
    conv_kernel: int
    n_groups: int
    use_bias: bool
    use_conv_bias: bool
    initializer_range: float
    residual_in_fp32: bool
    rescale_prenorm_residual: bool
    rms_norm: bool
    chunk_size: int
    tie_word_embeddings: bool
    dim: int = None
    intermediate_size: int = None
    time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
    time_step_rank: Union[int, str] = "auto"
    time_step_min: float = 0.001
    time_step_max: float = 0.1
    time_step_floor: float = 1e-4
    A_init_min: float = 1.0
    A_init_max: float = 16.0

    def __post_init__(self):
        # Dataclass fields always exist as attributes, so hasattr() is always
        # true here; check for unset (None) values instead.
        if self.intermediate_size is None:
            self.intermediate_size = int(self.expand * self.hidden_size)
        if self.hidden_size is None:
            self.hidden_size = self.dim
        if self.head_dim is None:
            self.head_dim = self.hidden_size // self.num_heads
        if self.time_step_rank == "auto":
            self.time_step_rank = math.ceil(self.hidden_size / 16)


class MambaRMSNormGated(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = mx.ones((hidden_size,))
        self.variance_epsilon = eps

    def __call__(self, hidden_states, gate=None):
        if gate is not None:
            hidden_states = hidden_states * nn.silu(gate)
        variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
        hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states


def silu(x):
    return x * mx.sigmoid(x)


class DepthWiseConv1d(nn.Module):
    def __init__(self, channels, kernel_size, bias=True, padding=0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.weight = mx.random.normal((channels, kernel_size, 1))
        self.bias = mx.zeros((channels,)) if bias else None

    def __call__(self, x, cache=None):
        B, L, C = x.shape
        _, K, _ = self.weight.shape

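        # Causal behavior: on the first call, left-pad with K-1 zeros; on later
        # calls, prepend the cached last K-1 input frames instead.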
        if cache is not None:
            x = mx.concatenate([cache, x], axis=1)
        else:
            x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])

        y = mx.conv_general(x, self.weight, groups=C)
        if self.bias is not None:
            y = y + self.bias
        return y, x[:, -K + 1:, :]


class Mamba2Block(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        
        # Same dimensions as before
        self.d_model = args.hidden_size
        self.d_state = args.state_size
        self.d_conv = args.conv_kernel
        self.expand = args.expand
        self.d_inner = int(self.expand * self.d_model)
        self.n_groups = args.n_groups
        self.n_heads = args.num_heads
        self.d_head = self.d_inner // self.n_heads
        
        # Input projection
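        # Projection output layout: [z (d_inner) | xBC (d_inner + 2*n_groups*d_state) | dt (n_heads)]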
        d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads
        self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)
        
        # Improved initialization of dt
        dt = mx.exp(
            mx.random.uniform(
                low=math.log(args.time_step_min),
                high=math.log(args.time_step_max),
                shape=(self.n_heads,)
            )
        )
        dt = mx.clip(dt, args.time_step_floor, float('inf'))
        inv_dt = dt + mx.log(-mx.exp(-dt) + 1)  # Inverse softplus
        self.dt_bias = mx.array(inv_dt)

        # Improved A initialization
        A = mx.random.uniform(
            low=args.A_init_min,
            high=args.A_init_max,
            shape=(self.n_heads,)
        )
        self.A_log = mx.log(A)
        
        # Same D initialization
        self.D = mx.random.normal((self.n_heads,)) * args.initializer_range

        # Convolution with proper initialization
        self.conv1d = DepthWiseConv1d(
            channels=self.d_inner + 2 * self.n_groups * self.d_state,
            kernel_size=self.d_conv,
            bias=args.use_conv_bias,
            padding=self.d_conv-1
        )
        
        # Output projections
        self.norm = MambaRMSNormGated(self.d_inner, eps=args.layer_norm_epsilon)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias)

    def __call__(self, u: mx.array, cache=None):
        batch_size, seq_len, _ = u.shape

        # Project input
        zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
        
        # Split projections
        z = zxbcdt[..., :self.d_inner]
        xBC = zxbcdt[..., self.d_inner:self.d_inner + (self.d_inner + 2 * self.n_groups * self.d_state)]
        dt = zxbcdt[..., -self.n_heads:]

        # Process time steps - simplified to match PyTorch
        dt = nn.softplus(dt + self.dt_bias) # (B, L, nheads)
        
        xBC, conv_state = self.conv1d(xBC, cache[0] if cache else None) # (B, L, self.d_inner + 2 * ngroups * d_state)
        if cache is not None:
            cache[0] = conv_state
        xBC = silu(xBC)

        xBC = xBC[:, :seq_len, :]
        
        # Split conv output and reshape
        x = xBC[..., :self.d_inner]
        B = mx.reshape(xBC[..., self.d_inner:self.d_inner + self.n_groups * self.d_state], (batch_size, seq_len, self.n_groups, -1))
        C = mx.reshape(xBC[..., -self.n_groups * self.d_state:], (batch_size, seq_len, self.n_groups, -1))
        
        # Reshape for SSM processing
        x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head))
        
        # Initialize state
        if cache and cache[1] is not None:
            # State initialization might need proper scaling
            prev_state = cache[1]
        else:
            prev_state = mx.zeros((batch_size, self.n_heads, self.d_head, self.d_state))
                
        # Compute dA - simplified to match PyTorch
        A = -mx.exp(self.A_log)
        dt = mx.reshape(dt, (batch_size, seq_len, self.n_heads))
        dA = mx.exp(dt * mx.expand_dims(A, axis=(0, 1)))
        
        # Process sequence
        next_state = prev_state
        outputs = []
        
        for t in range(seq_len):
            # Get current step tensors
            xt = x[:, t]  # [batch, n_heads, d_head]
            Bt = B[:, t]  # [batch, n_heads, d_state]
            Ct = C[:, t]  # [batch, n_heads, d_state]
            dAt = dA[:, t]  # [batch, n_heads]
            
            # Compute dBx using einsum to match PyTorch
            dBx = mx.einsum('bh,bgd,bhp->bhpd', dAt, Bt, xt) # dAt: (b,h), Bt: (b,g,d), xt: (b,h,p) -> (b,h,p,d)

            # Update state
            next_state = next_state * mx.expand_dims(dAt, axis=(-1, -2)) + dBx

            # Compute output with groups 
            yt = mx.einsum('bhpd,bgd->bhp', next_state, Ct)
            yt = yt + xt * mx.expand_dims(self.D, -1)
            
            # Reshape and normalize
            yt = mx.reshape(yt, (batch_size, 1, self.d_inner))
            yt = self.norm(yt, z[:, t:t+1])
            outputs.append(self.out_proj(yt))
        
        # Update cache
        if cache is not None:
            cache[1] = next_state

        return mx.concatenate(outputs, axis=1)


class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.residual_in_fp32 = args.residual_in_fp32
        self.mixer = Mamba2Block(args)
        self.norm = nn.RMSNorm(args.hidden_size)

    def __call__(self, x: mx.array, cache):
        if self.residual_in_fp32:
            x = x.astype(mx.float32)
        normed = self.norm(x)
        output = self.mixer(normed, cache)
        return output + x


class Mamba2(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

    def __call__(self, x: mx.array, cache):
        x = self.embeddings(x)
        if cache is None:
            cache = [None] * len(self.layers)
        
        hidden = x
        for layer, c in zip(self.layers, cache):
            hidden = layer(hidden, c)
        return self.norm_f(hidden)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.backbone = Mamba2(args)

        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, inputs: mx.array, cache=None):
        hidden = self.backbone(inputs, cache)
        
        if self.args.tie_word_embeddings:
            logits = self.backbone.embeddings.as_linear(hidden)
        else:
            logits = self.lm_head(hidden)
        
        return logits
    
    def sanitize(self, weights):
        for k, v in weights.items():
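            # PyTorch stores depthwise conv weights as (C, 1, K); MLX expects
            # (C, K, 1), so move the kernel axis when needed.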
            if "conv1d.weight" in k and v.shape[-1] != 1:
                weights[k] = v.moveaxis(2, 1)
        return weights

    def make_cache(self):
        return [MambaCache() for _ in range(len(self.layers))]

    @property
    def layers(self):
        return self.backbone.layers

Here is a generation from Codestral:

Prompt:
python -m mlx_lm.generate --model /Users/gokdenizgulmez/Desktop/Mamba-Codestral-7B-v0.1-4bit --prompt "Write me a function that computes fibonacci in Rust" -m 120
Output:
==========
U

:%.*

❒



 Станов
:%.*ICENSE Станов Станов❒❒ biologie≮ustr≮❒ Станов❒][<.❒ Станов Squad släktet Arbitro Станов Станов Становништво][< Floren Станов.CLUDING][<.Extra



ExtraFree Станов enthusExtra

.][<.][<amment Станов enthus."— släktet⌁][< /******/American.
 biologie släktetntil /******/.
 biologie.
 AMD.][< Dy. companICENSEml Syd.
 DJCLUDING
 c







 Net
 Net
ityEngine
ityEngine
ityEngine

==========

Here is the new code:

import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs
from .cache import MambaCache

@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    num_heads: int
    head_dim: int
    vocab_size: int
    hidden_size: int
    state_size: int
    num_hidden_layers: int
    layer_norm_epsilon: float
    expand: int
    conv_kernel: int
    n_groups: int
    use_bias: bool
    use_conv_bias: bool
    initializer_range: float
    residual_in_fp32: bool
    chunk_size: int
    tie_word_embeddings: bool
    time_step_limit: Tuple[float, float]
    time_step_rank: Union[int, str]
    time_step_min: float
    time_step_max: float
    time_step_floor: float
    norm_before_gate: bool = True

    def __post_init__(self):
        if not hasattr(self, "intermediate_size"):
            self.intermediate_size = int(self.expand * self.hidden_size)
        if not hasattr(self, "head_dim"):
            self.head_dim = self.hidden_size // self.num_heads
        if self.time_step_rank == "auto":
            self.time_step_rank = math.ceil(self.hidden_size / 16)


class DepthWiseConv1d(nn.Module):
    def __init__(self, channels, kernel_size, bias=True, padding=0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.weight = mx.random.normal((channels, kernel_size, 1))
        self.bias = mx.zeros((channels,)) if bias else None

    def __call__(self, x, cache=None):
        B, L, C = x.shape
        _, K, _ = self.weight.shape

        if cache is not None:
            x = mx.concatenate([cache, x], axis=1)
        else:
            x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])

        y = mx.conv_general(x, self.weight, groups=C)
        if self.bias is not None:
            y = y + self.bias
        return y, x[:, -K + 1:, :]


def segsum(x):
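    # Running-sum helper used by ssd_forward_attn: its output is turned into
    # decay factors via exp(-abs(...)).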
    # x shape: [b, h, l]
    b, h, l = x.shape
    
    # Create a lower triangular mask
    mask = mx.tril(mx.ones((l, l)), 0)
    
    # Reshape x for broadcasting
    x_expanded = x.reshape(b, h, l, 1)  # [b, h, l, 1]
    
    # Apply mask
    masked_x = x_expanded * mask.reshape(1, 1, l, l)  # [b, h, l, l]
    
    # Sum along the appropriate dimension
    result = mx.sum(masked_x, axis=2)  # [b, h, l]
    
    return result.reshape(b, h, 1, l)  # Return in shape [b, h, 1, l]


def ssd_forward_attn(
    x: mx.array,
    dt: mx.array,
    A: mx.array,
    B: mx.array,
    C: mx.array,
    D: mx.array,
    dt_bias: mx.array,
    dt_min: float,
    dt_max: float,
    prev_state=None,
) -> Tuple[mx.array, mx.array]:
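    # Quadratic "attention" form of the SSD computation: build the lower-
    # triangular surrogate attention matrix (CB * decay), apply it to dt * x,
    # then accumulate the recurrent state for the cache.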
    b, l, h, dh = x.shape
    _, _, g, _ = B.shape

    # Process dt
    if dt_bias is not None:
        dt = dt + dt_bias.reshape(1, 1, -1)
    dt = nn.softplus(dt)
    dt = mx.clip(dt, a_min=dt_min, a_max=dt_max)

    # Reshape tensors
    B_reshaped = mx.swapaxes(mx.swapaxes(B, 1, 3), 1, 2)
    C_reshaped = mx.swapaxes(C, 1, 2)

    # Compute CB
    CB = C_reshaped @ B_reshaped
    CB = mx.repeat(CB, repeats=h // g, axis=1)

    # Compute decay terms
    dtA = dt * A.reshape(1, 1, -1)
    dtA = mx.swapaxes(dtA, 1, 2)
    decay = mx.exp(-mx.abs(segsum(dtA)))

    # Create attention matrix
    surrogate_attention_matrix = mx.tril(CB * decay, 0)

    # Apply attention
    dtx = dt.reshape(b, l, h, 1) * x
    y = surrogate_attention_matrix @ dtx.swapaxes(1, 2)
    y = mx.swapaxes(y, 1, 2)

    # Compute next state
    decay_last = decay[:, :, -1, :].reshape(b, h, l).swapaxes(1, 2).reshape(b, l, h, 1)
    B_for_state = mx.repeat(B_reshaped, h // g, axis=1).swapaxes(2, 3)
    dtxdecay = dtx * decay_last
    dtxdecay = dtxdecay.swapaxes(1, 2).swapaxes(2, 3)
    
    # Calculate new state contribution
    new_state_contribution = dtxdecay @ B_for_state
    
    # Initialize or update state
    if prev_state is not None:
        decayed_prev_state = prev_state * decay[:, :, -1, :].reshape(b, h, 1, 1)
        next_state = decayed_prev_state + new_state_contribution
    else:
        next_state = new_state_contribution

    # Add skip connection if D is provided
    if D is not None:
        y += x * D.reshape(1, 1, h, 1)

    # Reshape output
    y = y.reshape(b, l, h * dh)

    return y, next_state


class Mamba2Block(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.d_model = args.hidden_size
        self.d_state = args.state_size
        self.d_conv = args.conv_kernel
        self.expand = args.expand
        self.d_inner = int(self.expand * self.d_model)
        self.n_groups = args.n_groups
        self.n_heads = args.num_heads
        self.d_head = self.d_inner // self.n_heads
        self.chunk_size = args.chunk_size
        
        d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads
        self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)

        self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range
        self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range
        self.D = mx.random.normal((self.n_heads,)) * args.initializer_range

        self.conv1d = DepthWiseConv1d(
            channels=self.d_inner + 2 * self.n_groups * self.d_state,
            kernel_size=self.d_conv,
            bias=args.use_conv_bias,
            padding=self.d_conv-1
        )
        
        self.norm = nn.RMSNorm(self.d_inner, eps=args.layer_norm_epsilon)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias)

    def __call__(self, u: mx.array, cache=None):
        batch_size, seq_len, _ = u.shape
        
        # Get cache states
        conv_state = None
        ssm_state = None
        if cache is not None:
            conv_state = cache[0]  # Access using index
            ssm_state = cache[1]   # Access using index

        zxBCdt = self.in_proj(u)

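        # Projection layout: [z (d_inner) | xBC (d_inner + 2*n_groups*d_state) | dt (n_heads)]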
        z, xBC, dt = mx.split(
            zxBCdt,
            [self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state], 
            axis=-1
        )

        xBC, conv_state = self.conv1d(xBC, conv_state)
        xBC = xBC * mx.sigmoid(xBC)
        xBC = xBC[:, :seq_len, :]

        x, B, C = mx.split(
            xBC, 
            [self.d_inner, self.d_inner + self.d_state * self.n_groups], 
            axis=-1
        )

        x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head))
        B = mx.reshape(B, (batch_size, seq_len, self.n_groups, -1))
        C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1))

        y, ssm_state = ssd_forward_attn(
            x=x,
            dt=dt,
            A=-mx.exp(self.A_log),
            B=B,
            C=C,
            D=self.D,
            dt_bias=self.dt_bias,
            dt_min=self.args.time_step_min,
            dt_max=self.args.time_step_max,
            prev_state=ssm_state
        )

        if self.args.norm_before_gate:
            y = self.norm(y)
            y = y * nn.silu(z)
        else:
            y = y * nn.silu(z)
            y = self.norm(y)

        y = self.out_proj(y)

        # Update cache using indexing
        if cache is not None:
            cache[0] = conv_state
            cache[1] = ssm_state
            
        return y


class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.residual_in_fp32 = args.residual_in_fp32
        self.mixer = Mamba2Block(args)
        self.norm = nn.RMSNorm(args.hidden_size)

    def __call__(self, x: mx.array, cache):
        if self.residual_in_fp32:
            x = x.astype(mx.float32)
        normed = self.norm(x)
        output = self.mixer(normed, cache)
        return output + x


class Mamba2(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

    def __call__(self, x: mx.array, cache):
        x = self.embeddings(x)
        if cache is None:
            cache = [None] * len(self.layers)
        
        hidden = x
        for layer, c in zip(self.layers, cache):
            hidden = layer(hidden, c)
        return self.norm_f(hidden)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.backbone = Mamba2(args)

        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, inputs: mx.array, cache=None):
        hidden = self.backbone(inputs, cache)
        
        if self.args.tie_word_embeddings:
            logits = self.backbone.embeddings.as_linear(hidden)
        else:
            logits = self.lm_head(hidden)
        
        return logits

    def make_cache(self):
        return [MambaCache() for _ in range(len(self.layers))]

    @property
    def layers(self):
        return self.backbone.layers

Example generation:

python -m mlx_lm.generate --model mlx-community/Mamba-Codestral-7B-v0.1-4bit --prompt "<s>[INST] Write me a function that computes fibonacci in Rust [/INST] " -m 64 --extra-eos-token "</s>"
Fetching 7 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 90061.74it/s]
==========
100% natural minerals, vitamins, minerals, vitamins, minerals, vitamins, minerals, vitamins, minerals, vitamins, minerals, vitamins, minerals, vitamins, minerals, vitamins, minerals,
==========
Prompt: 19 tokens, 37.486 tokens-per-sec
Generation: 64 tokens, 5.202 tokens-per-sec
Peak memory: 4.679 GB

Is the prompt format also correct?
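
For reference, one way to double-check the prompt format would be to let the tokenizer's chat template build it instead of writing the [INST] tags by hand (a sketch; it assumes the repo's tokenizer config includes a chat template):

# Build the prompt from the chat template and compare it to the hand-written
# "<s>[INST] ... [/INST] " string (illustrative sketch).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mlx-community/Mamba-Codestral-7B-v0.1-4bit")
prompt = tok.apply_chat_template(
    [{"role": "user", "content": "Write me a function that computes fibonacci in Rust"}],
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)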
