Fix the kv-cache dimensions

#47
by cchudant - opened
Files changed (1)
  1. modelling_RW.py +1 -1
modelling_RW.py CHANGED
@@ -271,7 +271,7 @@ class Attention(nn.Module):
         # concatenate along seq_length dimension:
         # - key: [batch_size * self.num_heads, head_dim, kv_length]
         # - value: [batch_size * self.num_heads, kv_length, head_dim]
-        key_layer = torch.cat((past_key, key_layer), dim=1)
+        key_layer = torch.cat((past_key, key_layer), dim=2)
         value_layer = torch.cat((past_value, value_layer), dim=1)

         _, kv_length, _ = key_layer.shape
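The comments in the diff give the cached layouts: the key is stored as [batch_size * self.num_heads, head_dim, kv_length] while the value is [batch_size * self.num_heads, kv_length, head_dim], so the key cache has to grow along its last axis (dim=2), not dim=1. A minimal sketch of those shapes, with toy sizes chosen purely for illustration (only the layout comments come from the diff itself):

import torch

# illustrative sizes, not taken from the model config
batch_size, num_heads, head_dim = 1, 4, 8
past_len, new_len = 5, 1

# key layout: [batch_size * num_heads, head_dim, kv_length]
past_key  = torch.randn(batch_size * num_heads, head_dim, past_len)
key_layer = torch.randn(batch_size * num_heads, head_dim, new_len)

# value layout: [batch_size * num_heads, kv_length, head_dim]
past_value  = torch.randn(batch_size * num_heads, past_len, head_dim)
value_layer = torch.randn(batch_size * num_heads, new_len, head_dim)

# kv_length is the last axis of the key, so the cache grows along dim=2 ...
key_layer = torch.cat((past_key, key_layer), dim=2)        # -> [4, 8, 6]
# ... while for the value it is the middle axis, so dim=1 is already correct
value_layer = torch.cat((past_value, value_layer), dim=1)  # -> [4, 6, 8]

print(key_layer.shape, value_layer.shape)

With the original dim=1, the key concatenation would try to stack past and current keys along head_dim, which mismatches the documented layout and breaks once past_len and new_len differ during generation.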