# Advanced Insights: Multi-Latent Attention Architecture
## Key Architectural Innovations
### Compression-Position Decoupling
```python
# Two parallel pathways with different roles:
[b, s, d] -> [b, s, d_c] -> [b, s, d] # Compression pathway
[b, s, d] -> [b, s, d_r] -> RoPE() # Position pathway
```
Critical insight: RoPE applies a position-dependent rotation that does not commute with the projection matrices, so the rotary component must live in its own pathway; otherwise the matrix absorption used at inference (see below) would no longer hold.
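Below is a minimal sketch of the split, assuming hypothetical names (`DecoupledKVProjection`, `apply_rope`, `d_rotate`) rather than the repo's actual API: the compressed latent never sees RoPE, while the rotary projection is the only tensor that does.
```python
# Illustrative sketch of the two pathways; all names here are assumptions.
import torch
import torch.nn as nn

class DecoupledKVProjection(nn.Module):
    def __init__(self, d_model: int, d_c: int, d_rotate: int):
        super().__init__()
        self.down_proj = nn.Linear(d_model, d_c, bias=False)       # compression pathway
        self.rope_proj = nn.Linear(d_model, d_rotate, bias=False)  # position pathway

    def forward(self, x, apply_rope):
        # x: [b, s, d_model]; apply_rope is a caller-supplied RoPE function
        c_kv = self.down_proj(x)               # [b, s, d_c]      -- RoPE never touches this
        k_rot = apply_rope(self.rope_proj(x))  # [b, s, d_rotate] -- RoPE applied only here
        return c_kv, k_rot
```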
### Asymmetric Dimensionality
```
Q pathway: per-head rotary dimensions [b, s, n_h, d_r]
K pathway: shared rotary dimensions [b, s, 1, d_r]
```
This asymmetry lets every query head reuse a single shared rotary key, so the positional component is computed and cached once per token while positional awareness is preserved per query head.
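A minimal sketch of how the shared rotary key is broadcast against per-head rotary queries when scores are formed (shapes as above; values and names are illustrative):
```python
# Shapes follow the block above; values are random, for illustration only.
import torch

b, s, n_h, d_r = 2, 5, 8, 16
q_rot = torch.randn(b, s, n_h, d_r)   # per-head rotary queries
k_rot = torch.randn(b, s, 1, d_r)     # one rotary key shared by every head

# Expanding the size-1 head dimension reuses the single shared key for all
# n_h heads without storing a per-head copy in the cache.
rot_scores = torch.einsum("bqhd,bkhd->bhqk", q_rot, k_rot.expand(b, s, n_h, d_r))
print(rot_scores.shape)  # torch.Size([2, 8, 5, 5])
```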
### Cache Optimization Strategy
Two distinct caches with different roles:
```python
cache_kv: [b, max_len, d_c] # Compressed KV states
cache_rk: [b, max_len, d_r] # Shared rotary key
```
Optimization insight: `d_c + d_r << d_model`, enabling significant memory reduction.
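A sketch of pre-allocating the two buffers and writing one segment into them (sizes are placeholders; the update pattern is revisited under Cache Management below):
```python
# Illustrative pre-allocation of the two caches; sizes are placeholders.
import torch

b, max_len, d_c, d_r = 2, 1024, 512, 64
cache_kv = torch.zeros(b, max_len, d_c)   # compressed KV states
cache_rk = torch.zeros(b, max_len, d_r)   # shared rotary key

# Writing s new tokens starting at write position pos:
pos, s = 10, 1
cache_kv[:, pos:pos + s] = torch.randn(b, s, d_c)
cache_rk[:, pos:pos + s] = torch.randn(b, s, d_r)
```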
## Implementation Subtleties
### Matrix Absorption During Inference
```
Standard:  (W^Q @ h).T @ (W^UK @ c^KV)     # decompress keys at every decoding step
Optimized: h.T @ ((W^Q).T @ W^UK) @ c^KV   # fold (W^Q).T @ W^UK once, offline
```
Key requirement: the main (compressed) pathway stays position-agnostic; inserting RoPE between W^UK and c^KV would add a position-dependent rotation and break the pre-multiplication.
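A toy numerical check of the identity, with random weights and illustrative dimensions (double precision keeps the comparison exact to within rounding):
```python
# Toy check of the absorption identity; dimensions and weights are arbitrary.
import torch

torch.manual_seed(0)
d_model, d_c, d_head = 64, 16, 32
W_Q  = torch.randn(d_head, d_model, dtype=torch.float64)  # query projection (content part only)
W_UK = torch.randn(d_head, d_c, dtype=torch.float64)      # key up-projection from c^KV

h    = torch.randn(d_model, dtype=torch.float64)  # query-side hidden state
c_kv = torch.randn(d_c, dtype=torch.float64)      # one cached compressed KV entry

# Standard: decompress the key every step, then dot with the query.
score_standard = (W_Q @ h) @ (W_UK @ c_kv)

# Absorbed: fold W_Q^T @ W_UK once, then attend over c^KV directly.
W_absorbed = W_Q.T @ W_UK                  # [d_model, d_c], computed offline
score_absorbed = h @ (W_absorbed @ c_kv)

assert torch.allclose(score_standard, score_absorbed)
```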
### Attention Pattern Evolution
```
t=1: Pattern[1:1] # Initial token
t=2: Pattern[1:2] # One previous token
t=n: Pattern[1:n] # Full context window
```
As the cache grows, each new query attends over a longer span whose length depends on its global position, so the attention mask must be constructed with the cache offset in mind.
### Dimension Flow Control
Critical transitions to monitor:
```
[b, s, d] -> [b, s, d_c] # Down projection
[b, s, d_c] -> [b, s+cache, d_c] # Cache concatenation
[b, s+cache, d_c] -> [b, s+cache, d] # Up projection
```
Each transition must preserve the content carried by the compressed states; positional information travels in parallel through the decoupled rotary pathway and is reintroduced when scores are formed.
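A shape-tracking sketch of the three transitions with placeholder weights and sizes:
```python
# Illustrative shape walk-through; all names and sizes are placeholders.
import torch

b, s, cache_len, d, d_c = 2, 3, 10, 64, 16
W_down = torch.randn(d, d_c)
W_up   = torch.randn(d_c, d)

x = torch.randn(b, s, d)
c_new = x @ W_down                                                  # [b, s, d_c]        down projection
c_all = torch.cat([torch.randn(b, cache_len, d_c), c_new], dim=1)   # [b, s+cache, d_c]  cache concatenation
kv    = c_all @ W_up                                                # [b, s+cache, d]    up projection
print(c_new.shape, c_all.shape, kv.shape)
```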
## Edge Cases and Considerations
### Cross-Attention Scenarios
```python
q_len != kv_len # Length mismatch
d_c < d_model # Compression bottleneck
```
When query and key/value lengths differ (for example, one new query token attending over a long cache), both the compressed content and the rotary position information must stay consistent across the two lengths.
### Position-Aware Cache Updates
```python
# Position-dependent causal mask: query i sits at global position start_pos + i
# and may attend to every key up to and including itself.
for i in range(seq_len):
    mask[:, :, i, :(start_pos + i + 1)] = 0             # can attend
    mask[:, :, i, (start_pos + i + 1):] = float("-inf") # cannot attend
```
Mask must evolve with cache to maintain causal attention patterns.
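An end-to-end sketch of building the evolving mask and applying it to raw scores (random values, illustrative shapes):
```python
# Build the mask for seq_len new queries against start_pos cached keys, then
# apply it to random scores; shapes and values are illustrative only.
import torch

b, n_h, start_pos, seq_len = 1, 2, 3, 2
total_len = start_pos + seq_len

mask = torch.zeros(1, 1, seq_len, total_len)
for i in range(seq_len):
    mask[:, :, i, (start_pos + i + 1):] = float("-inf")   # future keys are blocked

scores = torch.randn(b, n_h, seq_len, total_len)
attn = torch.softmax(scores + mask, dim=-1)
print(attn[0, 0, 0])   # last entry is exactly 0: query 0 cannot see the newest key
```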
### Numerical Stability
1. The scaling factor accounts for both pathways: `1/sqrt(d_head + d_rotate)`
2. The compression dimension trades memory efficiency against representation capacity
3. The RoPE dimension sets the granularity of the position encoding
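For point 1, a sketch of where that single scale enters (shapes illustrative; the two score components are random stand-ins for the content and RoPE dot products):
```python
# The content and rotary dot products are summed before softmax, so one scale
# of 1/sqrt(d_head + d_rotate) covers the combined query/key dimensionality.
import math
import torch

b, n_h, q_len, kv_len, d_head, d_rotate = 2, 4, 3, 7, 32, 16
content_scores = torch.randn(b, n_h, q_len, kv_len)   # from the compressed pathway
rotary_scores  = torch.randn(b, n_h, q_len, kv_len)   # from the RoPE pathway

scores = (content_scores + rotary_scores) / math.sqrt(d_head + d_rotate)
attn = torch.softmax(scores, dim=-1)
```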
## Performance Implications
### Memory Complexity
```
Standard: O(b * s * d_model)
MLA: O(b * s * (d_c + d_r))
```
Where `d_c + d_r << d_model`
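A back-of-the-envelope comparison per token and per layer, using illustrative dimensions in the ballpark of DeepSeek-V2's published configuration (treat the exact values as assumptions):
```python
# Per-token, per-layer cache entries; dimensions are illustrative assumptions.
d_model, d_c, d_r = 5120, 512, 64

standard_entries = 2 * d_model          # full K and V vectors: 10_240 values
mla_entries = d_c + d_r                 # compressed KV + shared rotary key: 576 values
print(standard_entries / mla_entries)   # ~17.8x smaller cache per token
```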
### Computational Trade-offs
1. Additional projections for position pathway
2. Reduced cache size enables longer sequences
3. Matrix absorption reduces inference compute
## Integration Considerations
### Initialization Strategy
```python
# Critical hyperparameters
d_c: int       # compression dimension
d_rotate: int  # position encoding dimension (written d_r in the cache shapes above)
```
Smaller `d_c` shrinks the cache but limits what the compressed states can represent; smaller `d_rotate` shrinks the shared rotary key but coarsens the position encoding.
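One way to group these, sketched as a config object with purely illustrative default values:
```python
# Illustrative hyperparameter grouping; the default values are examples only.
from dataclasses import dataclass

@dataclass
class MLAConfig:
    d_model: int = 1024
    n_heads: int = 16
    d_c: int = 128        # smaller -> smaller cache, less representation capacity
    d_rotate: int = 32    # smaller -> smaller rotary cache, coarser position signal
```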
### Cache Management
```python
# Two update patterns
cache_kv[:, pos:pos+s] = current_kv # Content cache
cache_rk[:, pos:pos+s] = current_rk # Position cache
```
Keeping the two caches synchronized (same write position, same segment length) is crucial for correctness.
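A small sketch that keeps the two writes in lockstep behind one function (names are illustrative):
```python
# Both caches are written with the same position and length so they never drift.
import torch

def update_caches(cache_kv, cache_rk, current_kv, current_rk, pos: int) -> int:
    s = current_kv.shape[1]
    assert current_rk.shape[1] == s, "content and rotary updates must cover the same tokens"
    cache_kv[:, pos:pos + s] = current_kv
    cache_rk[:, pos:pos + s] = current_rk
    return pos + s   # next write position
```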