Saving Memory Using Padding-Free Transformer Layers during Finetuning
For long sequence training, attention computation can become a memory bottleneck, as a naive implementation requires $O(s^2)$ memory, where $s$ is the sequence length. Recently, however, FlashAttention [1,2] has been proposed, which optimizes IO and uses online softmax [3] to reduce data movement [4] between GPU main memory (typically HBM for datacenter GPUs) and the GPU's on-chip cache. The FlashAttention algorithm also reduces the memory requirement for attention computation from $O(s^2)$ to $O(s)$, i.e. from quadratic to linear in the sequence length.
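To make the online-softmax idea concrete, here is a minimal, unfused PyTorch sketch that computes `softmax(scores) @ V` one block of keys at a time while keeping only running statistics. This is purely an illustration of the trick from [3]; it is not the FlashAttention kernel, and the function and variable names are ours.

```python
import torch

def online_softmax_weighted_sum(scores: torch.Tensor, values: torch.Tensor, block: int = 128):
    """Streaming softmax(scores) @ values over blocks of the key dimension,
    keeping only a running max, a running normalizer and a running output."""
    n = scores.shape[-1]
    m = torch.full(scores.shape[:-1], float("-inf"))            # running max
    l = torch.zeros(scores.shape[:-1])                          # running sum of exp
    out = torch.zeros(*scores.shape[:-1], values.shape[-1])     # running (unnormalized) output
    for start in range(0, n, block):
        s_blk = scores[..., start:start + block]                # (..., B)
        v_blk = values[start:start + block]                     # (B, d)
        m_new = torch.maximum(m, s_blk.max(dim=-1).values)
        p = torch.exp(s_blk - m_new.unsqueeze(-1))
        scale = torch.exp(m - m_new)                            # rescale old statistics
        l = l * scale + p.sum(dim=-1)
        out = out * scale.unsqueeze(-1) + p @ v_blk
        m = m_new
    return out / l.unsqueeze(-1)

# matches the naive computation up to floating point error
scores = torch.randn(4, 1024)        # e.g. one query row of QK^T per batch element
v = torch.randn(1024, 64)
assert torch.allclose(online_softmax_weighted_sum(scores, v),
                      torch.softmax(scores, dim=-1) @ v, atol=1e-4)
```

Because only the running statistics and the output are kept, the full $s \times s$ score matrix never needs to live in memory at once.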
The current integration of FlashAttention [1,2] for training in most libraries (including HuggingFace transformers) is based on non-intrusive changes to these libraries: most implementations just replace the naive attention module with FlashAttention [1,2]. Though this is quite easy to implement, it is suboptimal when a batch contains sequences of variable lengths (i.e. when the batch has padding). All operations in the transformer except attention are applied independently to each token position, and since FlashAttention [1,2] completely avoids any computation and memory on pad tokens, it is possible to drop all redundant computation and memory for padding tokens and essentially create a padding-free transformer model when using FlashAttention (see the sketch below). This is also done in the original FlashAttention training codebase. It should be noted that this is an exact implementation of the model and involves no approximations.
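As an illustration of the padding-free idea, the variable-length ("varlen") interface of the flash-attn package can be driven directly on the concatenated, unpadded tokens. The sketch below is an assumption-laden example, not the HuggingFace implementation: it assumes flash-attn 2.x on a CUDA GPU, and the exact signature of `flash_attn_varlen_func` may differ between versions.

```python
import torch
from flash_attn import flash_attn_varlen_func  # flash-attn >= 2.x

# three sequences of different lengths, already concatenated: no pad tokens anywhere
seqlens = torch.tensor([7, 3, 5], dtype=torch.int32)
total_tokens = int(seqlens.sum())
num_heads, head_dim = 8, 64

q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# cumulative sequence lengths tell the kernel where each example starts and ends
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        seqlens.cumsum(0, dtype=torch.int32)]).cuda()

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seqlens.max()), max_seqlen_k=int(seqlens.max()),
    dropout_p=0.0, causal=True,
)  # shape (total_tokens, num_heads, head_dim): no memory is ever spent on padding
```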
Similar optimizations are done in HuggingFace TGI for inference efficiency. It should be noted that padding is not a problem when the batch doesn't need it, i.e. when all examples in a batch have equal length or when examples are densely packed (as is the case for pretraining).
In this blog, we derive the theoretical memory consumption of naive attention, of FlashAttention with padded transformer blocks (the current implementation in the HuggingFace transformers library) and of padding-free transformer blocks.
Let's assume an input batch of embeddings of shape $(b, s, h)$ as input to a transformer layer, where $b$, $s_i$ and $h$ denote the batch size, the unpadded sequence length of the $i^{th}$ example in the batch and the hidden size of the transformer model respectively, and $s = \max_i s_i$ denotes the padded sequence length of the batch. For training the model, each transformer layer needs to cache the activations of each operation (computed in the forward pass) for the backward pass; our accounting of these cached activations follows the approach of [5]. We assume 16-bit precision for training (2 bytes per value in a tensor). We also assume Multi-Head Attention [6] with $a$ attention heads here for simplicity, though the same idea also applies to Multi-Query Attention [7] and Grouped-Query Attention [8].
Naive Attention
1. The input LayerNorm receives an input of shape $(b, s, h)$, which needs to be cached for the backward pass. The mean and variance also need to be cached, each of shape $(b, s)$. Since $h \gg 2$, we can ignore the $2bs$ elements for the mean and variance. Total activation memory for this operation is $2bsh$ bytes.
2. The input to the QKV projection matrices (shared among the Q, K and V projections) needs to be cached. It has $bsh$ elements taking $2bsh$ bytes. The outputs of the Q, K and V projections also need to be cached, each of which has $bsh$ elements taking $2bsh$ bytes. Total activation memory for this operation is $8bsh$ bytes.
3. The output of the attention softmax, which has $abs^2$ elements, also needs to be cached. Total activation memory for this operation is $2abs^2$ bytes.
4. The attention softmax is followed by a dropout, which requires saving a mask of $abs^2$ elements. Each element takes a full byte since PyTorch doesn't support 1-bit tensors; the reason for this is probably ease of implementation, since GPUs are generally byte-addressable. Total memory for this operation is $abs^2$ bytes.
5. The output of the softmax dropout has $abs^2$ elements and also needs to be cached. Total activation memory for this operation is $2abs^2$ bytes.
6. We cache the output of multiplying the dropout output with the values $V$, which is the input to the output projection matrix. It has $bsh$ elements. Total activation memory for this operation is $2bsh$ bytes.
7. For the dropout after the output projection, only the dropout mask needs to be cached. Total memory for this operation is $bsh$ bytes.
8. The post-attention LayerNorm behaves the same as the previous LayerNorm. Its memory requirement is $2bsh$ bytes.
9. We assume here that the feedforward hidden dimension is $4h$, as is typical for a standard transformer. The inputs to the two linear layers and the input to the GELU activation function need to be cached. These take $2bsh$ bytes, $8bsh$ bytes and $8bsh$ bytes respectively. The required memory for the MLP block is $18bsh$ bytes.
10. The dropout after the MLP block only needs its mask cached, the same as point (7) above, i.e. $bsh$ bytes.
Summing these up, the total activation memory per layer with naive attention is given by:

$$m_{\text{naive}} = 34\,bsh + 5\,abs^2 \;\; \text{bytes}.$$
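As a sanity check, the itemized accounting above can be written as a small helper; this is just the formula in code (function name is ours), assuming 16-bit activations.

```python
def naive_activation_memory(b: int, s: int, h: int, a: int) -> int:
    """Activation bytes cached per transformer layer with naive attention."""
    layernorms = 2 * (2 * b * s * h)         # points (1) and (8)
    qkv = 8 * b * s * h                      # point (2)
    softmax = 2 * a * b * s * s              # point (3)
    softmax_dropout = 3 * a * b * s * s      # points (4) and (5): mask + output
    attn_out = 2 * b * s * h + b * s * h     # points (6) and (7)
    mlp = 18 * b * s * h + b * s * h         # points (9) and (10)
    return layernorms + qkv + softmax + softmax_dropout + attn_out + mlp
    # equals 34*b*s*h + 5*a*b*s*s
```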
Transformer Layer with FlashAttention
FlashAttention [1,2] has been integrated into the HuggingFace transformers API. The implementation at the time of writing this blog does an unpad operation just before the FlashAttention kernel is executed. This operation converts the input Q, K and V tensors of shape $(b, s, h)$ to shape $\left(\sum_{i=1}^{b} s_i, h\right)$ (where the examples in the batch are concatenated one after another, resulting in a 2D tensor) and launches the FlashAttention kernel. After the attention computation, the output is padded back to the shape $(b, s, h)$.
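To illustrate what the unpad and pad steps do, here is a plain-PyTorch sketch; the actual transformers code relies on helper functions shipped with the flash-attn package, and the shapes and names below are illustrative only.

```python
import torch

b, s, h = 2, 5, 4
hidden = torch.randn(b, s, h)
# 1 = real token, 0 = pad; sequence lengths are 5 and 3
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])

# unpad: keep only the real token positions -> shape (sum(s_i), h)
indices = attention_mask.flatten().nonzero(as_tuple=True)[0]
unpadded = hidden.reshape(b * s, h).index_select(0, indices)

# ... the FlashAttention varlen kernel would run on `unpadded` here ...

# pad: scatter the rows back into a zero-initialized (b, s, h) tensor
repadded = torch.zeros(b * s, h, dtype=hidden.dtype)
repadded.index_copy_(0, indices, unpadded)
repadded = repadded.reshape(b, s, h)
```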
FlashAttention [1,2] avoids materializing the quadratic attention matrix in memory and uses online softmax [3], thereby dropping the need to cache the activations in points (3), (4) and (5). Instead, we only need to materialize the output matrix, which has shape $(b, s, h)$ and also serves as the cached input to the output projection in point (6), the 2 softmax statistics, both of shape $(b, a, s)$, and the random number generator state for the dropout, which we ignore here. For the algorithm in detail, refer to the FlashAttention papers [1,2]. We also need to cache the attention mask of $bs$ booleans, which is used for padding and unpadding; we ignore it in the calculations, though, since it is the same for every layer, can be cached once for the entire transformer model and doesn't need to be cached at every layer. Thus the memory required for attention (points (3) to (6)) becomes $2bsh + 4abs$ bytes.
Thus we have the total activation memory per layer with FlashAttention [1,2] as follows:

$$m_{\text{flash}} = 34\,bsh + 4\,abs \;\; \text{bytes}.$$
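The corresponding helper for the padded FlashAttention layer, under the same assumptions as before:

```python
def flash_activation_memory(b: int, s: int, h: int, a: int) -> int:
    """Activation bytes per layer with padded FlashAttention (16-bit activations):
    the 5*a*b*s^2 term is replaced by the attention output and softmax statistics."""
    non_attention = 32 * b * s * h   # everything except points (3)-(6)
    flash_output = 2 * b * s * h     # output matrix of shape (b, s, h)
    softmax_stats = 4 * a * b * s    # two per-head statistics of shape (b, a, s)
    return non_attention + flash_output + softmax_stats
    # equals 34*b*s*h + 4*a*b*s
```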
Padding-Free Transformer Layer
Since all operations (except attention) in the transformer layer are applied identically to each token position, we can avoid the padding and unpadding operations altogether and thus reduce the activation memory required by the transformer layer even further; this requires only minor changes to the HuggingFace transformers implementation. In this implementation of the transformer, there is no wasted memory for pad token positions at all! In this case, the input to the entire transformer model is of shape $\left(\sum_{i=1}^{b} s_i, h\right)$, and the memory is given by:

$$m_{\text{padding-free}} = (34h + 4a)\sum_{i=1}^{b} s_i \;\; \text{bytes}.$$
It should be noted that $m_{\text{padding-free}} = m_{\text{flash}}$ when there is no padding, i.e. when $s_1 = s_2 = \cdots = s_b = s$. This optimization is similar to running the transformer model with nested tensors. While there has been significant effort to reduce padding by taking approaches like binning examples by sequence length, these approaches can lead to model performance degradation, especially during finetuning.
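And the padding-free counterpart, which is the same expression with $bs$ replaced by $\sum_i s_i$ (again, the function name is ours):

```python
def padding_free_activation_memory(seqlens: list[int], h: int, a: int) -> int:
    """Activation bytes per layer for the padding-free layer (16-bit activations)."""
    total_tokens = sum(seqlens)
    return 34 * total_tokens * h + 4 * a * total_tokens

# with no padding (all sequences of length s) the two coincide:
# padding_free_activation_memory([s] * b, h, a) == flash_activation_memory(b, s, h, a)
```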
Motivation for using Padding-Free Transformer Layer
Now, we analyze the memory consumption of the 3 transformer layer implementations. We assume a dataset of sequences whose lengths follow a discrete uniform distribution, i.e. $S_i \sim \mathcal{U}\{1, s_{max}\}$, where $S_i$ is the random variable denoting the sequence length of the $i^{th}$ sample in the batch and $s_{max}$ is the maximum sequence length for the dataset and the model. We sample batches with $b$ examples each, with sequences of lengths $S_1, S_2, \ldots, S_b$. We compute the expectations $\mathbb{E}[m_{\text{naive}}]$, $\mathbb{E}[m_{\text{flash}}]$ and $\mathbb{E}[m_{\text{padding-free}}]$ under the discrete uniform distribution. To do so, we consider another random variable $S = \max(S_1, S_2, \ldots, S_b)$, the padded sequence length of the batch. The Cumulative Distribution Function for $S$ can be derived as

$$F_S(t) = P(S \le t) = P(S_1 \le t, S_2 \le t, \ldots, S_b \le t).$$

Now, using the fact that the examples in a batch are i.i.d., we have

$$F_S(t) = \prod_{i=1}^{b} P(S_i \le t) = \left(\frac{t}{s_{max}}\right)^b,$$

and thus we have the Probability Mass Function for $S$ as

$$P(S = t) = F_S(t) - F_S(t-1) = \frac{t^b - (t-1)^b}{s_{max}^b}.$$

We can use computational methods or Faulhaber's formula [9] together with this PMF to calculate $\mathbb{E}[S]$, $\mathbb{E}[S^2]$ and $\mathbb{E}\left[\sum_{i=1}^{b} S_i\right] = \frac{b(s_{max}+1)}{2}$, and hence the expected memory usage of the 3 methods. We report the theoretical memory consumption derived using the equations above for a 20B parameter model in the following table. We find that using a padding-free version of the transformer layer saves activation memory and also saves a lot of redundant FLOPs. We leave the analysis of FLOPs out of this blog, but they are easily derivable.
Maximum Sequence Length ($s_{max}$) | Naive Attention | FlashAttention | Padding-Free Transformer |
---|---|---|---|
512 | 1.085 GB | 0.721 GB | 0.411 GB |
1024 | 2.919 GB | 1.441 GB | 0.821 GB |
2048 | 8.837 GB | 2.882 GB | 1.642 GB |
4096 | 29.674 GB | 5.763 GB | 3.283 GB |
8192 | 107.347 GB | 11.524 GB | 6.566 GB |
16384 | 406.693 GB | 23.048 GB | 13.132 GB |
32768 | 1581.386 GB | 46.096 GB | 26.263 GB |
Table: Expected memory usage per transformer layer for the different attention implementations at different maximum sequence lengths $s_{max}$ for a 20B parameter model (with the hidden size $h$, FFN hidden size $4h$ and number of attention heads $a$ set to the model's configuration).
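For reference, the "computational methods" route mentioned above amounts to a few lines of PyTorch that evaluate the expectations from the PMF of $S$. The hyperparameters in the example call below ($b=8$, $h=4096$, $a=32$) are placeholders for illustration, not the configuration used to produce the table.

```python
import torch

def expected_memory(b: int, s_max: int, h: int, a: int):
    """Expected activation bytes per layer for the three implementations when the
    sequence lengths S_i are i.i.d. uniform on {1, ..., s_max}."""
    t = torch.arange(1, s_max + 1, dtype=torch.float64)
    # PMF of the batch maximum S = max(S_1, ..., S_b): P(S = t) = (t^b - (t-1)^b) / s_max^b
    pmf = (t ** b - (t - 1) ** b) / float(s_max) ** b
    e_s = (pmf * t).sum()          # E[S]
    e_s2 = (pmf * t ** 2).sum()    # E[S^2]
    e_sum = b * (s_max + 1) / 2    # E[sum_i S_i]

    e_naive = 34 * b * h * e_s + 5 * a * b * e_s2
    e_flash = 34 * b * h * e_s + 4 * a * b * e_s
    e_padding_free = (34 * h + 4 * a) * e_sum
    return e_naive, e_flash, e_padding_free

# placeholder hyperparameters for illustration only (not the 20B model's values)
for name, val in zip(["naive", "flash", "padding-free"],
                     expected_memory(b=8, s_max=2048, h=4096, a=32)):
    print(f"{name:>12}: {float(val) / 2 ** 30:.3f} GiB")
```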
Conclusion
In this blog, we present a way to completely avoid the computation and memory requirements of pad tokens during finetuning of transformer models using FlashAttention. Our changes are easily integrable into the HuggingFace transformers ecosystem for finetuning, and we derive equations for the theoretical memory consumption of the approach. The method doesn't involve writing any low-level device code; the only non-native PyTorch code we use is FlashAttention, which is already available.
References
- Dao, Tri, et al. "Flashattention: Fast and memory-efficient exact attention with io-awareness." Advances in Neural Information Processing Systems 35 (2022): 16344-16359.
- Dao, Tri. "Flashattention-2: Faster attention with better parallelism and work partitioning." arXiv preprint arXiv:2307.08691 (2023).
- Milakov, Maxim, and Natalia Gimelshein. "Online normalizer calculation for softmax." arXiv preprint arXiv:1805.02867 (2018).
- Ivanov, Andrei, et al. "Data movement is all you need: A case study on optimizing transformers." Proceedings of Machine Learning and Systems 3 (2021): 711-732.
- Korthikanti, Vijay Anand, et al. "Reducing activation recomputation in large transformer models." Proceedings of Machine Learning and Systems 5 (2023).
- Vaswani, Ashish, et al. "Attention is all you need." Advances in neural information processing systems 30 (2017).
- Shazeer, Noam. "Fast transformer decoding: One write-head is all you need." arXiv preprint arXiv:1911.02150 (2019).
- Ainslie, Joshua, et al. "Gqa: Training generalized multi-query transformer models from multi-head checkpoints." arXiv preprint arXiv:2305.13245 (2023).
- Knuth, Donald E. "Johann Faulhaber and sums of powers." Mathematics of Computation 61.203 (1993): 277-294.