Error when running inference with past_key_values
It seems that the logic for using past_key_values in generation is either not implemented or implemented incorrectly. I tried to write my own generation loop using past_key_values and got tensor-dimension errors in _convert_to_rw_cache(past) in modelling_RW.py, or "nonsense" in the generated text if I tried to skip this method. More details:
In modelling_RW.py there is this method:
def prepare_inputs_for_generation(
    self,
    input_ids: torch.LongTensor,
    past: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> dict:
    # only last token for input_ids if past is not None
    if past:
        input_ids = input_ids[:, -1].unsqueeze(-1)
        # the cache may be in the standard format (e.g. in contrastive search), convert to ours if needed
        if past[0][0].shape[0] == input_ids.shape[0]:
            past = self._convert_to_rw_cache(past)
    return {
        "input_ids": input_ids,
        "past_key_values": past,
        "use_cache": kwargs.get("use_cache"),
        "attention_mask": attention_mask,
    }
Now, if you debug the default pipeline generation example from the model card (https://huggingface.co./tiiuae/falcon-7b), you will see that this part of the prepare_inputs_for_generation method is never executed:
...
if past:
    input_ids = input_ids[:, -1].unsqueeze(-1)
    if past[0][0].shape[0] == input_ids.shape[0]:
        past = self._convert_to_rw_cache(past)
...
This is because past is None on every iteration of the generation loop.
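One possible reason, assuming a transformers release in which generate() hands the cache to prepare_inputs_for_generation under the keyword past_key_values rather than past: the cache then falls into **kwargs and is dropped, so every step re-feeds the whole sequence. The toy function below is a hypothetical stand-in that only mimics the signature quoted above; it shows the keyword mismatch in isolation, without torch.

```python
from typing import Any, Dict, List, Optional


def prepare_inputs_for_generation(
    input_ids: List[int],
    past: Optional[tuple] = None,
    attention_mask: Optional[List[int]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    # Hypothetical stand-in with the same keyword handling as the method quoted above:
    # the cache is only read from `past`; any other keyword ends up in **kwargs.
    if past:
        input_ids = input_ids[-1:]  # keep only the last token when a cache is present
    return {
        "input_ids": input_ids,
        "past_key_values": past,
        "use_cache": kwargs.get("use_cache"),
        "attention_mask": attention_mask,
    }


fake_cache = (("key", "value"),)  # placeholder for a real per-layer cache

# A caller that passes the cache as past_key_values (the assumed newer generate() behaviour)
dropped = prepare_inputs_for_generation([1, 2, 3], past_key_values=fake_cache, use_cache=True)
print(dropped["past_key_values"])  # None: the cache was swallowed by **kwargs
print(dropped["input_ids"])        # [1, 2, 3]: the full sequence is re-fed

# The same call with the keyword the method actually reads
used = prepare_inputs_for_generation([1, 2, 3], past=fake_cache, use_cache=True)
print(used["input_ids"])           # [3]
```

Under this assumption, generation still produces correct text, just without any speed-up from the cache, which matches the debugging observation above.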
I tried writing my own loop in which I set the past argument to the past_key_values returned by the previous step. After that, I always get a dimension error in _convert_to_rw_cache(past).
There is no error in the generation loop if I manually edit prepare_inputs_for_generation to skip the _convert_to_rw_cache call and keep the original tensor dimensions. But then I get "nonsense" when decoding the most probable tokens.
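A plausible source of the dimension error, under two assumptions: that falcon-7b's multi-query attention stores each layer's key cache as a 3-D tensor of shape [batch_size * num_kv_heads, seq_len, head_dim] with num_kv_heads = 1, and that _convert_to_rw_cache unpacks a 4-D "standard" shape the way BLOOM-style helpers do. With batch size 1, the heuristic past[0][0].shape[0] == input_ids.shape[0] is then true even though the cache is already in the RW layout, so the conversion runs and fails. The sketch below replays that logic on plain shape tuples; convert_to_rw_shape is a hypothetical simplification, not the real method.

```python
# Hypothetical shapes, assuming falcon-7b's multi-query cache layout (one shared K/V head)
batch_size, num_kv_heads, seq_len, head_dim = 1, 1, 5, 64
key_cache_shape = (batch_size * num_kv_heads, seq_len, head_dim)  # already 3-D "RW" layout

# The heuristic from prepare_inputs_for_generation: dim 0 equal to the batch size
# is taken to mean "standard format", which is also true here because batch_size == 1
looks_standard = key_cache_shape[0] == batch_size


def convert_to_rw_shape(shape):
    # Assumed simplification of _convert_to_rw_cache: expects [batch, heads, head_dim, seq]
    b, h, d, s = shape
    return (b * h, d, s)


try:
    convert_to_rw_shape(key_cache_shape)
    error = None
except ValueError as exc:
    error = str(exc)

print(looks_standard)  # True
print(error)           # not enough values to unpack (expected 4, got 3)
```

If these assumptions hold, skipping the conversion removes the crash but does not by itself make the cached forward pass correct, which is consistent with the "nonsense" output.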
So it seems that the past_key_values logic is either not implemented or implemented incorrectly.
I would be very happy to hear from you, because using past_key_values speeds up inference several times.
P.S. Again, when using the out-of-the-box Hugging Face generate or pipeline methods with the Falcon model, everything works as it should, but debugging shows that past_key_values are not actually used in these cases.
P.P.S. I also tried changing the logic in the _convert_to_rw_cache(past) method (there is clearly something wrong with the dimensions the code expects), but this also failed. At best I got "nonsense" when decoding the resulting tokens.
To add more clarity, here is my generation loop:
import random

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device("cuda")
model_id = "tiiuae/falcon-7b"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

text = "Hello there. How are"
inputs = tokenizer(text, return_tensors="pt").to(device)
input_ids = inputs["input_ids"]
output = None
step = 0
# generation loop with 10 steps
while step < 10:
    # the mask covers the full sequence (cached tokens + new token)
    attention_mask = input_ids.new_ones(input_ids.shape)
    past_key_values = None
    if output is not None:
        # reuse the cache returned by the previous forward pass
        past_key_values = output["past_key_values"]
    ids = model.prepare_inputs_for_generation(input_ids,
                                              past=past_key_values,
                                              use_cache=True,
                                              attention_mask=attention_mask)
    output = model(**ids)
    # pick one of the 3 most probable tokens at random and append it to input_ids
    top_k = 3
    next_token = random.choice(torch.topk(output.logits[:, -1, :], top_k, dim=-1).indices[0])
    input_ids = torch.cat([input_ids, torch.tensor([[next_token]]).to(device)], dim=-1)
    step += 1
print(tokenizer.decode(input_ids[0]))
Output:
Hello there. How are!
,
. I
.<|endoftext|>
P.S. I commented out this check in the prepare_inputs_for_generation method in modelling_RW.py:
'''
# the cache may be in the standard format (e.g. in contrastive search), convert to ours if needed
if past[0][0].shape[0] == input_ids.shape[0]:
    past = self._convert_to_rw_cache(past)
'''
Otherwise, a tensor-dimension error is raised.
Same problem!
@FalconLLM
I would be very grateful if you could tell me whether past_key_values is supposed to be used during generation, or whether this logic is not implemented. Perhaps it could be added, or are there some fundamental limitations? After all, using it significantly reduces inference time.
@FalconLLM Or perhaps you could suggest a specialist from your team who could help sort out this issue? I would be very grateful!
Same problem; I'd appreciate some suggestions from @FalconLLM!
It appears to me that RotaryEmbedding obtains seq_len from the q input, which is 1 when using the KV cache. This makes the embeddings incorrect.
I resolved this by passing in the position of the current token being generated, as follows. Although our embeddings now match what we see without the KV cache, our results are still garbage.
def cos_sin(
    self,
    seq_len: int,
    device="cuda",
    dtype=torch.bfloat16,
    position=None,
) -> torch.Tensor:
    if seq_len != self.seq_len_cached:
        self.seq_len_cached = seq_len
        # t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        t = torch.arange(position, device=device).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1).to(device)
        if dtype in [torch.float16, torch.bfloat16]:
            emb = emb.float()
        self.cos_cached = emb.cos()[None, :, :]
        self.sin_cached = emb.sin()[None, :, :]
        self.cos_cached = self.cos_cached.type(dtype)
        self.sin_cached = self.sin_cached.type(dtype)
    return (self.cos_cached[:, -1:, :], self.sin_cached[:, -1:, :]) if position != seq_len else (self.cos_cached, self.sin_cached)

def forward(self, q, k, position):
    # q: q_new, b*nh x q_len x d
    # k: k_new, b*nh x q_len x d
    # position: true position index of these tokens
    # seq_len below is the number of new tokens, not their true position ids
    batch, seq_len, head_dim = q.shape
    cos, sin = self.cos_sin(seq_len, q.device, q.dtype, position)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
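To illustrate the positional issue described above with plain numbers: standard RoPE rotates dimension pair i of the token at position p by the angle p * base^(-2i/d). If the table is built from q's length alone, the single new token is always treated as position 0; building it for the full length (cached plus new tokens) and keeping the last q_len rows recovers the angles the same token gets without a cache. This is a pure-Python sketch assuming base = 10000 and a toy head_dim; rotary_angles is a hypothetical helper, not Falcon's RotaryEmbedding.

```python
from typing import Iterable, List


def rotary_angles(positions: Iterable[int], head_dim: int, base: float = 10000.0) -> List[List[float]]:
    # Hypothetical helper: one inverse frequency per pair of dimensions, base^(-2i/d)
    inv_freq = [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]
    # Angle for each (position, frequency) pair, like torch.einsum("i,j->ij", t, inv_freq)
    return [[p * f for f in inv_freq] for p in positions]


head_dim = 8
past_len, q_len = 5, 1  # 5 cached tokens, 1 newly generated token

# Buggy: table sized from q's length, so the new token is rotated as if it were at position 0
buggy = rotary_angles(range(q_len), head_dim)[-q_len:]

# Fixed: table sized for past + new tokens, keeping only the rows of the new tokens
fixed = rotary_angles(range(past_len + q_len), head_dim)[-q_len:]

# Reference: angles for the same token when the whole sequence is processed without a cache
reference = rotary_angles(range(past_len + q_len), head_dim)[past_len:]

print(fixed == reference)  # True
print(buggy == reference)  # False
```

The "fixed" variant corresponds to computing cos/sin for the total length and slicing the last q_len rows.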
@ColmanTT
There was also this hypothesis: https://huggingface.co./tiiuae/falcon-40b/discussions/48 (in the 40b discussions). We discussed and tested it with @cchudant, but the results were also garbage.
Btw, this is how I'm attempting to run with the KV cache:
result = model(input_ids=input_ids, past_key_values=past_key_values, attention_mask=None, position_ids=None, use_cache=True, return_dict=True)
On the first iteration, input_ids contains the prompt and past_key_values is None. On subsequent iterations, input_ids contains only the new token, and the returned past_key_values are fed back into the model.
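The calling pattern described above (the full prompt with no cache on the first call, then only the newest token plus the returned cache) can be sketched independently of Falcon. The toy_model below is hypothetical: its "cache" is simply the token history and its "logit" is a position-weighted sum over that history. The check is only that the cached loop reproduces the uncached one when positions are counted from the start of the full sequence, not from the start of the current input.

```python
from typing import List, Optional, Tuple


def toy_model(input_ids: List[int], past: Optional[List[int]] = None) -> Tuple[int, List[int]]:
    # Hypothetical stand-in for a causal LM forward pass.
    # The "cache" is the full token history; positions count from the start of that history.
    history = (past or []) + input_ids
    next_token = sum((pos + 1) * tok for pos, tok in enumerate(history)) % 50
    return next_token, history


def generate_without_cache(prompt: List[int], steps: int) -> List[int]:
    ids = list(prompt)
    for _ in range(steps):
        token, _ = toy_model(ids)  # re-feed the whole sequence every step
        ids.append(token)
    return ids


def generate_with_cache(prompt: List[int], steps: int) -> List[int]:
    ids = list(prompt)
    past = None
    step_input = list(prompt)  # first call: the whole prompt, no cache
    for _ in range(steps):
        token, past = toy_model(step_input, past)
        ids.append(token)
        step_input = [token]  # later calls: only the newest token plus the cache
    return ids


print(generate_with_cache([3, 1, 4], 5) == generate_without_cache([3, 1, 4], 5))  # True
```

In a real model the same equivalence only holds if every position-dependent part (rotary embeddings, causal mask) uses absolute positions that include the cached tokens.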
@LevanKvirkvelia
Did you succeed? Are you proposing to replace the Attention class in the Falcon model with FlashRWAttention from HF?
@ColmanTT I got good output after I changed the code like this:
def forward(self, q, k, seq_len):
    # seq_len is the total length (cached tokens + new tokens)
    _, q_len, _ = q.shape
    cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
    # keep only the rows for the new tokens' positions
    cos = cos[:, -q_len:]
    sin = sin[:, -q_len:]
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
and I also changed the attention code as below:
if layer_past is not None:
    L = query_layer_.shape[-2]
    S = key_layer_.shape[-2]
    # all-True mask: each new query may attend to every cached and new key
    # (exact when a single token is decoded per step)
    attn_mask = torch.ones(L, S, dtype=torch.bool, device=query_layer_.device)
    attn_output = F.scaled_dot_product_attention(
        query_layer_, key_layer_, value_layer_, attn_mask, 0.0, is_causal=False
    )
else:
    attn_output = F.scaled_dot_product_attention(
        query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
    )
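Why the masking change matters: according to the PyTorch documentation, is_causal=True in F.scaled_dot_product_attention applies a top-left aligned lower-triangular mask over the L x S score matrix (L query tokens, S key tokens). With a cache, L = 1 and S > 1, so that mask lets the new token attend only to the first key. The pure-Python sketch below (both helpers are hypothetical, written only for illustration) compares that mask with a causal mask offset by the number of cached tokens; for L = 1 the offset mask is all True, which is what the torch.ones(L, S) mask above provides.

```python
from typing import List

Mask = List[List[bool]]


def top_left_causal_mask(q_len: int, kv_len: int) -> Mask:
    # Same pattern as torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
    return [[j <= i for j in range(kv_len)] for i in range(q_len)]


def cached_causal_mask(q_len: int, kv_len: int) -> Mask:
    # Query i sits at absolute position (kv_len - q_len) + i and may attend to every key up to it
    offset = kv_len - q_len
    return [[j <= offset + i for j in range(kv_len)] for i in range(q_len)]


print(top_left_causal_mask(1, 6))  # [[True, False, False, False, False, False]]
print(cached_causal_mask(1, 6))    # [[True, True, True, True, True, True]]
print(cached_causal_mask(2, 4))    # [[True, True, True, False], [True, True, True, True]]
```

When no cache is used (L == S) the two helpers agree, which is why is_causal=True is still fine in the else branch; the all-True mask is only exact when one token is decoded per step.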
@Tron2060
Could you please explain how you pass the new argument to RotaryEmbedding.forward(self, q, k, seq_len)?
The old call was:
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
There is no seq_len in this context, so I changed it to:
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, fused_qkv.shape[1])
This does give something more or less readable, but it still seems very far from normal model generation. Perhaps I misused RotaryEmbedding.
@dimaischenko
I pass the arguments this way:
_, seq_len, _ = query_layer.shape
if layer_past is not None:
    # add the number of cached positions to get the total sequence length
    _, seq_len_past, _ = layer_past[0].shape
    seq_len = seq_len + seq_len_past
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)
This will all get fixed eventually in the transformers GitHub code:
https://github.com/huggingface/transformers/issues/25151#issuecomment-1654062690