Vocab size vs. LM head size mismatch
Discussion #46, opened by harshil-shah
Hi,
It seems there is a mismatch between the vocabulary size of the `MllamaProcessor`'s tokenizer and the size of the `lm_head` weight matrix. Calling `resize_token_embeddings` doesn't fix this, which means it is not possible to train the model. Minimal example:
# Minimal reproduction: load Llama-3.2-11B-Vision-Instruct, try to resize the
# token embeddings to the tokenizer length, then run one forward pass with
# labels. The forward pass fails inside the cross-entropy loss (see traceback).
import requests
from PIL import Image
from transformers import MllamaForConditionalGeneration, MllamaProcessor
MODEL_NAME = "meta-llama/Llama-3.2-11B-Vision-Instruct"
processor = MllamaProcessor.from_pretrained(MODEL_NAME)
model = MllamaForConditionalGeneration.from_pretrained(MODEL_NAME)
# The `=` specifier in these f-strings prints the expression text as well as
# its value, so renaming these variables would change the printed output.
print(f"{len(processor.tokenizer) = }")
print(f"Before resize: {model.language_model.lm_head.weight.shape = }")
# Request an embedding/LM-head size equal to the tokenizer length (128257 in
# the reported run); the reported "After resize" shape is still 128256 rows.
model.resize_token_embeddings(len(processor.tokenizer))
print(f"After resize: {model.language_model.lm_head.weight.shape = }")
url = "https://huggingface.co./datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
# Stream the remote image and decode it with PIL (requires network access).
image = Image.open(requests.get(url, stream=True).raw)
# One user turn with an image placeholder followed by a text prompt.
messages = [
{"role": "user", "content": [
{"type": "image"},
{"type": "text", "text": "If I had to write a haiku for this one, it would be: "}
]}
]
input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(
image,
input_text,
add_special_tokens=False,
return_tensors="pt",
).to(model.device)
# NOTE(review): using input_ids directly as labels means the <|image|> token
# (ID 128256) becomes a loss target. That ID is outside lm_head's 128256 rows,
# hence "IndexError: Target 128256 is out of bounds". For training, mask that
# token in a copy of the labels first (e.g. set it to -100, assuming the loss
# uses PyTorch's default ignore_index).
output = model(**inputs, labels=inputs.input_ids)
This outputs:
len(processor.tokenizer) = 128257
Before resize: model.language_model.lm_head.weight.shape = torch.Size([128256, 4096])
After resize: model.language_model.lm_head.weight.shape = torch.Size([128256, 4096])
And then errors with:
File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File ~/.venv/lib/python3.11/site-packages/transformers/models/mllama/modeling_mllama.py:2188, in MllamaForConditionalGeneration.forward(self, input_ids, pixel_values, aspect_ratio_mask, aspect_ratio_ids, attention_mask, cross_attention_mask, cross_attention_states, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep)
2185 cross_attention_mask = cross_attention_mask[:, :, cache_position]
2186 full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
-> 2188 outputs = self.language_model(
2189 input_ids=input_ids,
2190 attention_mask=attention_mask,
2191 position_ids=position_ids,
2192 cross_attention_states=cross_attention_states,
2193 cross_attention_mask=cross_attention_mask,
2194 full_text_row_masked_out_mask=full_text_row_masked_out_mask,
2195 past_key_values=past_key_values,
2196 use_cache=use_cache,
2197 inputs_embeds=inputs_embeds,
2198 labels=labels,
2199 output_hidden_states=output_hidden_states,
2200 output_attentions=output_attentions,
2201 return_dict=return_dict,
2202 cache_position=cache_position,
2203 num_logits_to_keep=num_logits_to_keep,
2204 )
2206 return outputs
File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File ~/.venv/lib/python3.11/site-packages/transformers/models/mllama/modeling_mllama.py:1961, in MllamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, cross_attention_states, cross_attention_mask, full_text_row_masked_out_mask, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep)
1959 # Enable model parallelism
1960 shift_labels = shift_labels.to(shift_logits.device)
-> 1961 loss = loss_fct(shift_logits, shift_labels)
1963 if not return_dict:
1964 output = (logits,) + outputs[1:]
File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/loss.py:1179, in CrossEntropyLoss.forward(self, input, target)
1178 def forward(self, input: Tensor, target: Tensor) -> Tensor:
-> 1179 return F.cross_entropy(input, target, weight=self.weight,
1180 ignore_index=self.ignore_index, reduction=self.reduction,
1181 label_smoothing=self.label_smoothing)
File ~/.venv/lib/python3.11/site-packages/torch/nn/functional.py:3059, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
3057 if size_average is not None or reduce is not None:
3058 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3059 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
IndexError: Target 128256 is out of bounds.
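The `<|image|>` token, which is token ID 128256, is not intended to be trained on. Because that ID lies outside the 128256 rows of `lm_head`, using it as a label target makes the cross-entropy loss fail with the error above. You should take care of replacing / masking out that token in the labels for the forward pass in training, e.g. `labels = inputs.input_ids.clone(); labels[labels == 128256] = -100`. Another thread has contributed some training code here: https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct/discussions/31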