# Advanced Insights: Attention Masks with KV-Caching

## Key Pitfalls in Complex Attention Implementations

### Dimension Evolution with Caching

```python
# Crucial dimension transitions in cached, compressed-KV attention
# (new tokens -> compressed cache -> decompressed K/V -> attention output):
# [b, s, d_model] -> [b, s+cache, d_c] -> [b, s+cache, d_model] -> [b, num_h, s, d_head]
```

The non-obvious trap: even as the K/V cache grows, the attention output dimensions must match the query length, not the cached length.

### Mask Causality with Growing Cache

A standard square causal mask breaks with KV-caching: during incremental decoding the queries cover only the new tokens while the keys cover the new tokens plus the cache, so the mask must be rectangular and offset by the cache length.

Critical edge cases:
- A new token at local position `i` (absolute position `start_pos + i`) must attend to keys `[0 : start_pos + i + 1]`, i.e. the entire cache plus the new tokens up to and including itself (see the mask sketch at the end of this section)
- Naively extending the square causal mask fails to preserve causality across the cache boundary
- Generating the mask per position has a performance cost

### Optimization Considerations

1. Memory vs. compute tradeoff: precomputing extended masks vs. generating them per decoding step
2. Batch dimension handling: mask broadcasting affects memory usage
3. Fused attention kernels may break with custom mask handling

## Debugging Strategy for Non-Obvious Cases

Monitor these dimension transitions for subtle bugs (the end-to-end example at the end of this section turns them into assertions):

```python
C_KV.shape       # Should grow: [b, s₁, d_c] -> [b, s₁+s₂, d_c]
K_state.shape    # Post-projection growth affects attention patterns
att_output.shape # Must maintain query dimensions despite K/V growth
```

## Practical Example: DeepSeek's MLA Edge Case

In Multi-head Latent Attention (MLA), the compressed KV cache introduces subtle interactions with attention masks (see the sketches below) due to:
1. Joint compression affecting position-dependent patterns
2. Non-standard dimension flow through compression/decompression
3. Mask causality preservation across cached compressed states
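
To make the causal-with-cache rule concrete, here is a minimal sketch in PyTorch. It assumes `start_pos` is the number of already-cached tokens (as in the edge cases above) and builds a rectangular additive mask so that row `i` attends to keys `[0 : start_pos + i + 1]`; the function name and signature are illustrative, not from any particular library.

```python
import torch

def cached_causal_mask(seqlen: int, start_pos: int,
                       device=None, dtype=torch.float32):
    """Additive mask for `seqlen` new query tokens appended after `start_pos`
    cached tokens. Shape: [seqlen, start_pos + seqlen]; row i allows keys
    [0 : start_pos + i + 1], i.e. the whole cache plus new tokens up to i."""
    if seqlen == 1:
        # Single-token decoding: the new token may see the entire cache; no mask needed.
        return None
    # New tokens attending among themselves: upper-triangular -inf block.
    new_block = torch.triu(
        torch.full((seqlen, seqlen), float("-inf"), device=device, dtype=dtype),
        diagonal=1,
    )
    # Every new token may attend to every cached position: all-zero block.
    cache_block = torch.zeros((seqlen, start_pos), device=device, dtype=dtype)
    return torch.cat([cache_block, new_block], dim=1)
```

The mask broadcasts against attention scores of shape `[b, num_h, seqlen, start_pos + seqlen]`, so it is built once per decoding step rather than materialized per batch element, which is one side of the memory/compute tradeoff noted under the optimization considerations.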
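For the MLA-style dimension flow, the following is a rough sketch under assumed names: `W_dkv`, `W_uk`, `W_uv`, and the latent width `d_c` are illustrative, not DeepSeek's actual parameters. It shows the compressed cache growing along the sequence dimension while per-head K/V are reconstructed from it each step.

```python
import torch
import torch.nn as nn

class CompressedKVCache(nn.Module):
    """Sketch of a growing compressed KV cache. `d_c` is the latent width;
    `W_dkv` compresses new tokens, `W_uk`/`W_uv` decompress the full cache."""

    def __init__(self, d_model: int, d_c: int, num_h: int, d_head: int):
        super().__init__()
        self.W_dkv = nn.Linear(d_model, d_c, bias=False)        # compression
        self.W_uk = nn.Linear(d_c, num_h * d_head, bias=False)  # decompress to K
        self.W_uv = nn.Linear(d_c, num_h * d_head, bias=False)  # decompress to V
        self.num_h, self.d_head = num_h, d_head
        self.c_kv = None                                         # [b, cached_len, d_c]

    def append(self, x: torch.Tensor):
        """x: [b, s, d_model]; the cache grows [b, s1, d_c] -> [b, s1 + s2, d_c]."""
        c_new = self.W_dkv(x)
        self.c_kv = c_new if self.c_kv is None else torch.cat([self.c_kv, c_new], dim=1)
        b, total_len, _ = self.c_kv.shape
        # Decompress the *entire* cache into per-head K/V: [b, num_h, total_len, d_head].
        k = self.W_uk(self.c_kv).view(b, total_len, self.num_h, self.d_head).transpose(1, 2)
        v = self.W_uv(self.c_kv).view(b, total_len, self.num_h, self.d_head).transpose(1, 2)
        return k, v
```

In practice the up-projections can be absorbed into adjacent projections so the decompressed K/V need not be materialized; the explicit version here only serves to make the dimension transitions visible.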
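Combining the two sketches gives a place to hang the shape checks from the debugging strategy: the compressed cache grows, the decompressed K grows, and the attention output keeps the query length. All names come from the sketches above; `W_q` is an assumed query projection for illustration only.

```python
import torch
import torch.nn as nn

b, num_h, d_head, d_model, d_c = 2, 4, 16, 64, 32
W_q = nn.Linear(d_model, num_h * d_head, bias=False)   # assumed query projection
cache = CompressedKVCache(d_model, d_c, num_h, d_head)

start_pos = 0
for s in (5, 1):                                        # prefill 5 tokens, then decode 1
    x = torch.randn(b, s, d_model)
    q = W_q(x).view(b, s, num_h, d_head).transpose(1, 2)    # [b, num_h, s, d_head]
    k, v = cache.append(x)                                   # [b, num_h, start_pos+s, d_head]
    assert cache.c_kv.shape == (b, start_pos + s, d_c)       # compressed cache grows
    assert k.shape == (b, num_h, start_pos + s, d_head)      # post-projection growth

    scores = q @ k.transpose(-2, -1) / d_head ** 0.5         # [b, num_h, s, start_pos+s]
    mask = cached_causal_mask(s, start_pos)
    if mask is not None:
        scores = scores + mask                               # broadcasts over [b, num_h]
    att_output = torch.softmax(scores, dim=-1) @ v
    assert att_output.shape == (b, num_h, s, d_head)         # keeps query length
    start_pos += s
```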