Considerable speed loss after LoRA fine-tuning
I used @merve's notebook to LoRA fine-tune Idefics3. After finally finding the right transformers commit to install, this worked. Now, for inference, I just apply the adapter, but, very surprisingly, the model loses most of its superb speed on my hardware (4x RTX A6000).
A prompt+image that took 4s before now takes 20s. Is that something you were able to observe yourselves?
Furthermore, the whole merge_and_unload() step in PEFT is also a bit fishy: it somehow doubles the model's VRAM usage, and the same happens to its size on disk when saved.
Has anybody had a similar experience, or does anyone have pointers for getting back the speed we had before?
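One possible explanation for the doubled VRAM and disk size, offered as an assumption rather than something confirmed in this thread: in transformers versions of this period, from_pretrained loads weights in float32 unless torch_dtype is given, and a model merged and saved at that precision takes twice the space of a bfloat16 one. The sketch below only does the arithmetic for the weights; the 8-billion parameter count is a round-number assumption for an "8B" model, and activations and the KV cache are ignored.

```python
# Bytes needed to store one parameter in common dtypes.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2, "int8": 1}


def weight_memory_gb(num_params: int, dtype: str) -> float:
    """Approximate size of the weights alone, in GB (1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


if __name__ == "__main__":
    n = 8_000_000_000  # assumed round parameter count, not measured for Idefics3
    for dtype in ("bfloat16", "float32"):
        print(f"{dtype}: ~{weight_memory_gb(n, dtype):.0f} GB")
```

If this is the cause, passing torch_dtype=torch.bfloat16 when loading the base model before merging would keep the merged weights at the original size.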
I can confirm the same: inference on the model with the LoRA adapter applied takes a long time, about 20 s, as you state, whereas inference on the model without the adapter takes just 1-2 seconds. (I'm on an A100 with 80 GB of memory.)
Also, I load the model like so:
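import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
from peft import PeftModel

peft_path = "/somedir"  # local directory containing the trained LoRA adapter
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.bfloat16)
# Wrap the base model with the adapter (adapter weights stay separate, not merged)
peft_model = PeftModel.from_pretrained(model, peft_path, torch_dtype=torch.bfloat16).to(DEVICE)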
The above works. When I try to apply the adapter to the model via load_adaptor(), it somehow doesn't work, and I haven't figured out why yet: model.load_adaptor(peft_model_id)
Also, when I run inference, it repeats the output many times. I guess this is because it tries to generate up to 500 tokens? Not sure.
This is how I do the inference:
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, images=image4, return_tensors="pt")
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}  # move every input tensor to the GPU
# max_new_tokens is an upper bound on the number of newly generated tokens
generated_ids = peft_model.generate(**inputs, max_new_tokens=500)
generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
When I pass the prompt and image that @merve provided in her notebook, I get the following output:
['User: Answer briefly.<image>Which country is this located in?\nAssistant: thailand\nAssistant: thailand\nAssistant: thailand\nassistant: thailand\nassistant: thailand\nassistant: thailand\nassistant: thailand\nassistant: thailand\nassistant: thailand\nassistant: thailand\nassistant: thailand\nassistant: thailand\nassistant: thailand....]
Etc.
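The repetition and the slowdown fit together: generation stops either when the model emits its EOS token or when max_new_tokens is exhausted. The toy loop below is plain Python, not the transformers implementation, and all token ids are made up; it only illustrates why a model that never learned to emit EOS after its answer keeps opening new "Assistant:" turns and always runs the full token budget.

```python
from typing import Callable, List, Optional

EOS = 2                # hypothetical end-of-sequence id
ANSWER, TURN = 10, 11  # hypothetical ids standing in for "thailand" and "\nAssistant:"


def generate(next_token: Callable[[List[int]], int], prompt: List[int],
             max_new_tokens: int, eos_token_id: Optional[int]) -> List[int]:
    """Toy greedy loop: append one token per step until EOS or the budget runs out."""
    ids = list(prompt)
    for _ in range(max_new_tokens):
        tok = next_token(ids)
        ids.append(tok)
        if eos_token_id is not None and tok == eos_token_id:
            break  # a model that learned to emit EOS stops early here
    return ids


def model_without_eos(ids: List[int]) -> int:
    # Never emits EOS: alternates answer / new assistant turn forever.
    return TURN if ids[-1] == ANSWER else ANSWER


def model_with_eos(ids: List[int]) -> int:
    # Trained with EOS after the target: answers once, then ends.
    return EOS if ids[-1] == ANSWER else ANSWER


if __name__ == "__main__":
    prompt = [1, TURN]  # hypothetical BOS + "Assistant:"
    for budget in (32, 500):
        n_new = len(generate(model_without_eos, prompt, budget, EOS)) - len(prompt)
        print(f"without EOS, budget {budget}: {n_new} new tokens")
    print("with EOS:", generate(model_with_eos, prompt, 500, EOS))
```

Once EOS is learned, max_new_tokens only acts as an upper bound, which is why the fix identified later in the thread (adding EOS to the labels) matters more than the exact budget.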
@ayyylemao @fsommers I built the scripts midway through the transformers integration, when it worked on a specific PR commit. I will try to reproduce the error and change or fix things on the transformers side if necessary. In the meantime, can you tell me your transformers versions? (Also, did you use LoRA or QLoRA? QLoRA slows things down a lot.)
Also, thanks for reporting. I'll try to get on it as soon as possible, likely today.
Thank you, @merve! Really appreciate it! I'm using 4.45.0.dev0. I fetched it from here, as in your notebook: https://github.com/andimarafioti/transformers.git@idefics3.
By the way, the other thing I noticed is that I had to specify use_cache=False when loading Idefics3ForConditionalGeneration: https://github.com/huggingface/transformers/pull/32473#discussion_r1740209997
By the way, I could not do the training with plain LoRA because I ran out of memory; only QLoRA worked. Should LoRA alone fit on an 80 GB GPU? I cleared my GPU cache, etc., but it keeps running out of memory.
OK, one more piece of info: if I load the base model quantized (8 bits), inference is even slower. By the way, I execute merge_and_unload(), as recommended by the documentation. The docs make sense, and merge_and_unload() takes quite some time, but afterwards inference also takes very long. Thanks in advance for any suggestions on how inference can be sped up. (Again, without the LoRA adapter, inference on the base model is fast.)
https://huggingface.co/docs/peft/main/en/developer_guides/lora
Thanks for getting back to us, @merve.
The transformers commit I used was:
commit a72b30fe06bba77d9df4c72fcea48bbdc0d812a5 (HEAD)
Author: Andres Marafioti
Date: Thu Aug 8 14:20:46 2024 +0000
hot fix for merve
I used this one since I assumed this was the one you used for your tutorial script.
I used LoRA (with the setup you presented in your script), since we have the hardware to meet the VRAM requirements, but the loss of inference speed is crushing for our application.
Were you able to replicate or pinpoint the issue?
Very sorry, folks, @ayyylemao @fsommers. I have uploaded a new version of the notebook along with the training script I used (I often use scripts rather than notebooks).
I'm currently fine-tuning the model and will try to reproduce your issues!
@ayyylemao @fsommers I have attempted to reproduce the issue here, but I cannot find a difference between the original model and the LoRA fine-tuned model: https://colab.research.google.com/drive/17TuOBxcGH0EEE5d48FN-NiQNdkBa_SsV?usp=sharing
I have uploaded my training script here: https://github.com/merveenoyan/smol-vision/blob/main/train_idefics2.py and this model is based on that script.
Thank you so much, @merve, for uploading the updated notebook. I reproduced it, and I think I found the cause of the slow inference: it's the max_new_tokens param in the model.generate() function.
Redoing the LoRA training with your changes didn't seem to make a difference. Merging the LoRA weights into the base model via merge_and_unload() didn't make a difference either.
What did make a difference was specifying a lower number for max_new_tokens:
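from transformers import AutoProcessor, Idefics3ForConditionalGeneration

DEVICE = "cuda:0"
peft_model_id = ...  # my adapter id
base_model_id = "HuggingFaceM4/Idefics3-8B-Llama3"
processor = AutoProcessor.from_pretrained(base_model_id)
peft_model = Idefics3ForConditionalGeneration.from_pretrained(base_model_id)
peft_model = peft_model.to(DEVICE)
# Attach the LoRA adapter via the transformers PEFT integration (needs peft installed)
peft_model.load_adapter(peft_model_id)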
Then:
generated_ids = peft_model.generate(**inputs, max_new_tokens=500)
generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
print(generated_texts)
Result:
'User: Answer briefly.<image>Which country is this located in?\nAssistant: thailand\nAssistant: thailand\nAssistant: thailand\nAssistant: thailand\nAssistant: thailand\nassistant: thailand\nassistant: thailand\nassistant: thailand\nassistant: thailand\nassistant: thailand\nassistant: thailand\nassistant: thailand\nassistant: thailand\nassistant: thailand\nassistant: thailand\nassistant: thailand\nassistant....
CPU times: user 27 s, sys: 846 ms, total: 27.8 s
Wall time: 27.8 s
However, specifying 32 for max_new_tokens results in:
%%time
generated_ids = peft_model.generate(**inputs, max_new_tokens=32)
generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
print(generated_texts)
...
['User: Answer briefly.<image>Which country is this located in?\nAssistant: thailand\nAssistant: thailand\nAssistant: thailand\nAssistant: thailand\nAssistant: thailand\nAssistant: thailand']
CPU times: user 3.63 s, sys: 895 ms, total: 4.53 s
Wall time: 4.52 s
This makes sense, since the prompt says "Answer briefly." 500 new tokens is definitely not "brief", and the generation takes about 6 times longer.
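The two wall times (27.8 s for 500 tokens, 4.52 s for 32) can also be read as a rough linear cost model, time ≈ overhead + per_token × new_tokens. This sketch assumes both runs generated their full budget (the repeated outputs suggest they did) and that every decoding step costs about the same, which ignores growth of the KV cache; treat it as back-of-the-envelope only.

```python
def fit_linear_timing(n1: int, t1: float, n2: int, t2: float) -> tuple[float, float]:
    """Fit t = overhead + per_token * n through two (new_tokens, seconds) measurements."""
    per_token = (t1 - t2) / (n1 - n2)
    overhead = t1 - per_token * n1
    return overhead, per_token


if __name__ == "__main__":
    overhead, per_token = fit_linear_timing(500, 27.8, 32, 4.52)
    print(f"fixed overhead ~{overhead:.2f} s, ~{per_token * 1000:.1f} ms per new token")
    # fixed overhead ~2.93 s, ~49.7 ms per new token
```

Under these assumptions, the time ratio is about 6x rather than 500/32 ≈ 15.6x because a fixed cost of roughly 3 s (presumably prompt and image processing) is paid in both runs, while each extra token adds about 50 ms.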
I wonder if this should be added to the documentation, along with some advice on how to determine the right max_new_tokens. For example, what is the recommended value for a document AI use case, where one wishes to generate structured data from a document image? Or, as in this case, where the expected response is a single word? (I changed the prompt to say "Answer with a single word," but this made no difference.)
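One heuristic for picking max_new_tokens, which is an assumption on my part rather than an official recommendation: measure the token lengths of the training targets and choose a budget that covers nearly all of them, plus a margin. Once the model reliably emits EOS, the budget is only a safety cap, so erring on the high side costs little. suggest_max_new_tokens is a hypothetical helper, and str.split stands in for the real tokenizer (in practice something like len(processor.tokenizer(text).input_ids)).

```python
import math
from typing import Callable, Iterable, Sequence


def suggest_max_new_tokens(targets: Iterable[str], tokenize: Callable[[str], Sequence],
                           quantile: float = 0.99, margin: float = 1.2,
                           floor: int = 8) -> int:
    """Heuristic budget: cover `quantile` of target lengths, scaled by `margin`, plus EOS."""
    lengths = sorted(len(tokenize(t)) for t in targets)
    if not lengths:
        return floor
    # Index of the requested quantile in the sorted lengths (clamped to valid range)
    idx = min(len(lengths) - 1, max(0, math.ceil(quantile * len(lengths)) - 1))
    budget = math.ceil(lengths[idx] * margin) + 1  # +1 leaves room for the EOS token
    return max(floor, budget)


if __name__ == "__main__":
    vqa_answers = ["thailand", "two", "red and white"]  # short VQA-style targets
    print(suggest_max_new_tokens(vqa_answers, str.split))
```

For VQA-style single-word answers this lands on the floor value, while long structured targets (as in a document AI use case) push the budget up accordingly.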
@fsommers This model is fine-tuned on VQAv2, where the task is to come up with short answers (similar to MCQA), so it makes sense, and is fine, to pass a lower number for max_new_tokens. Again, it depends on your dataset: if you have a conversational dataset, I think you would have to set it to a higher number. I think "Answer briefly." is what conditions the model to generate short answers, as in this snippet from the notebook:
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Answer briefly."},
{"type": "image"},
{"type": "text", "text": question}
]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": answer}
]
}
]
So it's best to stick to that to generate the VQAv2 completion.
WDYT?
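For inference, the prompt can reuse exactly the user turn the model saw during training and simply leave out the assistant turn, letting add_generation_prompt=True open it. A minimal sketch; build_messages is a hypothetical helper name, and the content layout is copied from the notebook snippet above.

```python
from typing import Any, Dict, List, Optional


def build_messages(question: str, answer: Optional[str] = None) -> List[Dict[str, Any]]:
    """Training-format user turn; the assistant turn is added only when an answer is given."""
    messages: List[Dict[str, Any]] = [{
        "role": "user",
        "content": [
            {"type": "text", "text": "Answer briefly."},  # same conditioning prefix as training
            {"type": "image"},
            {"type": "text", "text": question},
        ],
    }]
    if answer is not None:  # training example: include the target
        messages.append({"role": "assistant",
                         "content": [{"type": "text", "text": answer}]})
    return messages
```

At inference time this would be used as processor.apply_chat_template(build_messages(question), add_generation_prompt=True), matching the call shown earlier in the thread.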
@fsommers It seems I forgot to add the EOS token at the end of the target, which is why this is happening.
Normally, when we train, this is done automatically, from what I remember, but since the implementation is still not ready, I missed it.
"I apparently forgot to add eos token at the end of the target,"
In the dataset? In a prompt somewhere?
Thanks.
@pierre-catie At the end of the labels.
Normally, for other VLMs, we have a label or suffix argument in the processors, which sort of indicates that we are training, and the EOS token is added automatically. In this case, I was preparing the inputs myself as completed text with the chat template, and I missed this.
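At the token level, the fix amounts to making sure every target ends with the EOS id before it goes into the labels. The sketch below uses plain lists and made-up ids; masking the prompt with -100 is a common choice (the default ignore_index of PyTorch's CrossEntropyLoss) but is an assumption here, not necessarily what the notebook does. In a text-level pipeline like the one described, the equivalent would be appending the tokenizer's end token (for example processor.tokenizer.eos_token; which end token the model should learn depends on its chat template) to the answer text before tokenizing.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_example(prompt_ids: List[int], target_ids: List[int], eos_token_id: int,
                  mask_prompt: bool = True) -> Tuple[List[int], List[int]]:
    """Return (input_ids, labels) for one training example, with EOS closing the target."""
    if not target_ids or target_ids[-1] != eos_token_id:
        target_ids = target_ids + [eos_token_id]  # the step that was missing
    input_ids = prompt_ids + target_ids
    # Causal-LM heads in transformers shift labels internally, so labels line up
    # position-for-position with input_ids.
    prompt_labels = [IGNORE_INDEX] * len(prompt_ids) if mask_prompt else list(prompt_ids)
    labels = prompt_labels + target_ids
    return input_ids, labels


if __name__ == "__main__":
    ids, labels = build_example([5, 6], [7], eos_token_id=2)
    print(ids, labels)  # [5, 6, 7, 2] [-100, -100, 7, 2]
```

Because EOS is now part of the supervised target, the model learns to end its answer, and generate() can stop well before max_new_tokens.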