Cannot set sequence length higher than 2048 & doesn't support the optimized triton implementation of FlashAttention
#2 · opened by t83714
I've tried the following code, and it doesn't work for me with any sequence length other than 2048. The same code does work for mosaicml/mpt-7b-instruct, though:
```python
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch

device = f'cuda:{torch.cuda.current_device()}' if torch.cuda.is_available() else 'cpu'
print(f'Selected device is: {device}')

model_name = "nomic-ai/gpt4all-mpt"

config = AutoConfig.from_pretrained(
    model_name,
    trust_remote_code=True
)
# To use the optimized triton implementation of FlashAttention, load the model
# with attn_impl='triton' and move the model to bfloat16:
# config.attn_config['attn_impl'] = 'triton'
config.init_device = device
# config.max_seq_len = 2048
# Update the maximum sequence length during inference to 3072:
config.max_seq_len = 3072
print(config)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
)
model.eval()
```
I got the following error:
```
RuntimeError: Error(s) in loading state_dict for MPTForCausalLM:
	size mismatch for transformer.wpe.weight: copying a param with shape torch.Size([2048, 4096]) from checkpoint, the shape in current model is torch.Size([3072, 4096]).
	You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.
```
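The mismatch suggests the checkpoint stores a learned positional-embedding table (`transformer.wpe`) with exactly 2048 rows, so any larger `max_seq_len` changes the model's `wpe` shape and the load fails. Here's a minimal sketch to confirm this directly from the checkpoint; the filename `pytorch_model.bin` is an assumption on my part and may differ if the repo ships sharded or safetensors weights:

```python
# Diagnostic sketch: inspect the checkpoint's positional-embedding shape without
# building the model. "pytorch_model.bin" is an assumed filename; adjust it if
# the repository stores sharded or safetensors weights instead.
import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download("nomic-ai/gpt4all-mpt", "pytorch_model.bin")
state_dict = torch.load(ckpt_path, map_location="cpu")

# The number of rows in the learned positional-embedding table caps the usable
# sequence length for this checkpoint.
print(state_dict["transformer.wpe.weight"].shape)  # expect torch.Size([2048, 4096])
```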
Setting `ignore_mismatched_sizes=True` still doesn't fix it. Instead, you get a different error:
```
File /opt/anaconda3/lib/python3.9/site-packages/transformers/modeling_utils.py:3031, in PreTrainedModel._load_pretrained_model.<locals>._find_mismatched_keys(state_dict, model_state_dict, loaded_keys, add_prefix_to_model, remove_prefix_from_model, ignore_mismatched_sizes)
   3025 elif add_prefix_to_model:
   3026     # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
   3027     model_key = ".".join(checkpoint_key.split(".")[1:])
   3029 if (
   3030     model_key in model_state_dict
-> 3031     and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
   3032 ):
   3033     mismatched_keys.append(
   3034         (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
   3035     )
   3036     del state_dict[checkpoint_key]

KeyError: 'transformer.blocks.11.ffn.down_proj.weight'
```
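My reading (unverified) is that the checkpoint is sharded and `_find_mismatched_keys` ends up indexing the per-shard `state_dict` with a key that lives in a different shard, hence the `KeyError`; if that's right, `ignore_mismatched_sizes=True` can't help here. A workaround I'd expect to work, though I haven't tested it, is to load at the checkpoint's native 2048 positions and grow the embedding table by hand:

```python
# Untested workaround sketch: load with the checkpoint's native max_seq_len,
# then manually enlarge the learned positional-embedding table to 3072 rows.
# Rows past 2048 keep nn.Embedding's random init, so quality beyond 2048
# tokens is not guaranteed without further fine-tuning.
import torch

config.max_seq_len = 2048  # match the checkpoint so loading succeeds
model = AutoModelForCausalLM.from_pretrained(
    model_name, config=config, torch_dtype=torch.bfloat16, trust_remote_code=True
)

with torch.no_grad():
    old_wpe = model.transformer.wpe.weight  # shape [2048, 4096]
    new_wpe = torch.nn.Embedding(
        3072, old_wpe.shape[1], dtype=old_wpe.dtype, device=old_wpe.device
    )
    new_wpe.weight[:2048] = old_wpe  # keep the trained positions
    model.transformer.wpe = new_wpe
model.config.max_seq_len = 3072
```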
By the way, unlike mosaicml/mpt-7b-instruct, this model also doesn't support the optimized triton implementation of FlashAttention. If you turn it on via `config.attn_config['attn_impl'] = 'triton'`, you get the same `KeyError: 'transformer.blocks.11.ffn.down_proj.weight'` error.
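For reference, the identical triton configuration loads fine for me on mosaicml/mpt-7b-instruct. A sketch of that working path, assuming a CUDA device and the triton prerequisites installed:

```python
# Reference sketch: the same triton setting applied to mosaicml/mpt-7b-instruct,
# which loads without the KeyError. Assumes a CUDA device and triton installed.
ref_name = "mosaicml/mpt-7b-instruct"
ref_config = AutoConfig.from_pretrained(ref_name, trust_remote_code=True)
ref_config.attn_config['attn_impl'] = 'triton'
ref_config.init_device = device

ref_model = AutoModelForCausalLM.from_pretrained(
    ref_name, config=ref_config, torch_dtype=torch.bfloat16, trust_remote_code=True
)
```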
Would appreciate it if anyone could shed some light on the possible cause of this error.