Cannot set sequence length higher than 2048, and the optimized Triton implementation of FlashAttention is not supported

#2, opened by t83714

I tried the following code, and it fails for any sequence length other than 2048:

The same code works for mosaicml/mpt-7b-instruct, though.

```python
from transformers import AutoConfig, AutoModelForCausalLM
import torch

device = f'cuda:{torch.cuda.current_device()}' if torch.cuda.is_available() else 'cpu'

print(f'Selected device is: {device}')

model_name = "nomic-ai/gpt4all-mpt"

config = AutoConfig.from_pretrained(
    model_name,
    trust_remote_code=True
)
# To use the optimized Triton implementation of FlashAttention, load the model
# with attn_impl='triton' and move the model to bfloat16:
# config.attn_config['attn_impl'] = 'triton'
config.init_device = device
# config.max_seq_len = 2048
# Raise the maximum sequence length used during inference (here to 3072)
config.max_seq_len = 3072

print(config)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
)

model.eval()
```

I got the following error:

```
RuntimeError: Error(s) in loading state_dict for MPTForCausalLM:
    size mismatch for transformer.wpe.weight: copying a param with shape torch.Size([2048, 4096]) from checkpoint, the shape in current model is torch.Size([3072, 4096]).
    You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.
```
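The shapes in that message point at the cause: `transformer.wpe.weight` is a learned positional-embedding table with one row per position, so the checkpoint's table has 2048 rows, while the model built from the edited config allocates `config.max_seq_len` (3072) rows. A model that relies only on ALiBi, which MosaicML cites as the reason mpt-7b can raise `max_seq_len` at inference, has no such table to mismatch. The toy sketch below models this with plain tuples. The helper names are hypothetical, and the assumption that the table exists only when ALiBi is off is mine, not something confirmed from either model's remote code.

```python
# Hypothetical model of why the checkpoint and the edited config disagree.
# Assumption: the model allocates a learned positional table
# (transformer.wpe.weight) of shape (max_seq_len, d_model) only when ALiBi is off.

def positional_param_shapes(max_seq_len, d_model, alibi):
    """Shapes of the length-dependent parameters the model would allocate."""
    if alibi:
        # ALiBi adds a distance-based bias inside attention, so no learned
        # parameter depends on the sequence length.
        return {}
    return {"transformer.wpe.weight": (max_seq_len, d_model)}


def find_mismatches(checkpoint_shapes, model_shapes):
    """Keys present in both dicts whose shapes differ, as the loader reports."""
    return [
        (key, checkpoint_shapes[key], model_shapes[key])
        for key in checkpoint_shapes
        if key in model_shapes and checkpoint_shapes[key] != model_shapes[key]
    ]


# Shape reported for the gpt4all-mpt checkpoint in the error above.
checkpoint = {"transformer.wpe.weight": (2048, 4096)}

for seq_len in (2048, 3072):
    model_shapes = positional_param_shapes(seq_len, d_model=4096, alibi=False)
    print(seq_len, find_mismatches(checkpoint, model_shapes))

# An ALiBi-only checkpoint has no table, so any max_seq_len loads cleanly.
print(find_mismatches({}, positional_param_shapes(3072, 4096, alibi=True)))
```

Under this assumption, 2048 produces no mismatch and 3072 reproduces the reported `(2048, 4096)` vs `(3072, 4096)` conflict.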

Setting `ignore_mismatched_sizes=True` doesn't fix it; instead, I get a different error:

```
File /opt/anaconda3/lib/python3.9/site-packages/transformers/modeling_utils.py:3031, in PreTrainedModel._load_pretrained_model.<locals>._find_mismatched_keys(state_dict, model_state_dict, loaded_keys, add_prefix_to_model, remove_prefix_from_model, ignore_mismatched_sizes)
   3025 elif add_prefix_to_model:
   3026     # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
   3027     model_key = ".".join(checkpoint_key.split(".")[1:])
   3029 if (
   3030     model_key in model_state_dict
-> 3031     and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
   3032 ):
   3033     mismatched_keys.append(
   3034         (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
   3035     )
   3036     del state_dict[checkpoint_key]

KeyError: 'transformer.blocks.11.ffn.down_proj.weight'
```
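Reading the traceback: the `model_key in model_state_dict` check already passed, so the `KeyError` has to come from `state_dict[checkpoint_key]`, meaning the loop looks up a key that is not in the `state_dict` it was given. One plausible explanation, which I have not confirmed against this transformers version, is a sharded checkpoint: if `loaded_keys` lists every key from the shard index while `state_dict` holds only the shard currently being loaded, any key stored in another shard (such as `transformer.blocks.11.ffn.down_proj.weight`) fails the lookup. The toy sketch below reproduces that pattern with plain dicts; the shard contents are hypothetical and this is a simplification, not the actual transformers code.

```python
# Toy reproduction of the lookup pattern in the traceback (not the real
# transformers code). Assumption: the checkpoint is split into shards and the
# mismatch check runs once per shard.

def find_mismatched_keys(state_dict, model_state_dict, loaded_keys):
    """Collect (key, checkpoint_shape, model_shape) for keys whose shapes differ."""
    mismatched = []
    for key in loaded_keys:
        # Same condition as in the traceback: the first test guards only the
        # model side, so a key missing from this shard's state_dict raises.
        if key in model_state_dict and state_dict[key] != model_state_dict[key]:
            mismatched.append((key, state_dict[key], model_state_dict[key]))
    return mismatched


model_state_dict = {
    "transformer.wpe.weight": (3072, 4096),
    "transformer.blocks.11.ffn.down_proj.weight": (4096, 16384),
}
# Two hypothetical shards; only the first one is being processed.
shard_1 = {"transformer.wpe.weight": (2048, 4096)}
shard_2 = {"transformer.blocks.11.ffn.down_proj.weight": (4096, 16384)}
all_loaded_keys = list(shard_1) + list(shard_2)  # keys from the whole index

try:
    find_mismatched_keys(shard_1, model_state_dict, all_loaded_keys)
except KeyError as err:
    print("KeyError:", err)

# Restricting the loop to keys actually present in this shard avoids the error.
print(find_mismatched_keys(shard_1, model_state_dict, list(shard_1)))
```

If this reading is right, the `KeyError` is a loader problem that only surfaces once `ignore_mismatched_sizes=True` is set, while the underlying cause is still the `wpe` size mismatch.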

By the way, unlike mosaicml/mpt-7b-instruct, this model also doesn't support the optimized Triton implementation of FlashAttention.
If I turn it on via `config.attn_config['attn_impl'] = 'triton'`, I get the same `KeyError: 'transformer.blocks.11.ffn.down_proj.weight'` error.

I'd appreciate it if anyone could shed some light on the possible cause of these errors.
