Unable to run inference with a fine-tuned LoRA model

#69
by rajuptvs - opened

I am unable to run inference with my fine-tuned model. Please let me know if I am doing anything wrong.

Inference Code

# Reproduction script: load a base Idefics2 model, attach a PEFT adapter on top
# of it, and generate text for one image from the evaluation dataset.
#
# NOTE(review): torch, AutoModelForCausalLM, AutoTokenizer and
# BitsAndBytesConfig are imported but never used in this snippet.
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoProcessor, BitsAndBytesConfig, Idefics2ForConditionalGeneration


# The adapter repo's config is read first so the base checkpoint name can be
# taken from it (config.base_model_name_or_path) rather than hard-coded.
peft_model_id = "rajuptvs/img2mkd-idefics2-8b"
config = PeftConfig.from_pretrained(peft_model_id)
# Base model is loaded with load_in_8bit=True and device_map='auto'.
model = Idefics2ForConditionalGeneration.from_pretrained(config.base_model_name_or_path, return_dict=True, load_in_8bit=True, device_map='auto')
# Wrap the base model with the adapter weights from the same repo id.
model = PeftModel.from_pretrained(model, peft_model_id)

# NOTE(review): load_dataset is not imported in this snippet -- presumably
# `from datasets import load_dataset` ran in an earlier notebook cell; confirm.
eval_dataset = load_dataset("rajuptvs/img2mkd", split="test")
# A single fixed test example (index 5) is used; only its "image" field is read.
example=eval_dataset[5]
image = example["image"]
# query = example["query"]
# NOTE(review): the processor is loaded from the hard-coded
# "HuggingFaceM4/idefics2-8b" id rather than config.base_model_name_or_path;
# presumably these refer to the same checkpoint -- confirm.
processor = AutoProcessor.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    do_image_splitting=False
)

# One user turn: a text part, an image placeholder, then the instruction text.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Answer the question."},
            {"type": "image"},
            {"type": "text", "text": "Given the image, Extract all the information from image and write it into markdown format"}
        ]
    }
]
# Render the messages into a prompt string, then build the model inputs from
# that prompt and the single image.
text = processor.apply_chat_template(messages, add_generation_prompt=True)
# NOTE(review): inputs are not explicitly moved to the model's device here.
inputs = processor(text=[text.strip()], images=[image], return_tensors="pt", padding=True)
# This is the call that raises the RuntimeError in the traceback that follows
# this snippet. Replies in the thread report that upgrading transformers, or
# passing use_cache=True to generate(), avoided the error.
generated_ids = model.generate(**inputs, max_new_tokens=10000)
# Decode only the columns after the prompt length (input_ids.size(1)), i.e.
# drop the leading prompt positions from each returned sequence.
generated_texts = processor.batch_decode(generated_ids[:, inputs["input_ids"].size(1):], skip_special_tokens=True)
print(generated_texts)

Error

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[13], line 1
----> 1 generated_ids = model.generate(**inputs, max_new_tokens=10000)
      2 generated_texts = processor.batch_decode(generated_ids[:, inputs["input_ids"].size(1):], skip_special_tokens=True)
      3 print(generated_texts)

File ~/anaconda3/envs/test/lib/python3.11/site-packages/peft/peft_model.py:647, in PeftModel.generate(self, *args, **kwargs)
    645 with self._enable_peft_forward_hooks(*args, **kwargs):
    646     kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
--> 647     return self.get_base_model().generate(*args, **kwargs)

File ~/anaconda3/envs/test/lib/python3.11/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/anaconda3/envs/test/lib/python3.11/site-packages/transformers/generation/utils.py:1896, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1888     input_ids, model_kwargs = self._expand_inputs_for_generation(
   1889         input_ids=input_ids,
   1890         expand_size=generation_config.num_return_sequences,
   1891         is_encoder_decoder=self.config.is_encoder_decoder,
   1892         **model_kwargs,
   1893     )
...
   1541 reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size)
-> 1542 new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
   1543 return new_inputs_embeds

RuntimeError: shape mismatch: value tensor of shape [64, 4096] cannot be broadcast to indexing result of shape [0, 4096]

Try again with the most recent version of transformers if you haven't already -- this looks like a bug that was fixed yesterday.

Setting use_cache=True in the generate() call solved the problem for me.

Try again with the most recent version of transformers if you haven't already -- this looks like a bug that was fixed yesterday.

Thanks, reinstalling transformers from source seems to fix the issue.

rajuptvs changed discussion status to closed

Sign up or log in to comment