purunfer22 committed
Commit
eca0280
1 Parent(s): c7f670a

Changes in modelling_RW.py to handle past_key_values for faster model generation


The current code does not pass past_key_values through the forward passes used for fast token generation, which results in a lot of recompute. The modelling_RW.py I am uploading handles this the way the Hugging Face transformers package's generation/utils.py expects. All of the changes revolve around threading past_key_values through everywhere. I think this applies to all Falcon models. These are the changes, specifically:
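For context, a minimal usage sketch of what this enables (the checkpoint name is only an example; use_cache=True is what makes generate() feed past_key_values back into every forward pass):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"  # example checkpoint that ships a modelling_RW.py
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("The quick brown fox", return_tensors="pt")
# With past_key_values threaded through, each decoding step only runs attention
# for the newly generated token; the cached keys/values of the prefix are reused.
output_ids = model.generate(**inputs, max_new_tokens=50, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))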

1) RotaryEmbedding forward method
Accept past_seq_length in the forward pass and apply the rotary embedding according to the absolute position of each query token ---- if/else condition added (lines 100-103)
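A standalone sketch of the idea, not the class itself, assuming cos/sin tables that already cover past_seq_length + q_len positions with shape [1, seq_len, head_dim]:

import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(q, k, cos, sin, past_seq_length=None):
    # cos/sin cover positions [0, past_seq_length + q_len); when a cache exists,
    # slice them so the new tokens are rotated at their absolute positions
    # instead of starting again from position 0.
    if past_seq_length is not None:
        cos, sin = cos[:, past_seq_length:, :], sin[:, past_seq_length:, :]
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)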

2) _make_causal_mask function
Build the mask the way F.scaled_dot_product_attention interprets a boolean attention_mask: True means a position takes part in attention, False means it is masked out. For example, if the attention_mask is [[True, False], [True, True]], the first token attends to the first token and not to the second. The original function used the opposite convention (True marked the positions to be masked out), so the past_key_values positions are now all set to True and the inequality is reversed for the same reason. ---- (line 114 inequality, line 117 attention mask set to True)
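A quick standalone check of that convention with toy shapes, showing that the lower-triangular all-True mask reproduces is_causal=True:

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 1, 2, 4)          # [batch, heads, seq_len=2, head_dim=4]
mask = torch.tensor([[True, False],          # row i: which key positions query i may attend to
                     [True, True]])
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(out_masked, out_causal)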

3) Attention forward method
a) The past key/value length is passed into the rotary function ---- if/else branch added (lines 271-277)
b) The past key is concatenated with the current key after permuting it to match the current key's shape ---- (lines 280-284)
c) To keep key_layer consistent with the expected cache layout, (batch_size * num_heads, head_dim, seq_length), another permutation is done before building "present" for the output ---- (lines 289-293); a shape sketch of (b) and (c) follows
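A standalone shape sketch of (b) and (c) with toy sizes (2 standing in for batch_size * num_kv, 64 for head_dim):

import torch

past_key  = torch.randn(2, 64, 5)   # cached layout:  [batch*heads, head_dim, past_length]
key_layer = torch.randn(2, 1, 64)   # current step:   [batch*heads, q_length, head_dim]

# b) permute the cached key into the current layout before concatenating on the sequence axis
key_layer = torch.cat((past_key.permute(0, 2, 1), key_layer), dim=1)   # [2, 6, 64]

# c) permute back so the cache returned in "present" keeps the [batch*heads, head_dim, kv_length] layout
present_key = key_layer.permute(0, 2, 1)
print(key_layer.shape, present_key.shape)   # torch.Size([2, 6, 64]) torch.Size([2, 64, 6])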

4) RWModel prepare_attn_mask
Removed the src_length > 1 condition for building the causal mask, so a mask covering the cached positions is also built during single-token decode steps (line 554).
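To see why the guard had to go, here is a standalone replica of the patched mask logic run for a decode step (target_length 1, five cached tokens; batch expansion omitted):

import torch

def make_causal_mask(target_length, past_key_values_length):
    mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool)
    seq_ids = torch.arange(target_length)
    mask[:, past_key_values_length:] = seq_ids[:, None] >= seq_ids[None, :]
    if past_key_values_length > 0:
        mask[:, :past_key_values_length] = True
    return mask[None, None, :, :]

# Decode step with a cache: src_length == 1 but the mask still has to span the cache.
print(make_causal_mask(1, 5))   # shape [1, 1, 1, 6], all True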

5) RWForCausalLM prepare_inputs_for_generation
Read past_key_values from the kwargs passed in by the Hugging Face generate method and do not call the _convert_to_rw_cache method (lines 740-748).
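Condensed from the patch, the method now behaves like this sketch (the commented-out _convert_to_rw_cache branch is left out for brevity):

def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **kwargs) -> dict:
    # generate() hands the cache back in via kwargs["past_key_values"];
    # it is forwarded untouched (no _convert_to_rw_cache round-trip).
    past_key_values = kwargs.get("past_key_values", None)
    if past_key_values:
        # with a cache, only the newest token needs to go through the model
        input_ids = input_ids[:, -1].unsqueeze(-1)
    return {
        "input_ids": input_ids,
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache"),
        "attention_mask": attention_mask,
    }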

Files changed (1)
  1. modelling_RW.py +72 -36
modelling_RW.py CHANGED
@@ -11,7 +11,9 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
 from torch.nn import functional as F
-
+import pdb
+import os
+import pickle
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -87,10 +89,19 @@ class RotaryEmbedding(torch.nn.Module):
 
         return self.cos_cached, self.sin_cached
 
-    def forward(self, q, k):
-        batch, seq_len, head_dim = q.shape
+    def forward(self, q, k, past_seq_length=None):
+        if past_seq_length == None :
+            batch, seq_len, head_dim = q.shape
+        else :
+            # print("past_seq_length", past_seq_length)
+            batch, input_seq_len, head_dim = q.shape
+            seq_len = past_seq_length + input_seq_len
         cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
-        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+        if past_seq_length != None :
+            return (q * cos[:, past_seq_length:, :]) + (rotate_half(q) * sin[:, past_seq_length:, :]), (k * cos[:, past_seq_length:, :]) + (rotate_half(k) * sin[:, past_seq_length:, :])
+        else :
+            return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+
 
 
 def _make_causal_mask(
@@ -100,10 +111,10 @@ def _make_causal_mask(
     mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
     # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
     seq_ids = torch.arange(target_length, device=device)
-    mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
+    mask[:, past_key_values_length:] = seq_ids[:, None] >= seq_ids[None, :]
 
     if past_key_values_length > 0:
-        mask[:, :past_key_values_length] = False
+        mask[:, :past_key_values_length] = True
 
     expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
     return expanded_mask
@@ -150,6 +161,10 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
     out = residual + out
     return out
 
+def dump_value(name, tensor) :
+    with open("/home/purushottam/inspect_falcon/{}".format(name), "wb") as f :
+        pickle.dump(tensor, f)
+
 
 class Attention(nn.Module):
     def __init__(self, config: RWConfig):
@@ -239,9 +254,8 @@ class Attention(nn.Module):
         use_cache: bool = False,
         output_attentions: bool = False,
     ):
+
         fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
-
-        # 3 x [batch_size, seq_length, num_heads, head_dim]
         (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
 
         batch_size, q_length, _, _ = query_layer.shape
@@ -254,20 +268,27 @@ class Attention(nn.Module):
         )
         value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)
 
-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
+        if layer_past is not None :
+            past_key, past_value = layer_past
+            past_kv_length = past_key.shape[2]
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
+        else :
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
+
+
 
         if layer_past is not None:
             past_key, past_value = layer_past
-            # concatenate along seq_length dimension:
-            #  - key: [batch_size * self.num_heads, head_dim, kv_length]
-            #  - value: [batch_size * self.num_heads, kv_length, head_dim]
+            past_key = past_key.permute(0, 2, 1)
             key_layer = torch.cat((past_key, key_layer), dim=1)
             value_layer = torch.cat((past_value, value_layer), dim=1)
+
 
         _, kv_length, _ = key_layer.shape
 
         if use_cache is True:
-            present = (key_layer, value_layer)
+            key_layer_permute = key_layer.permute(0, 2, 1)
+            present = (key_layer_permute, value_layer)
         else:
             present = None
 
@@ -275,10 +296,16 @@ class Attention(nn.Module):
         query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
         key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
        value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
+
 
-        attn_output = F.scaled_dot_product_attention(
-            query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
-        )
+        if attention_mask is not None :
+            attn_output = F.scaled_dot_product_attention(
+                query_layer_, key_layer_, value_layer_, attention_mask, 0.0, is_causal=False
+            )
+        else :
+            attn_output = F.scaled_dot_product_attention(
+                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
+            )
 
         x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
         x = x.permute(0, 2, 1, 3)
@@ -475,8 +502,8 @@ class RWPreTrainedModel(PreTrainedModel):
     def _convert_to_rw_cache(
         past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
-        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
-        batch_size_times_num_heads = batch_size * num_heads
+        batch_size, seq_length, head_dim = past_key_value[0][0].shape
+        batch_size_times_num_heads = batch_size
         # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
         # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
         return tuple(
@@ -488,6 +515,7 @@ class RWPreTrainedModel(PreTrainedModel):
         )
 
 
+
 class RWModel(RWPreTrainedModel):
     def __init__(self, config: RWConfig):
         super().__init__(config)
@@ -522,10 +550,11 @@ class RWModel(RWPreTrainedModel):
         device = attention_mask.device
         _, src_length = input_shape
 
-        if src_length > 1:
-            combined_attention_mask = _make_causal_mask(
-                input_shape, device=device, past_key_values_length=past_key_values_length
-            )
+
+        # if src_length > 1:
+        combined_attention_mask = _make_causal_mask(
+            input_shape, device=device, past_key_values_length=past_key_values_length
+        )
 
         # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
         expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
@@ -560,7 +589,7 @@ class RWModel(RWPreTrainedModel):
         )
         if len(deprecated_arguments) > 0:
             raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
-
+        # pdb.set_trace()
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -616,6 +645,7 @@ class RWModel(RWPreTrainedModel):
             input_shape=(batch_size, seq_length),
            past_key_values_length=past_key_values_length,
         )
+        # print("causal_mask", causal_mask)
 
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
 
@@ -646,16 +676,18 @@ class RWModel(RWPreTrainedModel):
             )
         else:
             outputs = block(
-                hidden_states,
-                layer_past=layer_past,
-                attention_mask=causal_mask,
-                head_mask=head_mask[i],
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-                alibi=alibi,
-            )
+                hidden_states,
+                layer_past=layer_past,
+                attention_mask=causal_mask,
+                head_mask=head_mask[i],
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                alibi=alibi,
+            )
+
 
         hidden_states = outputs[0]
+
         if use_cache is True:
             presents = presents + (outputs[1],)
 
@@ -704,16 +736,20 @@ class RWForCausalLM(RWPreTrainedModel):
         **kwargs,
     ) -> dict:
         # only last token for input_ids if past is not None
-        if past:
+        # only last token for input_ids if past is not None
+        if kwargs.get("past_key_values", None) :
            input_ids = input_ids[:, -1].unsqueeze(-1)
-
+            past_key_values = kwargs["past_key_values"]
            # the cache may be in the stardard format (e.g. in contrastive search), convert to our's format if needed
-            if past[0][0].shape[0] == input_ids.shape[0]:
-                past = self._convert_to_rw_cache(past)
+            # if kwargs["past_key_values"][0][0].shape[0] == input_ids.shape[0]:
+            #     past_key_values = self._convert_to_rw_cache(kwargs["past_key_values"])
+            #     past_key_values = kwargs["past_key_values"]
+        else :
+            past_key_values = None
 
         return {
            "input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }