gugarosa commited on
Commit
de35f90
1 Parent(s): d38e6f9

Adds support for MQA/GQA and attention mask during training.

Browse files
README.md CHANGED
@@ -127,7 +127,7 @@ with torch.autocast(model.device.type, dtype=torch.float16, enabled=True):
127
  ```
128
 
129
  **Remark.** In the generation function, our model currently does not support beam search (`num_beams` > 1).
130
- Furthermore, in the forward pass of the model, we currently do not support attention mask during training, outputting hidden states or attention values, or using custom input embeddings (instead of the model's).
131
 
132
  ### Citation
133
 
 
127
  ```
128
 
129
  **Remark.** In the generation function, our model currently does not support beam search (`num_beams` > 1).
130
+ Furthermore, in the forward pass of the model, we currently do not support outputting hidden states or attention values, or using custom input embeddings (instead of the model's).
131
 
132
  ### Citation
133
 
configuration_mixformer_sequential.py CHANGED
@@ -2,7 +2,7 @@
2
  # Licensed under the MIT license.
3
 
4
  import math
5
- from typing import Any, Dict, List, Optional, Union
6
 
7
  from transformers import PretrainedConfig
8
 
@@ -27,6 +27,7 @@ class MixFormerSequentialConfig(PretrainedConfig):
27
  n_layer: Optional[int] = 20,
28
  n_inner: Optional[int] = None,
29
  n_head: Optional[int] = 16,
 
30
  rotary_dim: Optional[int] = 32,
31
  activation_function: Optional[str] = "gelu_new",
32
  embd_pdrop: Optional[float] = 0.0,
@@ -43,6 +44,7 @@ class MixFormerSequentialConfig(PretrainedConfig):
43
  self.n_layer = n_layer
44
  self.n_inner = n_inner
45
  self.n_head = n_head
 
46
  self.rotary_dim = min(rotary_dim, n_embd // n_head)
47
  self.activation_function = activation_function
48
  self.embd_pdrop = embd_pdrop
 
2
  # Licensed under the MIT license.
3
 
4
  import math
5
+ from typing import Optional
6
 
7
  from transformers import PretrainedConfig
8
 
 
27
  n_layer: Optional[int] = 20,
28
  n_inner: Optional[int] = None,
29
  n_head: Optional[int] = 16,
30
+ n_head_kv: Optional[int] = None,
31
  rotary_dim: Optional[int] = 32,
32
  activation_function: Optional[str] = "gelu_new",
33
  embd_pdrop: Optional[float] = 0.0,
 
44
  self.n_layer = n_layer
45
  self.n_inner = n_inner
46
  self.n_head = n_head
47
+ self.n_head_kv = n_head_kv
48
  self.rotary_dim = min(rotary_dim, n_embd // n_head)
49
  self.activation_function = activation_function
50
  self.embd_pdrop = embd_pdrop
modeling_mixformer_sequential.py CHANGED
@@ -34,20 +34,20 @@
34
  from __future__ import annotations
35
 
36
  import math
37
- import copy
38
  from typing import Any, Dict, Optional, Tuple, Union
39
  from dataclasses import dataclass, field
40
 
41
  import torch
42
  import torch.nn as nn
43
 
44
- from einops import rearrange
45
  from transformers.activations import ACT2FN
46
  from transformers import PretrainedConfig, PreTrainedModel
47
  from transformers.modeling_outputs import CausalLMOutputWithPast
48
 
49
  from .configuration_mixformer_sequential import MixFormerSequentialConfig
50
 
 
51
  @dataclass
52
  class InferenceParams:
53
  """Inference parameters passed to model to efficiently calculate
@@ -57,21 +57,20 @@ class InferenceParams:
57
  https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
58
 
59
  Args:
60
- max_sequence_len: Maximum sequence length.
61
  max_batch_size: Maximum batch size.
62
- sequence_len_offset: Sequence length offset.
63
  batch_size_offset: Batch size offset.
64
  key_value_memory_dict: Key value memory dictionary.
65
- fused_ft_kernel: Whether to use fused kernel for fast inference.
66
  lengths_per_sample: Lengths per sample.
67
 
68
  """
69
 
70
- max_sequence_len: int = field(metadata={"help": "Maximum sequence length."})
71
 
72
  max_batch_size: int = field(metadata={"help": "Maximum batch size."})
73
 
74
- sequence_len_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
75
 
76
  batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
77
 
@@ -79,8 +78,6 @@ class InferenceParams:
79
  default_factory=dict, metadata={"help": "Key value memory dictionary."}
80
  )
81
 
82
- fused_ft_kernel: bool = field(default=False, metadata={"help": "Whether to use fused kernel for fast inference."})
83
-
84
  lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
85
 
86
 
@@ -103,12 +100,112 @@ class Embedding(nn.Module):
103
  return hidden_states
104
 
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  class RotaryEmbedding(nn.Module):
107
- """Rotary embeddings.
108
 
109
  Reference:
110
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py.
111
-
 
112
  """
113
 
114
  def __init__(
@@ -131,14 +228,14 @@ class RotaryEmbedding(nn.Module):
131
  self.device = device
132
 
133
  inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
134
- self.register_buffer("inv_freq", inv_freq)
135
 
136
  scale = (
137
  (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
138
  if scale_base is not None
139
  else None
140
  )
141
- self.register_buffer("scale", scale)
142
 
143
  self._seq_len_cached = 0
144
  self._cos_cached = None
@@ -146,28 +243,26 @@ class RotaryEmbedding(nn.Module):
146
  self._cos_k_cached = None
147
  self._sin_k_cached = None
148
 
149
- def _update_cos_sin_cache(self, x: torch.FloatTensor, seqlen_offset: int = 0) -> None:
150
- # Reset the tables if the sequence length has changed,
151
- # or if we're on a new device (possibly due to tracing for instance)
152
- seqlen = x.shape[1] + seqlen_offset
153
-
154
  # Re-generate the inverse frequency buffer if it's not fp32
155
  # (for instance if model.half() was called)
156
  if self.inv_freq.dtype != "torch.float32":
157
  self.inv_freq = 1.0 / (
158
- self.base ** (torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32) / self.dim)
159
  )
160
 
161
- if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
162
  self._seq_len_cached = seqlen
163
- t = torch.arange(seqlen, device=x.device, dtype=torch.float32)
164
 
165
  # Don't do einsum, it converts fp32 to fp16
166
  # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
167
  freqs = torch.outer(t, self.inv_freq.to(device=t.device, dtype=torch.float32))
168
  if self.scale is None:
169
- self._cos_cached = torch.cos(freqs).to(x.dtype)
170
- self._sin_cached = torch.sin(freqs).to(x.dtype)
171
  else:
172
  power = (
173
  torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
@@ -175,62 +270,32 @@ class RotaryEmbedding(nn.Module):
175
  scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
176
 
177
  # We want the multiplication by scale to happen in fp32
178
- self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
179
- self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
180
- self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
181
- self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
182
 
183
- def _apply_rotary_emb_qkv(
184
  self,
185
- qkv: torch.FloatTensor,
186
- sin: torch.FloatTensor,
187
- cos: torch.FloatTensor,
188
- sin_k: Optional[torch.FloatTensor] = None,
189
- cos_k: Optional[torch.FloatTensor] = None,
190
- ) -> torch.FloatTensor:
191
- _, seqlen, three, _, headdim = qkv.shape
192
- assert three == 3
193
-
194
- rotary_seqlen, rotary_dim = cos.shape
195
- rotary_dim *= 2
196
- assert rotary_dim <= headdim
197
- assert seqlen <= rotary_seqlen
198
-
199
- cos_k = cos if cos_k is None else cos_k
200
- sin_k = sin if sin_k is None else sin_k
201
- assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
202
-
203
- q_rot = qkv[:, :, 0, :, :rotary_dim]
204
- q_pass = qkv[:, :, 0, :, rotary_dim:]
205
-
206
- k_rot = qkv[:, :, 1, :, :rotary_dim]
207
- k_pass = qkv[:, :, 1, :, rotary_dim:]
208
-
209
- # Splits the queries and keys in half
210
- q1, q2 = q_rot.chunk(2, dim=-1)
211
- k1, k2 = k_rot.chunk(2, dim=-1)
212
- c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
213
-
214
- # Casts to fp32 are necessary to prevent fp16 overflow issues
215
- q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
216
-
217
- # Computes the new keys and queries, recasting to original dtype
218
- q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
219
- k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
220
-
221
- return torch.cat(
222
- [
223
- torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
224
- torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
225
- qkv[:, :, 2:3, :, :],
226
- ],
227
- axis=2,
228
- )
229
 
230
- def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
231
- # `qkv` is of shape (batch, seqlen, 3, nheads, headdim)
232
- self._update_cos_sin_cache(qkv, seqlen_offset)
233
- return self._apply_rotary_emb_qkv(qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:])
234
 
235
 
236
  class MLP(nn.Module):
@@ -290,21 +355,22 @@ class SelfAttention(nn.Module):
290
  attention_mask: Optional[torch.BoolTensor] = None,
291
  **kwargs,
292
  ) -> torch.FloatTensor:
293
- causal = self.causal if causal is None else causal
294
- batch_size, seq_len = qkv.shape[0], qkv.shape[1]
295
  q, k, v = qkv.unbind(dim=2)
296
 
 
297
  softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
 
298
  scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
299
 
300
  if attention_mask is not None:
301
- padding_mask = torch.full((batch_size, seq_len), -10000.0, dtype=scores.dtype, device=scores.device)
302
  padding_mask.masked_fill_(attention_mask, 0.0)
303
 
304
  scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
305
 
306
  if causal:
307
- causal_mask = torch.triu(torch.full((seq_len, seq_len), -10000.0, device=scores.device), 1)
308
  scores = scores + causal_mask.to(dtype=scores.dtype)
309
 
310
  attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
@@ -343,25 +409,31 @@ class CrossAttention(nn.Module):
343
  attention_mask: Optional[torch.BoolTensor] = None,
344
  **kwargs,
345
  ) -> torch.FloatTensor:
346
- causal = self.causal if causal is None else causal
347
- batch_size, seq_len_q = q.shape[0], q.shape[1]
348
- assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
349
 
350
- seq_len_k = kv.shape[1]
 
351
  k, v = kv.unbind(dim=2)
352
 
 
353
  softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
 
354
  scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
355
 
356
  if attention_mask is not None:
357
- padding_mask = torch.full((batch_size, seq_len_k), -10000.0, dtype=scores.dtype, device=scores.device)
358
  padding_mask.masked_fill_(attention_mask, 0.0)
359
 
360
  scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
361
 
362
  if causal:
363
- causal_mask = torch.triu(torch.full((seq_len_q, seq_len_k), -10000.0, device=scores.device), 1)
364
- scores = scores + causal_mask.to(dtype=scores.dtype)
 
 
 
365
 
366
  attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
367
  attention = self.drop(attention)
@@ -371,21 +443,12 @@ class CrossAttention(nn.Module):
371
  return output
372
 
373
 
374
- def find_mha_dims(
375
- config: PretrainedConfig, n_head: Optional[int] = None, head_dim: Optional[int] = None
 
 
 
376
  ) -> Tuple[int, int]:
377
- """Validate and return the number of heads and head dimension for multi-head attention.
378
-
379
- Args:
380
- config: Model configuration.
381
- n_head: Number of heads.
382
- head_dim: Head dimension.
383
-
384
- Returns:
385
- Number of heads and head dimension.
386
-
387
- """
388
-
389
  assert all(
390
  hasattr(config, attr) for attr in ["n_embd", "n_head"]
391
  ), "`config` must have `n_embd` and `n_head` attributes."
@@ -401,31 +464,20 @@ def find_mha_dims(
401
  elif n_head is None or head_dim is None:
402
  raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
403
 
404
- return n_head, head_dim
405
-
406
-
407
- def update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
408
- """Update the key-value cache for inference.
409
-
410
- Reference:
411
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
412
-
413
- Args:
414
- kv: Key-value tensor.
415
- inference_params: Inference parameters.
416
- layer_idx: Layer index.
417
 
418
- Returns:
419
- Updated key-value tensor.
420
 
421
- """
422
 
 
423
  num_heads, head_dim = kv.shape[-2:]
424
 
425
  if layer_idx not in inference_params.key_value_memory_dict:
426
  kv_cache = torch.empty(
427
  inference_params.max_batch_size,
428
- inference_params.max_sequence_len,
429
  2,
430
  num_heads,
431
  head_dim,
@@ -434,43 +486,19 @@ def update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, la
434
  )
435
  inference_params.key_value_memory_dict[layer_idx] = kv_cache
436
  else:
437
- if not inference_params.fused_ft_kernel:
438
- kv_cache = inference_params.key_value_memory_dict[layer_idx]
439
- else:
440
- k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
441
- kv_cache = None
442
 
443
  batch_start = inference_params.batch_size_offset
444
  batch_end = batch_start + kv.shape[0]
445
- assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
446
 
447
- sequence_start = inference_params.sequence_len_offset
448
  sequence_end = sequence_start + kv.shape[1]
449
- assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
450
-
451
- if not inference_params.fused_ft_kernel:
452
- assert kv_cache is not None
453
-
454
- kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
455
- kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
456
 
457
- return kv
458
-
459
- assert inference_params.sequence_len_offset == 0
460
- assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
461
-
462
- packsize = 4 if kv.dtype == torch.float32 else 8
463
-
464
- if kv_cache is not None:
465
- kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
466
- k_cache = rearrange(kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize).contiguous()
467
- v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous()
468
- inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
469
- else:
470
- k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
471
- kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
472
- )
473
- v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(kv[:, :, 1], "b s h d -> b h s d")
474
 
475
  return kv
476
 
@@ -486,6 +514,7 @@ class MHA(nn.Module):
486
  rotary_dim: Optional[int] = None,
487
  rotary_emb_scale_base: Optional[float] = None,
488
  n_head: Optional[int] = None,
 
489
  head_dim: Optional[int] = None,
490
  bias: bool = True,
491
  causal: bool = True,
@@ -506,12 +535,12 @@ class MHA(nn.Module):
506
  self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
507
 
508
  # MLP
509
- self.n_head, self.head_dim = find_mha_dims(config, n_head, head_dim)
510
- op_size = self.n_head * self.head_dim
511
  hidden_size = config.n_embd
512
 
513
- self.Wqkv = nn.Linear(hidden_size, 3 * op_size, bias=bias, device=device, dtype=dtype)
514
- self.out_proj = nn.Linear(op_size, hidden_size, bias=bias, device=device, dtype=dtype)
515
 
516
  # Attention
517
  self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
@@ -521,40 +550,75 @@ class MHA(nn.Module):
521
  self.return_residual = return_residual
522
  self.checkpointing = checkpointing
523
 
524
- def forward(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
  self,
526
  x: torch.FloatTensor,
527
- past_key_values: Optional[InferenceParams] = None,
528
- attention_mask: Optional[torch.BoolTensor] = None,
529
- cu_seqlens: Optional[torch.LongTensor] = None,
530
- max_seqlen: Optional[int] = None,
531
- **kwargs,
532
- ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
533
  qkv = self.Wqkv(x)
534
- qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
535
 
536
- seqlen_offset = past_key_values.sequence_len_offset if past_key_values is not None else 0
 
 
 
 
 
 
 
537
  if self.rotary_emb_dim > 0:
538
- qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
539
 
540
  if past_key_values is not None:
541
- kv = update_kv_cache(qkv[:, :, 1:], past_key_values, self.layer_idx)
542
 
543
- if attention_mask is not None:
544
- attention_mask = attention_mask[0] if isinstance(attention_mask, tuple) else attention_mask
545
- attention_mask = attention_mask.bool().to(qkv.device)
 
546
 
547
- attention_kwargs = {"attention_mask": attention_mask}
548
 
549
- if past_key_values is None or seqlen_offset == 0:
550
- if self.checkpointing:
551
- attn_output = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **attention_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
552
  else:
553
- attn_output = self.inner_attn(qkv, **attention_kwargs)
 
 
 
554
  else:
555
- q = qkv[:, :, 0]
556
- causal = None if past_key_values.sequence_len_offset == 0 else False
557
- attn_output = self.inner_cross_attn(q, kv, causal=causal, **attention_kwargs)
558
 
559
  output = rearrange(attn_output, "... h d -> ... (h d)")
560
  output = self.out_proj(output)
@@ -672,38 +736,29 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
672
  if module.padding_idx is not None:
673
  module.weight.data[module.padding_idx].zero_()
674
  elif isinstance(module, nn.LayerNorm):
675
- module.bias.data.zero_()
 
676
  module.weight.data.fill_(1.0)
677
 
678
  def prepare_inputs_for_generation(
679
  self,
680
  input_ids: torch.LongTensor,
681
  past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
682
- attention_mask: Optional[torch.BoolTensor] = None,
683
  **kwargs,
684
  ) -> Dict[str, Any]:
685
- if attention_mask is not None and torch.any(~attention_mask.bool()):
686
- total_seq_len = torch.sum(attention_mask, dim=1)
687
- max_seq_len = torch.max(total_seq_len)
688
-
689
- total_seq_len = torch.cat((torch.tensor([0], device=attention_mask.device), total_seq_len)).unsqueeze(1)
690
- cumulative_seq_len = torch.cumsum(total_seq_len, dim=0).squeeze(1).to(torch.int32)
691
- attention_mask = (attention_mask.bool(), cumulative_seq_len, max_seq_len.item())
692
- else:
693
- attention_mask = None
694
-
695
  if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
696
  past_key_values = InferenceParams(
 
697
  max_batch_size=input_ids.shape[0],
698
- max_sequence_len=self.config.n_positions,
699
- sequence_len_offset=0,
700
  batch_size_offset=0,
701
- fused_ft_kernel=False,
702
  key_value_memory_dict={},
 
703
  )
704
  else:
705
  # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
706
- past_key_values.sequence_len_offset = len(input_ids[0]) - 1
707
  input_ids = input_ids[:, -1].unsqueeze(-1)
708
 
709
  return {
@@ -712,9 +767,9 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
712
  "attention_mask": attention_mask,
713
  }
714
 
715
- def _set_gradient_checkpointing(self, module, value=False):
716
- if isinstance(module, MixFormerSequentialPreTrainedModel):
717
- module.gradient_checkpointing = value
718
 
719
 
720
  class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
@@ -756,13 +811,10 @@ class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
756
  labels: Optional[torch.LongTensor] = None,
757
  **kwargs,
758
  ) -> CausalLMOutputWithPast:
759
- if past_key_values is None and attention_mask is None:
760
- lm_logits = self.layers(input_ids)
761
- else:
762
- hidden_layer = self.layers[0](input_ids)
763
- for module in self.layers[1:-1]:
764
- hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask)
765
- lm_logits = self.layers[-1](hidden_layer)
766
 
767
  loss = None
768
  if labels is not None:
 
34
  from __future__ import annotations
35
 
36
  import math
 
37
  from typing import Any, Dict, Optional, Tuple, Union
38
  from dataclasses import dataclass, field
39
 
40
  import torch
41
  import torch.nn as nn
42
 
43
+ from einops import rearrange, repeat
44
  from transformers.activations import ACT2FN
45
  from transformers import PretrainedConfig, PreTrainedModel
46
  from transformers.modeling_outputs import CausalLMOutputWithPast
47
 
48
  from .configuration_mixformer_sequential import MixFormerSequentialConfig
49
 
50
+
51
  @dataclass
52
  class InferenceParams:
53
  """Inference parameters passed to model to efficiently calculate
 
57
  https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
58
 
59
  Args:
60
+ max_seqlen: Maximum sequence length.
61
  max_batch_size: Maximum batch size.
62
+ seqlen_offset: Sequence length offset.
63
  batch_size_offset: Batch size offset.
64
  key_value_memory_dict: Key value memory dictionary.
 
65
  lengths_per_sample: Lengths per sample.
66
 
67
  """
68
 
69
+ max_seqlen: int = field(metadata={"help": "Maximum sequence length."})
70
 
71
  max_batch_size: int = field(metadata={"help": "Maximum batch size."})
72
 
73
+ seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
74
 
75
  batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
76
 
 
78
  default_factory=dict, metadata={"help": "Key value memory dictionary."}
79
  )
80
 
 
 
81
  lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
82
 
83
 
 
100
  return hidden_states
101
 
102
 
103
+ def _apply_rotary_emb(
104
+ x: torch.FloatTensor,
105
+ cos: torch.FloatTensor,
106
+ sin: torch.FloatTensor,
107
+ ) -> torch.FloatTensor:
108
+ _, seqlen, _, head_dim = x.shape
109
+ rotary_seqlen, rotary_dim = cos.shape
110
+ rotary_dim *= 2
111
+
112
+ assert rotary_dim <= head_dim
113
+ assert seqlen <= rotary_seqlen
114
+ assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
115
+
116
+ x_rot = x[:, :, :, :rotary_dim]
117
+ x_pass = x[:, :, :, rotary_dim:]
118
+
119
+ x1, x2 = x_rot.chunk(2, dim=-1)
120
+ c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
121
+ x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]
122
+
123
+ x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype)
124
+
125
+ return torch.cat([x_rot, x_pass], axis=-1)
126
+
127
+
128
+ def _apply_rotary_emb_kv(
129
+ kv: torch.FloatTensor,
130
+ cos: torch.FloatTensor,
131
+ sin: torch.FloatTensor,
132
+ cos_k: Optional[torch.FloatTensor] = None,
133
+ sin_k: Optional[torch.FloatTensor] = None,
134
+ ) -> torch.FloatTensor:
135
+ _, seqlen, two, _, head_dim = kv.shape
136
+ assert two == 2
137
+
138
+ rotary_seqlen, rotary_dim = cos.shape
139
+ rotary_dim *= 2
140
+ assert rotary_dim <= head_dim
141
+ assert seqlen <= rotary_seqlen
142
+ assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
143
+
144
+ k_rot = kv[:, :, 0, :, :rotary_dim]
145
+ k_pass = kv[:, :, 0, :, rotary_dim:]
146
+
147
+ k1, k2 = k_rot.chunk(2, dim=-1)
148
+ c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
149
+ k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]]
150
+
151
+ k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype)
152
+
153
+ return torch.cat(
154
+ [
155
+ torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
156
+ kv[:, :, 1:2, :, :],
157
+ ],
158
+ axis=2,
159
+ )
160
+
161
+
162
+ def _apply_rotary_emb_qkv(
163
+ qkv: torch.FloatTensor,
164
+ cos: torch.FloatTensor,
165
+ sin: torch.FloatTensor,
166
+ cos_k: Optional[torch.FloatTensor] = None,
167
+ sin_k: Optional[torch.FloatTensor] = None,
168
+ ) -> torch.FloatTensor:
169
+ _, seqlen, three, _, head_dim = qkv.shape
170
+ assert three == 3
171
+
172
+ rotary_seqlen, rotary_dim = cos.shape
173
+ rotary_dim *= 2
174
+ assert rotary_dim <= head_dim
175
+ assert seqlen <= rotary_seqlen
176
+ assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
177
+
178
+ q_rot = qkv[:, :, 0, :, :rotary_dim]
179
+ q_pass = qkv[:, :, 0, :, rotary_dim:]
180
+
181
+ k_rot = qkv[:, :, 1, :, :rotary_dim]
182
+ k_pass = qkv[:, :, 1, :, rotary_dim:]
183
+
184
+ q1, q2 = q_rot.chunk(2, dim=-1)
185
+ k1, k2 = k_rot.chunk(2, dim=-1)
186
+ c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
187
+ q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
188
+
189
+ q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
190
+ k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
191
+
192
+ return torch.cat(
193
+ [
194
+ torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
195
+ torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
196
+ qkv[:, :, 2:3, :, :],
197
+ ],
198
+ axis=2,
199
+ )
200
+
201
+
202
  class RotaryEmbedding(nn.Module):
203
+ """Rotary positional embedding (RoPE).
204
 
205
  Reference:
206
+ RoFormer: Enhanced Transformer with Rotary Position Embedding.
207
+ https://arxiv.org/pdf/2104.09864.pdf.
208
+
209
  """
210
 
211
  def __init__(
 
228
  self.device = device
229
 
230
  inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
231
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
232
 
233
  scale = (
234
  (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
235
  if scale_base is not None
236
  else None
237
  )
238
+ self.register_buffer("scale", scale, persistent=False)
239
 
240
  self._seq_len_cached = 0
241
  self._cos_cached = None
 
243
  self._cos_k_cached = None
244
  self._sin_k_cached = None
245
 
246
+ def _update_cos_sin_cache(
247
+ self, seqlen: int, device: Optional[str] = None, dtype: Optional[torch.dtype] = None
248
+ ) -> None:
 
 
249
  # Re-generate the inverse frequency buffer if it's not fp32
250
  # (for instance if model.half() was called)
251
  if self.inv_freq.dtype != "torch.float32":
252
  self.inv_freq = 1.0 / (
253
+ self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
254
  )
255
 
256
+ if seqlen > self._seq_len_cached or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
257
  self._seq_len_cached = seqlen
258
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
259
 
260
  # Don't do einsum, it converts fp32 to fp16
261
  # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
262
  freqs = torch.outer(t, self.inv_freq.to(device=t.device, dtype=torch.float32))
263
  if self.scale is None:
264
+ self._cos_cached = torch.cos(freqs).to(dtype)
265
+ self._sin_cached = torch.sin(freqs).to(dtype)
266
  else:
267
  power = (
268
  torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
 
270
  scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
271
 
272
  # We want the multiplication by scale to happen in fp32
273
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
274
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
275
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
276
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
277
 
278
+ def forward(
279
  self,
280
+ qkv: torch.Tensor,
281
+ kv: Optional[torch.Tensor] = None,
282
+ seqlen_offset: int = 0,
283
+ max_seqlen: Optional[int] = None,
284
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
285
+ seqlen = qkv.shape[1]
286
+
287
+ if max_seqlen is not None:
288
+ self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
289
+ else:
290
+ self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
291
+
292
+ if kv is None:
293
+ return _apply_rotary_emb_qkv(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
294
+ else:
295
+ q = _apply_rotary_emb(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
296
+ kv = _apply_rotary_emb_kv(kv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
 
298
+ return q, kv
 
 
 
299
 
300
 
301
  class MLP(nn.Module):
 
355
  attention_mask: Optional[torch.BoolTensor] = None,
356
  **kwargs,
357
  ) -> torch.FloatTensor:
358
+ batch_size, seqlen = qkv.shape[0], qkv.shape[1]
 
359
  q, k, v = qkv.unbind(dim=2)
360
 
361
+ causal = self.causal if causal is None else causal
362
  softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
363
+
364
  scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
365
 
366
  if attention_mask is not None:
367
+ padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
368
  padding_mask.masked_fill_(attention_mask, 0.0)
369
 
370
  scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
371
 
372
  if causal:
373
+ causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
374
  scores = scores + causal_mask.to(dtype=scores.dtype)
375
 
376
  attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
 
409
  attention_mask: Optional[torch.BoolTensor] = None,
410
  **kwargs,
411
  ) -> torch.FloatTensor:
412
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
413
+ seqlen_k = kv.shape[1]
414
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
415
 
416
+ if kv.shape[3] != q.shape[2]:
417
+ kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
418
  k, v = kv.unbind(dim=2)
419
 
420
+ causal = self.causal if causal is None else causal
421
  softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
422
+
423
  scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
424
 
425
  if attention_mask is not None:
426
+ padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device)
427
  padding_mask.masked_fill_(attention_mask, 0.0)
428
 
429
  scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
430
 
431
  if causal:
432
+ rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
433
+ cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
434
+ causal_mask = cols > rows + seqlen_k - seqlen_q
435
+
436
+ scores = scores.masked_fill(causal_mask, -10000.0)
437
 
438
  attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
439
  attention = self.drop(attention)
 
443
  return output
444
 
445
 
446
+ def _find_mha_dims(
447
+ config: PretrainedConfig,
448
+ n_head: Optional[int] = None,
449
+ n_head_kv: Optional[int] = None,
450
+ head_dim: Optional[int] = None,
451
  ) -> Tuple[int, int]:
 
 
 
 
 
 
 
 
 
 
 
 
452
  assert all(
453
  hasattr(config, attr) for attr in ["n_embd", "n_head"]
454
  ), "`config` must have `n_embd` and `n_head` attributes."
 
464
  elif n_head is None or head_dim is None:
465
  raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
466
 
467
+ if n_head_kv is None:
468
+ n_head_kv = getattr(config, "n_head_kv", None) or n_head
469
+ assert n_head % n_head_kv == 0, "`n_head` must be divisible by `n_head_kv`."
 
 
 
 
 
 
 
 
 
 
470
 
471
+ return n_head, n_head_kv, head_dim
 
472
 
 
473
 
474
+ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
475
  num_heads, head_dim = kv.shape[-2:]
476
 
477
  if layer_idx not in inference_params.key_value_memory_dict:
478
  kv_cache = torch.empty(
479
  inference_params.max_batch_size,
480
+ inference_params.max_seqlen,
481
  2,
482
  num_heads,
483
  head_dim,
 
486
  )
487
  inference_params.key_value_memory_dict[layer_idx] = kv_cache
488
  else:
489
+ kv_cache = inference_params.key_value_memory_dict[layer_idx]
 
 
 
 
490
 
491
  batch_start = inference_params.batch_size_offset
492
  batch_end = batch_start + kv.shape[0]
493
+ assert batch_end <= kv_cache.shape[0]
494
 
495
+ sequence_start = inference_params.seqlen_offset
496
  sequence_end = sequence_start + kv.shape[1]
497
+ assert sequence_end <= kv_cache.shape[1]
 
 
 
 
 
 
498
 
499
+ assert kv_cache is not None
500
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
501
+ kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
502
 
503
  return kv
504
 
 
514
  rotary_dim: Optional[int] = None,
515
  rotary_emb_scale_base: Optional[float] = None,
516
  n_head: Optional[int] = None,
517
+ n_head_kv: Optional[int] = None,
518
  head_dim: Optional[int] = None,
519
  bias: bool = True,
520
  causal: bool = True,
 
535
  self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
536
 
537
  # MLP
538
+ self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim)
539
+ op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
540
  hidden_size = config.n_embd
541
 
542
+ self.Wqkv = nn.Linear(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
543
+ self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
544
 
545
  # Attention
546
  self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
 
550
  self.return_residual = return_residual
551
  self.checkpointing = checkpointing
552
 
553
+ def _forward_self_attn(
554
+ self, x: torch.FloatTensor, attention_mask: Optional[torch.BoolTensor]
555
+ ) -> torch.FloatTensor:
556
+ qkv = self.Wqkv(x)
557
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
558
+
559
+ if self.rotary_emb_dim > 0:
560
+ qkv = self.rotary_emb(qkv)
561
+
562
+ if self.checkpointing:
563
+ return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, attention_mask=attention_mask)
564
+
565
+ return self.inner_attn(qkv, attention_mask=attention_mask)
566
+
567
+ def _forward_cross_attn(
568
  self,
569
  x: torch.FloatTensor,
570
+ past_key_values: Optional[InferenceParams],
571
+ attention_mask: Optional[torch.BoolTensor],
572
+ ) -> torch.FloatTensor:
 
 
 
573
  qkv = self.Wqkv(x)
 
574
 
575
+ q = qkv[..., : self.n_head * self.head_dim]
576
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
577
+
578
+ kv = qkv[..., self.n_head * self.head_dim :]
579
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
580
+
581
+ seqlen_offset = past_key_values.seqlen_offset if past_key_values is not None else 0
582
+ causal = None if seqlen_offset == 0 else False
583
  if self.rotary_emb_dim > 0:
584
+ q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset)
585
 
586
  if past_key_values is not None:
587
+ kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
588
 
589
+ if self.checkpointing:
590
+ return torch.utils.checkpoint.checkpoint(
591
+ self.inner_cross_attn, q, kv, attention_mask=attention_mask, causal=causal
592
+ )
593
 
594
+ return self.inner_cross_attn(q, kv, attention_mask=attention_mask, causal=causal)
595
 
596
+ def forward(
597
+ self,
598
+ x: torch.FloatTensor,
599
+ past_key_values: Optional[InferenceParams] = None,
600
+ attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
601
+ **kwargs,
602
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
603
+ if attention_mask is not None and torch.any(~attention_mask.bool()):
604
+ attention_mask = attention_mask.bool()
605
+ else:
606
+ attention_mask = None
607
+
608
+ # MHA
609
+ if self.n_head == self.n_head_kv:
610
+ if past_key_values is None:
611
+ # If `past_key_values` are not supplied, we run self-attention
612
+ attn_output = self._forward_self_attn(x, attention_mask)
613
  else:
614
+ # If `past_key_values` are supplied, it means that we might have cached values and
615
+ # could take advantage of cross-attention
616
+ attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
617
+ # MQA / GQA
618
  else:
619
+ # Regardless of `past_key_values` being supplied or not, it always use cross-attention
620
+ # because `q` and `kv` lengths might be different
621
+ attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
622
 
623
  output = rearrange(attn_output, "... h d -> ... (h d)")
624
  output = self.out_proj(output)
 
736
  if module.padding_idx is not None:
737
  module.weight.data[module.padding_idx].zero_()
738
  elif isinstance(module, nn.LayerNorm):
739
+ if module.bias is not None:
740
+ module.bias.data.zero_()
741
  module.weight.data.fill_(1.0)
742
 
743
  def prepare_inputs_for_generation(
744
  self,
745
  input_ids: torch.LongTensor,
746
  past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
747
+ attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
748
  **kwargs,
749
  ) -> Dict[str, Any]:
 
 
 
 
 
 
 
 
 
 
750
  if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
751
  past_key_values = InferenceParams(
752
+ max_seqlen=self.config.n_positions,
753
  max_batch_size=input_ids.shape[0],
754
+ seqlen_offset=0,
 
755
  batch_size_offset=0,
 
756
  key_value_memory_dict={},
757
+ lengths_per_sample=None,
758
  )
759
  else:
760
  # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
761
+ past_key_values.seqlen_offset = len(input_ids[0]) - 1
762
  input_ids = input_ids[:, -1].unsqueeze(-1)
763
 
764
  return {
 
767
  "attention_mask": attention_mask,
768
  }
769
 
770
+ def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False) -> None:
771
+ if isinstance(module, MixFormerSequentialPreTrainedModel):
772
+ module.gradient_checkpointing = value
773
 
774
 
775
  class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
 
811
  labels: Optional[torch.LongTensor] = None,
812
  **kwargs,
813
  ) -> CausalLMOutputWithPast:
814
+ hidden_layer = self.layers[0](input_ids)
815
+ for module in self.layers[1:-1]:
816
+ hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask)
817
+ lm_logits = self.layers[-1](hidden_layer)
 
 
 
818
 
819
  loss = None
820
  if labels is not None: