Understanding memory consumption during inference
Hello!
Is there a good way to quantify the amount of memory that this model will consume during inference based on the input token count of the data I'm generating embeddings for?
When I start the model, its initial consumption is approx. 14-15 GB of VRAM. The example below was run on an AWS EC2 g5.12xlarge with 4 A10s:
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 44696 C .pyenv/versions/3.11.8/bin/python3 3580MiB |
| 1 N/A N/A 44696 C .pyenv/versions/3.11.8/bin/python3 4210MiB |
| 2 N/A N/A 44696 C .pyenv/versions/3.11.8/bin/python3 4210MiB |
| 3 N/A N/A 44696 C .pyenv/versions/3.11.8/bin/python3 3330MiB |
+---------------------------------------------------------------------------------------+
However, when I send input through, the usage seems to balloon significantly. For example, this is the memory consumption after sending a 4096-token input through:
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 44696 C .pyenv/versions/3.11.8/bin/python3 10762MiB |
| 1 N/A N/A 44696 C .pyenv/versions/3.11.8/bin/python3 13334MiB |
| 2 N/A N/A 44696 C .pyenv/versions/3.11.8/bin/python3 13334MiB |
| 3 N/A N/A 44696 C .pyenv/versions/3.11.8/bin/python3 10620MiB |
+---------------------------------------------------------------------------------------+
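Summing the per-GPU figures from the two nvidia-smi snapshots gives the total footprint before and after the 4096-token request. This is plain arithmetic on the numbers posted above:

```python
# Per-GPU "GPU Memory Usage" values (MiB) copied from the two nvidia-smi snapshots.
idle_mib = [3580, 4210, 4210, 3330]
after_4096_mib = [10762, 13334, 13334, 10620]

idle_total = sum(idle_mib)          # total after loading the model
after_total = sum(after_4096_mib)   # total after one 4096-token input
growth = after_total - idle_total   # extra memory attributable to the request

print(f"idle: {idle_total} MiB ({idle_total / 1024:.1f} GiB)")
print(f"after 4096 tokens: {after_total} MiB ({after_total / 1024:.1f} GiB)")
print(f"growth: {growth} MiB ({growth / 1024:.1f} GiB), {after_total / idle_total:.2f}x")
```

Keep in mind that nvidia-smi reports memory the process has reserved (including the CUDA context and PyTorch's caching allocator), not just live tensors, so the post-request figure reflects the peak and may stay high after the call returns; torch.cuda.max_memory_allocated() gives a tensor-level view.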
Is it expected that the model's memory consumption would scale so significantly with input size? Or am I doing something wrong in my hosting config? Is there a good way to cap memory consumption while still allowing embeddings of longer text sequences?
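I am afraid this is indeed expected. Standard (eager) self-attention materializes a seq_len × seq_len attention score matrix for every head, so its memory grows quadratically with the input length and becomes huge on long sequences.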
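As a rough sketch of the quadratic term, the snippet below sizes a single layer's attention score matrix when attention is computed eagerly. The head count of 32 is a hypothetical value, not taken from this thread; read the real one from the model's config.json:

```python
def attention_scores_bytes(seq_len: int, num_heads: int, bytes_per_elem: int,
                           batch_size: int = 1) -> int:
    """Bytes for one layer's score matrix (batch x heads x seq_len x seq_len)
    when attention is computed eagerly, i.e. the full matrix is materialized."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_elem

# ASSUMPTION: hypothetical 32 attention heads; check your model's config.json.
HEADS = 32
for n in (512, 1024, 2048, 4096):
    gib = attention_scores_bytes(n, HEADS, bytes_per_elem=4) / 2**30
    print(f"{n:5d} tokens -> {gib:.3f} GiB per layer (fp32 scores)")
```

Each doubling of the token count quadruples this figure, which is why memory grows much faster than the input. The real peak is larger, because other tensors of the same shape (for example the softmax output) and the rest of the activations also live on the GPU.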
A few things you can try:
1. Use FlashAttention-2 to reduce GPU memory consumption, as described in https://huggingface.co./docs/transformers/main/model_doc/mistral#speeding-up-mistral-by-using-flash-attention . It also speeds up inference.
2. Make sure inference runs inside the torch.no_grad() context, and use fp16 / bf16 if possible.
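To compare the suggestions above, here is an order-of-magnitude sketch that counts only the attention score matrices at 4096 tokens under each setting. The 32-head, 32-layer configuration is hypothetical, and the model of what autograd retains is simplified:

```python
def score_matrix_bytes(seq_len: int, num_heads: int, num_layers: int,
                       bytes_per_elem: int, flash: bool, track_grad: bool) -> int:
    """Rough peak bytes spent on attention score matrices only (weights,
    MLP activations, and allocator overhead are ignored)."""
    if flash:
        # FlashAttention-2 computes attention in tiles and never materializes
        # the full seq_len x seq_len matrix, so this term drops out.
        return 0
    per_layer = num_heads * seq_len * seq_len * bytes_per_elem
    # With autograd on, at least one seq_len x seq_len tensor per layer is saved
    # for the backward pass, so they accumulate across layers; under
    # torch.no_grad() only one layer's matrix needs to be alive at a time.
    return per_layer * (num_layers if track_grad else 1)

# ASSUMPTION: hypothetical config of 32 heads x 32 layers; read the real
# values from your model's config.json.
HEADS, LAYERS, N = 32, 32, 4096
configs = [
    # (label, bytes per element, FlashAttention-2?, autograd on?)
    ("fp32, grad on", 4, False, True),
    ("fp16/bf16, grad on", 2, False, True),
    ("fp16/bf16, no_grad", 2, False, False),
    ("bf16 + FA2, no_grad", 2, True, False),
]
for label, nbytes, flash, grad in configs:
    gib = score_matrix_bytes(N, HEADS, LAYERS, nbytes, flash, grad) / 2**30
    print(f"{label:20s} {gib:6.1f} GiB")
```

In transformers, these settings correspond roughly to passing torch_dtype=torch.bfloat16 (or torch.float16) and attn_implementation="flash_attention_2" to from_pretrained (the latter needs the flash-attn package installed), and wrapping the forward pass in torch.no_grad().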
I've been having a hell of a time getting this running on all 4 GPUs on a g5.12xlarge (using all of the GPU memory).
Do you by chance have an example from your code showing how you achieved that (based on the numbers above)?
@andrew-kirfman ?