
# Advanced Insights: Attention Masks with KV-Caching

## Key Pitfalls in Complex Attention Implementations

### Dimension Evolution with Caching

```python
# Crucial dimension transitions in cached attention:
# [b, s, d_model] -> [b, s+cache, d_c] -> [b, s+cache, d_model] -> [b, num_h, s, d_head]
```

The non-obvious trap: even with a growing K/V cache, the attention output's sequence dimension must match the query length `s`, not the cached length `s+cache`.
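A minimal sketch of that invariant, using plain multi-head attention against an already-populated cache (the concrete sizes, and the absence of the MLA compression path, are assumptions for illustration):

```python
import torch

b, s, cache_len = 2, 1, 16       # decoding one new token against 16 cached positions
num_h, d_head = 8, 64

q = torch.randn(b, num_h, s, d_head)                # queries: only the new token(s)
k = torch.randn(b, num_h, cache_len + s, d_head)    # keys: cached + new
v = torch.randn(b, num_h, cache_len + s, d_head)    # values: cached + new

scores = q @ k.transpose(-2, -1) / d_head ** 0.5    # [b, num_h, s, cache_len + s]
out = torch.softmax(scores, dim=-1) @ v             # [b, num_h, s, d_head]

# The output follows the query length, not the key/value length.
assert out.shape == (b, num_h, s, d_head)
```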

### Mask Causality with Growing Cache

Standard causal masks break with KV-caching: a square `[s, s]` mask built from the current query length ignores the cached positions that every new token must also attend to. Critical edge cases:

  - A query token at local position i (absolute position start_pos + i) must attend to absolute positions 0 through start_pos + i, its own position included, so the mask must span the full cached length (see the sketch after this list)
  - Naively extending a square causal mask does not preserve causality across the cache boundary
  - Generating the mask position by position at every decode step has a measurable performance cost
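One way to build a mask that satisfies these constraints (a sketch; `start_pos` stands for the number of already-cached positions, and the additive 0 / -inf convention is an assumption):

```python
import torch

def cached_causal_mask(s, start_pos, device=None):
    """Additive mask of shape [s, start_pos + s] for one cached decode step.

    Query i (absolute position start_pos + i) may attend to absolute
    positions 0 .. start_pos + i; every later position gets -inf.
    """
    q_pos = torch.arange(s, device=device).unsqueeze(1) + start_pos   # [s, 1] absolute query positions
    k_pos = torch.arange(start_pos + s, device=device).unsqueeze(0)   # [1, start_pos + s] key positions
    mask = torch.zeros(s, start_pos + s, device=device)
    return mask.masked_fill_(k_pos > q_pos, float("-inf"))

# 3 new tokens on top of 4 cached positions -> a [3, 7] mask
print(cached_causal_mask(3, 4))
# tensor([[0., 0., 0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., 0.,   0., -inf],
#         [0., 0., 0., 0., 0.,   0.,   0.]])
```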

### Optimization Considerations

  1. Memory vs. compute tradeoff: precompute one extended mask up to the maximum sequence length, or generate a fresh mask per decode step (both options are sketched after this list)
  2. Batch dimension handling: how the mask broadcasts across batch and head dimensions determines its memory footprint
  3. Fused attention kernels: custom mask handling can be incompatible with fused/flash-attention fast paths
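A sketch of the first tradeoff, assuming a fixed `max_seq_len` and the same additive-mask convention as above (all names here are illustrative):

```python
import torch

max_seq_len = 4096

# Option A: precompute one [max_seq_len, max_seq_len] causal mask and slice it per step.
# Costs O(max_seq_len^2) memory once; each step is just a slice of the precomputed tensor.
full_mask = torch.full((max_seq_len, max_seq_len), float("-inf")).triu(diagonal=1)

def step_mask_precomputed(s, start_pos):
    # rows: the s new queries; columns: all start_pos + s visible positions
    return full_mask[start_pos:start_pos + s, :start_pos + s]

# Option B: rebuild the [s, start_pos + s] mask every step.
# No persistent memory, but repeated work and an allocation per decode step.
def step_mask_on_the_fly(s, start_pos):
    q_pos = torch.arange(s).unsqueeze(1) + start_pos
    k_pos = torch.arange(start_pos + s).unsqueeze(0)
    return torch.zeros(s, start_pos + s).masked_fill_(k_pos > q_pos, float("-inf"))

# Both strategies produce the same mask for a given step.
assert torch.equal(step_mask_precomputed(1, 16), step_mask_on_the_fly(1, 16))
```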

## Debugging Strategy for Non-Obvious Cases

Monitor these dimension transitions for subtle bugs:

```python
C_KV.shape       # should grow: [b, s₁, d_c] -> [b, s₁+s₂, d_c]
K_state.shape    # post-projection growth affects attention patterns
att_output.shape # must keep the query length despite K/V growth
```
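These checks can be made automatic. A sketch of assertions around a hypothetical decode step (the `mla_step` call, its signature, and an `att_output` of shape `[b, s, d_model]` are placeholder assumptions, not this repository's API):

```python
# Hypothetical wrapper: assert the cache-growth and query-length invariants every step.
def checked_decode_step(mla_step, x, C_KV, start_pos):
    b, s, _ = x.shape
    d_c = C_KV.shape[-1]
    att_output, C_KV = mla_step(x, C_KV, start_pos)   # placeholder signature
    assert C_KV.shape == (b, start_pos + s, d_c), "compressed cache did not grow by s"
    assert att_output.shape[1] == s, "output must keep the query length, not the cached length"
    return att_output, C_KV
```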

## Practical Example: DeepSeek's MLA Edge Case

In Multi-Latent Attention, the compressed KV cache introduces subtle interactions with attention masks due to:

  1. Joint compression affecting position-dependent patterns
  2. Non-standard dimension flow through compression/decompression
  3. Mask causality preservation across cached compressed states (see the end-to-end sketch below)
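A minimal, self-contained sketch of how these three points interact (deliberately simplified: single down/up projections, no RoPE decoupling, no output projection; names such as `W_dkv`, `W_uk`, `W_uv`, `mla_attend` and all sizes are illustrative assumptions, not the repository's implementation):

```python
import torch

torch.manual_seed(0)
b, d_model, d_c, num_h, d_head = 2, 256, 64, 4, 64

# Projections: down to the shared latent, then up to per-head K/V; queries stay standard.
W_dkv = torch.randn(d_model, d_c) / d_model ** 0.5        # compress hidden states -> latent
W_uk  = torch.randn(d_c, num_h * d_head) / d_c ** 0.5     # decompress latent -> keys
W_uv  = torch.randn(d_c, num_h * d_head) / d_c ** 0.5     # decompress latent -> values
W_q   = torch.randn(d_model, num_h * d_head) / d_model ** 0.5

def mla_attend(x, C_KV, start_pos):
    """One cached step: x is [b, s, d_model], C_KV is [b, start_pos, d_c] (may be empty)."""
    s = x.shape[1]
    C_KV = torch.cat([C_KV, x @ W_dkv], dim=1)                        # [b, start_pos + s, d_c]
    total = C_KV.shape[1]

    q = (x @ W_q).view(b, s, num_h, d_head).transpose(1, 2)           # [b, num_h, s, d_head]
    k = (C_KV @ W_uk).view(b, total, num_h, d_head).transpose(1, 2)   # [b, num_h, total, d_head]
    v = (C_KV @ W_uv).view(b, total, num_h, d_head).transpose(1, 2)

    # Mask over *absolute* positions: query i sees positions 0 .. start_pos + i.
    q_pos = torch.arange(s).unsqueeze(1) + start_pos
    k_pos = torch.arange(total).unsqueeze(0)
    mask = torch.zeros(s, total).masked_fill_(k_pos > q_pos, float("-inf"))

    att = torch.softmax(q @ k.transpose(-2, -1) / d_head ** 0.5 + mask, dim=-1)
    out = (att @ v).transpose(1, 2).reshape(b, s, num_h * d_head)     # length follows the query
    return out, C_KV

# Prefill 5 tokens, then decode 1 token against the compressed cache.
C_KV = torch.empty(b, 0, d_c)
out, C_KV = mla_attend(torch.randn(b, 5, d_model), C_KV, start_pos=0)
out, C_KV = mla_attend(torch.randn(b, 1, d_model), C_KV, start_pos=5)
print(out.shape, C_KV.shape)   # torch.Size([2, 1, 256]) torch.Size([2, 6, 64])
```

Note that the mask is built from absolute positions even though only the latent `C_KV` is cached: decompression restores the full key length, so causality has to be enforced against `start_pos + s` positions, not against the query length alone.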