
# Advanced Insights: Multi-Latent Attention Architecture

## Key Architectural Innovations

### Compression-Position Decoupling

```
# Two parallel pathways with different roles:
[b, s, d] -> [b, s, d_c] -> [b, s, d]     # Compression pathway
[b, s, d] -> [b, s, d_r] -> RoPE()        # Position pathway
```

Critical insight: RoPE applies a position-dependent rotation that does not commute with the projection matrices, so the position pathway must be kept separate for matrix absorption to remain valid at inference time.
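
A minimal sketch of the two pathways under illustrative sizes; the layer names (`W_down`, `W_up`, `W_rope`) and dimensions are assumptions for exposition, not the reference implementation:

```python
import torch
import torch.nn as nn

b, s, d, d_c, d_r = 2, 16, 512, 64, 32         # illustrative sizes

x = torch.randn(b, s, d)

# Compression pathway: down-project to the latent dim, up-project back later.
W_down = nn.Linear(d, d_c, bias=False)
W_up   = nn.Linear(d_c, d, bias=False)
c_kv = W_down(x)                               # [b, s, d_c] -- this is what gets cached
kv   = W_up(c_kv)                              # [b, s, d]

# Position pathway: a separate projection, the only place RoPE is applied.
W_rope = nn.Linear(d, d_r, bias=False)
k_rope = W_rope(x)                             # [b, s, d_r] -- rotate with RoPE before use

print(c_kv.shape, kv.shape, k_rope.shape)
```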

### Asymmetric Dimensionality

```
Q pathway: per-head rotary dimensions [b, s, n_h, d_r]
K pathway: shared rotary dimensions   [b, s, 1, d_r]
```

Sharing one rotary key across all heads allows computational and cache reuse while maintaining positional awareness.
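
A sketch of how the single shared rotary key broadcasts against per-head rotary queries when scores are formed; tensor names and sizes are illustrative:

```python
import torch

b, s, n_h, d_r = 2, 16, 8, 32

q_rope = torch.randn(b, s, n_h, d_r)     # per-head rotary queries
k_rope = torch.randn(b, s, 1, d_r)       # one rotary key shared across heads

q = q_rope.transpose(1, 2)               # [b, n_h, s, d_r]
k = k_rope.transpose(1, 2)               # [b, 1,   s, d_r]

# The size-1 head dimension broadcasts: every head reuses the same rotary key.
scores_rope = q @ k.transpose(-2, -1)    # [b, n_h, s, s]
print(scores_rope.shape)
```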

### Cache Optimization Strategy

Two distinct caches with different roles:

```
cache_kv: [b, max_len, d_c]    # Compressed KV states
cache_rk: [b, max_len, d_r]    # Shared rotary key
```

Optimization insight: d_c + d_r << d_model, enabling significant memory reduction.
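
A sketch of the two pre-allocated caches; shapes follow the text above, sizes are illustrative:

```python
import torch

b, max_len, d_c, d_r = 2, 1024, 64, 32

cache_kv = torch.zeros(b, max_len, d_c)   # compressed KV states
cache_rk = torch.zeros(b, max_len, d_r)   # shared rotary key

per_token = d_c + d_r                     # values cached per token per layer
print(per_token)                          # 96 here, vs. roughly 2 * d_model for full K/V
```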

## Implementation Subtleties

### Matrix Absorption During Inference

```
Standard:  W^Q @ (W^UK @ c^KV)    # decompress the cached latent with W^UK at every step
Optimized: (W^Q @ W^UK) @ c^KV    # pre-multiply W^Q @ W^UK once; skip per-step decompression
```

Key requirement: Position-agnostic main pathway enables matrix pre-multiplication.
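
A sketch verifying the algebraic equivalence on the position-agnostic pathway; plain weight matrices with illustrative sizes, not the reference implementation:

```python
import torch

d, d_c, s = 512, 64, 16

W_q  = torch.randn(d, d,   dtype=torch.float64)   # query projection (content part only)
W_uk = torch.randn(d, d_c, dtype=torch.float64)   # up-projection from latent to key space
c_kv = torch.randn(s, d_c, dtype=torch.float64)   # cached compressed latents

# Standard: decompress the cached latents, then apply the query-side projection.
standard = W_q @ (W_uk @ c_kv.T)            # runs W_uk against the cache at every step

# Optimized: pre-multiply the two weight matrices once, offline.
W_absorbed = W_q @ W_uk                     # computed once at load time
optimized  = W_absorbed @ c_kv.T            # single matmul against the cache per step

print(torch.allclose(standard, optimized))  # True: identical up to floating-point error
```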

### Attention Pattern Evolution

```
t=1: Pattern[1:1]     # Initial token
t=2: Pattern[1:2]     # One previous token
t=n: Pattern[1:n]     # Full context window
```

Cache growth introduces subtle position-dependent patterns requiring careful mask handling.

### Dimension Flow Control

Critical transitions to monitor:

```
[b, s, d] -> [b, s, d_c]              # Down projection
[b, s, d_c] -> [b, s+cache, d_c]      # Cache concatenation
[b, s+cache, d_c] -> [b, s+cache, d]  # Up projection
```

Each transition must preserve both positional and content information.
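
A sketch tracing the three transitions with shape assertions; cache handling is simplified and all names are illustrative:

```python
import torch
import torch.nn as nn

b, s, cache_len, d, d_c = 2, 4, 12, 512, 64

W_down = nn.Linear(d, d_c, bias=False)
W_up   = nn.Linear(d_c, d, bias=False)

x = torch.randn(b, s, d)
cache_kv = torch.randn(b, cache_len, d_c)       # latents cached from earlier steps

c_new = W_down(x)                               # down projection
assert c_new.shape == (b, s, d_c)

c_all = torch.cat([cache_kv, c_new], dim=1)     # cache concatenation
assert c_all.shape == (b, cache_len + s, d_c)

kv = W_up(c_all)                                # up projection
assert kv.shape == (b, cache_len + s, d)
```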

## Edge Cases and Considerations

### Cross-Attention Scenarios

```
q_len != kv_len  # Length mismatch
d_c < d_model    # Compression bottleneck
```

Compression and position information must be maintained across different sequence lengths.

### Position-Aware Cache Updates

```
# Position-dependent attention mask creation
mask[:, :, i, :(start_pos + i + 1)] = 0       # Can attend
mask[:, :, i, (start_pos + i + 1):] = -inf    # Cannot attend
```

Mask must evolve with cache to maintain causal attention patterns.
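
The same logic wrapped in a small helper for clarity; a sketch rather than the reference implementation, where `start_pos` is the number of tokens already in the cache:

```python
import torch

def causal_mask_with_cache(q_len: int, start_pos: int, device=None) -> torch.Tensor:
    """Additive attention mask of shape [1, 1, q_len, start_pos + q_len].

    Query i (the i-th new token) attends to every cached token plus the
    new tokens up to and including itself; later positions get -inf.
    """
    total = start_pos + q_len
    mask = torch.full((1, 1, q_len, total), float("-inf"), device=device)
    for i in range(q_len):
        mask[:, :, i, : start_pos + i + 1] = 0.0    # can attend
    return mask

# Decode step: one new token after 5 cached tokens sees all 6 positions.
print(causal_mask_with_cache(q_len=1, start_pos=5)[0, 0])
```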

### Numerical Stability

  1. The scaling factor accounts for both pathways: `1/sqrt(d_head + d_rotate)` (see the sketch after this list)
  2. The compression dimension balances cache efficiency against representation capacity
  3. The RoPE dimension sets the granularity of the position encoding
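
A sketch of the combined scaling factor from item 1; sizes are illustrative, and the queries and keys here stand in for the concatenated content and rotary parts:

```python
import math
import torch

b, n_h, s, d_head, d_rotate = 2, 8, 16, 64, 32

q = torch.randn(b, n_h, s, d_head + d_rotate)   # content + rotary parts concatenated
k = torch.randn(b, n_h, s, d_head + d_rotate)

scale = 1.0 / math.sqrt(d_head + d_rotate)      # scale by the full per-head width
scores = (q @ k.transpose(-2, -1)) * scale      # [b, n_h, s, s]
print(scores.shape)
```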

## Performance Implications

### Memory Complexity

```
Standard: O(b * s * d_model)
MLA:      O(b * s * (d_c + d_r))
```

Where d_c + d_r << d_model
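
A back-of-the-envelope comparison using DeepSeek-V2-style sizes (d_model = 5120, d_c = 512, d_r = 64); treat the numbers as illustrative assumptions rather than a benchmark:

```python
# Per token, per layer, assuming n_h * d_head == d_model for the standard cache.
d_model, d_c, d_r = 5120, 512, 64

standard_per_token = 2 * d_model          # full K and V
mla_per_token      = d_c + d_r            # compressed latent + shared rotary key

print(standard_per_token / mla_per_token) # ~17.8x fewer cached values per token
```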

### Computational Trade-offs

  1. Additional projections for position pathway
  2. Reduced cache size enables longer sequences
  3. Matrix absorption reduces inference compute

## Integration Considerations

### Initialization Strategy

```
# Critical hyperparameters
d_c:      Compression dimension
d_rotate: Position encoding dimension
```

Trade-off between compression efficiency and position encoding capacity.
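
A minimal config sketch that makes the trade-off explicit; the class and its default values are placeholders, not recommended settings:

```python
from dataclasses import dataclass

@dataclass
class MLAConfig:
    d_model: int = 512    # model width
    n_heads: int = 8
    d_c: int = 64         # compression dim: smaller -> smaller cache, tighter bottleneck
    d_rotate: int = 32    # position dim: larger -> finer positional resolution, bigger cache

    @property
    def cache_per_token(self) -> int:
        return self.d_c + self.d_rotate

print(MLAConfig().cache_per_token)   # 96
```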

### Cache Management

```
# Two update patterns
cache_kv[:, pos:pos+s] = current_kv    # Content cache
cache_rk[:, pos:pos+s] = current_rk    # Position cache
```

Synchronization between the two caches is crucial for correctness: both must be written at the same positions on every step.
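
One way to keep them in lockstep is a single helper that writes both caches at the same offset; a sketch with illustrative names:

```python
import torch

def update_caches(cache_kv: torch.Tensor, cache_rk: torch.Tensor,
                  current_kv: torch.Tensor, current_rk: torch.Tensor,
                  pos: int) -> int:
    """Write the new compressed states and rotary keys at the same positions."""
    s = current_kv.size(1)
    assert current_rk.size(1) == s, "content and position caches must advance together"
    cache_kv[:, pos:pos + s] = current_kv
    cache_rk[:, pos:pos + s] = current_rk
    return pos + s   # next write offset

# Usage: pos = update_caches(cache_kv, cache_rk, c_kv_new, k_rope_new, pos)
```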