Setting num_return_sequences results in a shape mismatch error.

#28
by Watarungurunnn - opened

hf_args:
do_sample: true
temperature: 0.8
top_k: 50
top_p: 0.95
num_return_sequences: 30

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto",
    **model_args,
)
generated_tokens = model.generate(
    inputs=input_ids,
    pad_token_id=tokenizer.pad_token_id,
    **hf_args,
)

Error:

  File "/home/user_2942/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user_2942/.local/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 1068, in forward
    outputs = self.model(
  File "/home/user_2942/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user_2942/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user_2942/.local/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 908, in forward
    layer_outputs = decoder_layer(
  File "/home/user_2942/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user_2942/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user_2942/.local/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 650, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/user_2942/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user_2942/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user_2942/.local/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 252, in forward
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  File "/home/user_2942/.local/lib/python3.10/site-packages/transformers/cache_utils.py", line 1227, in update
    return update_fn(
  File "/home/user_2942/.local/lib/python3.10/site-packages/transformers/cache_utils.py", line 1202, in _static_update
    k_out[:, :, cache_position] = key_states
RuntimeError: shape mismatch: value tensor of shape [30, 16, 942, 128] cannot be broadcast to indexing result of shape [1, 16, 942, 128]
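The mismatch can be reproduced outside of generate: the Gemma 2 cache is pre-allocated for the original batch size (1 here), while sampling expands the key/value states by num_return_sequences (30), so the indexed assignment in _static_update fails. A minimal sketch of that failure, using smaller illustrative tensor sizes than the traceback:

```python
import torch

# Cache tensor pre-allocated for the original batch size of 1
# (batch, num_kv_heads, max_cache_len, head_dim)
k_out = torch.zeros(1, 16, 4, 128)

# Key states after generate expanded the batch by num_return_sequences=30
key_states = torch.zeros(30, 16, 4, 128)
cache_position = torch.arange(4)

try:
    # Same indexed assignment as _static_update in the traceback
    k_out[:, :, cache_position] = key_states
except RuntimeError as e:
    # "shape mismatch: value tensor of shape [30, 16, 4, 128] cannot be
    #  broadcast to indexing result of shape [1, 16, 4, 128]"
    print(e)
```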

Hi,
Changing line 1767 in generation/utils.py to getattr(generation_config, "num_beams", 1) * getattr(generation_config, "num_return_sequences", 1) * batch_size fixed the problem for me. Hope you find that helpful :)

Any update on this issue? I cannot use a fork of transformers on my project.

It was fixed and released already, just make sure to update transformers 😄
