Advanced Insights: Attention Masks with KV-Caching
Key Pitfalls in Complex Attention Implementations
Dimension Evolution with Caching
```
# Crucial dimension transitions in cached attention:
[b, s, d_model] -> [b, s+cache, d_c] -> [b, s+cache, d_model] -> [b, num_h, s, d_head]
```
The non-obvious trap: even though the K/V cache grows, the attention output's sequence dimension must match the query length, not the cached length.
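A minimal PyTorch sketch of this flow, with illustrative sizes and placeholder projection names (`W_down`, `W_up_k`, `W_up_v`, `C_KV` here are assumptions for the sketch, not any particular implementation); the causal mask is omitted and covered in the next section:

```python
import torch
import torch.nn as nn

# Illustrative sizes; none of these names come from an actual implementation.
b, s_cache, s_new = 2, 16, 4          # batch, already-cached tokens, new tokens
d_model, d_c, num_h = 512, 64, 8      # model dim, compressed (latent) dim, heads
d_head = d_model // num_h

W_q    = nn.Linear(d_model, d_model, bias=False)
W_down = nn.Linear(d_model, d_c, bias=False)     # joint KV compression
W_up_k = nn.Linear(d_c, d_model, bias=False)     # decompress latents to keys
W_up_v = nn.Linear(d_c, d_model, bias=False)     # decompress latents to values

x_new = torch.randn(b, s_new, d_model)           # [b, s, d_model]
C_KV  = torch.randn(b, s_cache, d_c)             # previously cached latents
C_KV  = torch.cat([C_KV, W_down(x_new)], dim=1)  # [b, s+cache, d_c]

K_state = W_up_k(C_KV)                           # [b, s+cache, d_model]
split = lambda t: t.view(b, -1, num_h, d_head).transpose(1, 2)
Q, K, V = split(W_q(x_new)), split(K_state), split(W_up_v(C_KV))

# (causal mask omitted here; handled in the next section)
att = (Q @ K.transpose(-2, -1)) / d_head ** 0.5  # [b, num_h, s, s+cache]
out = att.softmax(dim=-1) @ V                    # [b, num_h, s, d_head]
print(out.shape)  # torch.Size([2, 8, 4, 64]): rows follow the query length
```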
Mask Causality with Growing Cache
Standard causal masks break with KV-caching: a mask built only over the current query block ignores how the cached prefix shifts each query's absolute position. Critical edge cases (see the sketch after this list):
- A token at query position `i` must attend to positions `[0 : start_pos + i]`, i.e. the entire cache plus the new tokens up to and including itself
- Naively extending the square `[s, s]` causal mask to the `[s, start_pos + s]` score matrix either breaks causality or masks out valid cached positions
- Generating a fresh mask for every position adds overhead inside the decode loop
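A sketch of the common pattern for building the offset-aware mask; `cached_causal_mask` and `start_pos` are illustrative names, not from any particular library:

```python
import torch

def cached_causal_mask(s_new: int, start_pos: int, device=None) -> torch.Tensor:
    """Additive mask of shape [s_new, start_pos + s_new].

    Query i (absolute position start_pos + i) may attend to keys
    0 .. start_pos + i; later keys get -inf. Hypothetical helper,
    mirroring the pattern commonly used with KV caches.
    """
    # Upper-triangular -inf block restricts attention among the *new* keys...
    mask = torch.full((s_new, s_new), float("-inf"), device=device).triu(diagonal=1)
    # ...while every cached key (columns 0 .. start_pos - 1) stays fully visible.
    return torch.cat([torch.zeros(s_new, start_pos, device=device), mask], dim=1)

# Decode step with 5 cached tokens and 3 new tokens:
m = cached_causal_mask(3, 5)
print(m.shape)               # torch.Size([3, 8])
print((m == 0).sum(dim=1))   # tensor([6, 7, 8]): row i sees start_pos + i + 1 keys
```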
Optimization Considerations
- Memory vs compute tradeoff: precomputing an extended mask up to the maximum sequence length versus generating one per decode step (both options are sketched below)
- Batch dimension handling: how the mask is broadcast across batch and head dimensions determines whether it is stored once or materialized per head
- Fused attention kernels often support only their built-in causal masking, so custom mask handling may force a slower unfused path
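A sketch contrasting the two mask strategies and the broadcasting point, assuming a fixed `MAX_SEQ` context limit; names are illustrative:

```python
import torch

MAX_SEQ = 4096  # assumed context limit for this sketch

# Option A: precompute one [MAX_SEQ, MAX_SEQ] causal mask and slice a view per step.
# Pays O(MAX_SEQ^2) memory once; each decode step is then just indexing.
full_mask = torch.full((MAX_SEQ, MAX_SEQ), float("-inf")).triu(diagonal=1)

def sliced_mask(s_new: int, start_pos: int) -> torch.Tensor:
    # Rows are the new queries (absolute positions start_pos .. start_pos + s_new - 1),
    # columns are every key seen so far.
    return full_mask[start_pos:start_pos + s_new, :start_pos + s_new]

# Option B: build the mask per step (see cached_causal_mask above); no standing
# memory cost, but a fresh allocation inside the decode loop.

# Broadcasting: keep the mask at [1, 1, s, total] so adding it to
# [b, num_h, s, total] scores never materializes b * num_h copies.
print(sliced_mask(3, 5)[None, None].shape)   # torch.Size([1, 1, 3, 8])
```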
Debugging Strategy for Non-Obvious Cases
Monitor these dimension transitions for subtle bugs:
```python
C_KV.shape        # Should grow: [b, s₁, d_c] -> [b, s₁+s₂, d_c]
K_state.shape     # Post-projection growth affects attention patterns
att_output.shape  # Must maintain query dimensions despite K/V growth
```
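A hypothetical assertion helper that encodes these checks so they can run inside the generation loop (all argument names are assumptions for the sketch):

```python
def check_cache_shapes(C_KV, K_state, att_output, *,
                       b, s_new, s_total, d_c, d_model, num_h, d_head):
    """Hypothetical debug helper: assert the transitions listed above."""
    assert C_KV.shape == (b, s_total, d_c), f"latent cache: {tuple(C_KV.shape)}"
    assert K_state.shape == (b, s_total, d_model), f"decompressed K: {tuple(K_state.shape)}"
    # The output must track the query length s_new, never the cached length s_total.
    assert att_output.shape == (b, num_h, s_new, d_head), f"attn out: {tuple(att_output.shape)}"
```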
Practical Example: DeepSeek's MLA Edge Case
In Multi-head Latent Attention (MLA), the compressed KV cache introduces subtle interactions with attention masks (a condensed sketch follows this list) due to:
- Joint compression affecting position-dependent patterns
- Non-standard dimension flow through compression/decompression
- Mask causality preservation across cached compressed states
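To tie the pieces together, here is a condensed, simplified sketch of an MLA-style block: a jointly compressed latent cache plus the offset causal mask. It is illustrative only: it omits RoPE handling and the decoupled key path used in the real architecture, and all names (`LatentKVAttention`, `w_down`, `w_up_k`, ...) are placeholders.

```python
import torch
import torch.nn as nn

class LatentKVAttention(nn.Module):
    """Simplified MLA-style attention with a jointly compressed latent KV cache."""

    def __init__(self, d_model: int = 512, d_c: int = 64, num_h: int = 8):
        super().__init__()
        self.num_h, self.d_head = num_h, d_model // num_h
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_down = nn.Linear(d_model, d_c, bias=False)   # joint KV compression
        self.w_up_k = nn.Linear(d_c, d_model, bias=False)   # latent -> keys
        self.w_up_v = nn.Linear(d_c, d_model, bias=False)   # latent -> values

    def forward(self, x: torch.Tensor, cache: torch.Tensor):
        # x: [b, s_new, d_model], cache: [b, start_pos, d_c] (latents, not K/V)
        b, s_new, _ = x.shape
        start_pos = cache.shape[1]
        cache = torch.cat([cache, self.w_down(x)], dim=1)    # grow the latent cache

        split = lambda t: t.view(b, -1, self.num_h, self.d_head).transpose(1, 2)
        q = split(self.w_q(x))            # [b, h, s_new, d_head]
        k = split(self.w_up_k(cache))     # [b, h, start_pos + s_new, d_head]
        v = split(self.w_up_v(cache))

        # Causality is defined over absolute positions, so the cache offset must
        # be baked into the mask even though q only has s_new rows.
        mask = torch.full((s_new, s_new), float("-inf"), device=x.device).triu(1)
        mask = torch.cat([torch.zeros(s_new, start_pos, device=x.device), mask], dim=1)

        att = (q @ k.transpose(-2, -1)) / self.d_head ** 0.5 + mask
        out = (att.softmax(dim=-1) @ v).transpose(1, 2).reshape(b, s_new, -1)
        return out, cache                 # out: [b, s_new, d_model]

# Prefill with 16 tokens, then decode one token:
attn = LatentKVAttention()
cache = torch.zeros(2, 0, 64)                        # empty latent cache
y, cache = attn(torch.randn(2, 16, 512), cache)      # cache -> [2, 16, 64]
y, cache = attn(torch.randn(2, 1, 512), cache)       # cache -> [2, 17, 64]
print(y.shape, cache.shape)   # torch.Size([2, 1, 512]) torch.Size([2, 17, 64])
```

The sketch mirrors the bullets above: the cache stores `d_c`-dimensional latents rather than per-head K/V, the mask is built against absolute positions via the cache offset, and the output shape follows the query block regardless of cache length.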