linzheng committed on
Commit 474addc · verified · 1 Parent(s): ceadb38

Upload EvaByteForCausalLM

README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
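+ A minimal loading sketch (an assumption, not an official recipe): the repository id below is a hypothetical placeholder, a tokenizer is assumed to be available in the same repository, and `trust_remote_code=True` is assumed to be needed because the config registers the custom `EvaByteConfig`/`EvaByteForCausalLM` classes via `auto_map`.
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ repo_id = "your-org/EvaByte"  # hypothetical placeholder; replace with the actual Hub repo id
+ tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     repo_id,
+     torch_dtype=torch.bfloat16,   # matches `torch_dtype` in config.json
+     trust_remote_code=True,       # required for the custom EvaByte classes
+ ).eval()
+
+ inputs = tokenizer("Hello, world!", return_tensors="pt")
+ outputs = model.generate(**inputs, max_new_tokens=32)
+ print(tokenizer.decode(outputs[0]))
+ ```
+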
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,48 @@
1
+ {
2
+ "_name_or_path": null,
3
+ "architectures": [
4
+ "EvaByteForCausalLM"
5
+ ],
6
+ "attention_bias": false,
7
+ "attention_class": "eva",
8
+ "attention_dropout": 0.0,
9
+ "auto_map": {
10
+ "AutoConfig": "configuration_evabyte.EvaByteConfig",
11
+ "AutoModelForCausalLM": "modeling_evabyte.EvaByteForCausalLM"
12
+ },
13
+ "bos_token_id": 1,
14
+ "chunk_size": 16,
15
+ "eos_token_id": 11,
16
+ "fp32_ln": false,
17
+ "fp32_logits": true,
18
+ "fp32_skip_add": true,
19
+ "hidden_act": "silu",
20
+ "hidden_size": 4096,
21
+ "init_cutoff_factor": null,
22
+ "init_fn": "v2",
23
+ "init_std": 0.01275,
24
+ "initializer_range": 0.01275,
25
+ "intermediate_size": 11008,
26
+ "lazy_init": true,
27
+ "max_position_embeddings": 32768,
28
+ "max_seq_length": 32768,
29
+ "mixedp_attn": true,
30
+ "model_type": "evabyte",
31
+ "norm_add_unit_offset": true,
32
+ "num_attention_heads": 32,
33
+ "num_chunks": null,
34
+ "num_hidden_layers": 32,
35
+ "num_key_value_heads": 32,
36
+ "num_pred_heads": 8,
37
+ "pad_token_id": 0,
38
+ "return_dict": false,
39
+ "rms_norm_eps": 1e-05,
40
+ "rope_scaling": null,
41
+ "rope_theta": 100000,
42
+ "tie_word_embeddings": false,
43
+ "torch_dtype": "bfloat16",
44
+ "transformers_version": "4.47.1",
45
+ "use_cache": true,
46
+ "vocab_size": 320,
47
+ "window_size": 2048
48
+ }
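For orientation, the EVA-specific fields above fix how attention is tiled: each window spans `window_size` tokens and is summarized into `window_size // chunk_size` chunk-level states (the `chunks_per_window` value computed in `eva.py`). A small sketch of the derived layout, using only values from this config:

```python
# Derived attention layout from config.json (values copied from above).
window_size = 2048
chunk_size = 16
max_seq_length = 32768

chunks_per_window = window_size // chunk_size   # 128 chunk-level RFA states per window
num_windows = max_seq_length // window_size     # 16 windows at the maximum sequence length
num_chunks = max_seq_length // chunk_size       # 2048 chunk summaries at the maximum sequence length
print(chunks_per_window, num_windows, num_chunks)
```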
configuration_evabyte.py ADDED
@@ -0,0 +1,99 @@
1
+ """ EvaByte configuration"""
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+
5
+ class EvaByteConfig(PretrainedConfig):
6
+ model_type = "evabyte"
7
+ keys_to_ignore_at_inference = ["past_key_values"]
8
+
9
+ def __init__(
10
+ self,
11
+ vocab_size=320,
12
+ hidden_size=4096,
13
+ intermediate_size=11008,
14
+ num_hidden_layers=32,
15
+ num_attention_heads=32,
16
+ num_key_value_heads=None,
17
+ hidden_act="silu",
18
+ max_position_embeddings=2048,
19
+ initializer_range=0.02,
20
+ rms_norm_eps=1e-6,
21
+ use_cache=True,
22
+ pad_token_id=None,
23
+ bos_token_id=1,
24
+ eos_token_id=2,
25
+ tie_word_embeddings=False,
26
+ rope_theta=10000.0,
27
+ rope_scaling=None,
28
+ attention_bias=False,
29
+ attention_dropout=0.0,
30
+ norm_add_unit_offset=False,
31
+ init_fn="mitchell",
32
+ init_std=0.006,
33
+ init_cutoff_factor=None,
34
+ attention_class="mha",
35
+ window_size=512,
36
+ num_chunks=None,
37
+ chunk_size=256,
38
+ **kwargs,
39
+ ):
40
+ self.vocab_size = vocab_size
41
+ self.max_position_embeddings = max_position_embeddings
42
+ self.hidden_size = hidden_size
43
+ self.intermediate_size = intermediate_size
44
+ self.num_hidden_layers = num_hidden_layers
45
+ self.num_attention_heads = num_attention_heads
46
+
47
+ # for backward compatibility
48
+ if num_key_value_heads is None:
49
+ num_key_value_heads = num_attention_heads
50
+
51
+ self.num_key_value_heads = num_key_value_heads
52
+ self.hidden_act = hidden_act
53
+ self.initializer_range = initializer_range
54
+ self.rms_norm_eps = rms_norm_eps
55
+ self.use_cache = use_cache
56
+ self.rope_theta = rope_theta
57
+ self.rope_scaling = rope_scaling
58
+ self._rope_scaling_validation()
59
+ self.attention_bias = attention_bias
60
+ self.attention_dropout = attention_dropout
61
+
62
+ self.norm_add_unit_offset = norm_add_unit_offset
63
+ self.init_fn = init_fn
64
+ self.init_std = init_std
65
+ self.init_cutoff_factor = init_cutoff_factor
66
+
67
+ # Attention-specific parameters
68
+ self.attention_class = attention_class
69
+ self.window_size = window_size
70
+ self.num_chunks = num_chunks
71
+ self.chunk_size = chunk_size
72
+
73
+ super().__init__(
74
+ pad_token_id=pad_token_id,
75
+ bos_token_id=bos_token_id,
76
+ eos_token_id=eos_token_id,
77
+ tie_word_embeddings=tie_word_embeddings,
78
+ **kwargs,
79
+ )
80
+
81
+ def _rope_scaling_validation(self):
82
+ """
83
+ Validate the `rope_scaling` configuration.
84
+ """
85
+ if self.rope_scaling is None:
86
+ return
87
+
88
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
89
+ raise ValueError(
90
+ "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
91
+ )
92
+ rope_scaling_type = self.rope_scaling.get("type", None)
93
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
94
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
95
+ raise ValueError(
96
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
97
+ )
98
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
99
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
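A minimal usage sketch of `EvaByteConfig`, assuming this file is importable from the working directory; the keyword values mirror `config.json` above, and the second call illustrates the `rope_scaling` validation.

```python
from configuration_evabyte import EvaByteConfig  # assumes configuration_evabyte.py is on the import path

# Keyword values mirror config.json above.
config = EvaByteConfig(
    vocab_size=320,
    hidden_size=4096,
    num_hidden_layers=32,
    attention_class="eva",
    window_size=2048,
    chunk_size=16,
    rope_theta=100000,
)
print(config.num_key_value_heads)  # defaults to num_attention_heads (32) when not given

# The validator rejects malformed rope_scaling settings, e.g. a factor <= 1.
try:
    EvaByteConfig(rope_scaling={"type": "linear", "factor": 0.5})
except ValueError as err:
    print(err)
```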
eva.py ADDED
@@ -0,0 +1,419 @@
1
+ from typing import Dict, Optional, Tuple, List, Any, Union
2
+ import torch
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ from .eva_agg_kernel import triton_eva_agg_fwd
6
+ from .eva_prep_kv_kernel import triton_eva_prep_kv_fwd
7
+ try:
8
+ import triton
9
+ USE_TRITON_IMPL = True
10
+ except ImportError:
11
+ USE_TRITON_IMPL = False
12
+ raise ImportError("Triton is not installed. Please install it by running `pip install triton`.")
13
+
14
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
15
+ """
16
+ Rotates half the hidden dims (last dim) of the input.
17
+ Args:
18
+ x: Rotary embedded tensor
19
+ Return:
20
+ Tensor with half of last dim negated and rotated to the front.
21
+ """
22
+ x1, x2 = x.split(x.shape[-1] // 2, dim=-1)
23
+ return torch.cat((-x2, x1), dim=-1)
24
+
25
+ def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
26
+ position_ids: torch.Tensor) -> torch.Tensor:
27
+ """
28
+ Apply rotary embedding (cos, sin) to the query and key tensor on the sequence dimension.
29
+
30
+ The legends for dimensions are defined as:
31
+ num_heads: number of attention heads
32
+ current_seq_len: the current batch's sequence length, should be either 1 or max_seq_len
33
+ max_seq_len: the static sequence length, different from current_seq_len in cached inference case where it is always
34
+ maximum lenghth, e.g. the length of static sequence length of KV cache
35
+
36
+
37
+ Args:
38
+ q: Query tensor, of size (batch_size, num_heads, current_seq_len, head_dim)
39
+ k: Key tensor, of size (batch_size, num_key_value_heads, current_seq_len, head_dim)
40
+ cos: Cosine base of rotary embedding, of size (max_seq_len, head_dim)
41
+ sin: Sine base of rotary embedding, of size (max_seq_len, head_dim)
42
+ position_ids: The position indices of the tokens corresponding to the query and key tensors. It has a size of
43
+ (batch_size, current_seq_len).
44
+
45
+ Returns:
46
+ Embedded query and key tensor of same size as input.
47
+
48
+ """
49
+ bs, nheads, cur_seq_len, head_dim = q.shape
50
+ assert len(
51
+ k.shape) == 4, f"k should be of shape (batch_size, num_heads, current_seq_len, head_dim), got {k.shape} instead"
52
+ assert k.shape[0] == bs, f"k has a different batch_size {k.shape[0]} compared to q {bs}"
53
+ assert list(k.shape[2:]) == [cur_seq_len,
54
+ head_dim], f"k has different current_seq_len and/or head_dim compared to q"
55
+ assert cos.shape[3] == head_dim, f"cos should have dim of head dim {head_dim}, got {cos.shape[3]} instead"
56
+ assert list(position_ids.shape) in [[bs, cur_seq_len], [1, cur_seq_len]],\
57
+ f"position_ids should be of shape {[bs, cur_seq_len]} or {[1, cur_seq_len]}, got {position_ids.shape} instead"
58
+
59
+ q_embed = (q * cos) + (rotate_half(q) * sin)
60
+ k_embed = (k * cos) + (rotate_half(k) * sin)
61
+ return q_embed, k_embed
62
+
63
+ class EvaAttention(nn.Module):
64
+ """
65
+ Causal EVA for language modeling.
66
+ """
67
+
68
+ def __init__(self, config, layer_idx: Optional[int] = None):
69
+ super().__init__()
70
+ self.config = config
71
+ self.layer_idx = layer_idx
72
+ self.hidden_size = config.hidden_size
73
+ self.num_heads = config.num_attention_heads
74
+ self.head_dim = self.hidden_size // self.num_heads
75
+ self.head_dim_scaling = self.head_dim ** -0.5
76
+
77
+ self.max_position_embeddings = config.max_position_embeddings
78
+
79
+ if (self.head_dim * self.num_heads) != self.hidden_size:
80
+ raise ValueError(
81
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
82
+ f" and `num_heads`: {self.num_heads})."
83
+ )
84
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
85
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
86
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
87
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
88
+
89
+ self.window_size = config.window_size
90
+
91
+ self.num_chunks = config.num_chunks
92
+ self.chunk_size = config.chunk_size
93
+ if self.chunk_size is not None:
94
+ assert self.window_size >= self.chunk_size and self.window_size % self.chunk_size == 0
95
+ # chunk_size overrides the number of landmarks
96
+ self.num_chunks = None
97
+
98
+ self.chunks_per_window = int(self.window_size // self.chunk_size)
99
+ self.adaptive_phi = nn.Parameter(
100
+ torch.randn(
101
+ 1,
102
+ self.num_heads,
103
+ 1,
104
+ 1,
105
+ self.head_dim
106
+ ).clamp(-1., 1.) * self.head_dim_scaling
107
+ )
108
+ self.adaptive_mu_k = nn.Parameter(
109
+ torch.randn(
110
+ 1,
111
+ self.num_heads,
112
+ 1,
113
+ 1,
114
+ self.head_dim
115
+ ).clamp(-1., 1.) * self.head_dim_scaling
116
+ )
117
+
118
+ def _triton_forward(
119
+ self,
120
+ hidden_states: torch.Tensor,
121
+ attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
122
+ position_ids: Optional[torch.LongTensor] = None,
123
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
124
+ output_attentions: bool = False,
125
+ use_cache: bool = False,
126
+ cos: Optional[torch.Tensor] = None,
127
+ sin: Optional[torch.Tensor] = None,
128
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
129
+ assert not output_attentions
130
+ bsz, q_len, _ = hidden_states.size()
131
+
132
+ if use_cache and past_key_value is None:
133
+ raise ValueError
134
+
135
+ assert isinstance(attention_mask, tuple)
136
+
137
+ # infer the model's running mode
138
+ is_prefilling = use_cache and past_key_value.get_seq_length(self.layer_idx) == 0
139
+ is_decoding = use_cache and past_key_value.get_seq_length(self.layer_idx) > 0
140
+
141
+ if is_prefilling:
142
+ assert len(attention_mask) == 2
143
+ window_mask, intra_chunk_mask = attention_mask
144
+ chunk_dummy_mask = None
145
+ elif is_decoding:
146
+ assert len(attention_mask) == 3
147
+ window_mask, intra_chunk_mask, chunk_dummy_mask = attention_mask
148
+ else:
149
+ window_mask, intra_chunk_mask = attention_mask
150
+ chunk_dummy_mask = None
151
+
152
+ ############################################
153
+ # compute q, k, v from hidden states
154
+ ############################################
155
+ # [b, h, q_len, d]
156
+ q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
157
+ # [b, h, kv_len, d]
158
+ k = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
159
+ # [b, h, kv_len, d]
160
+ v = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
161
+
162
+ if use_cache:
163
+ past_key_value.update_past_len(q.shape[-2], self.layer_idx)
164
+
165
+ ############################################
166
+ # apply rotary positional embeddings to q, k
167
+ ############################################
168
+ q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
169
+
170
+ ############################################
171
+ # update and get cached singleton tokens
172
+ # update and cache k and v for calculating chunk-level RFAs
173
+ ############################################
174
+ if use_cache:
175
+ s_k, s_v, dump_k, dump_v = past_key_value.update_singletons_and_chunks(
176
+ k,
177
+ v,
178
+ self.layer_idx,
179
+ self.window_size,
180
+ )
181
+ else:
182
+ s_k, s_v = k, v
183
+ dump_k, dump_v = k, v
184
+
185
+ if use_cache:
186
+ singleton_mask, dump_rf_mask = past_key_value.update_mask(
187
+ s_mask=window_mask,
188
+ rf_mask=intra_chunk_mask,
189
+ layer_idx=self.layer_idx,
190
+ window_size=self.window_size,
191
+ )
192
+ else:
193
+ singleton_mask = window_mask
194
+ dump_rf_mask = intra_chunk_mask
195
+
196
+ if dump_k is not None and dump_v is not None:
197
+ # 1. in prefilling, the input shape is
198
+ # dump_k/dump_v: [b, h, n, d]
199
+ # rfa_k/rfa_v: [b, h, n // c, d]
200
+ # 2. in decoding, the input shape is
201
+ # k/v: [b, h, w, d]
202
+ # rfa_k/rfa_v: [b, h, w//c, d]
203
+ # 3. in forward inference; the seq_len is already divisible
204
+ rfa_k, rfa_v = triton_eva_prep_kv_fwd(
205
+ dump_k, dump_v,
206
+ self.adaptive_mu_k, self.adaptive_phi,
207
+ dump_rf_mask, self.head_dim_scaling, self.chunk_size
208
+ )
209
+ # rfa_mask = get_rfa_chunk_mask(dump_rf_mask)
210
+ if use_cache:
211
+ rfa_k, rfa_v = past_key_value.update_chunk_rfas(
212
+ rfa_k, rfa_v, self.layer_idx
213
+ )
214
+ elif use_cache:
215
+ # if there are not enough elements within the last chunk,
216
+ # we will only use the cached chunk-level RFAs
217
+ rfa_k, rfa_v = past_key_value.get_chunk_rfas(self.layer_idx)
218
+ else:
219
+ rfa_k, rfa_v = None, None
220
+
221
+ ############################################
222
+ # compute the full attention output
223
+ ############################################
224
+ if is_prefilling:
225
+ # prefilling
226
+ # 1. in prefilling, the input shape is
227
+ # q: [b, h, n, d]
228
+ # k/v: [b, h, n, d]
229
+ # rfa_k/rfa_v: [b, h, n // c, d]
230
+ attn_output = triton_eva_agg_fwd(
231
+ q, s_k, s_v,
232
+ rfa_k, rfa_v,
233
+ singleton_mask, self.head_dim_scaling, self.window_size, self.chunks_per_window
234
+ )
235
+ elif is_decoding:
236
+ # 2. in decoding, the input shape is
237
+ # q: [b, h, 1, d] or [b, h, z, d] (for multi-byte prediction)
238
+ # k/v: [b, h, 1 + s, d]
239
+ # rfa_k/rfa_v: [b, h, n // c, d]
240
+ if rfa_k is not None and rfa_v is not None:
241
+ # we only take the chunk-level RFAs not in the current window
242
+ seen_seq_len = past_key_value.get_seq_length(self.layer_idx)
243
+ if seen_seq_len <= self.window_size:
244
+ agg_k = s_k
245
+ agg_v = s_v
246
+ attn_mask = singleton_mask
247
+ else:
248
+ # NOTE: we already updated the cache so the length now
249
+ # includes the current token
250
+ # we subtract 1 from seen_seq_len because we want
251
+ # if seen_seq_len = 2048 -> num_windows_seen_so_far = 0
252
+ # if seen_seq_len = 4096 -> num_windows_seen_so_far = 1
253
+ # if seen_seq_len = 4097 -> num_windows_seen_so_far = 2
254
+ # NOTE the cat order should be taken care of;
255
+ # should align with the order based on which
256
+ # the attention mask is constructed
257
+ num_windows_seen_so_far = (seen_seq_len - 1) // self.window_size
258
+ agg_k = torch.cat([s_k, rfa_k[..., :num_windows_seen_so_far * self.chunks_per_window, :]], dim=-2)
259
+ agg_v = torch.cat([s_v, rfa_v[..., :num_windows_seen_so_far * self.chunks_per_window, :]], dim=-2)
260
+ if singleton_mask is not None:
261
+ assert chunk_dummy_mask is not None
262
+ attn_mask = torch.cat([singleton_mask, chunk_dummy_mask], dim=-1)
263
+ else:
264
+ attn_mask = singleton_mask
265
+ else:
266
+ agg_k = s_k
267
+ agg_v = s_v
268
+ attn_mask = singleton_mask
269
+ attn_output = F.scaled_dot_product_attention(
270
+ q, agg_k, agg_v,
271
+ attn_mask=attn_mask,
272
+ is_causal=False,
273
+ dropout_p=0.0,
274
+ scale=self.head_dim_scaling
275
+ )
276
+ else:
277
+ # 3. in single-forward inference
278
+ attn_output = triton_eva_agg_fwd(
279
+ q, s_k, s_v,
280
+ rfa_k, rfa_v,
281
+ singleton_mask, self.head_dim_scaling, self.window_size, self.chunks_per_window
282
+ )
283
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
284
+ raise ValueError(
285
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
286
+ f" {attn_output.size()}"
287
+ )
288
+ attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
289
+ attn_output = self.o_proj(attn_output)
290
+ attn_weights = None
291
+ return attn_output, attn_weights, past_key_value
292
+
293
+ def _multibyte_decoding_forward(
294
+ self,
295
+ hidden_states: torch.Tensor,
296
+ attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
297
+ position_ids: Optional[torch.LongTensor] = None,
298
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
299
+ output_attentions: bool = False,
300
+ use_cache: bool = False,
301
+ cos: Optional[torch.Tensor] = None,
302
+ sin: Optional[torch.Tensor] = None,
303
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
304
+ # during multi-byte forwarding, we only read caches and do not update them
305
+ assert not output_attentions
306
+ bsz, q_len, _ = hidden_states.size()
307
+
308
+ if use_cache and past_key_value is None:
309
+ raise ValueError
310
+
311
+ assert USE_TRITON_IMPL
312
+ assert isinstance(attention_mask, torch.Tensor) and attention_mask.dtype == torch.bool
313
+
314
+ assert use_cache and past_key_value.get_seq_length(self.layer_idx) > 0
315
+
316
+ ############################################
317
+ # compute q, k, v from hidden states
318
+ ############################################
319
+ # [b, h, q_len, d]
320
+ q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
321
+ # [b, h, kv_len, d]
322
+ k = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
323
+ # [b, h, kv_len, d]
324
+ v = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
325
+
326
+ ############################################
327
+ # apply rotary positional embeddings to q, k
328
+ ############################################
329
+ q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
330
+
331
+ ############################################
332
+ # update and get cached singleton tokens
333
+ ############################################
334
+ input_len = k.shape[-2]
335
+ window_pos = past_key_value.past_window_pos[self.layer_idx]
336
+ new_window_pos = window_pos + input_len
337
+
338
+ past_key_value.past_window_k[self.layer_idx][:, :, window_pos : new_window_pos, :] = k
339
+ past_key_value.past_window_v[self.layer_idx][:, :, window_pos : new_window_pos, :] = v
340
+ s_k = past_key_value.past_window_k[self.layer_idx][:, :, : new_window_pos, :]
341
+ s_v = past_key_value.past_window_v[self.layer_idx][:, :, : new_window_pos, :]
342
+
343
+ rfa_k, rfa_v = past_key_value.get_chunk_rfas(self.layer_idx)
344
+
345
+ ############################################
346
+ # compute the full attention output
347
+ ############################################
348
+ # 2. in decoding, the input shape is
349
+ # q: [b, h, 1, d] or [b, h, z, d] (for multi-byte prediction)
350
+ # k/v: [b, h, 1 + s, d]
351
+ # rfa_k/rfa_v: [b, h, n // c, d]
352
+ if rfa_k is not None and rfa_v is not None:
353
+ # NOTE the cat order should be taken care of;
354
+ # should align with the order based on which
355
+ # the attention mask is constructed
356
+ # agg_k = torch.cat([s_k, rfa_k], dim=-2)
357
+ # agg_v = torch.cat([s_v, rfa_v], dim=-2)
358
+ agg_k = torch.cat([rfa_k, s_k], dim=-2)
359
+ agg_v = torch.cat([rfa_v, s_v], dim=-2)
360
+ else:
361
+ agg_k = s_k
362
+ agg_v = s_v
363
+ attn_output = F.scaled_dot_product_attention(
364
+ q, agg_k, agg_v,
365
+ attn_mask=attention_mask,
366
+ is_causal=False,
367
+ dropout_p=0.0,
368
+ scale=self.head_dim_scaling
369
+ )
370
+
371
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
372
+ raise ValueError(
373
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
374
+ f" {attn_output.size()}"
375
+ )
376
+ attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
377
+ attn_output = self.o_proj(attn_output)
378
+ attn_weights = None
379
+ return attn_output, attn_weights, past_key_value
380
+
381
+ def forward(
382
+ self,
383
+ hidden_states: torch.Tensor,
384
+ attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
385
+ position_ids: Optional[torch.LongTensor] = None,
386
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
387
+ output_attentions: bool = False,
388
+ use_cache: bool = False,
389
+ cos: Optional[torch.Tensor] = None,
390
+ sin: Optional[torch.Tensor] = None,
391
+ multibyte_decoding: Optional[bool] = False,
392
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
393
+ assert not output_attentions
394
+ if use_cache and past_key_value is None:
395
+ raise ValueError
396
+
397
+ assert USE_TRITON_IMPL
398
+ if use_cache and multibyte_decoding:
399
+ return self._multibyte_decoding_forward(
400
+ hidden_states,
401
+ attention_mask=attention_mask,
402
+ position_ids=position_ids,
403
+ past_key_value=past_key_value,
404
+ output_attentions=output_attentions,
405
+ use_cache=use_cache,
406
+ cos=cos,
407
+ sin=sin,
408
+ )
409
+ else:
410
+ return self._triton_forward(
411
+ hidden_states,
412
+ attention_mask=attention_mask,
413
+ position_ids=position_ids,
414
+ past_key_value=past_key_value,
415
+ output_attentions=output_attentions,
416
+ use_cache=use_cache,
417
+ cos=cos,
418
+ sin=sin,
419
+ )
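A shape-level sketch of the rotary helpers defined at the top of this file. The cos/sin construction below is a standard RoPE recipe and is an assumption (the actual rotary module lives in `modeling_evabyte.py`, which is not shown in this commit view); `apply_rotary_pos_emb` is used as defined above, e.g. pasted into a REPL, since `eva.py` itself relies on relative imports.

```python
import torch

# apply_rotary_pos_emb / rotate_half as defined in eva.py above.
b, h, n, d = 1, 2, 8, 16
q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)
position_ids = torch.arange(n)[None, :]               # (1, n), accepted by the shape assert

# Assumed standard RoPE tables with the 4-D shape the asserts expect: (1, 1, n, d).
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, d, 2).float() / d))
freqs = position_ids[..., None].float() * inv_freq    # (1, n, d/2)
emb = torch.cat((freqs, freqs), dim=-1)                # (1, n, d)
cos, sin = emb.cos()[:, None], emb.sin()[:, None]      # (1, 1, n, d)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```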
eva_agg_kernel.py ADDED
@@ -0,0 +1,469 @@
1
+
2
+ import math
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ # Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128
8
+ # @triton.autotune(
9
+ # configs=[
10
+ # triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
11
+ # # This config has a race condition when EVEN_M == False, disabling it for now.
12
+ # # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
13
+ # ],
14
+ # key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
15
+ # )
16
+ @triton.heuristics(
17
+ {
18
+ "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
19
+ "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
20
+ "EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
21
+ "EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
22
+ "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
23
+ }
24
+ )
25
+ @triton.jit
26
+ def _fwd_eva_agg_kernel(
27
+ Q,
28
+ K,
29
+ V,
30
+ RFA_K,
31
+ RFA_V,
32
+ WindowMask,
33
+ Out,
34
+ softmax_scale,
35
+ stride_qb, stride_qh, stride_qm,
36
+ stride_kb, stride_kh, stride_kn,
37
+ stride_vb, stride_vh, stride_vn,
38
+ stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
39
+ stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
40
+ stride_mb, stride_mm,
41
+ stride_ob, stride_oh, stride_om,
42
+ nheads,
43
+ seqlen_q,
44
+ seqlen_k,
45
+ nchunks,
46
+ headdim,
47
+ CACHE_KEY_SEQLEN_Q, # TODO: why keeping this
48
+ CACHE_KEY_SEQLEN_K, # TODO: why keeping this
49
+ CACHE_KEY_NCHUNKS, # TODO: why keeping this
50
+ CHUNKS_PER_WINDOW: tl.constexpr,
51
+ WINDOW_SIZE: tl.constexpr,
52
+ MASK_TYPE: tl.constexpr,
53
+ EMPTY_RFA_KV: tl.constexpr,
54
+ BLOCK_HEADDIM: tl.constexpr,
55
+ EVEN_M: tl.constexpr,
56
+ EVEN_N: tl.constexpr,
57
+ EVEN_W: tl.constexpr,
58
+ EVEN_C: tl.constexpr,
59
+ EVEN_HEADDIM: tl.constexpr,
60
+ BLOCK_M: tl.constexpr,
61
+ BLOCK_N: tl.constexpr,
62
+ ):
63
+ start_m = tl.program_id(0)
64
+ off_bh = tl.program_id(1)
65
+ off_h = off_bh % nheads
66
+ off_b = off_bh // nheads
67
+ # initialize offsets
68
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
69
+ offs_w = (start_m * BLOCK_M) // WINDOW_SIZE
70
+ offs_n = tl.arange(0, BLOCK_N)
71
+ offs_c = tl.arange(0, BLOCK_N)
72
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
73
+ # TODO: add parentheses or not
74
+ q_ptrs = (
75
+ Q +
76
+ off_b * stride_qb +
77
+ off_h * stride_qh +
78
+ (offs_m[:, None] * stride_qm + offs_d[None, :])
79
+ )
80
+ k_ptrs = (
81
+ K +
82
+ off_b * stride_kb +
83
+ off_h * stride_kh +
84
+ (offs_n[:, None] * stride_kn + offs_d[None, :])
85
+ )
86
+ v_ptrs = (
87
+ V +
88
+ off_b * stride_vb +
89
+ off_h * stride_vh +
90
+ (offs_n[:, None] * stride_vn + offs_d[None, :])
91
+ )
92
+ if EMPTY_RFA_KV == 0:
93
+ rfa_k_ptrs = (
94
+ RFA_K +
95
+ off_b * stride_rfa_kb +
96
+ off_h * stride_rfa_kh +
97
+ (offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
98
+ )
99
+ rfa_v_ptrs = (
100
+ RFA_V +
101
+ off_b * stride_rfa_vb +
102
+ off_h * stride_rfa_vh +
103
+ (offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
104
+ )
105
+
106
+ qk_scale = softmax_scale
107
+ qk_scale *= 1.4426950408889634 # log2(e)
108
+ if MASK_TYPE == 1:
109
+ m_ptrs = (
110
+ WindowMask +
111
+ off_b * stride_mb +
112
+ (offs_m[:, None] * stride_mm + offs_n[None, :])
113
+ )
114
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
115
+ d_i = tl.zeros([BLOCK_M], dtype=tl.float32)
116
+ acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
117
+ # load q: it will stay in SRAM throughout
118
+ # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
119
+ # tl.load(q_ptrs), we get the wrong output!
120
+ if EVEN_M & EVEN_N:
121
+ if EVEN_HEADDIM:
122
+ q = tl.load(
123
+ q_ptrs
124
+ )
125
+ else:
126
+ q = tl.load(
127
+ q_ptrs,
128
+ mask=offs_d[None, :] < headdim,
129
+ other=0.0
130
+ )
131
+ else:
132
+ if EVEN_HEADDIM:
133
+ q = tl.load(
134
+ q_ptrs,
135
+ mask=offs_m[:, None] < seqlen_q,
136
+ other=0.0
137
+ )
138
+ else:
139
+ q = tl.load(
140
+ q_ptrs,
141
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
142
+ other=0.0
143
+ )
144
+ # loop over k, v and update accumulator
145
+ # Iterate over local singletons;
146
+ # so we only iterate over blocks within the current window
147
+ start_idx_n = offs_w * WINDOW_SIZE
148
+ end_idx_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
149
+ for start_n in range(start_idx_n, end_idx_n, BLOCK_N):
150
+ start_n = tl.multiple_of(start_n, BLOCK_N)
151
+ # -- compute qk ----
152
+ if EVEN_N & EVEN_M:
153
+ if EVEN_HEADDIM:
154
+ k = tl.load(
155
+ k_ptrs + start_n * stride_kn
156
+ )
157
+ else:
158
+ k = tl.load(
159
+ k_ptrs + start_n * stride_kn,
160
+ mask=offs_d[None, :] < headdim,
161
+ other=0.0
162
+ )
163
+ else:
164
+ if EVEN_HEADDIM:
165
+ k = tl.load(
166
+ k_ptrs + start_n * stride_kn,
167
+ mask=(start_n + offs_n)[:, None] < seqlen_k,
168
+ other=0.0,
169
+ )
170
+ else:
171
+ k = tl.load(
172
+ k_ptrs + start_n * stride_kn,
173
+ mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
174
+ other=0.0,
175
+ )
176
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
177
+ qk += tl.dot(q, tl.trans(k))
178
+ # Trying to combine the two masks seems to make the result wrong
179
+ if not EVEN_N: # Need to mask out otherwise the softmax is wrong
180
+ qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
181
+
182
+ if MASK_TYPE == 1:
183
+ if EVEN_M & EVEN_W:
184
+ mask = tl.load(
185
+ m_ptrs + start_n - start_idx_n
186
+ ).to(tl.float32)
187
+ else:
188
+ mask = tl.load(
189
+ m_ptrs + start_n - start_idx_n,
190
+ mask=(offs_m[:, None] < seqlen_q)
191
+ & ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE),
192
+ other=0.0,
193
+ ).to(tl.float32)
194
+ # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
195
+ # can then fuse the mult and add into an fma instruction. But if we have bias we need to
196
+ # multiply with softmax_scale here.
197
+ # we assume mask already implies the causal masking
198
+ qk = qk * qk_scale + mask
199
+ m_ij = tl.maximum(tl.max(qk, 1), m_i)
200
+ p = tl.exp2(qk - m_ij[:, None])
201
+ else:
202
+ qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
203
+ m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
204
+ p = tl.exp2(qk * qk_scale - m_ij[:, None])
205
+
206
+ d_ij = tl.sum(p, 1)
207
+
208
+ # scale acc_o
209
+ prev_scale = tl.exp2(m_i - m_ij)
210
+ # # -- update output accumulator --
211
+ acc_o = acc_o * prev_scale[:, None]
212
+ # update acc_o
213
+ if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
214
+ if EVEN_HEADDIM:
215
+ v = tl.load(
216
+ v_ptrs + start_n * stride_vn
217
+ )
218
+ else:
219
+ v = tl.load(
220
+ v_ptrs + start_n * stride_vn,
221
+ mask=offs_d[None, :] < headdim,
222
+ other=0.0
223
+ )
224
+ else:
225
+ if EVEN_HEADDIM:
226
+ v = tl.load(
227
+ v_ptrs + start_n * stride_vn,
228
+ mask=(start_n + offs_n)[:, None] < seqlen_k,
229
+ other=0.0,
230
+ )
231
+ else:
232
+ v = tl.load(
233
+ v_ptrs + start_n * stride_vn,
234
+ mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
235
+ other=0.0,
236
+ )
237
+ p = p.to(v.dtype)
238
+ acc_o = tl.dot(p, v, acc_o)
239
+
240
+ # -- update statistics
241
+ d_i = d_i * prev_scale + d_ij
242
+ m_i = m_ij
243
+
244
+ if EMPTY_RFA_KV == 0:
245
+ # Iterate over RFA chunks
246
+ # we only iterate over chunks before the current local singleton window
247
+ end_idx_c = tl.minimum(offs_w * CHUNKS_PER_WINDOW, nchunks)
248
+ for start_c in range(0, end_idx_c, BLOCK_N):
249
+ start_c = tl.multiple_of(start_c, BLOCK_N)
250
+ # -- compute qk ----
251
+ if EVEN_C & EVEN_M:
252
+ if EVEN_HEADDIM:
253
+ rfa_k = tl.load(
254
+ rfa_k_ptrs + start_c * stride_rfa_kc
255
+ )
256
+ else:
257
+ rfa_k = tl.load(
258
+ rfa_k_ptrs + start_c * stride_rfa_kc,
259
+ mask=offs_d[None, :] < headdim,
260
+ other=0.0
261
+ )
262
+ else:
263
+ if EVEN_HEADDIM:
264
+ rfa_k = tl.load(
265
+ rfa_k_ptrs + start_c * stride_rfa_kc,
266
+ mask=(start_c + offs_c)[:, None] < nchunks,
267
+ other=0.0,
268
+ )
269
+ else:
270
+ rfa_k = tl.load(
271
+ rfa_k_ptrs + start_c * stride_rfa_kc,
272
+ mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
273
+ other=0.0,
274
+ )
275
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
276
+ qk += tl.dot(q, tl.trans(rfa_k))
277
+ # Trying to combine the two masks seems to make the result wrong
278
+ if not EVEN_C: # Need to mask out otherwise the softmax is wrong
279
+ qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf"))
280
+
281
+ m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
282
+ p = tl.exp2(qk * qk_scale - m_ij[:, None])
283
+
284
+ d_ij = tl.sum(p, 1)
285
+
286
+ # scale acc_o
287
+ prev_scale = tl.exp2(m_i - m_ij)
288
+ # # -- update output accumulator --
289
+ acc_o = acc_o * prev_scale[:, None]
290
+ # update acc_o
291
+ # TODO: If we just do "if EVEN_N", there seems to be some race condition ?
292
+ if EVEN_C & EVEN_M:
293
+ if EVEN_HEADDIM:
294
+ rfa_v = tl.load(
295
+ rfa_v_ptrs + start_c * stride_rfa_vc
296
+ )
297
+ else:
298
+ rfa_v = tl.load(
299
+ rfa_v_ptrs + start_c * stride_rfa_vc,
300
+ mask=offs_d[None, :] < headdim,
301
+ other=0.0
302
+ )
303
+ else:
304
+ if EVEN_HEADDIM:
305
+ rfa_v = tl.load(
306
+ rfa_v_ptrs + start_c * stride_rfa_vc,
307
+ mask=(start_c + offs_n)[:, None] < nchunks,
308
+ other=0.0,
309
+ )
310
+ else:
311
+ rfa_v = tl.load(
312
+ rfa_v_ptrs + start_c * stride_rfa_vc,
313
+ mask=((start_c + offs_n)[:, None] < nchunks) & (offs_d[None, :] < headdim),
314
+ other=0.0,
315
+ )
316
+ p = p.to(rfa_v.dtype)
317
+ acc_o = tl.dot(p, rfa_v, acc_o)
318
+
319
+ # -- update statistics
320
+ d_i = d_i * prev_scale + d_ij
321
+ m_i = m_ij
322
+
323
+ # BUG: have to store and immediately load
324
+ acc_o = acc_o / d_i[:, None]
325
+ # TODO: understand why rematerializing the offsets saves registers
326
+ start_m = tl.program_id(0)
327
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
328
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
329
+ out_ptrs = (
330
+ Out +
331
+ off_b * stride_ob +
332
+ off_h * stride_oh +
333
+ (offs_m[:, None] * stride_om + offs_d[None, :])
334
+ )
335
+ if EVEN_M:
336
+ if EVEN_HEADDIM:
337
+ tl.store(
338
+ out_ptrs, acc_o
339
+ )
340
+ else:
341
+ tl.store(
342
+ out_ptrs, acc_o,
343
+ mask=offs_d[None, :] < headdim
344
+ )
345
+ else:
346
+ if EVEN_HEADDIM:
347
+ tl.store(
348
+ out_ptrs, acc_o,
349
+ mask=offs_m[:, None] < seqlen_q
350
+ )
351
+ else:
352
+ tl.store(
353
+ out_ptrs, acc_o,
354
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
355
+ )
356
+
357
+ def triton_eva_agg_fwd(q, k, v, rfa_k, rfa_v, window_mask, softmax_scale, window_size, chunks_per_window):
358
+ if rfa_k is None and rfa_v is None:
359
+ empty_rfa_kv = 1
360
+
361
+ q, k, v = [
362
+ x if x.stride(-1) == 1 else x.contiguous()
363
+ for x in [q, k, v]
364
+ ]
365
+ else:
366
+ assert rfa_k is not None and rfa_v is not None, "Both rfa_k and rfa_v must either be None or have values at the same time."
367
+ empty_rfa_kv = 0
368
+
369
+ q, k, v, rfa_k, rfa_v = [
370
+ x if x.stride(-1) == 1 else x.contiguous()
371
+ for x in [q, k, v, rfa_k, rfa_v]
372
+ ]
373
+
374
+ # shape constraints
375
+ batch, nheads, seqlen_q, head_dim = q.shape
376
+ _, _, seqlen_k, _ = k.shape
377
+ if empty_rfa_kv == 0:
378
+ nchunks = rfa_k.shape[-2]
379
+ assert rfa_k.shape == (batch, nheads, nchunks, head_dim)
380
+ assert rfa_v.shape == (batch, nheads, nchunks, head_dim)
381
+ assert q.dtype == k.dtype == v.dtype == rfa_k.dtype == rfa_v.dtype
382
+ else:
383
+ nchunks = 0
384
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
385
+ assert k.shape == (batch, nheads, seqlen_k, head_dim)
386
+ assert v.shape == (batch, nheads, seqlen_k, head_dim)
387
+
388
+ assert head_dim <= 128, "We only test head dimensions up to 128"
389
+ # assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
390
+ assert q.dtype in [torch.bfloat16, torch.float], "Only support bf16 and fp32 for now"
391
+ assert q.is_cuda and k.is_cuda and v.is_cuda
392
+ softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)
393
+
394
+ mask_type = 0
395
+ if window_mask is not None:
396
+ mask_type = 1
397
+ assert window_mask.dtype == q.dtype, f"window_mask dtype {window_mask.dtype} must match q dtype {q.dtype}"
398
+ assert window_mask.is_cuda
399
+ assert window_mask.dim() == 4
400
+ assert window_mask.shape == (batch, 1, seqlen_q, window_size)
401
+ if window_mask.stride(-1) != 1:
402
+ window_mask = window_mask.contiguous()
403
+ mask_strides = (
404
+ (window_mask.stride(0), window_mask.stride(2))
405
+ if mask_type == 1 else
406
+ (0, 0)
407
+ )
408
+
409
+ rfa_k_strides = (
410
+ (rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2))
411
+ if empty_rfa_kv == 0 else
412
+ (0, 0, 0)
413
+ )
414
+ rfa_v_strides = (
415
+ (rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2))
416
+ if empty_rfa_kv == 0 else
417
+ (0, 0, 0)
418
+ )
419
+ assert chunks_per_window > 0, "chunks_per_window must be greater than 0"
420
+
421
+ o = torch.empty_like(q)
422
+
423
+ BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
424
+ if q.dtype == torch.float:
425
+ BLOCK = 64
426
+ else:
427
+ BLOCK = 128
428
+ num_warps = 4 if head_dim <= 64 else 8
429
+ assert chunks_per_window >= BLOCK, "chunks_per_window must be at least BLOCK"
430
+ # WINDOW_MASK_TYPE:
431
+ # - 0: regular causal mask, simply None
432
+ # - 1: the shape must be B, 1, W, I, J
433
+
434
+ grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
435
+ _fwd_eva_agg_kernel[grid](
436
+ q,
437
+ k,
438
+ v,
439
+ rfa_k,
440
+ rfa_v,
441
+ window_mask,
442
+ o,
443
+ softmax_scale,
444
+ q.stride(0), q.stride(1), q.stride(2),
445
+ k.stride(0), k.stride(1), k.stride(2),
446
+ v.stride(0), v.stride(1), v.stride(2),
447
+ rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
448
+ rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
449
+ mask_strides[0], mask_strides[1],
450
+ o.stride(0), o.stride(1), o.stride(2),
451
+ nheads,
452
+ seqlen_q,
453
+ seqlen_k,
454
+ nchunks,
455
+ head_dim,
456
+ seqlen_q // 32,
457
+ seqlen_k // 32,
458
+ nchunks // 32,
459
+ chunks_per_window,
460
+ window_size,
461
+ mask_type,
462
+ empty_rfa_kv,
463
+ BLOCK_HEADDIM,
464
+ BLOCK_M=BLOCK,
465
+ BLOCK_N=BLOCK,
466
+ num_warps=num_warps,
467
+ num_stages=1,
468
+ )
469
+ return o
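A shape-level sketch of calling `triton_eva_agg_fwd` as defined above, based only on the asserts in the wrapper; it assumes a CUDA device with Triton installed, reuses the window/chunk sizes from `config.json`, and passes `window_mask=None` (the plain causal-mask path).

```python
import math
import torch
from eva_agg_kernel import triton_eva_agg_fwd  # assumes this file is on the import path

b, h, d = 1, 4, 128                              # head_dim <= 128 per the assert above
window_size, chunk_size = 2048, 16
seqlen = 4096                                    # two full windows
chunks_per_window = window_size // chunk_size    # 128, satisfies chunks_per_window >= BLOCK
nchunks = seqlen // chunk_size

device, dtype = "cuda", torch.bfloat16
q = torch.randn(b, h, seqlen, d, device=device, dtype=dtype)
k = torch.randn(b, h, seqlen, d, device=device, dtype=dtype)
v = torch.randn(b, h, seqlen, d, device=device, dtype=dtype)
rfa_k = torch.randn(b, h, nchunks, d, device=device, dtype=dtype)
rfa_v = torch.randn(b, h, nchunks, d, device=device, dtype=dtype)

out = triton_eva_agg_fwd(
    q, k, v, rfa_k, rfa_v,
    window_mask=None,                            # MASK_TYPE == 0: plain causal masking
    softmax_scale=1.0 / math.sqrt(d),
    window_size=window_size,
    chunks_per_window=chunks_per_window,
)
print(out.shape)                                 # torch.Size([1, 4, 4096, 128])
```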
eva_cache.py ADDED
@@ -0,0 +1,761 @@
1
+ from typing import Dict, Optional, Tuple, List, Any, Union
2
+ import torch
3
+ from transformers.cache_utils import Cache
4
+
5
+ class EvaCache(Cache):
6
+ """
7
+ A dynamically growing cache for EVA attention, used during generation.
+
+ For each layer it stores the within-window key/value states, the running chunk statistics used to
+ build chunk-level RFAs, and temporary attention-mask buffers. Window key/value tensors have the
+ expected shape `[batch_size, num_heads, seq_len, head_dim]`.
11
+ """
12
+
13
+ def __init__(self) -> None:
14
+ self.w_k: List[torch.Tensor] = []
15
+ self.w_v: List[torch.Tensor] = []
16
+
17
+ self.rf_q: List[torch.Tensor] = []
18
+ self.rf_k: List[torch.Tensor] = []
19
+ self.rf_v: List[torch.Tensor] = []
20
+
21
+ self.softmax_phi_k_v: List[torch.Tensor] = []
22
+ self.log_sum_phi_k: List[torch.Tensor] = []
23
+ self.rf_k_bar: List[torch.Tensor] = []
24
+ self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
25
+
26
+ # attention masks temporary buffer
27
+ self.rf_mask: List[Optional[torch.Tensor]] = []
28
+ self.s_mask: List[torch.Tensor] = []
29
+ self.chunk_mask: List[torch.Tensor] = []
30
+
31
+ def __len__(self):
32
+ """
33
+ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
34
+ to the number of layers in the model.
35
+ """
36
+ return len(self.w_k)
37
+
38
+ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
39
+ """Given the sequence length of the new inputs, returns the usable length of the cache."""
40
+ # Cache without size limit -> all cache is usable
41
+ # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
42
+ # length, we will need to evict part of the cache (and thus not all cache is usable)
43
+ max_length = self.get_max_length()
44
+ previous_seq_length = self.get_seq_length(layer_idx)
45
+ if max_length is not None and previous_seq_length + new_seq_length > max_length:
46
+ return max_length - new_seq_length
47
+ return previous_seq_length
48
+
49
+ def reorder_cache(self, beam_idx: torch.LongTensor):
50
+ """Reorders the cache for beam search, given the selected beam indices."""
51
+ for layer_idx in range(len(self.w_k)):
52
+ device = self.w_k[layer_idx].device
53
+ self.w_k[layer_idx] = self.w_k[layer_idx].index_select(0, beam_idx.to(device))
54
+
55
+ device = self.w_v[layer_idx].device
56
+ self.w_v[layer_idx] = self.w_v[layer_idx].index_select(0, beam_idx.to(device))
57
+
58
+ device = self.rf_q[layer_idx].device
59
+ self.rf_q[layer_idx] = self.rf_q[layer_idx].index_select(0, beam_idx.to(device))
60
+
61
+ device = self.rf_k[layer_idx].device
62
+ self.rf_k[layer_idx] = self.rf_k[layer_idx].index_select(0, beam_idx.to(device))
63
+
64
+ device = self.rf_v[layer_idx].device
65
+ self.rf_v[layer_idx] = self.rf_v[layer_idx].index_select(0, beam_idx.to(device))
66
+
67
+ device = self.softmax_phi_k_v[layer_idx].device
68
+ self.softmax_phi_k_v[layer_idx] = self.softmax_phi_k_v[layer_idx].index_select(0, beam_idx.to(device))
69
+
70
+ device = self.log_sum_phi_k[layer_idx].device
71
+ self.log_sum_phi_k[layer_idx] = self.log_sum_phi_k[layer_idx].index_select(0, beam_idx.to(device))
72
+
73
+ device = self.rf_k_bar[layer_idx].device
74
+ self.rf_k_bar[layer_idx] = self.rf_k_bar[layer_idx].index_select(0, beam_idx.to(device))
75
+
76
+ device = self.rf_mask[layer_idx].device
77
+ self.rf_mask[layer_idx] = self.rf_mask[layer_idx].index_select(0, beam_idx.to(device))
78
+
79
+ device = self.s_mask[layer_idx].device
80
+ self.s_mask[layer_idx] = self.s_mask[layer_idx].index_select(0, beam_idx.to(device))
81
+
82
+ device = self.chunk_mask[layer_idx].device
83
+ self.chunk_mask[layer_idx] = self.chunk_mask[layer_idx].index_select(0, beam_idx.to(device))
84
+ @property
85
+ def seen_tokens(self):
86
+ if hasattr(self, "_seen_tokens"):
87
+ return self._seen_tokens
88
+ else:
89
+ return None
90
+
91
+ def update_past_len(
92
+ self,
93
+ cur_q_len: int,
94
+ layer_idx: int
95
+ ):
96
+ # Update the number of seen tokens
97
+ if layer_idx == 0:
98
+ self._seen_tokens += cur_q_len
99
+ return self._seen_tokens
100
+
101
+ def update_mask(
102
+ self,
103
+ prev_s_mask,
104
+ cur_s_mask,
105
+ chunk_mask,
106
+ rf_mask,
107
+ layer_idx,
108
+ window_size,
109
+ chunk_size,
110
+ ):
111
+ ############################################
112
+ # compute masks for singletons
113
+ ############################################
114
+ q_len = None
115
+ if len(self.s_mask) <= layer_idx:
116
+ q_len = chunk_mask.shape[-2]
117
+ # prefill stage
118
+ # q is of shape [b, h, n, d]
119
+ if q_len < window_size:
120
+ assert prev_s_mask is None
121
+
122
+ # w_v = # [b, h, 1, j, d]
123
+ # store the past window-wise key-value pairs
124
+ self.s_mask.append(cur_s_mask[..., -1:, :] if cur_s_mask is not None else prev_s_mask[..., -1, -1:, :])
125
+ else:
126
+ # decoding stage
127
+ prev_s_mask = None
128
+
129
+ cached_s_mask = self.s_mask[layer_idx]
130
+ assert cached_s_mask is not None
131
+ if cached_s_mask.shape[-1] == window_size:
132
+ cur_s_mask = cur_s_mask
133
+ else:
134
+ cur_s_mask = torch.cat([cached_s_mask, cur_s_mask], dim=-1)
135
+
136
+ # store the past window-wise key-value pairs
137
+ self.s_mask[layer_idx] = cur_s_mask
138
+
139
+ ############################################
140
+ # compute masks for intra-chunks
141
+ ############################################
142
+ dump_rf_mask = None
143
+ if len(self.rf_mask) <= layer_idx:
144
+ # initialize chunk stats
145
+ # prefill stage
146
+ if q_len < chunk_size:
147
+ cur_rf_mask = rf_mask
148
+ else:
149
+ if q_len % chunk_size == 0:
150
+ dump_rf_mask = rf_mask
151
+ cur_rf_mask = None
152
+ else:
153
+ remainder_tokens = q_len % chunk_size
154
+ if rf_mask is not None:
155
+ dump_rf_mask, cur_rf_mask = torch.split(rf_mask, [q_len - remainder_tokens, remainder_tokens], dim=-2)
156
+ else:
157
+ dump_rf_mask = None
158
+ cur_rf_mask = None
159
+ self.rf_mask.append(cur_rf_mask)
160
+ else:
161
+ past_rf_mask = self.rf_mask[layer_idx]
162
+ if past_rf_mask is not None:
163
+ # when decoding tokens, we always assume the
164
+ # incoming token mask is 0 (not masked)
165
+ cur_rf_mask = torch.cat([past_rf_mask, rf_mask], dim=-2)
166
+ else:
167
+ # we do not need to use rf_mask anymore after we receive generated tokens
168
+ cur_rf_mask = None
169
+ # We need to store rf_k_bar and RFA-results that
170
+ # compute the per-chunk RFA.
171
+
172
+ # Dump the chunk if the len of current chunk reaches <chunk_size>.
173
+ if cur_rf_mask is not None and cur_rf_mask.shape[-2] == chunk_size:
174
+ dump_rf_mask = cur_rf_mask
175
+ cur_rf_mask = None
176
+
177
+ self.rf_mask[layer_idx] = cur_rf_mask
178
+
179
+ ############################################
180
+ # compute masks for inter chunks
181
+ ############################################
182
+ if len(self.chunk_mask) <= layer_idx:
183
+ # prefill stage
184
+ # q is of shape [b, h, n, d]
185
+ if q_len < window_size:
186
+ cur_chunk_mask = chunk_mask
187
+ prev_chunk_mask = None
188
+ else:
189
+ if q_len % window_size == 0:
190
+ cur_chunk_mask = None
191
+ prev_chunk_mask = chunk_mask
192
+ else:
193
+ remainder_tokens = q_len % window_size
194
+ # [b, h, n-r, d] [b, h, r, d]
195
+ prev_chunk_mask, cur_chunk_mask = torch.split(chunk_mask, [q_len - remainder_tokens, remainder_tokens], dim=-2)
196
+ bsz, num_heads, _, head_dim = prev_chunk_mask.shape
197
+ prev_chunk_mask = prev_chunk_mask.reshape(bsz, num_heads, -1, window_size, head_dim)
198
+
199
+ assert prev_s_mask is not None
200
+ if prev_s_mask.shape[-3] == 1 and prev_chunk_mask.shape[-3] > 1:
201
+ # need to expand
202
+ prev_s_mask = prev_s_mask.expand(-1, -1, prev_chunk_mask.shape[-3], -1, -1)
203
+ # w_v = # [b, h, 1, j, d]
204
+ # store the past window-wise key-value pairs
205
+ self.chunk_mask.append(cur_chunk_mask[..., -1:, :] if cur_chunk_mask is not None else prev_chunk_mask[..., -1, -1:, :])
206
+ else:
207
+ # decoding stage
208
+ prev_chunk_mask = None
209
+ cur_chunk_mask = self.chunk_mask[layer_idx]
210
+
211
+ # if the current sequence length reaches <chunk_size>,
212
+ # we append a new 1 to the end of chunk_mask
213
+ seen_seq_len = self.get_seq_length(layer_idx)
214
+ if seen_seq_len > 0 and seen_seq_len % chunk_size == 0:
215
+ past_chunk_mask = self.chunk_mask[layer_idx]
216
+ if past_chunk_mask is not None:
217
+ # when decoding tokens, we always assume the
218
+ # incoming token mask is 0 (not masked)
219
+ cur_chunk_mask = torch.cat([past_chunk_mask, chunk_mask], dim=-1)
220
+ else:
221
+ cur_chunk_mask = chunk_mask
222
+ self.chunk_mask[layer_idx] = cur_chunk_mask
223
+
224
+ # if the len of current sequence reaches <window_size> + 1,
225
+ # we turn on the mask for most recent chunks
226
+ if seen_seq_len > 0 and seen_seq_len % window_size == 1:
227
+ cur_chunk_mask = self.chunk_mask[layer_idx]
228
+ # we do not need to use rf_mask anymore after we receive generated tokens
229
+ num_chunks_per_window = window_size // chunk_size
230
+ cur_chunk_mask[..., -num_chunks_per_window:] = False
231
+ self.chunk_mask[layer_idx] = cur_chunk_mask
232
+
233
+ return (prev_s_mask, cur_s_mask, prev_chunk_mask, cur_chunk_mask, dump_rf_mask)
234
+
235
+ def update_singletons(
236
+ self,
237
+ q,
238
+ k,
239
+ v,
240
+ layer_idx,
241
+ window_size,
242
+ ):
243
+ if len(self.w_k) <= layer_idx:
244
+ # prefill stage
245
+ # q is of shape [b, h, n, d]
246
+ q_len = q.shape[-2]
247
+ if q_len < window_size:
248
+ w_q = q
249
+ w_k = k
250
+ w_v = v
251
+ past_w_q = past_w_k = past_w_v = None
252
+ else:
253
+ if q_len % window_size == 0:
254
+ w_q = None
255
+ w_k = None
256
+ w_v = None
257
+ past_w_q = q
258
+ past_w_k = k
259
+ past_w_v = v
260
+ else:
261
+ remainder_tokens = q_len % window_size
262
+ # [b, h, n-r, d] [b, h, r, d]
263
+ past_w_q, w_q = torch.split(q, [q_len - remainder_tokens, remainder_tokens], dim=-2)
264
+ past_w_k, w_k = torch.split(k, [q_len - remainder_tokens, remainder_tokens], dim=-2)
265
+ past_w_v, w_v = torch.split(v, [q_len - remainder_tokens, remainder_tokens], dim=-2)
266
+ bsz, num_heads, _, head_dim = past_w_q.shape
267
+ past_w_q = past_w_q.reshape(bsz, num_heads, -1, window_size, head_dim)
268
+ past_w_k = past_w_k.reshape(bsz, num_heads, -1, window_size, head_dim)
269
+ past_w_v = past_w_v.reshape(bsz, num_heads, -1, window_size, head_dim)
270
+ # w_q = q[..., None, -window_size:, :] # [b, h, 1, j, d]
271
+ # w_k = # [b, h, 1, j, d]
272
+ # w_v = # [b, h, 1, j, d]
273
+ # store the past window-wise key-value pairs
274
+ # if w_k is None, it means we happen to pass in a sequence that is divisible by window_size
275
+ # we leave the cache with window_size-sized kv pairs to be cleared next iteration
276
+ self.w_k.append(w_k if w_k is not None else past_w_k[..., -1, :, :])
277
+ self.w_v.append(w_v if w_v is not None else past_w_v[..., -1, :, :])
278
+ else:
279
+ # decoding stage
280
+ past_w_q = past_w_k = past_w_v = None
281
+ # this is implemented as either a sliding window or fixed window
282
+ w_q = q # [b, h, 1, d]
283
+ w_k = k # [b, h, 1, d]
284
+ w_v = v # [b, h, 1, d]
285
+
286
+ cached_w_k = self.w_k[layer_idx]
287
+ assert cached_w_k is not None # [b, h, j, d]
288
+ if cached_w_k.shape[-2] == window_size:
289
+ w_k = w_k
290
+ else:
291
+ w_k = torch.cat([cached_w_k, w_k], dim=-2)
292
+
293
+ cached_w_v = self.w_v[layer_idx]
294
+ assert cached_w_v is not None
295
+ if cached_w_v.shape[-2] == window_size:
296
+ w_v = w_v
297
+ else:
298
+ w_v = torch.cat([cached_w_v, w_v], dim=-2)
299
+
300
+ # store the past window-wise key-value pairs
301
+ self.w_k[layer_idx] = w_k
302
+ self.w_v[layer_idx] = w_v
303
+ return (past_w_q, past_w_k, past_w_v), (w_q, w_k, w_v)
304
+
305
+ def update_chunks(
306
+ self,
307
+ q,
308
+ k,
309
+ v,
310
+ layer_idx,
311
+ chunk_size
312
+ ):
313
+ q_len = q.shape[-2]
314
+ dump_q = None
315
+ dump_k = None
316
+ dump_v = None
317
+ if len(self.rf_q) <= layer_idx:
318
+ # initialize chunk stats
319
+ # prefill stage
320
+ if q_len < chunk_size:
321
+ rf_q = q
322
+ rf_k = k
323
+ rf_v = v
324
+ else:
325
+ if q_len % chunk_size == 0:
326
+ rf_q = None
327
+ rf_k = None
328
+ rf_v = None
329
+ dump_q = q
330
+ dump_k = k
331
+ dump_v = v
332
+ else:
333
+ remainder_tokens = q_len % chunk_size
334
+ # [b, h, n-r, d] [b, h, r, d]
335
+ dump_q, rf_q = torch.split(q, [q_len - remainder_tokens, remainder_tokens], dim=-2)
336
+ dump_k, rf_k = torch.split(k, [q_len - remainder_tokens, remainder_tokens], dim=-2)
337
+ dump_v, rf_v = torch.split(v, [q_len - remainder_tokens, remainder_tokens], dim=-2)
338
+ self.rf_q.append(rf_q)
339
+ self.rf_k.append(rf_k)
340
+ self.rf_v.append(rf_v)
341
+ else:
342
+ # decode tokens
343
+ # add query, key & value to the current chunk.
344
+ past_rf_q = self.rf_q[layer_idx]
345
+ if past_rf_q is not None:
346
+ rf_q = torch.cat([past_rf_q, q], dim=-2)
347
+ else:
348
+ rf_q = q
349
+
350
+ past_rf_k = self.rf_k[layer_idx]
351
+ if past_rf_k is not None:
352
+ rf_k = torch.cat([past_rf_k, k], dim=-2)
353
+ else:
354
+ rf_k = k
355
+
356
+ past_rf_v = self.rf_v[layer_idx]
357
+ if past_rf_v is not None:
358
+ rf_v = torch.cat([past_rf_v, v], dim=-2)
359
+ else:
360
+ rf_v = v
361
+
362
+ # We need to store rf_k_bar and the RFA statistics used to
+ # compute the per-chunk RFA.
364
+
365
+ # Dump the chunk if the len of current chunk reaches <chunk_size>.
366
+ if rf_q.shape[-2] == chunk_size:
367
+ dump_q = rf_q
368
+ dump_k = rf_k
369
+ dump_v = rf_v
370
+ # clear the chunk
371
+ rf_q = None
372
+ rf_k = None
373
+ rf_v = None
374
+
375
+ self.rf_q[layer_idx] = rf_q
376
+ self.rf_k[layer_idx] = rf_k
377
+ self.rf_v[layer_idx] = rf_v
378
+
379
+ return dump_q, dump_k, dump_v
380
+
381
+ def update_chunk_rfas(
382
+ self,
383
+ softmax_phi_k_v,
384
+ log_sum_phi_k,
385
+ rf_k_bar,
386
+ layer_idx,
387
+ random_feature_dim
388
+ ):
389
+ if len(self.softmax_phi_k_v) <= layer_idx:
390
+ # prefill stage
391
+ self.softmax_phi_k_v.append(softmax_phi_k_v)
392
+ self.log_sum_phi_k.append(log_sum_phi_k)
393
+ self.rf_k_bar.append(rf_k_bar)
394
+ else:
395
+ # token decoding
396
+ past_softmax_phi_k_v = self.softmax_phi_k_v[layer_idx]
397
+ past_log_sum_phi_k = self.log_sum_phi_k[layer_idx]
398
+ past_rf_k_bar = self.rf_k_bar[layer_idx]
399
+
400
+ if past_softmax_phi_k_v is not None:
401
+ if random_feature_dim == 1:
402
+ dim = -2
403
+ else:
404
+ dim = -3
405
+ softmax_phi_k_v = torch.cat([past_softmax_phi_k_v, softmax_phi_k_v], dim=dim)
406
+
407
+ if past_log_sum_phi_k is not None:
408
+ if random_feature_dim == 1:
409
+ dim = -2
410
+ else:
411
+ dim = -3
412
+ log_sum_phi_k = torch.cat([past_log_sum_phi_k, log_sum_phi_k], dim=dim)
413
+
414
+ if past_rf_k_bar is not None:
415
+ rf_k_bar = torch.cat([past_rf_k_bar, rf_k_bar], dim=-2)
416
+
417
+ self.softmax_phi_k_v[layer_idx] = softmax_phi_k_v
418
+ self.log_sum_phi_k[layer_idx] = log_sum_phi_k
419
+ self.rf_k_bar[layer_idx] = rf_k_bar
420
+
421
+ return softmax_phi_k_v, log_sum_phi_k, rf_k_bar
422
+
423
+ def get_chunk_rfas(self, layer_idx):
424
+ if len(self.softmax_phi_k_v) <= layer_idx:
425
+ return (
426
+ None,
427
+ None,
428
+ None
429
+ )
430
+ else:
431
+ return (
432
+ self.softmax_phi_k_v[layer_idx],
433
+ self.log_sum_phi_k[layer_idx],
434
+ self.rf_k_bar[layer_idx]
435
+ )
436
+
437
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
438
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
439
+ if len(self.w_k) <= layer_idx:
440
+ return 0
441
+ return self._seen_tokens
442
+
443
+ def get_max_length(self) -> Optional[int]:
444
+ """Returns the maximum sequence length of the cached states. This cache does not have a maximum length."""
445
+ return None
446
+
447
+ def update(
448
+ self,
449
+ layer_idx: int,
450
+ cache_kwargs: Optional[Dict[str, Any]] = None,
451
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
452
+ raise NotImplementedError("`update` is not used in Eva Cache.")
453
+
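+ # Sketch of the intended per-layer call pattern (a reading of the methods above,
+ # not part of the original file): during prefill, each `update_*` method appends a
+ # new layer-indexed entry (the `len(...) <= layer_idx` branch); during decoding it
+ # concatenates the new token's stats onto that entry and dumps a chunk/window once
+ # it fills up.
+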
454
+ class EvaStaticCacheForTriton(Cache):
455
+ """
456
+ A variant of EvaCache for eva's triton kernels
457
+ """
458
+
459
+ def __init__(
460
+ self,
461
+ batch_size,
462
+ num_key_value_heads,
463
+ window_size,
464
+ head_dim,
465
+ num_layers,
466
+ dtype,
467
+ device
468
+ ) -> None:
469
+ self.past_window_k: List[torch.Tensor] = []
470
+ self.past_window_v: List[torch.Tensor] = []
471
+
472
+ cache_shape = (batch_size, num_key_value_heads, window_size, head_dim)
473
+ for idx in range(num_layers):
474
+ new_window_k = torch.zeros(cache_shape, dtype=dtype, device=device)
475
+ new_window_v = torch.zeros(cache_shape, dtype=dtype, device=device)
476
+ self.past_window_k.append(new_window_k)
477
+ self.past_window_v.append(new_window_v)
478
+
479
+ self.past_window_pos: List[int] = []
480
+
481
+ self.rfa_k: List[torch.Tensor] = []
482
+ self.rfa_v: List[torch.Tensor] = []
483
+ # self.rfa_mask: List[torch.Tensor] = []
484
+
485
+ self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
486
+
487
+ # attention masks temporary buffer
488
+ self.rf_mask: List[Optional[torch.Tensor]] = []
489
+ self.s_mask: List[torch.Tensor] = []
490
+
491
+ def __len__(self):
492
+ """
493
+ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
494
+ to the number of layers in the model.
495
+ """
496
+ return len(self.past_window_pos)
497
+
498
+ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
499
+ """Given the sequence length of the new inputs, returns the usable length of the cache."""
500
+ # Cache without size limit -> all cache is usable
501
+ # Cache with size limit -> if the cache length plus the length of the new inputs is larger than the maximum cache
+ # length, we will need to evict part of the cache (and thus not all cache is usable)
503
+ max_length = self.get_max_length()
504
+ previous_seq_length = self.get_seq_length(layer_idx)
505
+ if max_length is not None and previous_seq_length + new_seq_length > max_length:
506
+ return max_length - new_seq_length
507
+ return previous_seq_length
508
+
509
+ def reorder_cache(self, beam_idx: torch.LongTensor):
510
+ """Reorders the cache for beam search, given the selected beam indices."""
511
+ for layer_idx in range(len(self.past_window_k)):
512
+ device = self.past_window_k[layer_idx].device
513
+ self.past_window_k[layer_idx] = self.past_window_k[layer_idx].index_select(0, beam_idx.to(device))
514
+
515
+ device = self.past_window_v[layer_idx].device
516
+ self.past_window_v[layer_idx] = self.past_window_v[layer_idx].index_select(0, beam_idx.to(device))
517
+
518
+ device = self.rfa_k[layer_idx].device
519
+ self.rfa_k[layer_idx] = self.rfa_k[layer_idx].index_select(0, beam_idx.to(device))
520
+
521
+ device = self.rfa_v[layer_idx].device
522
+ self.rfa_v[layer_idx] = self.rfa_v[layer_idx].index_select(0, beam_idx.to(device))
523
+
524
+ # device = self.rfa_mask[layer_idx].device
525
+ # self.rfa_mask[layer_idx] = self.rfa_mask[layer_idx].index_select(0, beam_idx.to(device))
526
+
527
+ device = self.rf_mask[layer_idx].device
528
+ self.rf_mask[layer_idx] = self.rf_mask[layer_idx].index_select(0, beam_idx.to(device))
529
+
530
+ device = self.s_mask[layer_idx].device
531
+ self.s_mask[layer_idx] = self.s_mask[layer_idx].index_select(0, beam_idx.to(device))
532
+
533
+ @property
534
+ def seen_tokens(self):
535
+ if hasattr(self, "_seen_tokens"):
536
+ return self._seen_tokens
537
+ else:
538
+ return None
539
+
540
+ def update_past_len(
541
+ self,
542
+ cur_q_len: int,
543
+ layer_idx: int
544
+ ):
545
+ # Update the number of seen tokens
546
+ if layer_idx == 0:
547
+ self._seen_tokens += cur_q_len
548
+ return self._seen_tokens
549
+
550
+ def update_mask(
551
+ self,
552
+ s_mask,
553
+ rf_mask,
554
+ layer_idx,
555
+ window_size,
556
+ ):
557
+ ############################################
558
+ # compute masks for singletons
559
+ ############################################
560
+ if len(self.s_mask) <= layer_idx:
561
+ # prefill stage
562
+ # q is of shape [b, h, n, d]
563
+ # s_v = # [b, h, 1, j, d]
564
+ # store the past window-wise key-value pairs
565
+ if s_mask is None:
566
+ cur_s_mask = None
567
+ else:
568
+ q_len = s_mask.shape[-2]
569
+ # s_mask is of shape [b, h, n, w]
570
+ # let r = q_len % window_size
571
+ # if r == 0, the mask to be appended is of shape [b, h, 1, w]
572
+ # otherwise, r < w, the mask to be appended is of shape [b, h, 1, r]
573
+ remainder_tokens = q_len % window_size
574
+ if remainder_tokens == 0:
575
+ cur_s_mask = None
576
+ else:
577
+ cur_s_mask = s_mask[..., -1:, :remainder_tokens]
578
+ self.s_mask.append(cur_s_mask)
579
+ # we use the passed s_mask for subsequent computations
580
+ dump_s_mask = s_mask
581
+ else:
582
+ # decoding stage
583
+ past_s_mask = self.s_mask[layer_idx]
584
+ if past_s_mask is None:
585
+ assert s_mask is None
586
+ cur_s_mask = None
587
+ else:
588
+ assert s_mask is not None
589
+ cur_s_mask = torch.cat([past_s_mask, s_mask], dim=-1)
590
+
591
+ dump_s_mask = cur_s_mask
592
+ if cur_s_mask is not None and cur_s_mask.shape[-1] == window_size:
593
+ cur_s_mask = None
594
+ # store the past window-wise key-value pairs
595
+ self.s_mask[layer_idx] = cur_s_mask
596
+
597
+ ############################################
598
+ # compute masks for intra-chunks
599
+ ############################################
600
+ dump_rf_mask = None
601
+ if len(self.rf_mask) <= layer_idx:
602
+ # initialize chunk stats
603
+ # prefill stage
604
+ if rf_mask is None:
605
+ cur_rf_mask = None
606
+ else:
607
+ q_len = rf_mask.shape[-2]
608
+ if q_len < window_size:
609
+ dump_rf_mask = None
610
+ cur_rf_mask = rf_mask
611
+ else:
612
+ if q_len % window_size == 0:
613
+ dump_rf_mask = rf_mask
614
+ cur_rf_mask = None
615
+ else:
616
+ remainder_tokens = q_len % window_size
617
+ dump_rf_mask, cur_rf_mask = torch.split(rf_mask, [q_len - remainder_tokens, remainder_tokens], dim=-2)
618
+ self.rf_mask.append(cur_rf_mask)
619
+ else:
620
+ past_rf_mask = self.rf_mask[layer_idx]
621
+ if past_rf_mask is not None:
622
+ # when decoding tokens, we always assume the
623
+ # incoming token mask is 0 (not masked)
624
+ cur_rf_mask = torch.cat([past_rf_mask, rf_mask], dim=-2)
625
+ else:
626
+ cur_rf_mask = None
627
+
628
+ if cur_rf_mask is not None and cur_rf_mask.shape[-2] == window_size:
629
+ dump_rf_mask = cur_rf_mask
630
+ cur_rf_mask = None
631
+
632
+ self.rf_mask[layer_idx] = cur_rf_mask
633
+
634
+ return dump_s_mask, dump_rf_mask
635
+
636
+ def update_singletons_and_chunks(
637
+ self,
638
+ k,
639
+ v,
640
+ layer_idx,
641
+ window_size,
642
+ ):
643
+ if len(self.past_window_pos) <= layer_idx:
644
+ # prefill stage
645
+ s_k = k
646
+ s_v = v
647
+ input_len = k.shape[-2]
648
+ window_pos = 0
649
+ if input_len <= window_size:
650
+ new_window_pos = window_pos + input_len
651
+
652
+ cached_window_k = k
653
+ cached_window_v = v
654
+ dump_k = None
655
+ dump_v = None
656
+ else:
657
+ remainder_tokens = input_len % window_size
658
+ if remainder_tokens == 0:
659
+ remainder_tokens = window_size
660
+ new_window_pos = window_pos + remainder_tokens
661
+
662
+ # [b, h, n-r, d] [b, h, r, d]
663
+ cached_window_k = k[..., -remainder_tokens:, :]
664
+ cached_window_v = v[..., -remainder_tokens:, :]
665
+ dump_k = k[..., :-remainder_tokens, :]
666
+ dump_v = v[..., :-remainder_tokens, :]
667
+ # store the past window-wise key-value pairs
668
+ self.past_window_k[layer_idx][:, :, window_pos : new_window_pos, :] = cached_window_k
669
+ self.past_window_v[layer_idx][:, :, window_pos : new_window_pos, :] = cached_window_v
670
+ self.past_window_pos.append(new_window_pos)
671
+ else:
672
+ # decoding stage
673
+ # if the previous cache has full tokens,
674
+ # roll back to the first elements
675
+ if self.past_window_pos[layer_idx] == window_size:
676
+ self.past_window_pos[layer_idx] = 0
677
+ dump_k = self.past_window_k[layer_idx].clone()
678
+ dump_v = self.past_window_v[layer_idx].clone()
679
+ else:
680
+ dump_k = None
681
+ dump_v = None
682
+
683
+ input_len = k.shape[-2]
684
+ window_pos = self.past_window_pos[layer_idx]
685
+ new_window_pos = window_pos + input_len
686
+
687
+ self.past_window_k[layer_idx][:, :, window_pos : new_window_pos, :] = k
688
+ self.past_window_v[layer_idx][:, :, window_pos : new_window_pos, :] = v
689
+
690
+ s_k = self.past_window_k[layer_idx][:, :, : new_window_pos, :]
691
+ s_v = self.past_window_v[layer_idx][:, :, : new_window_pos, :]
692
+
693
+ self.past_window_pos[layer_idx] = new_window_pos
694
+
695
+ return s_k, s_v, dump_k, dump_v
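+
+ # Note (added summary, not in the original file): past_window_k/past_window_v act
+ # as a fixed-size per-layer buffer of length window_size; once past_window_pos
+ # reaches window_size during decoding, the filled window is dumped for chunk-level
+ # RFA and the write position wraps back to 0.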
696
+
697
+ def update_chunk_rfas(
698
+ self,
699
+ rfa_k,
700
+ rfa_v,
701
+ layer_idx,
702
+ ):
703
+ if len(self.rfa_k) <= layer_idx:
704
+ # prefill stage
705
+ self.rfa_k.append(rfa_k)
706
+ self.rfa_v.append(rfa_v)
707
+ else:
708
+ # token decoding
709
+ past_rfa_k = self.rfa_k[layer_idx]
710
+ past_rfa_v = self.rfa_v[layer_idx]
711
+
712
+ if past_rfa_k is not None:
713
+ rfa_k = torch.cat([past_rfa_k, rfa_k], dim=-2)
714
+
715
+ if past_rfa_v is not None:
716
+ rfa_v = torch.cat([past_rfa_v, rfa_v], dim=-2)
717
+
718
+ self.rfa_k[layer_idx] = rfa_k
719
+ self.rfa_v[layer_idx] = rfa_v
720
+
721
+ return rfa_k, rfa_v
722
+
723
+ def get_past_window_pos(self, layer_idx):
724
+ if len(self.past_window_pos) <= layer_idx:
725
+ return None
726
+ else:
727
+ return self.past_window_pos[layer_idx]
728
+
729
+ def get_past_window_kv(self, layer_idx):
730
+ if len(self.past_window_pos) <= layer_idx:
731
+ return None, None
732
+ else:
733
+ return (
734
+ self.past_window_k[layer_idx][:, :, : self.past_window_pos[layer_idx], :],
735
+ self.past_window_v[layer_idx][:, :, : self.past_window_pos[layer_idx], :]
736
+ )
737
+
738
+ def get_chunk_rfas(self, layer_idx):
739
+ if len(self.rfa_k) <= layer_idx:
740
+ return None, None
741
+ else:
742
+ return self.rfa_k[layer_idx], self.rfa_v[layer_idx]
743
+
744
+ def get_seq_length(self, layer_idx = 0) -> int:
745
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
746
+ # layer_idx should be passed explicitly: otherwise any layer > 0 would
+ # only see _seen_tokens after layer 0 has already updated it
748
+ if len(self.past_window_pos) <= layer_idx:
749
+ return 0
750
+ return self._seen_tokens
751
+
752
+ def get_max_length(self) -> Optional[int]:
753
+ """Returns the maximum sequence length of the cached states. This cache does not have a maximum length."""
754
+ return None
755
+
756
+ def update(
757
+ self,
758
+ layer_idx: int,
759
+ cache_kwargs: Optional[Dict[str, Any]] = None,
760
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
761
+ raise NotImplementedError("`update` is not used in Eva Cache.")
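+
+ # A minimal construction sketch for the Triton-path cache (values, tensors, and
+ # `prompt_len` below are illustrative assumptions, not part of the original file):
+ # cache = EvaStaticCacheForTriton(
+ #     batch_size=1, num_key_value_heads=8, window_size=256, head_dim=64,
+ #     num_layers=32, dtype=torch.bfloat16, device="cuda",
+ # )
+ # cache.update_past_len(cur_q_len=prompt_len, layer_idx=0)
+ # s_k, s_v, dump_k, dump_v = cache.update_singletons_and_chunks(k, v, 0, window_size=256)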
eva_prep_kv_kernel.py ADDED
@@ -0,0 +1,357 @@
1
+
2
+ import math
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ @triton.heuristics(
8
+ {
9
+ "EVEN_N": lambda args: args["seqlen"] % args["BLOCK_N"] == 0,
10
+ "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
11
+ }
12
+ )
13
+ @triton.jit
14
+ def _fwd_eva_prep_kv_kernel(
15
+ K, # [b, h, n, d]
16
+ V, # [b, h, n, d]
17
+ PARAM_MU, # [1, h, 1, 1, d]
18
+ PARAM_PHI, # [1, h, 1, 1, d]
19
+ ChunkMask, # [b, h, n, 1]
20
+ Out_RFA_K, # [b, h, c, d]
21
+ Out_RFA_V, # [b, h, c, d]
22
+ softmax_scale,
23
+ stride_kb, stride_kh, stride_kn,
24
+ stride_vb, stride_vh, stride_vn,
25
+ stride_mu_h,
26
+ stride_phi_h,
27
+ stride_mb, stride_mn,
28
+ stride_ok_b, stride_ok_h, stride_ok_c,
29
+ stride_ov_b, stride_ov_h, stride_ov_c,
30
+ nheads,
31
+ seqlen,
32
+ nchunks,
33
+ headdim,
34
+ CACHE_KEY_SEQLEN, # TODO: why keeping this
35
+ CACHE_KEY_NCHUNKS, # TODO: why keeping this
36
+ CHUNKS_PER_BLOCK: tl.constexpr,
37
+ CHUNK_SIZE: tl.constexpr,
38
+ MASK_TYPE: tl.constexpr,
39
+ BLOCK_HEADDIM: tl.constexpr,
40
+ EVEN_N: tl.constexpr,
41
+ EVEN_HEADDIM: tl.constexpr,
42
+ BLOCK_N: tl.constexpr,
43
+ ):
44
+ start_n = tl.program_id(0)
45
+ offs_bh = tl.program_id(1)
46
+ offs_h = offs_bh % nheads
47
+ offs_b = offs_bh // nheads
48
+ # initialize offsets
49
+ # we load BLOCK_N keys and values each time, and
50
+ # reshape it to [CHUNKS_PER_BLOCK, CHUNK_SIZE]
51
+ offs_c = tl.arange(0, CHUNKS_PER_BLOCK)
52
+ offs_m = tl.arange(0, CHUNK_SIZE)
53
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
54
+
55
+ k_ptrs = (
56
+ K +
57
+ offs_b * stride_kb +
58
+ offs_h * stride_kh +
59
+ (
60
+ (
61
+ start_n * BLOCK_N +
62
+ offs_c[:, None, None] * CHUNK_SIZE +
63
+ offs_m[None, :, None]
64
+ ) * stride_kn +
65
+ offs_d[None, None, :]
66
+ )
67
+ )
68
+ v_ptrs = (
69
+ V +
70
+ offs_b * stride_vb +
71
+ offs_h * stride_vh +
72
+ (
73
+ (
74
+ start_n * BLOCK_N +
75
+ offs_c[:, None, None] * CHUNK_SIZE +
76
+ offs_m[None, :, None]
77
+ ) * stride_vn +
78
+ offs_d[None, None, :]
79
+ )
80
+ )
81
+ param_mu_ptrs = (
82
+ PARAM_MU +
83
+ offs_h * stride_mu_h +
84
+ offs_d[None, None, :]
85
+ )
86
+ param_phi_ptrs = (
87
+ PARAM_PHI +
88
+ offs_h * stride_phi_h +
89
+ offs_d[None, None, :]
90
+ )
91
+ log2e = 1.4426950408889634
92
+ if MASK_TYPE == 1:
93
+ m_ptrs = (
94
+ ChunkMask +
95
+ offs_b * stride_mb +
96
+ (
97
+ (
98
+ start_n * BLOCK_N +
99
+ offs_c[:, None] * CHUNK_SIZE +
100
+ offs_m[None, :]
101
+ ) * stride_mn
102
+ )
103
+ )
104
+ if EVEN_N:
105
+ if EVEN_HEADDIM:
106
+ k = tl.load(
107
+ k_ptrs
108
+ )
109
+ else:
110
+ k = tl.load(
111
+ k_ptrs,
112
+ mask=offs_d[None, None, :] < headdim,
113
+ other=0.0
114
+ )
115
+ else:
116
+ if EVEN_HEADDIM:
117
+ k = tl.load(
118
+ k_ptrs,
119
+ mask=(
120
+ start_n * BLOCK_N +
121
+ offs_c[:, None, None] * CHUNK_SIZE +
122
+ offs_m[None, :, None]
123
+ ) < seqlen,
124
+ other=0.0
125
+ )
126
+ else:
127
+ k = tl.load(
128
+ k_ptrs,
129
+ mask=(
130
+ (
131
+ start_n * BLOCK_N +
132
+ offs_c[:, None, None] * CHUNK_SIZE +
133
+ offs_m[None, :, None]
134
+ ) < seqlen
135
+ ) & (offs_d[None, None, :] < headdim),
136
+ other=0.0
137
+ )
138
+
139
+ param_mu = tl.load(param_mu_ptrs).to(k.dtype)
140
+ rfa_k_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
141
+ rfa_k_c_w += tl.sum(k * param_mu, axis=-1)
142
+ rfa_k_c_w *= log2e
143
+ if MASK_TYPE == 1:
144
+ if EVEN_N:
145
+ mask = tl.load(
146
+ m_ptrs
147
+ ).to(tl.float32)
148
+ else:
149
+ mask = tl.load(
150
+ m_ptrs,
151
+ mask=(
152
+ start_n * BLOCK_N +
153
+ offs_c[:, None] * CHUNK_SIZE +
154
+ offs_m[None, :]
155
+ ) < seqlen,
156
+ other=0.0,
157
+ ).to(tl.float32)
158
+ rfa_k_c_w = rfa_k_c_w + mask
159
+
160
+ rfa_k_c_w = tl.exp2(rfa_k_c_w - tl.max(rfa_k_c_w, axis=-1)[:, None])
161
+ rfa_k_c_w = rfa_k_c_w / tl.sum(rfa_k_c_w, axis=-1)[:, None]
162
+ rfa_k_c = tl.sum(k * rfa_k_c_w[:, :, None].to(k.dtype), axis=-2)
163
+ # TODO: understand why rematerializing offsets here saves registers
164
+ offs_out_c = start_n * CHUNKS_PER_BLOCK + tl.arange(0, CHUNKS_PER_BLOCK)
165
+ out_rfa_k_ptrs = (
166
+ Out_RFA_K +
167
+ offs_b * stride_ok_b +
168
+ offs_h * stride_ok_h +
169
+ (offs_out_c[:, None] * stride_ok_c + offs_d[None, :])
170
+ )
171
+
172
+ if EVEN_N:
173
+ if EVEN_HEADDIM:
174
+ tl.store(
175
+ out_rfa_k_ptrs, rfa_k_c
176
+ )
177
+ else:
178
+ tl.store(
179
+ out_rfa_k_ptrs, rfa_k_c,
180
+ mask=offs_d[None, :] < headdim
181
+ )
182
+ else:
183
+ if EVEN_HEADDIM:
184
+ tl.store(
185
+ out_rfa_k_ptrs, rfa_k_c,
186
+ mask=offs_out_c[:, None] < nchunks
187
+ )
188
+ else:
189
+ tl.store(
190
+ out_rfa_k_ptrs, rfa_k_c,
191
+ mask=(offs_out_c[:, None] < nchunks) & (offs_d[None, :] < headdim)
192
+ )
193
+
194
+
195
+ param_phi = tl.load(param_phi_ptrs).to(k.dtype)
196
+ rfa_v_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
197
+ rfa_v_c_w += tl.sum(k * param_phi, axis=-1)
198
+ rfa_v_c_w -= (0.5 * tl.sum(k * k, axis=-1))
199
+ rfa_v_c_w *= log2e * softmax_scale
200
+ if not EVEN_N: # Need to mask out otherwise the softmax is wrong
201
+ rfa_v_c_w += tl.where(
202
+ (
203
+ start_n * BLOCK_N +
204
+ offs_c[:, None] * CHUNK_SIZE +
205
+ offs_m[None, :]
206
+ ) < seqlen,
207
+ 0,
208
+ float("-inf")
209
+ )
210
+
211
+ if MASK_TYPE == 1:
212
+ rfa_v_c_w = rfa_v_c_w + mask
213
+
214
+ if EVEN_N:
215
+ if EVEN_HEADDIM:
216
+ v = tl.load(
217
+ v_ptrs
218
+ )
219
+ else:
220
+ v = tl.load(
221
+ v_ptrs,
222
+ mask=offs_d[None, None, :] < headdim,
223
+ other=0.0
224
+ )
225
+ else:
226
+ if EVEN_HEADDIM:
227
+ v = tl.load(
228
+ v_ptrs,
229
+ mask=(
230
+ start_n * BLOCK_N +
231
+ offs_c[:, None, None] * CHUNK_SIZE +
232
+ offs_m[None, :, None]
233
+ ) < seqlen,
234
+ other=0.0
235
+ )
236
+ else:
237
+ v = tl.load(
238
+ v_ptrs,
239
+ mask=(
240
+ (
241
+ start_n * BLOCK_N +
242
+ offs_c[:, None, None] * CHUNK_SIZE +
243
+ offs_m[None, :, None]
244
+ ) < seqlen
245
+ ) & (offs_d[None, None, :] < headdim),
246
+ other=0.0
247
+ )
248
+
249
+ rfa_v_c_w = tl.exp2(rfa_v_c_w - tl.max(rfa_v_c_w, axis=-1)[:, None])
250
+ rfa_v_c_w = rfa_v_c_w / tl.sum(rfa_v_c_w, axis=-1)[:, None]
251
+ rfa_v_c = tl.sum(v * rfa_v_c_w[:, :, None].to(v.dtype), axis=-2)
252
+
253
+ offs_out_c = start_n * CHUNKS_PER_BLOCK + tl.arange(0, CHUNKS_PER_BLOCK)
254
+ out_rfa_v_ptrs = (
255
+ Out_RFA_V +
256
+ offs_b * stride_ov_b +
257
+ offs_h * stride_ov_h +
258
+ (offs_out_c[:, None] * stride_ov_c + offs_d[None, :])
259
+ )
260
+ if EVEN_N:
261
+ if EVEN_HEADDIM:
262
+ tl.store(
263
+ out_rfa_v_ptrs, rfa_v_c
264
+ )
265
+ else:
266
+ tl.store(
267
+ out_rfa_v_ptrs, rfa_v_c,
268
+ mask=offs_d[None, :] < headdim
269
+ )
270
+ else:
271
+ if EVEN_HEADDIM:
272
+ tl.store(
273
+ out_rfa_v_ptrs, rfa_v_c,
274
+ mask=offs_out_c[:, None] < nchunks
275
+ )
276
+ else:
277
+ tl.store(
278
+ out_rfa_v_ptrs, rfa_v_c,
279
+ mask=(offs_out_c[:, None] < nchunks) & (offs_d[None, :] < headdim)
280
+ )
281
+
282
+ def triton_eva_prep_kv_fwd(k, v, param_mu, param_phi, chunk_mask, softmax_scale, chunksize):
283
+ k, v, param_mu, param_phi = [
284
+ x if x.stride(-1) == 1 else x.contiguous()
285
+ for x in [k, v, param_mu, param_phi]
286
+ ]
287
+
288
+ # shape constraints
289
+ batch, nheads, seqlen, head_dim = k.shape
290
+ assert seqlen % chunksize == 0, "seqlen must be divisible by chunksize"
291
+ nchunks = seqlen // chunksize
292
+ assert k.shape == (batch, nheads, seqlen, head_dim)
293
+ assert v.shape == (batch, nheads, seqlen, head_dim)
294
+ assert param_mu.shape == (1, nheads, 1, 1, head_dim)
295
+ assert param_phi.shape == (1, nheads, 1, 1, head_dim)
296
+ assert head_dim <= 128, "We only test head dimensions up to 128"
297
+ assert k.dtype == v.dtype == param_mu.dtype == param_phi.dtype, "All tensors must have the same type"
298
+ assert k.dtype in [torch.bfloat16, torch.float], "Only support bf16 and fp32 for now"
299
+ assert k.is_cuda and v.is_cuda
300
+ softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)
301
+
302
+ mask_type = 0
303
+ if chunk_mask is not None:
304
+ mask_type = 1
305
+ assert chunk_mask.dtype == k.dtype
306
+ assert chunk_mask.is_cuda
307
+ assert chunk_mask.dim() == 4
308
+ assert chunk_mask.shape == (batch, 1, seqlen, 1)
309
+ if chunk_mask.stride(-1) != 1:
310
+ chunk_mask = chunk_mask.contiguous()
311
+ mask_strides = (
312
+ (chunk_mask.stride(0), chunk_mask.stride(2))
313
+ if mask_type == 1 else
314
+ (0, 0)
315
+ )
316
+ out_rfa_k = torch.empty((batch, nheads, nchunks, head_dim), dtype=k.dtype, device=k.device)
317
+ out_rfa_v = torch.empty((batch, nheads, nchunks, head_dim), dtype=v.dtype, device=v.device)
318
+
319
+ BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
320
+ BLOCK = 128
321
+ num_warps = 4 if head_dim <= 64 else 8
322
+
323
+ assert BLOCK > chunksize and BLOCK % chunksize == 0, "BLOCK must be larger than and divisible by chunksize"
324
+ chunks_per_block = BLOCK // chunksize
325
+
326
+ grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_N"]), batch * nheads)
327
+ _fwd_eva_prep_kv_kernel[grid](
328
+ k,
329
+ v,
330
+ param_mu,
331
+ param_phi,
332
+ chunk_mask,
333
+ out_rfa_k,
334
+ out_rfa_v,
335
+ softmax_scale,
336
+ k.stride(0), k.stride(1), k.stride(2),
337
+ v.stride(0), v.stride(1), v.stride(2),
338
+ param_mu.stride(1),
339
+ param_phi.stride(1),
340
+ mask_strides[0], mask_strides[1],
341
+ out_rfa_k.stride(0), out_rfa_k.stride(1), out_rfa_k.stride(2),
342
+ out_rfa_v.stride(0), out_rfa_v.stride(1), out_rfa_v.stride(2),
343
+ nheads,
344
+ seqlen,
345
+ nchunks,
346
+ head_dim,
347
+ seqlen // 32,
348
+ nchunks // 32,
349
+ chunks_per_block,
350
+ chunksize,
351
+ mask_type,
352
+ BLOCK_HEADDIM,
353
+ BLOCK_N=BLOCK,
354
+ num_warps=num_warps,
355
+ num_stages=1,
356
+ )
357
+ return out_rfa_k, out_rfa_v
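+
+ # A minimal usage sketch of the wrapper above (added for illustration; the sizes
+ # below are assumptions, and a CUDA device with bf16 support is assumed):
+ if __name__ == "__main__":
+     if torch.cuda.is_available():
+         b, h, n, d, c = 2, 4, 256, 64, 16  # batch, heads, seqlen, head_dim, chunksize
+         k = torch.randn(b, h, n, d, dtype=torch.bfloat16, device="cuda")
+         v = torch.randn(b, h, n, d, dtype=torch.bfloat16, device="cuda")
+         mu = torch.randn(1, h, 1, 1, d, dtype=torch.bfloat16, device="cuda")
+         phi = torch.randn(1, h, 1, 1, d, dtype=torch.bfloat16, device="cuda")
+         # no chunk mask, default softmax scale (1/sqrt(d))
+         rfa_k, rfa_v = triton_eva_prep_kv_fwd(k, v, mu, phi, None, None, c)
+         # one aggregated key/value per chunk of size c
+         assert rfa_k.shape == (b, h, n // c, d) and rfa_v.shape == (b, h, n // c, d)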
eva_pt_ref.py ADDED
@@ -0,0 +1,422 @@
1
+ from typing import Optional, Tuple, Union
2
+ import torch
3
+ from torch import nn
4
+
5
+ MASK_MIN_VALUE = -10e10
6
+
7
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
8
+ """
9
+ Rotates half the hidden dims (last dim) of the input.
10
+ Args:
11
+ x: Rotary embedded tensor
12
+ Return:
13
+ Tensor with half of last dim negated and rotated to the front.
14
+ """
15
+ x1, x2 = x.split(x.shape[-1] // 2, dim=-1)
16
+ return torch.cat((-x2, x1), dim=-1)
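+
+ # Illustrative example (added; not in the original file):
+ # rotate_half(torch.tensor([[1., 2., 3., 4.]])) -> tensor([[-3., -4., 1., 2.]])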
17
+
18
+ def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
19
+ position_ids: torch.Tensor) -> torch.Tensor:
20
+ """
21
+ Apply rotary embedding (cos, sin) to the query and key tensor on the sequence dimension.
22
+
23
+ The legends for dimensions are defined as:
24
+ num_heads: number of attention heads
25
+ current_seq_len: the current batch's sequence length, should be either 1 or max_seq_len
26
+ max_seq_len: the static sequence length, different from current_seq_len in the cached-inference case, where it is
+ always the maximum length, e.g. the static sequence length of the KV cache
28
+
29
+
30
+ Args:
31
+ q: Query tensor, of size (batch_size, num_heads, current_seq_len, head_dim)
32
+ k: Key tensor, of size (batch_size, num_key_value_heads, current_seq_len, head_dim)
33
+ cos: Cosine base of rotary embedding, of size (max_seq_len, head_dim)
34
+ sin: Sine base of rotary embedding, of size (max_seq_len, head_dim)
35
+ position_ids: The position indices of the tokens corresponding to the query and key tensors. It has a size of
36
+ (batch_size, current_seq_len).
37
+
38
+ Returns:
39
+ Embedded query and key tensor of same size as input.
40
+
41
+ """
42
+ bs, nheads, cur_seq_len, head_dim = q.shape
43
+ assert len(
44
+ k.shape) == 4, f"k should be of shape (batch_size, num_heads, current_seq_len, head_dim), got {k.shape} instead"
45
+ assert k.shape[0] == bs, f"k has a different batch_size {k.shape[0]} compared to q {bs}"
46
+ assert list(k.shape[2:]) == [cur_seq_len,
47
+ head_dim], f"k has different current_seq_len and/or head_dim compared to q"
48
+ assert cos.shape[3] == head_dim, f"cos should have dim of head dim {head_dim}, got {cos.shape[3]} instead"
49
+ assert list(position_ids.shape) in [[bs, cur_seq_len], [1, cur_seq_len]],\
50
+ f"position_ids should be of shape {[bs, cur_seq_len]} or {[1, cur_seq_len]}, got {position_ids.shape} instead"
51
+
52
+ q_embed = (q * cos) + (rotate_half(q) * sin)
53
+ k_embed = (k * cos) + (rotate_half(k) * sin)
54
+ return q_embed, k_embed
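+
+ # A minimal call sketch (added; shapes and the cos/sin construction below are
+ # assumptions, shown in broadcastable 4-D form):
+ # b, h, n, d = 1, 2, 8, 16
+ # q = torch.randn(b, h, n, d); k = torch.randn(b, h, n, d)
+ # inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))
+ # angles = torch.arange(n).float()[:, None] * inv_freq[None, :]   # [n, d/2]
+ # emb = torch.cat([angles, angles], dim=-1)                       # [n, d]
+ # cos, sin = emb.cos()[None, None], emb.sin()[None, None]         # [1, 1, n, d]
+ # q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, torch.arange(n)[None, :])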
55
+
56
+ def attention_op(
57
+ q,
58
+ k,
59
+ v,
60
+ attn_mask,
61
+ mixedp_attn,
62
+ head_dim_scaling
63
+ ):
64
+ attn = torch.matmul(q, k.transpose(-2, -1))
65
+ if mixedp_attn:
66
+ attn = attn.to(torch.float)
67
+ attn = attn * head_dim_scaling
68
+ if attn_mask is not None:
69
+ attn = attn.masked_fill(attn_mask, MASK_MIN_VALUE)
70
+
71
+ attn_weights = torch.softmax(attn, dim=-1).to(q.dtype)
72
+ attn_output = torch.matmul(attn_weights, v)
73
+ return attn_output
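+
+ # Usage note (added): with q, k, v of shape [b, h, n, d] and a boolean attn_mask
+ # broadcastable to [b, h, n, n] in which True marks masked-out positions,
+ # attention_op(q, k, v, attn_mask, mixedp_attn=False, head_dim_scaling=d ** -0.5)
+ # returns an output of shape [b, h, n, d].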
74
+
75
+ def prm_projection(
76
+ x: torch.Tensor,
77
+ projection_matrix: torch.Tensor,
78
+ mixedp_attn: bool = False
79
+ ):
80
+ """
81
+ Constructs nonnegative kernel features for fast softmax attention.
82
+ Args:
83
+ x: input for which features are computed
84
+ projection_matrix: random matrix used to compute features
85
+ Returns:
86
+ Random features for fast attention.
87
+ """
88
+ # x : [..., m, d]
89
+ # proj : [..., r, d]
90
+ scaling_factor = (x.shape[-1] ** -0.5)
91
+ proj_x = torch.matmul(projection_matrix, x.transpose(-1, -2)) # [..., r, m]
92
+ norm = torch.sum(x ** 2, dim=-1).unsqueeze(-2) * 0.5 # [..., 1]
93
+ if mixedp_attn:
94
+ proj_x = proj_x.to(torch.float)
95
+ norm = norm.to(torch.float)
96
+ phi_x = scaling_factor * (proj_x - norm)
97
+ return phi_x
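+
+ # Added note: for each token x_m and projection row w_r this computes the
+ # log-domain score d**-0.5 * (w_r @ x_m - ||x_m||**2 / 2), i.e. the exponent of
+ # the positive random-feature estimator of the softmax kernel used for the
+ # chunk-level RFA.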
98
+
99
+ class EvaAttention(nn.Module):
100
+ def __init__(self, config, layer_idx: Optional[int] = None):
101
+ super().__init__()
102
+ self.config = config
103
+ self.layer_idx = layer_idx
104
+ self.hidden_size = config.hidden_size
105
+ self.num_heads = config.num_attention_heads
106
+ self.head_dim = self.hidden_size // self.num_heads
107
+ self.head_dim_scaling = self.head_dim ** -0.5
108
+
109
+ self.max_position_embeddings = config.max_position_embeddings
110
+
111
+ if (self.head_dim * self.num_heads) != self.hidden_size:
112
+ raise ValueError(
113
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
114
+ f" and `num_heads`: {self.num_heads})."
115
+ )
116
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
117
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
118
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
119
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
120
+
121
+ self.window_size = config.window_size
122
+
123
+ self.num_chunks = config.num_chunks
124
+ self.chunk_size = config.chunk_size
125
+ if self.chunk_size is not None:
126
+ assert self.window_size >= self.chunk_size and self.window_size % self.chunk_size == 0
127
+ # chunk_size overrides the number of landmarks
128
+ self.num_chunks = None
129
+
130
+ self.chunks_per_window = int(self.window_size // self.chunk_size)
131
+ self.random_feature_dim = 1
132
+ self.adaptive_phi = nn.Parameter(
133
+ torch.randn(
134
+ 1,
135
+ self.num_heads,
136
+ 1,
137
+ 1,
138
+ self.head_dim
139
+ ).clamp(-1., 1.) * self.head_dim_scaling
140
+ )
141
+ self.adaptive_mu_k = nn.Parameter(
142
+ torch.randn(
143
+ 1,
144
+ self.num_heads,
145
+ 1,
146
+ 1,
147
+ self.head_dim
148
+ ).clamp(-1., 1.) * self.head_dim_scaling
149
+ )
150
+
151
+ def _generate_feature_map(self, rf_q, rf_k, rf_v):
152
+ rf_k_logits = torch.sum(self.adaptive_mu_k.to(rf_k.dtype) * rf_k, dim=-1, keepdim=True) # b h c m 1
153
+ if self.config.mixedp_attn:
154
+ rf_k_logits = rf_k_logits.to(torch.float)
155
+ rf_k_weights = torch.softmax(rf_k_logits, dim=-2).to(rf_k.dtype)
156
+ rf_k_bar = torch.sum(rf_k_weights * rf_k, dim=-2)
157
+ weights = self.adaptive_phi.to(rf_k.dtype)
158
+ return weights, rf_k_bar
159
+
160
+ def _calculate_chunk_rfa_cache(self, rf_q, rf_k, rf_v, weights, rf_mask=None):
161
+ proj_x = torch.sum(weights * rf_k, dim=-1, keepdim=True)
162
+ norm = torch.sum(rf_k ** 2, dim=-1, keepdim=True) * 0.5 # [..., 1]
163
+ if self.config.mixedp_attn:
164
+ proj_x = proj_x.to(torch.float)
165
+ norm = norm.to(torch.float)
166
+ log_phi_k = self.head_dim_scaling * (proj_x - norm)
167
+
168
+ if rf_mask is not None:
169
+ log_phi_k = log_phi_k.masked_fill(rf_mask, MASK_MIN_VALUE)
170
+
171
+ # [b, h, c, m, r]
172
+ softmax_phi_k = torch.softmax(log_phi_k, dim=-2).to(rf_k.dtype)
173
+ softmax_phi_k_v = torch.sum(softmax_phi_k * rf_v, dim=-2)
174
+ # [b, h, c, r, m] [b, h, c, m, d] -> [b, h, c, r, d]
175
+ # softmax_phi_k_v = torch.matmul(softmax_phi_k.transpose(-1, -2), rf_v).squeeze(-2)
176
+ log_sum_phi_k = None
177
+ return softmax_phi_k_v, log_sum_phi_k
178
+
179
+ def _calculate_chunk_rfa(self, q, softmax_phi_k_v, log_sum_phi_k, weights):
180
+ if self.random_feature_dim == 1:
181
+ # when r = 1, the snis weights becomes 1, so this takes no effect
182
+ # [b, h, c, r, d] -> [b, h, c, d]
183
+ return softmax_phi_k_v
184
+ else:
185
+ # [b, h, c, r, d] [b, h, 1, s, d] -> [b, h, c, r, s]
186
+ log_phi_q = prm_projection(q.unsqueeze(-3), weights, self.config.mixedp_attn)
187
+ # [b, h, c, r, s] [b, h, c, r, 1] -> [b, h, c, r, s]
188
+ sniw = torch.softmax(log_phi_q + log_sum_phi_k, dim=-1).to(q.dtype)
189
+ # [b, h, c, r, s] [b, h, c, r, d] -> [b, h, c, s, d] -> [b, h, s, c, d]
190
+ rfa_per_chunk = torch.matmul(sniw.transpose(-1, -2), softmax_phi_k_v).transpose(-3, -2)
191
+ return rfa_per_chunk
192
+
193
+ def window_partition(self, x, window_size=None):
194
+ window_size = window_size if window_size is not None else self.window_size
195
+
196
+ gw, d = x.shape[-2:]
197
+ leading_dims = x.shape[:-2]
198
+ n_groups = gw // window_size
199
+ return x.reshape(*leading_dims, n_groups, window_size, d)
200
+
201
+ def window_merge(self, x, window_size=None):
202
+ g, w, d = x.shape[-3:]
203
+ leading_dims = x.shape[:-3]
204
+ return x.reshape(*leading_dims, g * w, d)
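+
+ # Shape example (added): window_partition on [b, h, 1024, d] with window_size=512
+ # yields [b, h, 2, 512, d]; window_merge reverses the operation.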
205
+
206
+ def forward(
207
+ self,
208
+ hidden_states: torch.Tensor,
209
+ attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
210
+ position_ids: Optional[torch.LongTensor] = None,
211
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
212
+ output_attentions: bool = False,
213
+ use_cache: bool = False,
214
+ cos: Optional[torch.Tensor] = None,
215
+ sin: Optional[torch.Tensor] = None,
216
+ multibyte_decoding: Optional[bool] = False,
217
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
218
+ assert not output_attentions
219
+ bsz, q_len, _ = hidden_states.size()
220
+
221
+ ############################################
222
+ # initialize past states if not provided
223
+ ############################################
224
+ if use_cache and past_key_value is None:
225
+ raise ValueError
226
+ if use_cache and multibyte_decoding:
227
+ raise NotImplementedError("Multibyte decoding is not supported for PyTorch native implementation")
228
+ # assert isinstance(attention_mask, tuple)
229
+ if len(attention_mask) == 4:
230
+ assert use_cache
231
+ prev_causal_mask, cur_causal_mask, chunk_causal_mask, intra_chunk_mask = attention_mask
232
+ elif len(attention_mask) == 3:
233
+ assert not use_cache
234
+ window_causal_mask, chunk_causal_mask, intra_chunk_mask = attention_mask
235
+ else:
236
+ raise NotImplementedError("Only attention-mask tuple with length 2 or 3 is supported")
237
+
238
+ ############################################
239
+ # compute q, k, v from hidden states
240
+ ############################################
241
+ # [b, h, q_len, d]
242
+ q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
243
+ # [b, h, kv_len, d]
244
+ k = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
245
+ # [b, h, kv_len, d]
246
+ v = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
247
+
248
+ if use_cache:
249
+ past_key_value.update_past_len(q.shape[-2], self.layer_idx)
250
+
251
+ ############################################
252
+ # apply rotary positional embeddings to q, k
253
+ ############################################
254
+ q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
255
+
256
+ ############################################
257
+ # compute q, k, v stats for the local window
258
+ ############################################
259
+ if use_cache:
260
+ (prev_w_q, prev_w_k, prev_w_v), (cur_w_q, cur_w_k, cur_w_v) = past_key_value.update_singletons(
261
+ q,
262
+ k,
263
+ v,
264
+ self.layer_idx,
265
+ self.window_size,
266
+ self.singleton_update
267
+ )
268
+ else:
269
+ prev_w_q = self.window_partition(q) # [b, h, w, i, d]
270
+ prev_w_k = self.window_partition(k) # [b, h, w, j, d]
271
+ prev_w_v = self.window_partition(v) # [b, h, w, j, d]
272
+ # during training, we assume window_size divides seq_len so no remainders
273
+ cur_w_q = cur_w_k = cur_w_v = None
274
+
275
+ ############################################
276
+ # compute q, k, v stats for chunk-level RFAs
277
+ ############################################
278
+ if use_cache:
279
+ dump_q, dump_k, dump_v = past_key_value.update_chunks(q, k, v, self.layer_idx, self.chunk_size)
280
+ else:
281
+ dump_q, dump_k, dump_v = q, k, v
282
+
283
+ if use_cache:
284
+ prev_s_mask, cur_s_mask, prev_chunk_mask, cur_chunk_mask, dump_rf_mask = past_key_value.update_mask(
285
+ prev_s_mask=prev_causal_mask,
286
+ cur_s_mask=cur_causal_mask,
287
+ chunk_mask=chunk_causal_mask,
288
+ rf_mask=intra_chunk_mask,
289
+ layer_idx=self.layer_idx,
290
+ window_size=self.window_size,
291
+ chunk_size=self.chunk_size,
292
+ singleton_update=self.singleton_update
293
+ )
294
+ else:
295
+ prev_s_mask = window_causal_mask # [1, 1, w, i, j]
296
+ cur_s_mask = None
297
+ prev_chunk_mask = self.window_partition(chunk_causal_mask)
298
+ cur_chunk_mask = None
299
+ dump_rf_mask = intra_chunk_mask
300
+ if prev_s_mask.shape[-3] == 1:
301
+ # need to expand
302
+ prev_s_mask = prev_s_mask.expand(-1, -1, prev_chunk_mask.shape[-3], -1, -1)
303
+
304
+ if (
305
+ dump_q is not None and
306
+ dump_k is not None and
307
+ dump_v is not None
308
+ ):
309
+ # [b, h, c, j, d]
310
+ rf_q = self.window_partition(dump_q, window_size=self.chunk_size)
311
+ # [b, h, c, j, d]
312
+ rf_k = self.window_partition(dump_k, window_size=self.chunk_size)
313
+ # [b, h, c, j, d]
314
+ rf_v = self.window_partition(dump_v, window_size=self.chunk_size)
315
+
316
+ if dump_rf_mask is not None:
317
+ rf_mask = self.window_partition(dump_rf_mask, window_size=self.chunk_size)
318
+ rf_q = rf_q.masked_fill(rf_mask, 0.)
319
+ rf_k = rf_k.masked_fill(rf_mask, 0.)
320
+ rf_v = rf_v.masked_fill(rf_mask, 0.)
321
+ else:
322
+ rf_mask = None
323
+ else:
324
+ rf_q = None
325
+ rf_k = None
326
+ rf_v = None
327
+ rf_mask = None
328
+
329
+
330
+ if rf_q is not None:
331
+ # import pdb; pdb.set_trace()
332
+ weights, rf_k_bar = self._generate_feature_map(rf_q, rf_k, rf_v)
333
+ softmax_phi_k_v, log_sum_phi_k = self._calculate_chunk_rfa_cache(rf_q, rf_k, rf_v, weights, rf_mask=rf_mask)
334
+ if use_cache:
335
+ softmax_phi_k_v, log_sum_phi_k, rf_k_bar = past_key_value.update_chunk_rfas(
336
+ softmax_phi_k_v, log_sum_phi_k, rf_k_bar, self.layer_idx, 1
337
+ )
338
+ elif use_cache:
339
+ weights = None
340
+ softmax_phi_k_v, log_sum_phi_k, rf_k_bar = past_key_value.get_chunk_rfas(self.layer_idx)
341
+ else:
342
+ weights = None
343
+ softmax_phi_k_v = None
344
+ log_sum_phi_k = None
345
+ rf_k_bar = None
346
+
347
+ if rf_k_bar is not None:
348
+ rfa_per_chunk = self._calculate_chunk_rfa(q, softmax_phi_k_v, log_sum_phi_k, weights)
349
+ ############################################
350
+ # compute meta-attention weights for
351
+ # - group-wise RFAs and
352
+ # - singletons (equivalent to exact local attention)
353
+ ############################################
354
+ if prev_w_k is not None:
355
+ if rf_k_bar is not None:
356
+ num_windows = prev_w_k.shape[-3]
357
+ # rf_k_bar and rfa_per_chunk take the shape [b, h, c, d]
358
+ # -> [b, h, 1, c, d] -> [b, h, w, c, d]
359
+ prev_rf_k_bar = rf_k_bar.unsqueeze(-3).expand(-1, -1, num_windows, -1, -1)
360
+ prev_rfa_per_chunk = rfa_per_chunk.unsqueeze(-3).expand(-1, -1, num_windows, -1, -1)
361
+ prev_agg_k = torch.cat([prev_w_k, prev_rf_k_bar], dim=-2)
362
+ prev_agg_v = torch.cat([prev_w_v, prev_rfa_per_chunk], dim=-2)
363
+
364
+ prev_attn_mask = torch.cat([prev_s_mask, prev_chunk_mask], dim=-1)
365
+ else:
366
+ prev_agg_k = prev_w_k
367
+ prev_agg_v = prev_w_v
368
+ prev_attn_mask = prev_s_mask
369
+
370
+ prev_attn_output = attention_op(
371
+ q=prev_w_q,
372
+ k=prev_agg_k,
373
+ v=prev_agg_v,
374
+ attn_mask=prev_attn_mask,
375
+ mixedp_attn=self.config.mixedp_attn,
376
+ head_dim_scaling=self.head_dim_scaling
377
+ )
378
+ prev_attn_output = self.window_merge(prev_attn_output)
379
+
380
+ if cur_w_k is not None:
381
+ if rf_k_bar is not None:
382
+ # rf_k_bar and rfa_per_chunk take the shape [b, h, c, d]
383
+ # cur_w_k and cur_w_v also has shape [b, h, m, d]
384
+ cur_agg_k = torch.cat([cur_w_k, rf_k_bar], dim=-2)
385
+ cur_agg_v = torch.cat([cur_w_v, rfa_per_chunk], dim=-2)
386
+
387
+ cur_attn_mask = torch.cat([cur_s_mask, cur_chunk_mask], dim=-1)
388
+ else:
389
+ cur_agg_k = cur_w_k
390
+ cur_agg_v = cur_w_v
391
+ cur_attn_mask = cur_s_mask
392
+
393
+ cur_attn_output = attention_op(
394
+ q=cur_w_q,
395
+ k=cur_agg_k,
396
+ v=cur_agg_v,
397
+ attn_mask=cur_attn_mask,
398
+ mixedp_attn=self.config.mixedp_attn,
399
+ head_dim_scaling=self.head_dim_scaling
400
+ )
401
+
402
+ if prev_w_k is not None and cur_w_k is not None:
403
+ attn_output = torch.cat([prev_attn_output, cur_attn_output], dim=-2)
404
+ elif prev_w_k is not None:
405
+ attn_output = prev_attn_output
406
+ elif cur_w_k is not None:
407
+ attn_output = cur_attn_output
408
+ else:
409
+ raise ValueError("There must be some bug")
410
+
411
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
412
+ raise ValueError(
413
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
414
+ f" {attn_output.size()}"
415
+ )
416
+
417
+ attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
418
+ attn_output = self.o_proj(attn_output)
419
+
420
+ attn_weights = None
421
+
422
+ return attn_output, attn_weights, past_key_value
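+
+ # A minimal construction sketch (added; the config fields below are the ones read
+ # by this module, with illustrative values):
+ # from types import SimpleNamespace
+ # cfg = SimpleNamespace(hidden_size=512, num_attention_heads=8,
+ #                       max_position_embeddings=4096, window_size=256,
+ #                       chunk_size=64, num_chunks=None, mixedp_attn=False)
+ # attn = EvaAttention(cfg, layer_idx=0)
+ # In training (use_cache=False), attention_mask is the 3-tuple
+ # (window_causal_mask, chunk_causal_mask, intra_chunk_mask) unpacked in forward().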
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 11,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.47.1"
7
+ }
model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b8f009b49caa79bbd15766b5b29b2dcf4a74a030af1a2043e53c3f15971cf33d
3
+ size 4994268984
model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6131455aa5c8cfb07ecf35a9d7cb99f2f7e3a5e0fc9aa90c0c2e00f7e13d583
3
+ size 4947590376
model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c497c85b8d257c9b78a1db9caccf6f07718b9b0aa06351fc357f64c9ca896b79
3
+ size 3034842568
model.safetensors.index.json ADDED
@@ -0,0 +1,362 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 12976660480
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00003-of-00003.safetensors",
7
+ "model.embed_tokens.weight": "model-00001-of-00003.safetensors",
8
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
9
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
10
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
11
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
12
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
13
+ "model.layers.0.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
14
+ "model.layers.0.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
15
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
16
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
17
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
18
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
19
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
20
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
21
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
22
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
23
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
24
+ "model.layers.1.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
25
+ "model.layers.1.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
26
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
27
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
28
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
29
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
30
+ "model.layers.10.input_layernorm.weight": "model-00001-of-00003.safetensors",
31
+ "model.layers.10.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
32
+ "model.layers.10.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
33
+ "model.layers.10.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
34
+ "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
35
+ "model.layers.10.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
36
+ "model.layers.10.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
37
+ "model.layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
38
+ "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
39
+ "model.layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
40
+ "model.layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
41
+ "model.layers.11.input_layernorm.weight": "model-00001-of-00003.safetensors",
42
+ "model.layers.11.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
43
+ "model.layers.11.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
44
+ "model.layers.11.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
45
+ "model.layers.11.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
46
+ "model.layers.11.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
47
+ "model.layers.11.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
48
+ "model.layers.11.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
49
+ "model.layers.11.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
50
+ "model.layers.11.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
51
+ "model.layers.11.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
52
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
53
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
54
+ "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
55
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
56
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
57
+ "model.layers.12.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
58
+ "model.layers.12.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
59
+ "model.layers.12.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
60
+ "model.layers.12.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
61
+ "model.layers.12.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
62
+ "model.layers.12.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
63
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
64
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
65
+ "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
66
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
67
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
68
+ "model.layers.13.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
69
+ "model.layers.13.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
70
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
71
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
72
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
73
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
74
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
75
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
76
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
77
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
78
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
79
+ "model.layers.14.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
80
+ "model.layers.14.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
81
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
82
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
83
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
84
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
85
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
86
+ "model.layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
87
+ "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
88
+ "model.layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
89
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
90
+ "model.layers.15.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
91
+ "model.layers.15.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
92
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
93
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
94
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
95
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
96
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
97
+ "model.layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
98
+ "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
99
+ "model.layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
100
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
101
+ "model.layers.16.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
102
+ "model.layers.16.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
103
+ "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
104
+ "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
105
+ "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
106
+ "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
107
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
108
+ "model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
109
+ "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
110
+ "model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
111
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
112
+ "model.layers.17.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
113
+ "model.layers.17.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
114
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
115
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
116
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
117
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
118
+ "model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
119
+ "model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
120
+ "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
121
+ "model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
122
+ "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
123
+ "model.layers.18.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
124
+ "model.layers.18.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
125
+ "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
126
+ "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
127
+ "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
128
+ "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
129
+ "model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
130
+ "model.layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
131
+ "model.layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
132
+ "model.layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
133
+ "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
134
+ "model.layers.19.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
135
+ "model.layers.19.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
136
+ "model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
137
+ "model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
138
+ "model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
139
+ "model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
140
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
141
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
142
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
143
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
144
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
145
+ "model.layers.2.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
146
+ "model.layers.2.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
147
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
148
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
149
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
150
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
151
+ "model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
152
+ "model.layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
153
+ "model.layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
154
+ "model.layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
155
+ "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
156
+ "model.layers.20.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
157
+ "model.layers.20.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
158
+ "model.layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
159
+ "model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
160
+ "model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
161
+ "model.layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
162
+ "model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
163
+ "model.layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
164
+ "model.layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
165
+ "model.layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
166
+ "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
167
+ "model.layers.21.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
168
+ "model.layers.21.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
169
+ "model.layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
170
+ "model.layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
171
+ "model.layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
172
+ "model.layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
173
+ "model.layers.22.input_layernorm.weight": "model-00002-of-00003.safetensors",
174
+ "model.layers.22.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
175
+ "model.layers.22.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
176
+ "model.layers.22.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
177
+ "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
178
+ "model.layers.22.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
179
+ "model.layers.22.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
180
+ "model.layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
181
+ "model.layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
182
+ "model.layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
183
+ "model.layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
184
+ "model.layers.23.input_layernorm.weight": "model-00002-of-00003.safetensors",
185
+ "model.layers.23.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
186
+ "model.layers.23.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
187
+ "model.layers.23.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
188
+ "model.layers.23.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
189
+ "model.layers.23.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
190
+ "model.layers.23.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
191
+ "model.layers.23.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
192
+ "model.layers.23.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
193
+ "model.layers.23.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
194
+ "model.layers.23.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
195
+ "model.layers.24.input_layernorm.weight": "model-00003-of-00003.safetensors",
196
+ "model.layers.24.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
197
+ "model.layers.24.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
198
+ "model.layers.24.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
199
+ "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
200
+ "model.layers.24.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
201
+ "model.layers.24.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
202
+ "model.layers.24.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
203
+ "model.layers.24.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
204
+ "model.layers.24.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
205
+ "model.layers.24.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
206
+ "model.layers.25.input_layernorm.weight": "model-00003-of-00003.safetensors",
207
+ "model.layers.25.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
208
+ "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
209
+ "model.layers.25.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
210
+ "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
211
+ "model.layers.25.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
212
+ "model.layers.25.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
213
+ "model.layers.25.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
214
+ "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
215
+ "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
216
+ "model.layers.25.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
217
+ "model.layers.26.input_layernorm.weight": "model-00003-of-00003.safetensors",
218
+ "model.layers.26.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
219
+ "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
220
+ "model.layers.26.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
221
+ "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
222
+ "model.layers.26.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
223
+ "model.layers.26.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
224
+ "model.layers.26.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
225
+ "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
226
+ "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
227
+ "model.layers.26.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
228
+ "model.layers.27.input_layernorm.weight": "model-00003-of-00003.safetensors",
229
+ "model.layers.27.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
230
+ "model.layers.27.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
231
+ "model.layers.27.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
232
+ "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
233
+ "model.layers.27.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
234
+ "model.layers.27.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
235
+ "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
236
+ "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
237
+ "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
238
+ "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
239
+ "model.layers.28.input_layernorm.weight": "model-00003-of-00003.safetensors",
240
+ "model.layers.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
241
+ "model.layers.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
242
+ "model.layers.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
243
+ "model.layers.28.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
244
+ "model.layers.28.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
245
+ "model.layers.28.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
246
+ "model.layers.28.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
247
+ "model.layers.28.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
248
+ "model.layers.28.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
249
+ "model.layers.28.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
250
+ "model.layers.29.input_layernorm.weight": "model-00003-of-00003.safetensors",
251
+ "model.layers.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
252
+ "model.layers.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
253
+ "model.layers.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
254
+ "model.layers.29.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
255
+ "model.layers.29.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
256
+ "model.layers.29.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
257
+ "model.layers.29.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
258
+ "model.layers.29.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
259
+ "model.layers.29.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
260
+ "model.layers.29.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
261
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
262
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
263
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
264
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
265
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
266
+ "model.layers.3.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
267
+ "model.layers.3.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
268
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
269
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
270
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
271
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
272
+ "model.layers.30.input_layernorm.weight": "model-00003-of-00003.safetensors",
273
+ "model.layers.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
274
+ "model.layers.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
275
+ "model.layers.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
276
+ "model.layers.30.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
277
+ "model.layers.30.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
278
+ "model.layers.30.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
279
+ "model.layers.30.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
280
+ "model.layers.30.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
281
+ "model.layers.30.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
282
+ "model.layers.30.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
283
+ "model.layers.31.input_layernorm.weight": "model-00003-of-00003.safetensors",
284
+ "model.layers.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
285
+ "model.layers.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
286
+ "model.layers.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
287
+ "model.layers.31.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
288
+ "model.layers.31.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
289
+ "model.layers.31.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
290
+ "model.layers.31.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
291
+ "model.layers.31.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
292
+ "model.layers.31.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
293
+ "model.layers.31.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
294
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
295
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
296
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
297
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
298
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
299
+ "model.layers.4.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
300
+ "model.layers.4.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
301
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
302
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
303
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
304
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
305
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
306
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
307
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
308
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
309
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
310
+ "model.layers.5.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
311
+ "model.layers.5.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
312
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
313
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
314
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
315
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
316
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
317
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
318
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
319
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
320
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
321
+ "model.layers.6.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
322
+ "model.layers.6.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
323
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
324
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
325
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
326
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
327
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
328
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
329
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
330
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
331
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
332
+ "model.layers.7.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
333
+ "model.layers.7.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
334
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
335
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
336
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
337
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
338
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
339
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
340
+ "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
341
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
342
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
343
+ "model.layers.8.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
344
+ "model.layers.8.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
345
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
346
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
347
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
348
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
349
+ "model.layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
350
+ "model.layers.9.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
351
+ "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
352
+ "model.layers.9.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
353
+ "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
354
+ "model.layers.9.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
355
+ "model.layers.9.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
356
+ "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
357
+ "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
358
+ "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
359
+ "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
360
+ "model.norm.weight": "model-00003-of-00003.safetensors"
361
+ }
362
+ }
modeling_evabyte.py ADDED
@@ -0,0 +1,1092 @@
+ from typing import List, Optional, Tuple, Union
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint
6
+ from torch import nn
7
+ from torch.nn import CrossEntropyLoss
8
+ from transformers.activations import ACT2FN
9
+ from transformers.cache_utils import Cache
10
+ from transformers.modeling_outputs import (
11
+ BaseModelOutputWithPast,
12
+ CausalLMOutputWithPast,
13
+ )
14
+ from transformers.modeling_utils import PreTrainedModel
15
+
16
+ from .configuration_evabyte import EvaByteConfig
17
+ from .multibyte_decoding_evabyte import MultiByteDecodingMixin
18
+ try:
19
+ import triton
20
+ USE_TRITON_IMPL = True
21
+ from .eva import EvaAttention
22
+ from .eva_agg_kernel import triton_eva_agg_fwd
23
+ from .eva_prep_kv_kernel import triton_eva_prep_kv_fwd
24
+ except ImportError:
25
+ USE_TRITON_IMPL = False
26
+ print("WARNING: triton is not installed, using fallback EVA which might be slow and throw errors")
27
+ from .eva_pt_ref import EvaAttention
28
+ from .eva_cache import EvaCache, EvaStaticCacheForTriton
29
+
30
+ MASK_MIN_VALUE = -10e10
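+ # Large negative value used in place of -inf when converting boolean masks to
+ # additive float masks (see the triton mask conversion in EvaByteModel.forward).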
31
+
32
+ def prepare_eva_attention_mask(
33
+ seq_len,
34
+ device,
35
+ chunk_size,
36
+ window_size,
37
+ use_cache=False,
38
+ cache=None
39
+ ):
40
+ """
41
+ Prepare attention masks for EVA.
42
+
43
+ """
44
+ chunk_causal_mask = None
45
+ window_causal_mask = None
46
+ if use_cache:
47
+ cached_seq_len = cache.get_seq_length()
48
+ total_seq_len = seq_len + cached_seq_len
49
+ # cached_seq_len will be 0 during prefilling
50
+ # padded_seq_len = chunk_size * math.ceil(total_seq_len / chunk_size)
51
+ padded_seq_len = window_size * math.ceil(total_seq_len / window_size)
52
+ num_chunks = padded_seq_len // chunk_size
53
+ else:
54
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
55
+ assert seq_len % chunk_size == 0
56
+ num_chunks = seq_len // chunk_size
57
+
58
+ assert seq_len % window_size == 0
59
+
60
+ # create causal mask
61
+ ################################
62
+ # generate chunked causal masks
63
+ ################################
64
+ # [b, h, j, c, c]
65
+ chunks_per_window = window_size // chunk_size
66
+ if num_chunks >= chunks_per_window:
67
+ chunk_causal_mask = torch.ones(
68
+ (chunk_size, num_chunks, num_chunks),
69
+ device=device,
70
+ dtype=torch.bool
71
+ ).triu(0)
72
+
73
+ num_blocks = num_chunks // chunks_per_window
74
+ chunk_causal_mask = chunk_causal_mask.reshape(
75
+ chunk_size,
76
+ num_blocks,
77
+ chunks_per_window,
78
+ num_blocks,
79
+ chunks_per_window
80
+ ).transpose(-2, -3)
81
+
82
+ block_diag_zero = (
83
+ torch.eye(num_blocks, device=device, dtype=torch.bool)
84
+ .unsqueeze(-1)
85
+ .unsqueeze(-1)
86
+ .unsqueeze(0)
87
+ )
88
+
89
+ # Mask out (set to True) the block-diagonal entries: chunks within the same window are covered by exact window attention
90
+ chunk_causal_mask = chunk_causal_mask.masked_fill(block_diag_zero, True)
91
+
92
+ # Reshape back to original size
93
+ chunk_causal_mask = (
94
+ chunk_causal_mask
95
+ .transpose(-2, -3)
96
+ .reshape(chunk_size, num_chunks, num_chunks)
97
+ .transpose(-2, -3)
98
+ .reshape(chunk_size * num_chunks, num_chunks)
99
+ .unsqueeze(0)
100
+ .unsqueeze(0)
101
+ )
102
+ else:
103
+ chunk_causal_mask = torch.ones(
104
+ (1, 1, chunk_size, num_chunks, num_chunks),
105
+ device=device,
106
+ dtype=torch.bool,
107
+ ).triu(0).transpose(-2, -3) # [1, 1, c, j, c]
108
+ chunk_causal_mask = chunk_causal_mask.reshape(
109
+ 1, 1, chunk_size * num_chunks, num_chunks
110
+ ) # [1, 1, n, c]
111
+
112
+ if use_cache:
113
+ chunk_causal_mask = chunk_causal_mask[..., cached_seq_len : cached_seq_len + seq_len, :]
114
+
115
+ window_causal_mask = torch.ones(
116
+ (1, 1, 1, window_size, window_size),
117
+ device=device
118
+ ).triu(1).to(torch.bool)
119
+ return (chunk_causal_mask, window_causal_mask)
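+ # Illustrative shapes (hypothetical settings: seq_len=256, chunk_size=16, window_size=64, use_cache=False):
+ #   chunk_causal_mask:  [1, 1, 256, 16]   -- True marks chunks a query must not attend to
+ #   window_causal_mask: [1, 1, 1, 64, 64] -- strict upper-triangular causal mask within each window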
120
+
121
+ def pad_to_multiple(tensor, multiple, dim=-2, value=0, create_mask=False, left_padding=False):
122
+ assert dim < 0 # only accept ``dim'' index in a reverse manner
123
+ seqlen = int(tensor.shape[dim])
124
+ m = seqlen / multiple
125
+ if m.is_integer():
126
+ if create_mask:
127
+ return tensor, torch.ones(size=(tensor.shape[0], tensor.shape[dim]), dtype=torch.bool, device=tensor.device)
128
+ else:
129
+ return tensor
130
+ remainder = math.ceil(m) * multiple - seqlen
131
+ pad_offset = (0,) * (-1 - dim) * 2
132
+ if left_padding:
133
+ padded_res = F.pad(tensor, (*pad_offset, remainder, 0), value=value)
134
+ else:
135
+ padded_res = F.pad(tensor, (*pad_offset, 0, remainder), value=value)
136
+ if create_mask:
137
+ # assume dim 0 is the batch size
138
+ padding_mask = torch.ones(size=(padded_res.shape[0], padded_res.shape[dim]), dtype=torch.bool, device=padded_res.device)
139
+ if left_padding:
140
+ padding_mask[:, :remainder] = False
141
+ else:
142
+ padding_mask[:, -remainder:] = False
143
+ return padded_res, padding_mask
144
+ else:
145
+ return padded_res
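+ # Example (hypothetical shapes): pad_to_multiple(torch.zeros(2, 10, 8), multiple=4, dim=-2)
+ # right-pads the sequence dimension from 10 to 12; with create_mask=True it additionally
+ # returns a [2, 12] boolean mask that is False over the padded positions.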
146
+
147
+ class EvaByteRMSNorm(nn.Module):
148
+ def __init__(self, config):
149
+ super().__init__()
150
+ self.config = config
151
+ self.fp32_ln = config.fp32_ln
152
+ self.variance_epsilon = config.rms_norm_eps
153
+ self.add_unit_offset = config.norm_add_unit_offset
154
+ if self.add_unit_offset:
155
+ self.weight = nn.Parameter(torch.zeros(config.hidden_size))
156
+ else:
157
+ self.weight = nn.Parameter(torch.ones(config.hidden_size))
158
+
159
+ def forward(self, hidden_states):
160
+ if hasattr(self, 'config'):
161
+ fp32_ln = self.config.fp32_ln
162
+ else:
163
+ fp32_ln = self.fp32_ln
164
+ hidden_states = hidden_states.to(torch.float32 if fp32_ln else torch.bfloat16)
165
+
166
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
167
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
168
+ if self.add_unit_offset:
169
+ return (1 + self.weight) * hidden_states
170
+ else:
171
+ return self.weight * hidden_states
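+ # RMSNorm: y = x / sqrt(mean(x^2) + eps) * w, with w parameterized as (1 + weight)
+ # when norm_add_unit_offset is set, so the learned offset starts at zero.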
172
+
173
+ class EvaByteRotaryEmbedding(torch.nn.Module):
174
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
175
+ super().__init__()
176
+
177
+ self.dim = dim
178
+ self.max_position_embeddings = max_position_embeddings
179
+ self.base = base
180
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
181
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
182
+
183
+ self._set_cos_sin_cache(seq_len=max_position_embeddings,
184
+ device=self.inv_freq.device,
185
+ dtype=torch.get_default_dtype())
186
+
187
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
188
+ self.max_seq_len_cached = seq_len
189
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
190
+
191
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
192
+ emb = torch.cat((freqs, freqs), dim=-1)
193
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
194
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
195
+
196
+
197
+ def forward(self, x, seq_len=None):
198
+ # x: [bs, num_attention_heads, seq_len, head_size]
199
+ if seq_len > self.max_seq_len_cached:
200
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
201
+
202
+ # return (
203
+ # self.cos_cached[:seq_len].to(dtype=x.dtype),
204
+ # self.sin_cached[:seq_len].to(dtype=x.dtype),
205
+ # )
206
+ if seq_len < self.max_seq_len_cached:
207
+ cos_slice = self.cos_cached.split(seq_len, dim=0)[0]
208
+ sin_slice = self.sin_cached.split(seq_len, dim=0)[0]
209
+ else:
210
+ cos_slice = self.cos_cached
211
+ sin_slice = self.sin_cached
212
+
213
+ return (
214
+ cos_slice.to(dtype=x.dtype),
215
+ sin_slice.to(dtype=x.dtype),
216
+ )
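+ # The returned cos/sin caches are 2-D tensors of shape (seq_len, head_dim);
+ # EvaByteModel.forward indexes them with position_ids and unsqueezes a head
+ # dimension before handing them to the attention layers.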
217
+
218
+
219
+
220
+ class EvaByteLinearScalingRotaryEmbedding(EvaByteRotaryEmbedding):
221
+ """EvaByteRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
222
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
223
+ self.scaling_factor = scaling_factor
224
+ super().__init__(dim, max_position_embeddings, base, device)
225
+
226
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
227
+ self.max_seq_len_cached = seq_len
228
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
229
+ t = t / self.scaling_factor
230
+
231
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
232
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
233
+ emb = torch.cat((freqs, freqs), dim=-1)
234
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
235
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
236
+
237
+
238
+ class EvaByteDynamicNTKScalingRotaryEmbedding(EvaByteRotaryEmbedding):
239
+ """EvaByteRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
240
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
241
+ self.scaling_factor = scaling_factor
242
+ super().__init__(dim, max_position_embeddings, base, device)
243
+
244
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
245
+ self.max_seq_len_cached = seq_len
246
+
247
+ if seq_len > self.max_position_embeddings:
248
+ base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) -
249
+ (self.scaling_factor - 1))**(self.dim / (self.dim - 2))
250
+ inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
251
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
252
+
253
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
254
+
255
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
256
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
257
+ emb = torch.cat((freqs, freqs), dim=-1)
258
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
259
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
260
+
261
+
262
+ class EvaByteMLP(nn.Module):
263
+ def __init__(self, config, layer_idx: int = None):
264
+ super().__init__()
265
+ self.hidden_size = config.hidden_size
266
+ self.intermediate_size = config.intermediate_size
267
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
268
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
269
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
270
+ self.act_fn = ACT2FN[config.hidden_act]
271
+ self.layer_idx = layer_idx
272
+ self.config = config
273
+
274
+ def forward(self, x):
275
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
276
+ return down_proj
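+ # SwiGLU-style feed-forward: down_proj(act_fn(gate_proj(x)) * up_proj(x)).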
277
+
278
+ class EvaByteDecoderLayer(nn.Module):
279
+ def __init__(self, config: EvaByteConfig, layer_idx: int = None):
280
+ super().__init__()
281
+ self.config = config
282
+ self.hidden_size = config.hidden_size
283
+ self.self_attn = EvaAttention(config=config, layer_idx=layer_idx)
284
+ self.mlp = EvaByteMLP(config, layer_idx=layer_idx)
285
+ self.input_layernorm = EvaByteRMSNorm(config)
286
+ self.post_attention_layernorm = EvaByteRMSNorm(config)
287
+
288
+ def forward(
289
+ self,
290
+ hidden_states: torch.Tensor,
291
+ attention_mask: Optional[torch.Tensor] = None,
292
+ position_ids: Optional[torch.LongTensor] = None,
293
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
294
+ output_attentions: Optional[bool] = False,
295
+ use_cache: Optional[bool] = False,
296
+ cos: Optional[torch.Tensor] = None,
297
+ sin: Optional[torch.Tensor] = None,
298
+ multibyte_decoding: Optional[bool] = False,
299
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
300
+ residual = hidden_states
301
+ if self.config.fp32_skip_add:
302
+ residual = residual.float()
303
+
304
+ hidden_states = self.input_layernorm(hidden_states)
305
+
306
+ # Self Attention
307
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(hidden_states=hidden_states,
308
+ attention_mask=attention_mask,
309
+ position_ids=position_ids,
310
+ past_key_value=past_key_value,
311
+ output_attentions=output_attentions,
312
+ use_cache=use_cache,
313
+ cos=cos,
314
+ sin=sin,
315
+ multibyte_decoding=multibyte_decoding)
316
+ hidden_states = residual + hidden_states
317
+
318
+ # Fully Connected
319
+ residual = hidden_states
320
+ if self.config.fp32_skip_add:
321
+ residual = residual.float()
322
+ hidden_states = self.post_attention_layernorm(hidden_states)
323
+ hidden_states = self.mlp(hidden_states)
324
+ hidden_states = residual + hidden_states
325
+
326
+ outputs = (hidden_states, )
327
+
328
+ if output_attentions:
329
+ outputs += (self_attn_weights, )
330
+
331
+ if use_cache:
332
+ outputs += (present_key_value, )
333
+ return outputs
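+ # outputs layout: (hidden_states, [self_attn_weights if output_attentions],
+ # [present_key_value if use_cache]).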
334
+
335
+ class EvaBytePreTrainedModel(PreTrainedModel):
336
+ config_class = EvaByteConfig
337
+ base_model_prefix = "model"
338
+ supports_gradient_checkpointing = True
339
+ _no_split_modules = ["EvaByteDecoderLayer"]
340
+ _skip_keys_device_placement = "past_key_values"
341
+
342
+ def _init_weights(self, module):
343
+ std = getattr(self.config, "initializer_range", 0.02)
344
+ if isinstance(module, nn.Linear):
345
+ module.weight.data.normal_(mean=0.0, std=std)
346
+ if module.bias is not None:
347
+ module.bias.data.zero_()
348
+ elif isinstance(module, nn.Embedding):
349
+ module.weight.data.normal_(mean=0.0, std=std)
350
+ if module.padding_idx is not None:
351
+ module.weight.data[module.padding_idx].zero_()
352
+
353
+ def _set_gradient_checkpointing(self, module, value=False):
354
+ if isinstance(module, EvaByteModel):
355
+ module.gradient_checkpointing = value
356
+
357
+ class EvaByteModel(EvaBytePreTrainedModel):
358
+ """
359
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`EvaByteDecoderLayer`]
360
+
361
+ Args:
362
+ config: EvaByteConfig
363
+ """
364
+ def __init__(self, config: EvaByteConfig):
365
+ super().__init__(config)
366
+ self.padding_idx = config.pad_token_id
367
+ self.vocab_size = config.vocab_size
368
+ self.hidden_size = config.hidden_size
369
+ self.num_heads = config.num_attention_heads
370
+ self.head_dim = self.hidden_size // self.num_heads
371
+ self.max_position_embeddings = self.config.max_position_embeddings
372
+
373
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
374
+ self.layers = nn.ModuleList([EvaByteDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)])
375
+ self.norm = EvaByteRMSNorm(config)
376
+
377
+ self.gradient_checkpointing = False
378
+ self.rope = config.rope_theta
379
+ # Initialize weights and apply final processing
380
+ self.post_init()
381
+ self._init_rope()
382
+
383
+ def _init_rope(self):
384
+ if self.config.rope_scaling is None:
385
+ self.rotary_emb = EvaByteRotaryEmbedding(self.head_dim,
386
+ max_position_embeddings=self.max_position_embeddings,
387
+ base=self.rope)
388
+ else:
389
+ scaling_type = self.config.rope_scaling["type"]
390
+ scaling_factor = self.config.rope_scaling["factor"]
391
+ if scaling_type == "linear":
392
+ self.rotary_emb = EvaByteLinearScalingRotaryEmbedding(
393
+ self.head_dim,
394
+ max_position_embeddings=self.max_position_embeddings,
395
+ scaling_factor=scaling_factor,
396
+ base=self.rope)
397
+ elif scaling_type == "dynamic":
398
+ self.rotary_emb = EvaByteDynamicNTKScalingRotaryEmbedding(
399
+ self.head_dim,
400
+ max_position_embeddings=self.max_position_embeddings,
401
+ scaling_factor=scaling_factor,
402
+ base=self.rope)
403
+ else:
404
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
405
+
406
+ def get_input_embeddings(self):
407
+ return self.embed_tokens
408
+
409
+ def set_input_embeddings(self, value):
410
+ self.embed_tokens = value
411
+
412
+ def _helper_padding_mask(
413
+ self,
414
+ padding_mask,
415
+ causal_mask
416
+ ):
417
+ padding_mask = torch.logical_or(padding_mask, padding_mask.transpose(-1, -2))
418
+ return torch.logical_or(padding_mask, causal_mask)
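+ # Symmetrizes the padding mask (a position is masked if either the query or the
+ # key is padding) and ORs it with the causal mask; True means "do not attend".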
419
+
420
+ def _prepare_eva_generation_attn_mask_triton(
421
+ self,
422
+ attention_mask,
423
+ input_ids,
424
+ use_cache,
425
+ past_key_values
426
+ ):
427
+ batch_size, seq_len = input_ids.shape
428
+ if use_cache and past_key_values.get_seq_length() > 0:
429
+ # decoding phase
430
+ if past_key_values.rf_mask[0] is not None:
431
+ cur_rf_mask = torch.zeros(
432
+ (batch_size, 1, seq_len, 1),
433
+ dtype=past_key_values.rf_mask[0].dtype,
434
+ device=past_key_values.rf_mask[0].device
435
+ )
436
+ else:
437
+ cur_rf_mask = None
438
+
439
+ if past_key_values.s_mask[0] is not None:
440
+ cur_s_mask = torch.zeros(
441
+ (batch_size, 1, seq_len, 1),
442
+ dtype=past_key_values.s_mask[0].dtype,
443
+ device=past_key_values.s_mask[0].device
444
+ )
445
+ else:
446
+ cur_s_mask = None
447
+
448
+ seen_tokens = past_key_values.get_seq_length()
449
+ if seen_tokens <= self.config.window_size:
450
+ rfa_chunks_dummy_mask = None
451
+ else:
452
+ if cur_s_mask is not None:
453
+ chunks_per_window = int(self.config.window_size // self.config.chunk_size)
454
+ # the ongoing decoding step would be (seen_seq_len + 1)-th token
455
+ num_windows_seen_so_far = seen_tokens // self.config.window_size
456
+ rfa_chunks_dummy_mask = torch.zeros(
457
+ (batch_size, 1, seq_len, num_windows_seen_so_far * chunks_per_window),
458
+ dtype=past_key_values.s_mask[0].dtype,
459
+ device=past_key_values.s_mask[0].device
460
+ )
461
+ else:
462
+ rfa_chunks_dummy_mask = None
463
+ # all masks for the current decoding step are zeros because we do not want to mask it
464
+ return (cur_s_mask, cur_rf_mask, rfa_chunks_dummy_mask)
465
+
466
+ if attention_mask is not None and torch.any(attention_mask == 0.0):
467
+ # convert 0 -> padding to 1 -> padding
468
+ padded_attention_mask = pad_to_multiple(
469
+ attention_mask,
470
+ self.config.window_size,
471
+ dim=-1,
472
+ value=0,
473
+ create_mask=False,
474
+ left_padding=False
475
+ )
476
+ # convert 0 -> padding to 1 -> padding
477
+ padded_rf_mask = ~padded_attention_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool) # [b, 1, n, 1]
478
+ # [b, 1, w, j, 1]
479
+ padded_w_attn_mask = padded_rf_mask.reshape(batch_size, 1, -1, self.config.window_size, 1).to(torch.bool)
480
+ # [b, 1, w, j, 1] [b, 1, w, 1, j] -> [b, 1, w, j, j]
481
+ w_padding_mask = torch.logical_or(padded_w_attn_mask, padded_w_attn_mask.transpose(-1, -2))
482
+ w_causal_mask = torch.ones(
483
+ (1, 1, 1, self.config.window_size, self.config.window_size),
484
+ device=input_ids.device
485
+ ).triu(1).to(torch.bool)
486
+ s_mask = torch.logical_or(w_padding_mask, w_causal_mask)
487
+ s_mask = s_mask.reshape(batch_size, 1, -1, self.config.window_size)
488
+ s_mask = s_mask[..., :seq_len, :]
489
+ # negate the attention mask to get the padding mask
490
+ rf_mask = ~attention_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool) # [b, 1, n, 1]
491
+ return (s_mask, rf_mask)
492
+ else:
493
+ return (None, None)
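+ # Summary of return values for the triton path:
+ #   prefill:  (s_mask, rf_mask)                       -- or (None, None) when there is no padding
+ #   decoding: (cur_s_mask, cur_rf_mask, rfa_chunks_dummy_mask)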
494
+
495
+ def _prepare_eva_generation_attn_mask(
496
+ self,
497
+ attention_mask,
498
+ input_ids,
499
+ use_cache,
500
+ past_key_values
501
+ ):
502
+ batch_size, seq_len = input_ids.shape
503
+ if use_cache and past_key_values.get_seq_length() > 0:
504
+ # decoding phase
505
+ if past_key_values.rf_mask[0] is not None:
506
+ rf_mask = torch.zeros(
507
+ (batch_size, 1, seq_len, 1),
508
+ dtype=past_key_values.rf_mask[0].dtype,
509
+ device=past_key_values.rf_mask[0].device
510
+ )
511
+ else:
512
+ rf_mask = None
513
+
514
+ cur_causal_mask = torch.zeros(
515
+ (batch_size, 1, seq_len, 1),
516
+ dtype=torch.bool,
517
+ device=input_ids.device
518
+ )
519
+
520
+ chunk_causal_mask = torch.ones(
521
+ (batch_size, 1, seq_len, 1),
522
+ dtype=torch.bool,
523
+ device=input_ids.device
524
+ )
525
+ # chunk_causal_mask is all ones because these positions are masked by default
+ # and are unmasked once the current singleton attention has been processed
527
+ return (None, cur_causal_mask, chunk_causal_mask, rf_mask)
528
+
529
+ true_num_chunks = seq_len // self.config.chunk_size
530
+ chunk_causal_mask, _ = prepare_eva_attention_mask(
531
+ seq_len,
532
+ input_ids.device,
533
+ self.config.chunk_size,
534
+ self.config.window_size,
535
+ use_cache=use_cache,
536
+ cache=past_key_values
537
+ )
538
+ chunk_causal_mask = chunk_causal_mask[..., :seq_len, :true_num_chunks]
539
+ if attention_mask is not None and torch.any(attention_mask == 0.0):
540
+ # convert 0 -> padding to 1 -> padding
541
+ rf_mask = ~attention_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool) # [b, 1, n, 1]
542
+ else:
543
+ rf_mask = None
544
+
545
+ if seq_len < self.config.window_size:
546
+ cur_window_mask = torch.ones(
547
+ (1, 1, seq_len, seq_len),
548
+ device=input_ids.device
549
+ ).triu(1).to(torch.bool)
550
+ if rf_mask is not None:
551
+ cur_window_mask = self._helper_padding_mask(rf_mask, cur_window_mask)
552
+ prev_window_mask = None
553
+ else:
554
+ if seq_len % self.config.window_size == 0:
555
+ num_windows = seq_len // self.config.window_size
556
+ cur_window_mask = None
557
+ prev_window_mask = torch.ones(
558
+ (1, 1, num_windows, self.config.window_size, self.config.window_size),
559
+ device=input_ids.device
560
+ ).triu(1).to(torch.bool)
561
+ if rf_mask is not None:
562
+ prev_rf_mask = rf_mask.reshape(batch_size, 1, -1, self.config.window_size, 1)
563
+ prev_window_mask = self._helper_padding_mask(prev_rf_mask, prev_window_mask)
564
+ else:
565
+ num_windows = seq_len // self.config.window_size
566
+ remainder_tokens = seq_len % self.config.window_size
567
+ cur_window_mask = torch.ones(
568
+ (1, 1, remainder_tokens, remainder_tokens),
569
+ device=input_ids.device
570
+ ).triu(1).to(torch.bool)
571
+ prev_window_mask = torch.ones(
572
+ (1, 1, num_windows, self.config.window_size, self.config.window_size),
573
+ device=input_ids.device
574
+ ).triu(1).to(torch.bool)
575
+ if rf_mask is not None:
576
+ prev_rf_mask, cur_rf_mask = torch.split(rf_mask, [seq_len - remainder_tokens, remainder_tokens], dim=-2)
577
+ cur_window_mask = self._helper_padding_mask(cur_rf_mask, cur_window_mask)
578
+ prev_rf_mask = prev_rf_mask.reshape(batch_size, 1, -1, self.config.window_size, 1)
579
+ prev_window_mask = self._helper_padding_mask(prev_rf_mask, prev_window_mask)
580
+
581
+ return (prev_window_mask, cur_window_mask, chunk_causal_mask, rf_mask)
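+ # Non-triton path returns a 4-tuple of boolean masks (True = masked):
+ # (prev_window_mask, cur_window_mask, chunk_causal_mask, rf_mask); any entry may
+ # be None when the corresponding mask is not needed.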
582
+
583
+ def forward(
584
+ self,
585
+ input_ids: torch.LongTensor = None,
586
+ attention_mask: Optional[torch.Tensor] = None,
587
+ position_ids: Optional[torch.LongTensor] = None,
588
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
589
+ inputs_embeds: Optional[torch.FloatTensor] = None,
590
+ use_cache: Optional[bool] = None,
591
+ output_attentions: Optional[bool] = None,
592
+ output_hidden_states: Optional[bool] = None,
593
+ return_dict: Optional[bool] = None,
594
+ multibyte_decoding: Optional[bool] = None,
595
+ ) -> Tuple:
596
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
597
+ output_hidden_states = (output_hidden_states
598
+ if output_hidden_states is not None else self.config.output_hidden_states)
599
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
600
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
601
+
602
+ if (input_ids is None) ^ (inputs_embeds is not None):
603
+ raise ValueError(
604
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
605
+ )
606
+
607
+ if self.gradient_checkpointing and self.training and use_cache:
608
+ raise ValueError("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
609
+
610
+ batch_size, seq_len = input_ids.shape if input_ids is not None else inputs_embeds.shape[:2]
611
+ #### Step 0. Hack
612
+ if (not self.training) and (not use_cache) and (not multibyte_decoding):
613
+ # forward-only inference mode: force use_cache=True so the generation
+ # code path below can be reused
615
+ use_cache = True
616
+ device = input_ids.device if input_ids is not None else None
617
+ if position_ids is None:
618
+ position_ids = torch.arange(0, seq_len, device=device, dtype=int).reshape(1, -1).expand(batch_size, -1)
619
+
620
+ #### Step 1. Prepare caches if in inference mode
621
+ if use_cache:
622
+ if past_key_values is not None:
623
+ assert isinstance(past_key_values, Cache)
624
+ else:
625
+ if not USE_TRITON_IMPL:
626
+ past_key_values = EvaCache()
627
+ else:
628
+ past_key_values = EvaStaticCacheForTriton(
629
+ input_ids.shape[0],
630
+ self.config.num_attention_heads,
631
+ self.config.window_size,
632
+ self.config.hidden_size // self.config.num_attention_heads,
633
+ self.config.num_hidden_layers,
634
+ self.embed_tokens.weight.dtype,
635
+ self.embed_tokens.weight.device,
636
+ )
637
+
638
+ if not multibyte_decoding:
639
+ if use_cache:
640
+ if USE_TRITON_IMPL:
641
+ causal_mask = self._prepare_eva_generation_attn_mask_triton(
642
+ attention_mask,
643
+ input_ids,
644
+ use_cache,
645
+ past_key_values
646
+ )
647
+ else:
648
+ causal_mask = self._prepare_eva_generation_attn_mask(
649
+ attention_mask,
650
+ input_ids,
651
+ use_cache,
652
+ past_key_values
653
+ )
654
+ else:
655
+ assert self.training
656
+ assert seq_len % self.config.window_size == 0
657
+ # for training, we need to pass in the attention mask
658
+ # usually calculated by _prepare_training_attn_mask()
659
+ causal_mask = attention_mask
660
+ else:
661
+ assert use_cache
662
+ causal_mask = attention_mask
663
+
664
+ if inputs_embeds is None:
665
+ inputs_embeds = self.embed_tokens(input_ids)
666
+
667
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
668
+ max_seq_length = past_seen_tokens + inputs_embeds.shape[1]
669
+
670
+ hidden_states = inputs_embeds
671
+
672
+ if position_ids is None:
673
+ assert not use_cache, "during decoding we must explicitly pass position_ids to the model call"
674
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
675
+ position_ids = torch.arange(past_seen_tokens, max_seq_length, device=device, dtype=int).reshape(1, -1).expand(batch_size, -1)
676
+
677
+ cos, sin = self.rotary_emb(hidden_states, seq_len=max_seq_length)
678
+ assert len(cos.shape) == 2, f"cos should be of shape (max_seq_len, head_dim), got {cos.shape} instead"
679
+ assert sin.shape == cos.shape, f"sin should be of shape (max_seq_len, head_dim), got {sin.shape} instead"
680
+ assert len(position_ids.shape) == 2, f"position_ids should be of 2D, got {position_ids.shape} instead"
681
+ cos = cos[position_ids, :]
682
+ sin = sin[position_ids, :]
683
+ cos = cos.unsqueeze(1)
684
+ sin = sin.unsqueeze(1)
685
+
686
+ if USE_TRITON_IMPL and (not multibyte_decoding):
687
+ # the masks generated above for triton kernels are boolean. Convert them to floats
688
+ if (
689
+ (not use_cache) or
690
+ (use_cache and past_seen_tokens == 0)
691
+ ):
692
+ window_mask, intra_chunk_mask = causal_mask
693
+
694
+ if window_mask is not None:
695
+ assert window_mask.dtype == torch.bool
696
+ window_mask_float = window_mask.to(torch.float)
697
+ window_mask_float = window_mask_float.masked_fill(window_mask.to(torch.bool), MASK_MIN_VALUE)
698
+ window_mask_float = window_mask_float.reshape(batch_size, 1, -1, self.config.window_size)
699
+ window_mask = window_mask_float.to(hidden_states.dtype)
700
+
701
+ if intra_chunk_mask is not None:
702
+ assert intra_chunk_mask.dtype == torch.bool
703
+ intra_chunk_mask_float = intra_chunk_mask.to(torch.float)
704
+ intra_chunk_mask_float = intra_chunk_mask_float.masked_fill(intra_chunk_mask.to(torch.bool), MASK_MIN_VALUE)
705
+ intra_chunk_mask = intra_chunk_mask_float.to(hidden_states.dtype)
706
+ causal_mask = (window_mask, intra_chunk_mask)
707
+
708
+ if self.config.fp32_skip_add:
709
+ hidden_states = hidden_states.float()
710
+
711
+ # decoder layers
712
+ all_hidden_states = () if output_hidden_states else None
713
+ all_self_attns = () if output_attentions else None
714
+ next_decoder_cache = None
715
+
716
+ for decoder_layer in self.layers:
717
+ if output_hidden_states:
718
+ all_hidden_states += (hidden_states, )
719
+
720
+ if self.gradient_checkpointing and self.training:
721
+
722
+ def create_custom_forward(module):
723
+ def custom_forward(*inputs):
724
+ # None for past_key_value
725
+ return module(*inputs, output_attentions, use_cache=None)
726
+
727
+ return custom_forward
728
+
729
+ layer_outputs = torch.utils.checkpoint.checkpoint(
730
+ create_custom_forward(decoder_layer),
731
+ hidden_states,
732
+ causal_mask,
733
+ position_ids,
734
+ None,
735
+ )
736
+ else:
737
+ layer_outputs = decoder_layer(
738
+ hidden_states,
739
+ attention_mask=causal_mask,
740
+ position_ids=position_ids,
741
+ past_key_value=past_key_values,
742
+ output_attentions=output_attentions,
743
+ use_cache=use_cache,
744
+ cos=cos,
745
+ sin=sin,
746
+ multibyte_decoding=multibyte_decoding,
747
+ )
748
+
749
+ hidden_states = layer_outputs[0]
750
+
751
+ if use_cache:
752
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
753
+
754
+ if output_attentions:
755
+ all_self_attns += (layer_outputs[1], )
756
+
757
+ hidden_states = self.norm(hidden_states)
758
+
759
+ # add hidden states from the last decoder layer
760
+ if output_hidden_states:
761
+ all_hidden_states += (hidden_states, )
762
+
763
+ next_cache = next_decoder_cache if use_cache else None
764
+ if not return_dict:
765
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
766
+
767
+ return BaseModelOutputWithPast(
768
+ last_hidden_state=hidden_states,
769
+ past_key_values=next_cache,
770
+ hidden_states=all_hidden_states,
771
+ attentions=all_self_attns,
772
+ )
773
+
774
+
775
+ class EvaByteForCausalLM(EvaBytePreTrainedModel, MultiByteDecodingMixin):
776
+ _tied_weights_keys = ["lm_head.weight"]
777
+
778
+ def __init__(self, config):
779
+ EvaBytePreTrainedModel.__init__(self, config)
780
+
781
+ self.model = EvaByteModel(config)
782
+ self.vocab_size = config.vocab_size
783
+ # define multibyte prediction heads
784
+ if hasattr(config, "num_pred_heads") and config.num_pred_heads > 1:
785
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size * config.num_pred_heads, bias=False)
786
+ else:
787
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
788
+
789
+ self.post_init()
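+ # With num_pred_heads > 1 the lm_head emits vocab_size * num_pred_heads logits per
+ # position; presumably these are split into one byte distribution per prediction
+ # head by the MultiByteDecodingMixin (the exact reshape lives in that mixin).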
790
+
791
+ def get_input_embeddings(self):
792
+ return self.model.embed_tokens
793
+
794
+ def set_input_embeddings(self, value):
795
+ self.model.embed_tokens = value
796
+
797
+ def get_output_embeddings(self):
798
+ return self.lm_head
799
+
800
+ def set_output_embeddings(self, new_embeddings):
801
+ self.lm_head = new_embeddings
802
+
803
+ def set_decoder(self, decoder):
804
+ self.model = decoder
805
+
806
+ def get_decoder(self):
807
+ return self.model
808
+
809
+ def _prepare_training_attn_mask(
810
+ self,
811
+ target_token_type_ids,
812
+ use_doc_boundary_attention,
813
+ EOS_TOKEN_TYPE_ID=None,
814
+ PAD_TOKEN_TYPE_ID=None,
815
+ ):
816
+ '''
817
+ This function prepares the attention mask for training byte models.
818
+ target_token_type_ids:
819
+ Tensor of shape (batch_size, seq_len), marking the token type ids
820
+ for the target sequence. In particular, we should have
821
+ - target_token_type_ids[i, j] = EOS_TOKEN_TYPE_ID
822
+ if the j-th token in the i-th sequence is the end of an article.
823
+ - target_token_type_ids[i, j] = PAD_TOKEN_TYPE_ID
824
+ if the j-th token in the i-th sequence is the padding token.
825
+ use_doc_boundary_attention: bool,
826
+ whether to enable doc boundary attention.
827
+ EOS_TOKEN_TYPE_ID: int,
828
+ the token type id for the end of an article.
829
+ PAD_TOKEN_TYPE_ID: int,
830
+ the token type id for the padding token.
831
+ '''
832
+ assert self.training
833
+ batch_size, num_tokens = target_token_type_ids.shape
834
+
835
+ chunk_causal_mask, window_causal_mask = prepare_eva_attention_mask(
836
+ num_tokens,
837
+ target_token_type_ids.device,
838
+ chunk_size=self.config.chunk_size,
839
+ window_size=self.config.window_size,
840
+ use_cache=False,
841
+ cache=None
842
+ )
843
+ if use_doc_boundary_attention:
844
+ #### step 1: mark each document with a unique id
845
+ end_token_ids = {EOS_TOKEN_TYPE_ID, PAD_TOKEN_TYPE_ID}
846
+ token_types = torch.zeros(batch_size, num_tokens)
847
+ for sequence_idx, sequence in enumerate(target_token_type_ids):
848
+ num_articles = 0
849
+ start_index = 0
850
+ # for each sample in the batch, the collapsed attention mask looks like:
851
+ # [1, 1, .... 1, 0, 2, 2, ... 2, 0, ... n, n ..... n], assuming there are n articles in the sequence.
852
+ # Each of the n articles is separated by a 0.
853
+ for token_idx, token_type_id in enumerate(sequence):
854
+ if start_index is not None and token_type_id.item() in end_token_ids:
855
+ num_articles += 1
856
+ end_index = token_idx if token_type_id == PAD_TOKEN_TYPE_ID else token_idx + 1
857
+ token_types[sequence_idx][start_index:end_index] = num_articles
858
+ start_index = None
859
+ elif start_index is None and token_type_id not in end_token_ids:
860
+ start_index = token_idx + 1
861
+
862
+ assert num_tokens % self.config.chunk_size == 0, "Number of tokens must be divisible by chunk size"
863
+ assert num_tokens % self.config.window_size == 0, "Number of tokens must be divisible by window size"
864
+ num_chunks = num_tokens // self.config.chunk_size
865
+ num_windows = num_tokens // self.config.window_size
866
+
867
+ article_separator = 0
868
+
869
+ #### step 2: generate attention masks for each window
870
+ #### NOTE: we perform exact attention within each window,
871
+ #### so we only need to mask out different documents
872
+ #### for each window.
873
+ token_types_windows = token_types.reshape(batch_size, num_windows, self.config.window_size, 1)
874
+ token_types_windows_t = token_types_windows.transpose(-1, -2)
875
+ # replace all elements in TOKEN_SEPS with -1
876
+ token_types_windows = torch.where(token_types_windows == article_separator, -1, token_types_windows)
877
+ window_3d_mask = (token_types_windows == token_types_windows_t)
878
+ window_3d_mask = ~window_3d_mask
879
+
880
+ #### step 3: generate chunk-level 3D masks
881
+ #### NOTE: this is a bit tricky, as we aim to mask out different
882
+ #### documents to avoid cross-doc attention across chunks.
883
+ #### Example: suppose we have a sequence of length 12 with 3 documents:
884
+ #### [1, 1, 1, 1, 1, 2, 2, 3, 3, 3, 3, 3].
885
+ #### The chunk-size and window-size are both 4.
886
+ #### The chunk-level mask of shape (batch_size, seq_len, num_chunks) is:
887
+ #### [
888
+ #### [0, 0, 0],
889
+ #### [0, 0, 0],
890
+ #### [0, 0, 0],
891
+ #### [0, 0, 0],
892
+ ####
893
+ #### [1, 0, 0],
894
+ #### [0, 0, 0],
895
+ #### [0, 0, 0],
896
+ #### [0, 0, 0],
897
+ ####
898
+ #### [0, 1, 0],
899
+ #### [0, 1, 0],
900
+ #### [0, 1, 0],
901
+ #### [0, 1, 0],
902
+ #### ]
903
+ #### Explanation:
904
+ #### - Tokens will not attend to their own and future chunks.
905
+ #### (as tokens within a chunk are captured by the window-level exact attention)
906
+ #### - Tokens will attend to a chunk only if there are tokens
907
+ #### from the same document in that chunk.
908
+ #### The mask within each chunk of shape (batch_size, num_chunks, chunk_size) is:
909
+ #### [
910
+ #### [1, 1, 1, 1],
911
+ #### [0, 0, 0, 1],
912
+ #### [1, 1, 1, 1],
913
+ #### ]
914
+ #### Explanation:
915
+ #### - If all tokens in a chunk are from the same document,
916
+ #### no tokens will be masked out.
917
+ #### - If there are tokens from different documents in a chunk,
918
+ #### only tokens from the rightmost document will be kept.
919
+ #### (b/c the future chunks might contain tokens from the rightmost document,
920
+ #### but all the remaining docs will never get attended by other docs)
921
+ token_types_chunks = token_types.reshape(batch_size, num_chunks, self.config.chunk_size)
922
+ inter_chunk_mask = torch.zeros((batch_size, num_tokens, num_chunks), dtype=torch.bool)
923
+ intra_chunk_mask = torch.ones_like(token_types_chunks, dtype=torch.bool)
924
+
925
+ for chunk_idx in range(num_chunks):
926
+ for batch_idx in range(batch_size):
927
+ # Identify tokens in the current chunk belonging to each sequence
928
+ chunk = token_types_chunks[batch_idx, chunk_idx]
929
+ unique_elements = torch.unique(chunk, sorted=True).tolist()
930
+
931
+ # Create a mask for whether each token can attend to the current chunk
932
+ for token_type in unique_elements:
933
+ if token_type == article_separator:
934
+ continue
935
+ token_mask = (token_types[batch_idx] == token_type)
936
+ inter_chunk_mask[batch_idx, :, chunk_idx] |= token_mask
937
+
938
+ # Create a mask within each chunk
939
+ unique_elements = [x for x in unique_elements if x != article_separator]
940
+ if len(unique_elements) > 1 and chunk[-1] != article_separator:
941
+ intra_chunk_mask[batch_idx, chunk_idx] = (chunk == unique_elements[-1])
942
+
943
+ inter_chunk_mask = ~inter_chunk_mask
944
+ intra_chunk_mask = ~intra_chunk_mask
945
+
946
+ window_mask = torch.logical_or(window_causal_mask, window_3d_mask.unsqueeze(1))
947
+ inter_chunk_mask = torch.logical_or(chunk_causal_mask, inter_chunk_mask.unsqueeze(1))
948
+ intra_chunk_mask = intra_chunk_mask.unsqueeze(1).unsqueeze(-1)
949
+
950
+ joint_mask = torch.cat([window_mask, inter_chunk_mask.reshape(*window_mask.shape)], dim=-1)
951
+ attention_mask = (joint_mask, intra_chunk_mask)
952
+ else:
953
+ joint_mask = torch.cat([window_causal_mask, chunk_causal_mask.reshape(*window_causal_mask.shape)], dim=-1)
954
+ attention_mask = (joint_mask, None)
955
+ return attention_mask
956
+
957
+ def forward(
958
+ self,
959
+ input_ids: torch.LongTensor = None,
960
+ attention_mask: Optional[torch.Tensor] = None,
961
+ position_ids: Optional[torch.LongTensor] = None,
962
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
963
+ inputs_embeds: Optional[torch.FloatTensor] = None,
964
+ labels: Optional[torch.LongTensor] = None,
965
+ use_cache: Optional[bool] = None,
966
+ output_attentions: Optional[bool] = None,
967
+ output_hidden_states: Optional[bool] = None,
968
+ return_dict: Optional[bool] = None,
969
+ return_all_pred_logits: Optional[bool] = None,
970
+ multibyte_decoding: Optional[bool] = None) -> Union[Tuple, CausalLMOutputWithPast]:
971
+
972
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
973
+ output_hidden_states = (output_hidden_states
974
+ if output_hidden_states is not None else self.config.output_hidden_states)
975
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
976
+
977
+ if input_ids is None:
978
+ assert past_key_values is None
979
+
980
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
981
+ outputs = self.model(
982
+ input_ids=input_ids,
983
+ attention_mask=attention_mask,
984
+ position_ids=position_ids,
985
+ past_key_values=past_key_values,
986
+ inputs_embeds=inputs_embeds,
987
+ use_cache=use_cache,
988
+ output_attentions=output_attentions,
989
+ output_hidden_states=output_hidden_states,
990
+ return_dict=return_dict,
991
+ multibyte_decoding=multibyte_decoding,
992
+ )
993
+
994
+ hidden_states = outputs[0]
995
+
996
+ logits = self.lm_head(hidden_states)
997
+ if self.config.fp32_logits:
998
+ logits = logits.float()
999
+
1000
+ loss = None
1001
+ if labels is not None:
1002
+ loss_fct = CrossEntropyLoss(reduction="none")
1003
+ if hasattr(self.config, "num_pred_heads") and self.config.num_pred_heads > 1:
1004
+ shift_logits = logits.view(logits.shape[0], logits.shape[1], self.config.num_pred_heads, self.config.vocab_size)
1005
+ # shift_logits = shift_logits.view(-1, logits.shape[1] * self.config.num_pred_heads, self.config.vocab_size)
1006
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1007
+ else:
1008
+ shift_logits = logits.view(-1, self.config.vocab_size)
1009
+ shift_labels = labels.view(-1)
1010
+ # Enable model parallelism
1011
+ shift_labels = shift_labels.to(shift_logits.device)
1012
+ loss = loss_fct(shift_logits, shift_labels)
1013
+
1014
+ if hasattr(self.config, "num_pred_heads") and self.config.num_pred_heads > 1:
1015
+ all_pred_logits = logits.reshape(logits.shape[0], logits.shape[1], self.config.num_pred_heads, self.config.vocab_size)
1016
+
1017
+ if return_all_pred_logits:
1018
+ logits = all_pred_logits
1019
+ else:
1020
+ logits = all_pred_logits[..., 0, :]
1021
+
1022
+ if not return_dict:
1023
+ output = (logits, ) + outputs[1:]
1024
+ return (loss, ) + output if loss is not None else output
1025
+
1026
+ return CausalLMOutputWithPast(
1027
+ loss=loss,
1028
+ logits=logits,
1029
+ past_key_values=outputs.past_key_values,
1030
+ hidden_states=outputs.hidden_states,
1031
+ attentions=outputs.attentions,
1032
+ )
1033
+
1034
+
1035
+ def prepare_inputs_for_generation(self,
1036
+ input_ids,
1037
+ past_key_values=None,
1038
+ attention_mask=None,
1039
+ inputs_embeds=None,
1040
+ use_cache=True,
1041
+ **kwargs):
1042
+ # prefill phase:
1043
+ # input_ids: b x s
1044
+ # attention_mask: None if no padding or b x s
1045
+ # position_ids : b x s
1046
+
1047
+ # token gen phase:
1048
+ # input_ids : b x 1
1049
+ # attention_mask: b x 1 x s
1050
+ # position_ids: b x 1
1051
+ past_length = 0
1052
+ if past_key_values is not None:
1053
+ assert isinstance(past_key_values, Cache)
1054
+ past_length = past_key_values.get_seq_length()
1055
+
1056
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1057
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
1058
+ elif past_length < input_ids.shape[1]:
1059
+ input_ids = input_ids[:, past_length:]
1060
+
1061
+ position_ids = kwargs.get("position_ids", None)
1062
+ if attention_mask is not None and position_ids is None:
1063
+ position_ids = attention_mask.long().cumsum(-1) - 1
1064
+ position_ids.masked_fill_(attention_mask == 0, 1)
1065
+ if past_key_values:
1066
+ position_ids = position_ids[:, -input_ids.shape[1]:]
1067
+
1068
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1069
+ if inputs_embeds is not None and past_key_values is None:
1070
+ model_inputs = {"inputs_embeds": inputs_embeds}
1071
+ else:
1072
+ model_inputs = {"input_ids": input_ids}
1073
+
1074
+ # must initialize position_ids at each step during GPU inference
1075
+ assert position_ids is not None
1076
+ model_inputs.update(
1077
+ {
1078
+ "position_ids": position_ids,
1079
+ "past_key_values": past_key_values,
1080
+ "use_cache": use_cache,
1081
+ "attention_mask": attention_mask,
1082
+ }
1083
+ )
1084
+ return model_inputs
1085
+
1086
+ @staticmethod
1087
+ def _reorder_cache(past_key_values, beam_idx):
1088
+ reordered_past = ()
1089
+ for layer_past in past_key_values:
1090
+ reordered_past += (tuple(
1091
+ past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), )
1092
+ return reordered_past
multibyte_decoding_evabyte.py ADDED
@@ -0,0 +1,881 @@
1
+
2
+ # The implementation of multibyte decoding is largely adapted from
3
+ # Medusa decoding: https://github.com/FasterDecoding/Medusa
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from transformers.generation.stopping_criteria import (
7
+ MaxLengthCriteria,
8
+ StoppingCriteriaList,
9
+ )
10
+ from typing import Union, List
11
+ from .eva_cache import EvaStaticCacheForTriton
12
+ from .eva_prep_kv_kernel import triton_eva_prep_kv_fwd
13
+
14
+ class MultibyteEosTokenCriteria:
15
+ """
16
+ This class implements a simple stopping criteria to stop generation whenever
17
+ the "end-of-sequence" token is generated in the last `new_tokens` tokens.
18
+
19
+ Adapted from
20
+ https://github.com/huggingface/transformers/blob/main/src/transformers/generation/stopping_criteria.py#L446
21
+ By default, it uses the `model.generation_config.eos_token_id`.
22
+
23
+ Args:
24
+ eos_token_id (`Union[int, List[int]]`):
25
+ The id(s) of the *end-of-sequence* token.
26
+ """
27
+
28
+ def __init__(self, eos_token_ids: Union[int, List[int]]):
29
+ if isinstance(eos_token_ids, int):
30
+ eos_token_ids = [eos_token_ids]
31
+ self.eos_token_ids = eos_token_ids
32
+
33
+ def __call__(self, input_ids: torch.LongTensor, new_tokens: int) -> bool:
34
+ current_input_len = input_ids.shape[-1]
35
+ new_token_ids = input_ids[:, current_input_len - new_tokens:]
36
+ for eos_token_id in self.eos_token_ids:
37
+ if torch.any(new_token_ids == eos_token_id):
38
+ return True
39
+ return False
40
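+ # Illustrative usage (the eos id below is a made-up example, not the model's actual config):
+ # >>> criteria = MultibyteEosTokenCriteria(eos_token_ids=2)
+ # >>> criteria(torch.tensor([[5, 7, 2]]), new_tokens=2)
+ # True # an EOS byte appears among the last 2 generated tokens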
+
41
+ def build_tree(spec):
42
+ nodes_at_depth = []
43
+ nodes_at_depth.append([()]) # Root at depth 1
44
+
45
+ for d in range(1, len(spec) + 1):
46
+ prev_nodes = nodes_at_depth[d - 1]
47
+ spec_list = spec[d - 1]
48
+ current_nodes = []
49
+ for node_idx, node in enumerate(prev_nodes):
50
+ if node_idx < len(spec_list):
51
+ num_children = spec_list[node_idx]
52
+ else:
53
+ num_children = 0
54
+ for child_idx in range(num_children):
55
+ new_node = node + (child_idx,)
56
+ current_nodes.append(new_node)
57
+ nodes_at_depth.append(current_nodes)
58
+
59
+ # Flatten the list of nodes, excluding the root node if desired
60
+ all_nodes = [node for depth_nodes in nodes_at_depth for node in depth_nodes if node]
61
+ return all_nodes
62
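+ # Illustrative example: a toy two-level spec where the first extra head keeps 2
+ # candidates and each of them keeps 1 child produces the flattened node list
+ # >>> build_tree([[2], [1, 1]])
+ # [(0,), (1,), (0, 0), (1, 0)]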
+
63
+ evabyte_7b_95 = build_tree(
64
+ [
65
+ [10],
66
+ [10, 8, 2, 2, 1, 1],
67
+ [10, 4, 2, 1, 0, 0, 0, 0, 0, 0, 2, 1, 1, 0, 0, 0, 0, 0, 1],
68
+ [8, 2, 2, 1, 0, 0, 0, 0, 0, 0, 1],
69
+ [6, 2, 1, 1],
70
+ [4, 2, 1, 1],
71
+ [4, 2, 1],
72
+ ]
73
+ )
74
+ evabyte_7b_31 = build_tree(
75
+ [
76
+ [4],
77
+ [3, 2, 1, 1],
78
+ [3, 2, 1, 1],
79
+ [2, 1, 1],
80
+ [2, 1],
81
+ [2, 1],
82
+ [2, 1],
83
+ ]
84
+ )
85
+ TOPK = 10 # number of top candidates kept per Medusa head when building the sparse tree (10 is sufficient for the trees above)
86
+
87
+ def pad_path(path, length, pad_value=-2):
88
+ """
89
+ Pad the given path list with a specific value up to a specified length.
90
+
91
+ Parameters:
92
+ - path (list): The original list that needs padding.
93
+ - length (int): The desired length of the padded list.
94
+ - pad_value (optional, default=-2): The value to use for padding.
95
+
96
+ Returns:
97
+ - list: A new list based on the original path but padded to the desired length.
98
+
99
+ Example:
100
+ >>> pad_path([1,2,3], 5)
101
+ [1, 2, 3, -2, -2]
102
+
103
+ Note:
104
+ If the given path is already longer than the specified length,
105
+ then no padding occurs, and the original path is returned.
106
+ """
107
+ return path + [pad_value] * (length - len(path))
108
+
109
+ def reset_past_key_values(passed_key_values):
110
+ """
111
+ Resets the current lengths in the passed key-values to zero.
112
+
113
+ This function is designed to be used during the evaluation of a baseline model.
114
+ It iterates through each layer's key-values and sets their current lengths to zero,
115
+ effectively resetting their state.
116
+
117
+ Args:
118
+ - passed_key_values (list of torch.Tensor): Contains past hidden states and past attention values for each layer.
119
+
120
+ Returns:
121
+ - passed_key_values (list of torch.Tensor): Updated past hidden states and past attention values with reset lengths.
122
+ """
123
+ for i in range(len(passed_key_values)):
124
+ for j in range(2):
125
+ passed_key_values[i][j].current_length.fill_(0)
126
+ return passed_key_values
127
+
128
+ def get_nucleus_one_token(logit, temperature, top_p):
129
+ """
130
+ Performs token sampling based on the nucleus (top-p) sampling method.
131
+
132
+ This function selects a token from a given logit distribution using the nucleus sampling strategy.
133
+ It allows for more controlled and diverse generation compared to traditional top-k sampling.
134
+
135
+ Args:
136
+ logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor (BxC).
137
+ temperature (float): A temperature parameter to control the randomness in sampling.
138
+ Higher values increase diversity, lower values make selections more deterministic.
139
+ top_p (float): The cumulative probability threshold for nucleus sampling.
140
+ It controls the size of the set of high-probability tokens to consider for sampling.
141
+
142
+ Returns:
143
+ torch.Tensor: A tensor containing the indices of the sampled tokens.
144
+ """
145
+ if top_p >= 1:
146
+ return torch.multinomial(F.softmax(logit / temperature, dim=-1), 1)
147
+ logit = logit / temperature
148
+ probs = torch.softmax(logit, dim=-1)
149
+ sorted_logits, sorted_indices = torch.sort(probs, descending=True)
150
+ cum_probs = torch.cumsum(sorted_logits, dim=-1)
151
+ sorted_indices_to_remove = cum_probs > top_p
152
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
153
+ sorted_indices_to_remove[..., 0] = 0
154
+ indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
155
+ logit[indices_to_remove] = float('-inf')
156
+ sampled_tokens = torch.multinomial(F.softmax(logit, dim=-1), 1)
157
+ return sampled_tokens
158
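+ # e.g. with probabilities [0.6, 0.3, 0.1] and top_p = 0.8, the cut keeps the
+ # smallest prefix reaching 0.8 ({0.6, 0.3}); the 0.1 token is masked before sampling.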
+
159
+ def get_typical_one_token(logit, temperature, posterior_threshold, posterior_alpha):
160
+ """
161
+ Implements token sampling based on the typical sampling method.
162
+
163
+ This function selects a token from a given logit distribution using the typical sampling strategy,
164
+ aiming to balance between diversity and likelihood in a more nuanced way compared to traditional methods.
165
+
166
+ Args:
167
+ logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor.
168
+ temperature (float): A parameter to control the randomness in sampling.
169
+ Higher values increase diversity, lower values make selections more deterministic.
170
+ posterior_threshold (float): A threshold to decide the lower bound of probabilities to be considered for sampling.
171
+ posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold.
172
+
173
+ Returns:
174
+ torch.Tensor: A tensor containing the indices of the sampled tokens.
175
+ """
176
+ logit = logit / temperature
177
+ probs = torch.softmax(logit, dim=-1)
178
+ entropy = -torch.sum(
179
+ probs * torch.log(probs + 1e-5), dim=-1
180
+ )
181
+ threshold = torch.minimum(
182
+ torch.ones_like(entropy) * posterior_threshold,
183
+ torch.exp(-entropy) * posterior_alpha,
184
+ )
185
+ indices_to_remove = probs < threshold.unsqueeze(-1)
186
+ logit[indices_to_remove] = float('-inf')
187
+ sampled_tokens = torch.multinomial(F.softmax(logit, dim=-1), 1)
188
+ return sampled_tokens
189
+
190
+
191
+
192
+ def generate_medusa_buffers(medusa_choices, device="cuda"):
193
+ """
194
+ Generate buffers for the Medusa structure based on the provided choices.
195
+
196
+ Parameters:
197
+ - medusa_choices (list): A nested list representing the tree in the Medusa structure.
198
+ - device (str): Device to which the tensors should be moved. Default is "cuda".
199
+
200
+ Returns:
201
+ - dict: A dictionary containing buffers related to the Medusa structure.
202
+ """
203
+
204
+ # Sort the medusa_choices based on their lengths and then their values
205
+ sorted_medusa_choices = sorted(medusa_choices, key=lambda x: (len(x), x))
206
+ medusa_len = len(sorted_medusa_choices) + 1
207
+
208
+ # Initialize depth_counts to keep track of how many choices have a particular depth
209
+ depth_counts = [0] * max([len(path) for path in sorted_medusa_choices])
210
+ for path in sorted_medusa_choices:
211
+ depth_counts[len(path) - 1] += 1
212
+
213
+ # Create the attention mask for Medusa
214
+ medusa_attn_mask = torch.eye(medusa_len, medusa_len)
215
+ medusa_attn_mask[:, 0] = 1
216
+ start = 0
217
+ for i in range(len(depth_counts)):
218
+ for j in range(depth_counts[i]):
219
+ cur_medusa_choice = sorted_medusa_choices[start + j]
220
+ # retrieve ancestor position
221
+ if len(cur_medusa_choice) == 1:
222
+ continue
223
+ ancestor_idx = []
224
+ for c in range(len(cur_medusa_choice) - 1):
225
+ ancestor_idx.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]) + 1)
226
+ medusa_attn_mask[j + start + 1, ancestor_idx] = 1
227
+ start += depth_counts[i]
228
+
229
+ # Generate tree indices for the Medusa structure
230
+ medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long)
231
+ medusa_tree_indices[0] = 0
232
+ start = 0
233
+ for i in range(len(depth_counts)):
234
+ for j in range(depth_counts[i]):
235
+ cur_medusa_choice = sorted_medusa_choices[start + j]
236
+ medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1
237
+ start += depth_counts[i]
238
+
239
+ # Generate position IDs for the Medusa structure
240
+ medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long)
241
+ start = 0
242
+ for i in range(len(depth_counts)):
243
+ medusa_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1
244
+ start += depth_counts[i]
245
+
246
+ # Generate retrieval indices for Medusa structure verification
247
+ retrieve_indices_nest = []
248
+ retrieve_paths = []
249
+ for i in range(len(sorted_medusa_choices)):
250
+ cur_medusa_choice = sorted_medusa_choices[-i-1]
251
+ retrieve_indice = []
252
+ if cur_medusa_choice in retrieve_paths:
253
+ continue
254
+ else:
255
+ for c in range(len(cur_medusa_choice)):
256
+ retrieve_indice.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]))
257
+ retrieve_paths.append(cur_medusa_choice[:c+1])
258
+ retrieve_indices_nest.append(retrieve_indice)
259
+ max_length = max([len(x) for x in retrieve_indices_nest])
260
+ retrieve_indices = [pad_path(path, max_length) for path in retrieve_indices_nest]
261
+ retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
262
+ retrieve_indices = retrieve_indices + 1
263
+ retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices], dim=1)
264
+
265
+ # Aggregate the generated buffers into a dictionary
266
+ medusa_buffers = {
267
+ "medusa_attn_mask": medusa_attn_mask.unsqueeze(0).unsqueeze(0),
268
+ "tree_indices": medusa_tree_indices,
269
+ "medusa_position_ids": medusa_position_ids.unsqueeze(0),
270
+ "retrieve_indices": retrieve_indices,
271
+ }
272
+
273
+ # Move the tensors in the dictionary to the specified device
274
+ medusa_buffers = {
275
+ k: v.clone().to(device)
276
+ if isinstance(v, torch.Tensor)
277
+ else torch.tensor(v, device=device)
278
+ for k, v in medusa_buffers.items()
279
+ }
280
+ return medusa_buffers
281
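+ # Resulting buffer shapes (derived from the construction above; medusa_len = number
+ # of tree nodes + 1 for the root):
+ # medusa_attn_mask: (1, 1, medusa_len, medusa_len)
+ # tree_indices: (medusa_len,)
+ # medusa_position_ids: (1, medusa_len)
+ # retrieve_indices: (num_leaf_paths, max_tree_depth + 1)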
+
282
+ def generate_candidates(
283
+ medusa_logits,
284
+ logits,
285
+ tree_indices,
286
+ retrieve_indices,
287
+ temperature = 0,
288
+ posterior_threshold=0.3,
289
+ posterior_alpha = 0.09,
290
+ top_p=0.8,
291
+ sampling = 'typical',
292
+ fast = False
293
+ ):
294
+ # Say we have 3 heads, and the top-4 for each head are:
295
+ # [10, 3, 8, 4]
296
+ # [9, 5, 1, 6]
297
+ # [7, 16, 3, 2]
298
+
299
+ # candidates_id = 10
300
+ if temperature == 0 or fast:
301
+ candidates_ids = torch.argmax(logits[:, -1]).unsqueeze(0)
302
+ else:
303
+ if sampling == 'typical':
304
+ candidates_ids = get_typical_one_token(logits[:, -1], temperature, posterior_threshold, posterior_alpha).squeeze(0)
305
+ elif sampling == 'nucleus':
306
+ candidates_ids = get_nucleus_one_token(logits[:, -1], temperature, top_p).squeeze(0)
307
+ else:
308
+ raise NotImplementedError
309
+
310
+ # this calculates the top-k medusa logits
311
+ # candidates_medusa_id = [
312
+ # [9, 5, 1, 6]
313
+ # [7, 16, 3, 2]
314
+ # ]
315
+ candidates_medusa_ids = torch.topk(medusa_logits[:, 0, -1], TOPK, dim=-1).indices
316
+
317
+ # [10, 9, 5, 1, 6, 7, 16, 3, 2]
318
+ candidate_ids = torch.cat([candidates_ids, candidates_medusa_ids.view(-1)], dim=-1)
319
+
320
+ # based on the pre-defined tree_indices, select the corresponding candidates
321
+ # if we select top-2 and top-3 for the two heads (we select top-1 for the first head):
322
+ # tree_candidates = [10, 9, 5, 7, 16, 3, 7, 16, 3]
323
+ tree_candidate_ids = candidate_ids[tree_indices]
324
+
325
+ # tree_candidate_ids = [10, 9, 5, 7, 16, 3, 7, 16, 3, 0]
326
+ # The retrieve_indices may be padded (pad value -2 becomes -1 after the shift), so we append a zero here
327
+ # so that all padded indices select the appended zero.
328
+ tree_candidate_ids_ext = torch.cat(
329
+ [
330
+ tree_candidate_ids,
331
+ torch.zeros((1), dtype=torch.long, device=tree_candidate_ids.device)
332
+ ],
333
+ dim=0
334
+ )
335
+ # [[10, 9, 7], [10, 9, 16], [10, 9, 3], [10, 5, 7], [10, 5, 16], [10, 5, 3]]
336
+ unflattened_candidate_ids = tree_candidate_ids_ext[retrieve_indices]
337
+
338
+ tree_candidate_ids = tree_candidate_ids.unsqueeze(0)
339
+
340
+ return tree_candidate_ids, unflattened_candidate_ids
341
+
342
+ def get_nucleus_posterior_mask(logits, candidates, temperature, top_p):
343
+ """
344
+ Generates a posterior mask for token candidates using nucleus (top-p) sampling.
345
+
346
+ This function applies nucleus sampling to a set of logits, and then generates a mask indicating
347
+ which candidate tokens are selected. It adapts the sampling strategy to accommodate for
348
+ temperature scaling and cumulative probability thresholding.
349
+
350
+ Args:
351
+ logits (torch.Tensor): A tensor of logits from a language model output.
352
+ candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens.
353
+ temperature (float): A parameter to scale the logits, controlling randomness in sampling.
354
+ top_p (float): The cumulative probability threshold for nucleus sampling.
355
+
356
+ Returns:
357
+ torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens.
358
+ """
359
+ # adapted from https://github.com/huggingface/transformers/blob/18a879f47576822aa1a5c49aecb27d89bfa5fa69/examples/run_generation.py#L79
360
+
361
+ # Apply temperature
362
+ logits = logits[:, :-1] / temperature
363
+ n_samples, n_tokens = logits.shape[0], logits.shape[1]
364
+ logits = logits.view(n_samples*n_tokens, -1)
365
+ if top_p >= 1:
366
+ sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
367
+ sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
368
+ posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
369
+ return posterior_mask
370
+ # Convert to probabilities (softmax)
371
+ probs = F.softmax(logits, dim=-1)
372
+ # Sort the probabilities
373
+ sorted_logits, sorted_indices = torch.sort(probs, descending=True)
374
+
375
+ # Compute cumulative probabilities
376
+ cum_probs = torch.cumsum(sorted_logits, dim=-1)
377
+
378
+ # Create mask for the top-p nucleus
379
+ sorted_indices_to_remove = cum_probs > top_p
380
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
381
+ sorted_indices_to_remove[..., 0] = 0
382
+
383
+ indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
384
+
385
+
386
+ # Remove low-probability tokens
387
+ logits[indices_to_remove] = float('-inf')
388
+ # Sample from the remaining tokens
389
+ sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
390
+ sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
391
+ # Create a mask for selected tokens
392
+ posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
393
+
394
+ return posterior_mask
395
+
396
+ def get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha):
397
+ """
398
+ Args:
399
+ logits (torch.Tensor): A tensor of logits from a language model output.
400
+ candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens.
401
+ temperature (float): A parameter to scale the logits, controlling randomness in sampling.
402
+ posterior_threshold (float): The minimum threshold for probabilities to be considered in sampling.
403
+ posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold.
404
+
405
+ Returns:
406
+ torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens.
407
+ """
408
+ logits = logits[:, :-1] / temperature
409
+ n_samples, n_tokens = logits.shape[0], logits.shape[1]
410
+ logits = logits.view(n_samples*n_tokens, -1)
411
+ probs = F.softmax(logits, dim=-1)
412
+ entropy = -torch.sum(
413
+ probs * torch.log(probs + 1e-5), dim=-1
414
+ )
415
+ threshold = torch.minimum(
416
+ torch.ones_like(entropy) * posterior_threshold,
417
+ torch.exp(-entropy) * posterior_alpha,
418
+ )
419
+ indices_to_remove = probs < threshold.unsqueeze(-1)
420
+ logits[indices_to_remove] = float('-inf')
421
+ sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
422
+ sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
423
+ posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
424
+ return posterior_mask
425
+
426
+
427
+
428
+ def evaluate_posterior(
429
+ logits,
430
+ candidates,
431
+ temperature,
432
+ posterior_threshold=0.3,
433
+ posterior_alpha = 0.09,
434
+ top_p=0.8,
435
+ sampling = 'typical',
436
+ fast = True
437
+ ):
438
+ if logits.shape[1] <= 1:
439
+ return torch.tensor(0, dtype=torch.long, device=candidates.device), 0
440
+ # Greedy decoding based on temperature value
441
+ if temperature == 0:
442
+ # Find the tokens that match the maximum logits for each position in the sequence
443
+ posterior_mask = (
444
+ candidates[:, 1:] == torch.argmax(logits[:, :-1], dim=-1)
445
+ ).int()
446
+ candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
447
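+ # e.g. if a candidate path matches the verified argmax at positions [1, 1, 0, 1],
+ # cumprod gives [1, 1, 0, 0] and its accepted length is 2: acceptance stops at
+ # the first mismatch.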
+ accept_length = candidates_accept_length.max().item()
448
+ # Choose the best candidate
449
+ if accept_length == 0:
450
+ # Default to the first candidate if none are accepted
451
+ best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
452
+ else:
453
+ best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
454
+ return best_candidate, accept_length
455
+ elif sampling == 'typical':
456
+ if fast:
457
+ posterior_prob = torch.softmax(logits[:, :-1] / temperature, dim=-1)
458
+ candidates_prob = torch.gather(
459
+ posterior_prob, dim=-1, index=candidates[:, 1:].unsqueeze(-1)
460
+ ).squeeze(-1)
461
+ posterior_entropy = -torch.sum(
462
+ posterior_prob * torch.log(posterior_prob + 1e-5), dim=-1
463
+ ) # torch.sum(torch.log(*)) is faster than torch.prod
464
+ threshold = torch.minimum(
465
+ torch.ones_like(posterior_entropy) * posterior_threshold,
466
+ torch.exp(-posterior_entropy) * posterior_alpha,
467
+ )
468
+ posterior_mask = candidates_prob > threshold
469
+ candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
470
+
471
+ # Choose the best candidate based on the evaluated posterior probabilities
472
+ accept_length = candidates_accept_length.max().item()
473
+ if accept_length == 0:
474
+ # If no candidates are accepted, just choose the first one
475
+ best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
476
+ else:
477
+ best_candidates = torch.where(candidates_accept_length == accept_length)[0]
478
+ # Accept the best one according to likelihood
479
+ likelihood = torch.sum(
480
+ torch.log(candidates_prob[best_candidates, :accept_length]), dim=-1
481
+ )
482
+ best_candidate = best_candidates[torch.argmax(likelihood)]
483
+ return best_candidate, accept_length
484
+ # Calculate posterior probabilities and thresholds for candidate selection
485
+ posterior_mask = get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha)
486
+ candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
487
+ # Choose the best candidate based on the evaluated posterior probabilities
488
+ accept_length = candidates_accept_length.max().item()
489
+
490
+ if accept_length == 0:
491
+ # If no candidates are accepted, just choose the first one
492
+ best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
493
+ else:
494
+ best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
495
+ # Accept the best one according to likelihood
496
+ return best_candidate, accept_length
497
+ elif sampling == 'nucleus':
498
+ assert top_p < 1.0 + 1e-6, "top_p should be between 0 and 1"
499
+ posterior_mask = get_nucleus_posterior_mask(logits, candidates, temperature, top_p)
500
+ candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
501
+ accept_length = candidates_accept_length.max().item()
502
+ # Choose the best candidate
503
+ if accept_length == 0:
504
+ # Default to the first candidate if none are accepted
505
+ best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
506
+ else:
507
+ best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
508
+ return best_candidate, accept_length
509
+ else:
510
+ raise NotImplementedError
511
+
512
+ def update_inference_inputs(
513
+ input_ids,
514
+ medusa_logits,
515
+ logits,
516
+ candidate_ids,
517
+ best_candidate,
518
+ accept_length,
519
+ ):
520
+ input_ids = torch.cat(
521
+ [
522
+ input_ids,
523
+ candidate_ids[None, best_candidate, : accept_length + 1]
524
+ ],
525
+ dim=-1
526
+ )
527
+ logits = logits[
528
+ None, best_candidate, accept_length : accept_length + 1
529
+ ]
530
+ medusa_logits = medusa_logits[
531
+ :, None, best_candidate, accept_length : accept_length + 1
532
+ ]
533
+ # Update the new token counter
534
+ new_token = accept_length + 1
535
+ return input_ids, medusa_logits, logits, new_token
536
+
537
+ def split_logits(full_logits):
538
+ # logits has shape [b, n, heads, vocab_size]
539
+ logits = full_logits[..., 0, :]
540
+ medusa_logits = full_logits[..., 1:, :].permute(2, 0, 1, 3)
541
+ return medusa_logits, logits
542
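+ # e.g. assuming 8 prediction heads and a byte vocabulary of size V (both are
+ # config-dependent), full_logits of shape [b, n, 8, V] splits into
+ # logits [b, n, V] for the next-byte head and medusa_logits [7, b, n, V]
+ # for the remaining heads.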
+
543
+ class MultiByteDecodingMixin:
544
+ def multi_byte_pred_update_cache(
545
+ self,
546
+ past_key_values,
547
+ retrieve_indices,
548
+ best_candidate,
549
+ new_tokens,
550
+ ):
551
+ prev_window_len = past_key_values.get_past_window_pos(0)
552
+ select_indices = (
553
+ retrieve_indices[best_candidate, : new_tokens] + prev_window_len
554
+ )
555
+ for layer_idx in range(self.config.num_hidden_layers):
556
+
557
+ past_key_values.update_past_len(new_tokens, layer_idx)
558
+
559
+ past_window_k = past_key_values.past_window_k[layer_idx]
560
+ past_window_v = past_key_values.past_window_v[layer_idx]
561
+
562
+ tgt_window_k = past_window_k[..., select_indices, :]
563
+ tgt_window_v = past_window_v[..., select_indices, :]
564
+
565
+ dst_window_k = past_window_k[..., prev_window_len : prev_window_len + new_tokens, :]
566
+ dst_window_v = past_window_v[..., prev_window_len : prev_window_len + new_tokens, :]
567
+
568
+ dst_window_k.copy_(tgt_window_k, non_blocking=True)
569
+ dst_window_v.copy_(tgt_window_v, non_blocking=True)
570
+
571
+ new_window_len = prev_window_len + new_tokens
572
+ if new_window_len >= self.config.window_size:
573
+ assert new_window_len < 2 * self.config.window_size
574
+
575
+ dump_k = past_window_k[..., :self.config.window_size, :].clone()
576
+ dump_v = past_window_v[..., :self.config.window_size, :].clone()
577
+
578
+ _window_len = new_window_len - self.config.window_size
579
+
580
+ if _window_len > 0:
581
+ new_window_k = past_window_k[..., self.config.window_size : new_window_len, :]
582
+ new_window_v = past_window_v[..., self.config.window_size : new_window_len, :]
583
+
584
+ _dst_window_k = past_window_k[..., : _window_len, :]
585
+ _dst_window_v = past_window_v[..., : _window_len, :]
586
+
587
+ _dst_window_k.copy_(new_window_k, non_blocking=True)
588
+ _dst_window_v.copy_(new_window_v, non_blocking=True)
589
+
590
+ past_key_values.past_window_pos[layer_idx] = _window_len
591
+ else:
592
+ dump_k = None
593
+ dump_v = None
594
+ past_key_values.past_window_pos[layer_idx] = new_window_len
595
+
596
+ if dump_k is not None and dump_v is not None:
597
+ rfa_k, rfa_v = triton_eva_prep_kv_fwd(
598
+ dump_k, dump_v,
599
+ self.model.layers[layer_idx].self_attn.adaptive_mu_k,
600
+ self.model.layers[layer_idx].self_attn.adaptive_phi,
601
+ None,
602
+ self.model.layers[layer_idx].self_attn.head_dim_scaling,
603
+ self.model.layers[layer_idx].self_attn.chunk_size
604
+ )
605
+ rfa_k, rfa_v = past_key_values.update_chunk_rfas(
606
+ rfa_k, rfa_v, layer_idx
607
+ )
608
+ return past_key_values
609
+
610
+ def _multi_byte_pred_update_cache_when_prefil_len_eq_window_size(
611
+ self,
612
+ past_key_values,
613
+ ):
614
+ prev_window_len = past_key_values.get_past_window_pos(0)
615
+ for layer_idx in range(self.config.num_hidden_layers):
616
+
617
+ past_window_k = past_key_values.past_window_k[layer_idx]
618
+ past_window_v = past_key_values.past_window_v[layer_idx]
619
+
620
+ new_window_len = prev_window_len
621
+ if new_window_len == self.config.window_size:
622
+ dump_k = past_window_k[..., :self.config.window_size, :].clone()
623
+ dump_v = past_window_v[..., :self.config.window_size, :].clone()
624
+ past_key_values.past_window_pos[layer_idx] = 0
625
+
626
+ if dump_k is not None and dump_v is not None:
627
+ rfa_k, rfa_v = triton_eva_prep_kv_fwd(
628
+ dump_k, dump_v,
629
+ self.model.layers[layer_idx].self_attn.adaptive_mu_k,
630
+ self.model.layers[layer_idx].self_attn.adaptive_phi,
631
+ None,
632
+ self.model.layers[layer_idx].self_attn.head_dim_scaling,
633
+ self.model.layers[layer_idx].self_attn.chunk_size
634
+ )
635
+ rfa_k, rfa_v = past_key_values.update_chunk_rfas(
636
+ rfa_k, rfa_v, layer_idx
637
+ )
638
+ return past_key_values
639
+
640
+ def multi_byte_pred_update_attn_mask(
641
+ self,
642
+ last_iter_new_tokens,
643
+ tree_candidate_ids,
644
+ past_attn_mask,
645
+ medusa_attn_mask,
646
+ past_key_values,
647
+ ):
648
+ batch_size, tree_candidate_len = tree_candidate_ids.shape
649
+ seen_tokens = past_key_values.get_seq_length()
650
+ # NOTE: past_key_values has been updated so now
651
+ # seen_tokens includes new tokens from the last tree iteration
652
+ assert seen_tokens > 0
653
+ # so one iteration would not cross two windows
654
+ assert last_iter_new_tokens < self.config.window_size
655
+
656
+ if past_attn_mask is not None and seen_tokens < self.config.window_size:
657
+ past_attn_mask = torch.cat(
658
+ [
659
+ past_attn_mask,
660
+ torch.ones(
661
+ [batch_size, 1, tree_candidate_len, last_iter_new_tokens],
662
+ dtype=torch.bool,
663
+ device=self.device
664
+ )
665
+ ],
666
+ dim=-1
667
+ )
668
+ else:
669
+ # we initialize attn mask each time when
670
+ # 1. the model crosses the window boundary, or
671
+ # 2. after prefilling
672
+ chunks_per_window = int(self.config.window_size // self.config.chunk_size)
673
+
674
+ window_tokens = seen_tokens % self.config.window_size
675
+ num_windows_seen_so_far = seen_tokens // self.config.window_size
676
+ attn_mask_len = num_windows_seen_so_far * chunks_per_window + window_tokens
677
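+ # e.g. with window_size = 2048 and chunk_size = 256 (illustrative values only),
+ # after 5000 seen tokens: 2 past windows contribute 2 * 8 = 16 chunk-level
+ # summaries and the current window holds 5000 % 2048 = 904 exact tokens,
+ # so attn_mask_len = 16 + 904 = 920.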
+ past_attn_mask = torch.ones(
678
+ (batch_size, 1, tree_candidate_len, attn_mask_len),
679
+ dtype=torch.bool,
680
+ device=self.device
681
+ )
682
+
683
+ # note that 1 indicates the position is not masked
684
+ tree_attn_mask = torch.cat(
685
+ [
686
+ past_attn_mask,
687
+ medusa_attn_mask.to(torch.bool)
688
+ ],
689
+ dim=-1
690
+ )
691
+ return tree_attn_mask, past_attn_mask
692
+
693
+ @torch.no_grad()
694
+ def multi_byte_generate(
695
+ self,
696
+ input_ids,
697
+ attention_mask=None,
698
+ temperature=0.0,
699
+ max_length=None,
700
+ max_new_tokens=None,
701
+ stopping_criteria=None,
702
+ posterior_threshold=0.09,
703
+ posterior_alpha=0.3,
704
+ top_p=0.8,
705
+ sampling='typical',
706
+ fast=True,
707
+ do_sample=False,
708
+ medusa_choices=None,
709
+ return_acc_lengths=False
710
+ ):
711
+ if do_sample or temperature > 0.0:
712
+ fast = False
713
+
714
+ ### Prepare `max_length` depending on other stopping criteria.
715
+ if max_new_tokens is not None:
716
+ max_length = max_new_tokens + input_ids.shape[-1]
717
+ elif max_new_tokens is None and max_length is None:
718
+ max_length = getattr(self.config, "max_position_embeddings", 32768)
719
+
720
+ ### Set up stopping criteria
721
+ eos_stop_criteria = MultibyteEosTokenCriteria(self.generation_config.eos_token_id)
722
+ stop_criteria = StoppingCriteriaList()
723
+ if max_length is not None:
724
+ max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
725
+ stop_criteria.append(
726
+ MaxLengthCriteria(
727
+ max_length=max_length,
728
+ max_position_embeddings=max_position_embeddings,
729
+ )
730
+ )
731
+ if stopping_criteria is not None and len(stopping_criteria) > 0:
732
+ stop_criteria.extend(stopping_criteria)
733
+
734
+ assert input_ids.shape[0] == 1, "Only support batch size 1 for now"
735
+ assert attention_mask is None, "Only support attention mask None for now"
736
+ # Avoid modifying the input_ids in-place
737
+ input_ids = input_ids.clone()
738
+ position_ids = torch.arange(0, input_ids.shape[1], device=self.device, dtype=int).reshape(1, -1)
739
+
740
+ ####################################################
741
+ # 0. initialize the medusa buffers
742
+ ####################################################
743
+ if medusa_choices is None:
744
+ medusa_choices = evabyte_7b_95
745
+ medusa_buffers = generate_medusa_buffers(
746
+ medusa_choices, device=self.device
747
+ )
748
+
749
+ past_key_values = EvaStaticCacheForTriton(
750
+ input_ids.shape[0],
751
+ self.config.num_attention_heads,
752
+ # we add 256 slots to leave room for the tree candidate tokens
753
+ self.config.window_size + 256,
754
+ self.config.hidden_size // self.config.num_attention_heads,
755
+ self.config.num_hidden_layers,
756
+ self.lm_head.weight.dtype,
757
+ self.lm_head.weight.device,
758
+ )
759
+ # prefill to get medusa logits and logits
760
+ full_logits, past_key_values = self.forward(
761
+ input_ids,
762
+ attention_mask=attention_mask,
763
+ position_ids=position_ids,
764
+ use_cache=True,
765
+ past_key_values=past_key_values,
766
+ return_all_pred_logits=True,
767
+ multibyte_decoding=False,
768
+ )
769
+ # handles an edge case where the prefill length == window_size
770
+ # we force the previous window to be dumped into RFA chunks
771
+ past_key_values = self._multi_byte_pred_update_cache_when_prefil_len_eq_window_size(
772
+ past_key_values
773
+ )
774
+ medusa_logits, logits = split_logits(full_logits)
775
+
776
+ past_attn_mask = None
777
+ last_iter_new_tokens = 0
778
+ max_iters = 32768
779
+ if return_acc_lengths:
780
+ acc_lengths = []
781
+ for _ in range(max_iters):
782
+ ####################################################
783
+ # 1. generate candidate_ids with topk predictions from Medusa heads
784
+ ####################################################
785
+ tree_candidate_ids, unflattened_candidate_ids = generate_candidates(
786
+ medusa_logits,
787
+ logits,
788
+ medusa_buffers["tree_indices"],
789
+ medusa_buffers["retrieve_indices"],
790
+ temperature=temperature,
791
+ posterior_alpha=posterior_alpha,
792
+ posterior_threshold=posterior_threshold,
793
+ top_p=top_p,
794
+ sampling=sampling,
795
+ fast=fast,
796
+ )
797
+
798
+ ####################################################
799
+ # 2. Build the medusa attention mask and position ids
800
+ ####################################################
801
+ # NOTE: 1 indicates the position is not masked
802
+ medusa_attn_mask, past_attn_mask = self.multi_byte_pred_update_attn_mask(
803
+ last_iter_new_tokens,
804
+ tree_candidate_ids,
805
+ past_attn_mask,
806
+ medusa_buffers["medusa_attn_mask"],
807
+ past_key_values,
808
+ )
809
+ medusa_position_ids = medusa_buffers["medusa_position_ids"] + input_ids.shape[1]
810
+
811
+ ####################################################
812
+ # 3. tree decoding
813
+ ####################################################
814
+ tree_full_logits, past_key_values = self.forward(
815
+ tree_candidate_ids,
816
+ past_key_values=past_key_values,
817
+ attention_mask=medusa_attn_mask,
818
+ position_ids=medusa_position_ids,
819
+ return_all_pred_logits=True,
820
+ multibyte_decoding=True,
821
+ )
822
+ _medusa_logits, _logits = split_logits(tree_full_logits)
823
+ medusa_logits = _medusa_logits[..., 0, medusa_buffers["retrieve_indices"], :]
824
+ logits = _logits[..., 0, medusa_buffers["retrieve_indices"], :]
825
+
826
+ ####################################################
827
+ # 4. candidate selection
828
+ ####################################################
829
+ # if the current iteration, with tree tokens, crosses window
830
+ # boundaries, trim the candidate_ids to be within the window
831
+ # so that tokens beyond the window (whose predictions would be inaccurate)
832
+ # will not be considered
833
+ tree_depth = unflattened_candidate_ids.shape[-1]
834
+ if tree_depth + past_key_values.get_past_window_pos(0) > self.config.window_size:
835
+ max_acc_len = self.config.window_size - past_key_values.get_past_window_pos(0)
836
+ _trimmed_unflattened_candidate_ids = unflattened_candidate_ids[:, :max_acc_len]
837
+ _trimmed_logits = logits[:, :max_acc_len]
838
+ else:
839
+ _trimmed_unflattened_candidate_ids = unflattened_candidate_ids
840
+ _trimmed_logits = logits
841
+ best_candidate, accept_length = evaluate_posterior(
842
+ _trimmed_logits,
843
+ _trimmed_unflattened_candidate_ids,
844
+ temperature,
845
+ posterior_threshold,
846
+ posterior_alpha,
847
+ top_p=top_p,
848
+ sampling=sampling,
849
+ fast=fast
850
+ )
851
+
852
+ ####################################################
853
+ # 5. update model inputs and caches
854
+ ####################################################
855
+ input_ids, medusa_logits, logits, last_iter_new_tokens = update_inference_inputs(
856
+ input_ids,
857
+ medusa_logits,
858
+ logits,
859
+ unflattened_candidate_ids,
860
+ best_candidate,
861
+ accept_length,
862
+ )
863
+
864
+ past_key_values = self.multi_byte_pred_update_cache(
865
+ past_key_values,
866
+ medusa_buffers["retrieve_indices"],
867
+ best_candidate,
868
+ last_iter_new_tokens,
869
+ )
870
+
871
+ if return_acc_lengths:
872
+ acc_lengths.append(last_iter_new_tokens)
873
+ if stop_criteria(input_ids, None) or eos_stop_criteria(input_ids, last_iter_new_tokens):
874
+ if return_acc_lengths:
875
+ return input_ids, acc_lengths
876
+ else:
877
+ return input_ids
878
+ if return_acc_lengths:
879
+ return input_ids, acc_lengths
880
+ else:
881
+ return input_ids