Pclanglais committed on
Commit a4ca909 · verified · 1 Parent(s): bd57b59

Upload folder using huggingface_hub

Files changed (2)
  1. config.json +1 -1
  2. model.py +1198 -0
config.json CHANGED
@@ -1,6 +1,6 @@
  {
    "architectures": [
-     "LlamaForCausalLM"
+     "LlamaForTrainingFromOurNanotron"
    ],
    "attention_bias": false,
    "attention_dropout": 0.0,
model.py ADDED
@@ -0,0 +1,1198 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch LLaMa model."""
16
+
17
+ from typing import Dict, List, Optional, Union
18
+
19
+ import torch
20
+ from torch import nn
21
+ from torch.utils.checkpoint import CheckpointFunction
22
+
23
+ from nanotron import distributed as dist
24
+ from nanotron import logging
25
+ from nanotron.config import Config, LlamaConfig, ParallelismArgs
26
+ from nanotron.config.models_config import RandomInit, SpectralMupInit
27
+ from nanotron.generation.generate_store import AttachableStore
28
+ from nanotron.logging import log_rank
29
+ from nanotron.models import NanotronModel
30
+ from nanotron.nn.activations import ACT2FN
31
+ from nanotron.nn.layer_norm import TritonRMSNorm
32
+ from nanotron.parallel import ParallelContext
33
+ from nanotron.parallel.parameters import NanotronParameter
34
+ from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
35
+ from nanotron.parallel.pipeline_parallel.p2p import P2P
36
+ from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy
37
+ from nanotron.parallel.tensor_parallel.nn import (
38
+ TensorParallelColumnLinear,
39
+ TensorParallelEmbedding,
40
+ TensorParallelLinearMode,
41
+ TensorParallelRowLinear,
42
+ )
43
+ from nanotron.random import RandomStates
44
+ from nanotron.scaling.parametrization import SpectralMupParametrizator, StandardParametrizator
45
+ from nanotron.utils import checkpoint_method
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+
50
+ class RotaryEmbedding(nn.Module):
51
+ def __init__(self, dim: int, end: int, theta: float = 10000.0):
52
+ super().__init__()
53
+ assert dim % 2 == 0
54
+ self.dim = dim
55
+ self.end = end
56
+ self.theta = theta
57
+ # TODO @nouamane: Figure out why we can't set `DTypeInvariantTensor` ...
58
+ # TODO @thomasw21: Complex buffers break DDP, instead we store float and view them as complex
59
+ self.freqs_cis: torch.Tensor
60
+ self._initialized_buffer = False
61
+
62
+ def init_rotary_embeddings(self):
63
+ if self._initialized_buffer is True:
64
+ # Buffer is already initialized
65
+ return
66
+ self.register_buffer(
67
+ "freqs_cis",
68
+ torch.empty(self.end, self.dim // 2, 2, dtype=torch.float, device="cuda"),
69
+ persistent=False,
70
+ )
71
+ assert self.freqs_cis.device.type == "cuda"
72
+ # TODO @nouamane: Once we figure out how to do the DTypeInvariantTensor, this can be removed and changed to an assert
73
+ if self.freqs_cis.dtype != torch.float:
74
+ self.freqs_cis = self.freqs_cis.to(torch.float)
75
+ assert self.freqs_cis.dtype == torch.float
76
+ freqs = 1.0 / (
77
+ self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cpu")[: (self.dim // 2)] / self.dim)
78
+ ).to(
79
+ "cuda"
80
+ ) # should be computed on CPU, otherwise different results with Transformers.
81
+ t = torch.arange(self.end, device="cuda")
82
+ freqs = torch.outer(t, freqs).float()
83
+ complex_freqs = torch.polar(torch.ones_like(freqs), freqs)
84
+ freqs = torch.view_as_real(complex_freqs)
85
+ self.freqs_cis.copy_(freqs)
86
+ self._initialized_buffer = True
87
+
88
+ def forward(
89
+ self,
90
+ x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk]
91
+ position_ids: Optional[torch.LongTensor], # [batch_size, seq_length]
92
+ ):
93
+ batch_size, seq_length, num_heads, inner_dim = x.shape
94
+ while (
95
+ position_ids is not None and position_ids[-1, -1] >= self.end
96
+ ) or seq_length >= self.end: # TODO @nouamane: check if this causes cpu-gpu sync
97
+ self.end *= 2
98
+ self._initialized_buffer = False
99
+ if self._initialized_buffer is False:
100
+ print(f"Initializing rotary embeddings with end={self.end}")
101
+ self.init_rotary_embeddings()
102
+ dtype = x.dtype
103
+ assert inner_dim % 2 == 0
104
+ x = x.view(
105
+ batch_size, seq_length, num_heads, inner_dim // 2, 2
106
+ ) # [batch_size, q_length, num_heads, inner_dim]
107
+ if x.dtype == torch.bfloat16:
108
+ x = x.float()
109
+ complex_x = torch.view_as_complex(x) # [batch_size, q_length, num_heads, inner_dim // 2]
110
+ if position_ids is None:
111
+ freqs_cis = self.freqs_cis[None, :seq_length, None, :]
112
+ else:
113
+ # TODO(kunhao): Should None follow the num_heads dimension?
114
+ if position_ids[-1, -1] < 0 or position_ids[-1, -1] >= self.end: # Quick test hopefully
115
+ raise ValueError(f"Position ids must be in the range [0, {self.end}), but got {position_ids}")
116
+ freqs_cis = self.freqs_cis[position_ids][:, :, None, :]
117
+ complex_freqs = torch.view_as_complex(freqs_cis)
118
+ x_out = torch.view_as_real(complex_x * complex_freqs).view(batch_size, seq_length, num_heads, inner_dim)
119
+ return x_out.type(dtype)
120
+
121
+
122
+ ## Copied from transformers. Non-interleaved version of RoPE. Will be refactored later.
123
+ def rotate_half(x):
124
+ """Rotates half the hidden dims of the input."""
125
+ x1 = x[..., : x.shape[-1] // 2]
126
+ x2 = x[..., x.shape[-1] // 2 :]
127
+ return torch.cat((-x2, x1), dim=-1)
128
+
129
+
130
+ class LlamaRotaryEmbedding(nn.Module):
131
+ def __init__(self, dim: int, end: int, theta: float = 500000.0):
132
+ super().__init__()
133
+ self.dim = dim
134
+ self.end = end
135
+ self.theta = theta
136
+ self.init_rotary_embeddings()
137
+
138
+ def init_rotary_embeddings(self):
139
+ inv_freq = 1.0 / (
140
+ self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cpu") / self.dim)
141
+ ) # important to compute on CPU
142
+ self.register_buffer(
143
+ "inv_freq", torch.empty(self.dim // 2, dtype=torch.float, device="cuda"), persistent=False
144
+ )
145
+ self.inv_freq = self.inv_freq.to(
146
+ torch.float
147
+ ) # make it float32 before copy to avoid precision loss during copy_
148
+ self.inv_freq.copy_(inv_freq)
149
+
150
+ @torch.no_grad()
151
+ def forward(
152
+ self,
153
+ x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk]
154
+ position_ids: Optional[torch.LongTensor], # [batch_size, seq_length]
155
+ ):
156
+ # x: [bs, num_attention_heads, seq_len, head_size]
157
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
158
+ position_ids_expanded = position_ids[:, None, :].float()
159
+ # Force float32 since bfloat16 loses precision on long contexts
160
+ # See https://github.com/huggingface/transformers/pull/29285
161
+ device_type = x.device.type
162
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
163
+ with torch.autocast(device_type=device_type, enabled=False):
164
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
165
+ emb = torch.cat((freqs, freqs), dim=-1)
166
+ cos = emb.cos()
167
+ sin = emb.sin()
168
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
169
+
170
+ def rotate_half(self, x):
171
+ """Rotates half the hidden dims of the input."""
172
+ x1 = x[..., : x.shape[-1] // 2]
173
+ x2 = x[..., x.shape[-1] // 2 :]
174
+ return torch.cat((-x2, x1), dim=-1)
175
+
176
+ def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=2):
177
+ """Applies Rotary Position Embedding to the query and key tensors.
178
+
179
+ Args:
180
+ q (`torch.Tensor`): The query tensor.
181
+ k (`torch.Tensor`): The key tensor.
182
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
183
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
184
+ unsqueeze_dim (`int`, *optional*, defaults to 2):
185
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
186
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
187
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
188
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
189
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
190
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
191
+ Returns:
192
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
193
+ """
194
+ cos = cos.unsqueeze(unsqueeze_dim)
195
+ sin = sin.unsqueeze(unsqueeze_dim)
196
+ q_embed = (q * cos) + (self.rotate_half(q) * sin)
197
+ k_embed = (k * cos) + (self.rotate_half(k) * sin)
198
+ return q_embed, k_embed
199
+
200
+
201
+ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
202
+ """Applies Rotary Position Embedding to the query and key tensors.
203
+
204
+ Args:
205
+ q (`torch.Tensor`): The query tensor.
206
+ k (`torch.Tensor`): The key tensor.
207
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
208
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
209
+ unsqueeze_dim (`int`, *optional*, defaults to 2):
210
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
211
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
212
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
213
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
214
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
215
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
216
+ Returns:
217
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
218
+ """
219
+ cos = cos.unsqueeze(unsqueeze_dim)
220
+ sin = sin.unsqueeze(unsqueeze_dim)
221
+ q_embed = (q * cos) + (rotate_half(q) * sin)
222
+ k_embed = (k * cos) + (rotate_half(k) * sin)
223
+ return q_embed, k_embed
224
+
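def _rope_shape_sketch():
    # Minimal shape sketch for the module-level apply_rotary_pos_emb above
    # (illustrative only; the tensor sizes are made up). With the
    # [batch_size, seq_length, num_heads, head_dim] layout used in this file,
    # unsqueeze_dim=2 broadcasts cos/sin over the heads axis.
    q = torch.randn(2, 16, 8, 64)   # [batch_size, seq_length, n_q_heads, head_dim]
    k = torch.randn(2, 16, 2, 64)   # [batch_size, seq_length, n_kv_heads, head_dim] (GQA)
    cos = torch.randn(2, 16, 64)    # [batch_size, seq_length, head_dim]
    sin = torch.randn(2, 16, 64)
    q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2)
    assert q_embed.shape == q.shape and k_embed.shape == k.shape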
225
+
226
+ class GLUActivation(nn.Module):
227
+ def __init__(self, act_fn_name: str):
228
+ super().__init__()
229
+ self.act = ACT2FN[act_fn_name]
230
+
231
+ def forward(self, merged_states: torch.Tensor):
232
+ gate_states, up_states = torch.split(merged_states, merged_states.shape[-1] // 2, dim=-1)
233
+ return self.act(gate_states) * up_states
234
+
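def _glu_activation_sketch():
    # Minimal sketch of GLUActivation (illustrative only; assumes "silu" is a key
    # of ACT2FN, as with the usual LlamaConfig.hidden_act). The fused gate_up_proj
    # output is split into [gate | up] halves on the last dim and combined as
    # act(gate) * up, i.e. SwiGLU when the activation is SiLU.
    glu = GLUActivation("silu")
    merged = torch.randn(5, 2, 2 * 8)  # [seq_length, batch_size, 2 * intermediate_size]
    out = glu(merged)
    assert out.shape == (5, 2, 8)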
235
+
236
+ class MLP(nn.Module):
237
+ def __init__(
238
+ self,
239
+ config: LlamaConfig,
240
+ parallel_config: Optional[ParallelismArgs],
241
+ tp_pg: dist.ProcessGroup,
242
+ ):
243
+ super().__init__()
244
+
245
+ # TODO @thomasw21: refactor so that we store that default in a single place.
246
+ tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
247
+ tp_linear_async_communication = (
248
+ parallel_config.tp_linear_async_communication if parallel_config is not None else False
249
+ )
250
+
251
+ gate_up_contiguous_chunks = (
252
+ config.intermediate_size, # shape of gate_linear
253
+ config.intermediate_size, # shape of up_linear
254
+ )
255
+ self.gate_up_proj = TensorParallelColumnLinear(
256
+ config.hidden_size,
257
+ 2 * config.intermediate_size,
258
+ pg=tp_pg,
259
+ mode=tp_mode,
260
+ bias=False,
261
+ async_communication=tp_linear_async_communication,
262
+ contiguous_chunks=gate_up_contiguous_chunks,
263
+ tp_recompute_allgather=parallel_config.tp_recompute_allgather,
264
+ )
265
+ self.down_proj = TensorParallelRowLinear(
266
+ config.intermediate_size,
267
+ config.hidden_size,
268
+ pg=tp_pg,
269
+ mode=tp_mode,
270
+ bias=False,
271
+ async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
272
+ )
273
+ self.split_silu_mul = torch.compile(GLUActivation(config.hidden_act))
274
+
275
+ def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim]
276
+ merged_states = self.gate_up_proj(hidden_states)
277
+ hidden_states = self.down_proj(self.split_silu_mul(merged_states))
278
+ return {"hidden_states": hidden_states}
279
+
280
+
281
+ class CoreAttention(nn.Module):
282
+ def __init__(self, config: LlamaConfig, parallel_config: Optional[ParallelismArgs], layer_idx: int):
283
+ super().__init__()
284
+ # TODO @thomasw21: GPT has a weird `d_kv` config which I'm guessing is essentially a `d_qkv`
285
+ assert (
286
+ config.hidden_size % config.num_attention_heads == 0
287
+ ), f"Hidden size {config.hidden_size} must be divisible by number of attention heads {config.num_attention_heads}."
288
+ self.d_qk = config.hidden_size // config.num_attention_heads
289
+ self.d_v = config.hidden_size // config.num_attention_heads
290
+ self.is_using_mup = config.is_using_mup
291
+
292
+ self.checkpoint_attention = False # Because flash_attn already does checkpointing
293
+
294
+ @checkpoint_method(attr_name="checkpoint_attention")
295
+ def forward(
296
+ self,
297
+ query_states: torch.Tensor, # [batch_size * q_length, n_local_q_heads, inner_dim]
298
+ key_states: torch.Tensor, # [batch_size * kv_length, n_local_kv_heads, inner_dim]
299
+ value_states: torch.Tensor, # [batch_size * kv_length, n_local_kv_heads, inner_dim]
300
+ q_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size)
301
+ kv_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size)
302
+ ):
303
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
304
+
305
+ # TODO @thomasw21: Compute once, instead of computing for each layers.
306
+ cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
307
+ cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
308
+ torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:])
309
+ torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:])
310
+
311
+ # TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not
312
+ # what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache.
313
+ causal = False if q_sequence_mask.shape[1] == 1 else True
314
+
315
+ # NOTE: this scale is for µTransfer,
316
+ # in SP, we use sqrt(1/d_h)
317
+ softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
318
+ attn_output = flash_attn_varlen_func(
319
+ q=query_states,
320
+ k=key_states,
321
+ v=value_states,
322
+ cu_seqlens_q=cu_seqlens_q,
323
+ cu_seqlens_k=cu_seqlens_k,
324
+ max_seqlen_q=q_sequence_mask.shape[1],
325
+ max_seqlen_k=kv_sequence_mask.shape[1],
326
+ dropout_p=0.0,
327
+ softmax_scale=softmax_scale,
328
+ causal=causal,
329
+ return_attn_probs=False,
330
+ )
331
+
332
+ return attn_output
333
+
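def _cu_seqlens_sketch():
    # Minimal sketch of the cu_seqlens construction used in CoreAttention.forward
    # (illustrative only, CPU tensors): cumulative sequence lengths mark where each
    # sample starts and ends once padding is dropped and the batch is packed for
    # flash_attn_varlen_func.
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
    cu_seqlens = torch.zeros(mask.shape[0] + 1, dtype=torch.int32)
    torch.cumsum(mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens[1:])
    assert cu_seqlens.tolist() == [0, 3, 5]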
334
+
335
+ def pad_to_right(tensor, mask, new_tensor=None):
336
+ """Transform a left-padded tensor into a right-padded tensor. (Useful for prefilling key/value states)
337
+ Args:
338
+ tensor: (batch_size, seqlen, d1, d2)
339
+ mask: (batch_size, seqlen)
340
+ new_tensor: (batch_size, new_tensor_seqlen, d1, d2)
341
+ Returns:
342
+ new_tensor: (batch_size, new_tensor_seqlen, d1, d2)
343
+ right_padded_mask: (batch_size, seqlen)
344
+ """
345
+ # First, we need to find the number of padding for each row
346
+ unpad_seqlens = mask.sum(1)
347
+ # Then, we need to find the maximum length of the tensor
348
+ max_seqlen = mask.shape[1]
349
+ # We can then create the indices to select the padded values
350
+ # The indices are the same for each row
351
+ indices = torch.arange(max_seqlen, device=mask.device)
352
+ # We can then create the mask for the padded values
353
+ right_padded_mask = indices < unpad_seqlens[:, None]
354
+ # We select the useful values
355
+ useful_values = tensor[mask]
356
+ # We create the new tensor (if not provided)
357
+ new_tensor = torch.zeros_like(tensor) if new_tensor is None else new_tensor
358
+ # We fill the new tensor with the useful values
359
+ new_tensor[:, : right_padded_mask.shape[1], :, :][right_padded_mask] = useful_values
360
+ return new_tensor, right_padded_mask
361
+
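def _pad_to_right_sketch():
    # Minimal sketch of pad_to_right (illustrative only): a left-padded batch is
    # turned into a right-padded one so key/value states can be written at the
    # start of the preallocated k_cache/v_cache during prefill.
    tensor = torch.arange(2 * 4, dtype=torch.float).view(2, 4, 1, 1)  # (batch, seqlen, d1, d2)
    mask = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]], dtype=torch.bool)  # pads on the left
    new_tensor, right_padded_mask = pad_to_right(tensor, mask)
    assert right_padded_mask.tolist() == [[True, True, False, False], [True, True, True, False]]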
362
+
363
+ class CausalSelfAttention(nn.Module, AttachableStore):
364
+ def __init__(
365
+ self,
366
+ config: LlamaConfig,
367
+ parallel_config: Optional[ParallelismArgs],
368
+ tp_pg: dist.ProcessGroup,
369
+ layer_idx: int,
370
+ ):
371
+ from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
372
+
373
+ super().__init__()
374
+ # Tensor parallel considerations: We split tensors along head dimension
375
+ assert (
376
+ config.num_attention_heads % tp_pg.size() == 0
377
+ ), f"Number of attention heads ({config.num_attention_heads}) must be divisible by TP size ({tp_pg.size()})."
378
+ try:
379
+ assert (
380
+ config.num_key_value_heads % tp_pg.size() == 0
381
+ ), f"Number of key/value heads ({config.num_key_value_heads}) must be divisible by TP size ({tp_pg.size()})."
382
+ except AttributeError:
383
+ log_rank(
384
+ "WARNING: num_key_value_heads not defined, assuming it is equal to num_attention_heads",
385
+ logger=logger,
386
+ level=logging.WARNING,
387
+ rank=0,
388
+ )
389
+ # If num_key_value_heads is not defined, we assume that it is equal to num_attention_heads
390
+ config.num_key_value_heads = config.num_attention_heads
391
+ assert (
392
+ config.num_attention_heads % config.num_key_value_heads == 0
393
+ ), f"Number of attention heads ({config.num_attention_heads}) must be divisible by number of key/value heads ({config.num_key_value_heads})."
394
+ self.n_local_q_heads = config.num_attention_heads // tp_pg.size()
395
+ self.n_local_kv_heads = config.num_key_value_heads // tp_pg.size()
396
+ self.n_repeats = config.num_attention_heads // config.num_key_value_heads
397
+ self.is_gqa = config.num_attention_heads != config.num_key_value_heads # Whether we are using GQA or not
398
+ self.d_qk = config.hidden_size // config.num_attention_heads
399
+ self.d_v = config.hidden_size // config.num_attention_heads
400
+ self.d_model = config.hidden_size
401
+ self.is_using_mup = config.is_using_mup
402
+
403
+ # TODO @thomasw21: refactor so that we store that default in a single place.
404
+ tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
405
+ tp_linear_async_communication = (
406
+ parallel_config.tp_linear_async_communication if parallel_config is not None else False
407
+ )
408
+
409
+ # build the slice config for self.qkv for save/load
410
+ # shard are done within the contiguous chunk
411
+ qkv_contiguous_chunks = (
412
+ config.num_attention_heads * self.d_qk, # shape of q
413
+ config.num_key_value_heads * self.d_qk, # shape of k
414
+ config.num_key_value_heads * self.d_qk, # shape of v
415
+ )
416
+ self.qkv_proj = TensorParallelColumnLinear(
417
+ self.d_model,
418
+ config.num_attention_heads * self.d_qk + 2 * config.num_key_value_heads * self.d_qk,
419
+ pg=tp_pg,
420
+ mode=tp_mode,
421
+ bias=False,
422
+ async_communication=tp_linear_async_communication,
423
+ contiguous_chunks=qkv_contiguous_chunks,
424
+ tp_recompute_allgather=parallel_config.tp_recompute_allgather,
425
+ )
426
+ # TODO(kunhao): We want to have only one version per device and not one version per layer.
427
+
428
+ if config.rope_interleaved:
429
+ self.rotary_embedding = RotaryEmbedding(
430
+ dim=self.d_qk,
431
+ end=config.max_position_embeddings,
432
+ theta=config.rope_theta,
433
+ )
434
+ else:
435
+ self.rotary_embedding = LlamaRotaryEmbedding(
436
+ dim=self.d_qk,
437
+ end=config.max_position_embeddings,
438
+ theta=config.rope_theta,
439
+ )
440
+ self.rope_interleaved = config.rope_interleaved
441
+
442
+ # NOTE: Only supported for training (TODO(fmom): position_ids not supported yet)
443
+ self.flash_rotary_embedding = FlashRotaryEmbedding(dim=self.d_qk, base=config.rope_theta, interleaved=config.rope_interleaved)
444
+
445
+ self.o_proj = TensorParallelRowLinear(
446
+ config.num_attention_heads * self.d_qk,
447
+ self.d_model,
448
+ pg=tp_pg,
449
+ mode=tp_mode,
450
+ bias=False,
451
+ async_communication=tp_linear_async_communication,
452
+ )
453
+
454
+ self.attention = CoreAttention(
455
+ config,
456
+ parallel_config=parallel_config,
457
+ layer_idx=layer_idx,
458
+ )
459
+
460
+ self.prefill_kv_len = (
461
+ config.max_position_embeddings
462
+ ) # TODO @nouamane: compute based on free memory, because in rope we can surpass max_position_embeddings
463
+
464
+ def forward(
465
+ self,
466
+ hidden_states, # [seq_length, batch_size, hidden_size]
467
+ sequence_mask, # [batch_size, seq_length]
468
+ ):
469
+ from flash_attn import bert_padding
470
+ from flash_attn.flash_attn_interface import (
471
+ flash_attn_varlen_func,
472
+ flash_attn_with_kvcache,
473
+ )
474
+
475
+ qkv_states = self.qkv_proj(
476
+ hidden_states
477
+ ) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk]
478
+ q_length, batch_size, _ = qkv_states.shape
479
+
480
+ if self.is_gqa:
481
+ query_states, key_states, value_states = torch.split(
482
+ qkv_states,
483
+ [
484
+ self.n_local_q_heads * self.d_qk,
485
+ self.n_local_kv_heads * self.d_qk,
486
+ self.n_local_kv_heads * self.d_qk,
487
+ ],
488
+ dim=-1,
489
+ )
490
+
491
+ query_states = (
492
+ query_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_q_heads, self.d_qk)
493
+ )
494
+ key_states = (
495
+ key_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_kv_heads, self.d_qk)
496
+ )
497
+ value_states = (
498
+ value_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_kv_heads, self.d_qk)
499
+ )
500
+ else:
501
+ query_states, key_states, value_states = (
502
+ qkv_states.view(q_length, batch_size, 3, self.n_local_q_heads, self.d_qk)
503
+ .permute(2, 1, 0, 3, 4)
504
+ .contiguous()
505
+ ) # [3, batch_size, seq_length, n_local_q_heads, d_qk]
506
+
507
+ store = self.get_local_store()
508
+ if store is not None: # Inference case
509
+ # Double check that we use store only at inference time
510
+ assert key_states.requires_grad is False
511
+ assert value_states.requires_grad is False
512
+ if "position_offsets" in store:
513
+ old_position_offsets = store["position_offsets"]
514
+ position_ids = old_position_offsets[:, None] + sequence_mask
515
+ else:
516
+ position_ids = torch.cumsum(sequence_mask, dim=-1, dtype=torch.int32) - 1
517
+ position_offsets = position_ids[:, -1]
518
+
519
+ # Compute rotary embeddings
520
+ # Note: keep track of old rotary embedding end to check if we need to enlarge k_cache and v_cache
521
+ old_rotary_embed_end = self.rotary_embedding.end
522
+ if self.rope_interleaved:
523
+ query_states = self.rotary_embedding(query_states, position_ids=position_ids)
524
+ key_states = self.rotary_embedding(key_states, position_ids=position_ids)
525
+ else:
526
+ cos, sin = self.rotary_embedding(value_states, position_ids)
527
+ query_states, key_states = self.rotary_embedding.apply_rotary_pos_emb(
528
+ query_states, key_states, cos, sin
529
+ )
530
+
531
+ if "key" not in store:
532
+ # First inference iteration (Prefill)
533
+ # TODO @nouamane: support custom masking
534
+ # assert that [ False, False, False, False, True, True, True, True, True, True] is accepted
535
+ # but [ False, False, False, False, True, True, False, False, True, True] is not (can't mask in the middle of sequence)
536
+ assert ~(
537
+ sequence_mask[:, :-1] & (~sequence_mask[:, 1:]) # True is never followed by False
538
+ ).any(), "Can't mask in the middle of sequence, please make sure that pads are at the left of the sequence if existing"
539
+
540
+ # preallocate k_cache, v_cache to self.prefill_kv_len
541
+ k_cache = torch.zeros(
542
+ (
543
+ batch_size,
544
+ self.prefill_kv_len,
545
+ self.n_local_kv_heads,
546
+ self.d_qk,
547
+ ),
548
+ dtype=query_states.dtype,
549
+ device=query_states.device,
550
+ )
551
+ v_cache = torch.zeros(
552
+ (batch_size, self.prefill_kv_len, self.n_local_kv_heads, self.d_v),
553
+ dtype=query_states.dtype,
554
+ device=query_states.device,
555
+ )
556
+ # Remove pad tokens from key_states and concatenate samples in key_unpad
557
+ # cu_seqlens_k is the cumulative sequence lengths of key_states
558
+ (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
559
+ query_states,
560
+ sequence_mask,
561
+ )
562
+ (key_unpad, indices_k, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
563
+ key_states, sequence_mask
564
+ )
565
+ (value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask)
566
+
567
+ # NOTE: this scale is for µTransfer,
568
+ # in SP, we use sqrt(1/d_h)
569
+ softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
570
+ output_unpad = flash_attn_varlen_func(
571
+ q=query_unpad, # (total_q, n_local_q_heads, d_qk)
572
+ k=key_unpad, # (total_kv, n_local_kv_heads, d_qk)
573
+ v=value_unpad, # (total_kv, n_local_kv_heads, d_v)
574
+ cu_seqlens_q=cu_seqlens_q,
575
+ cu_seqlens_k=cu_seqlens_k,
576
+ max_seqlen_q=max_seqlen_q,
577
+ max_seqlen_k=max_seqlen_k,
578
+ dropout_p=0.0,
579
+ softmax_scale=softmax_scale,
580
+ causal=True, # True in prefill phase, False in subsequent phases
581
+ return_attn_probs=False,
582
+ ) # (total_unpadded, n_local_q_heads, d_v)
583
+
584
+ attention_output = bert_padding.pad_input(
585
+ output_unpad, indices_q, batch_size, q_length
586
+ ) # (batch_size, q_length, n_local_q_heads, d_v)
587
+
588
+ pad_to_right(key_states, sequence_mask, new_tensor=k_cache)
589
+ pad_to_right(value_states, sequence_mask, new_tensor=v_cache)
590
+
591
+ else:
592
+ # Pull pre-computed key/value states
593
+ # Subsequent inference iterations (q_length=1)
594
+ k_cache = store["key"]
595
+ v_cache = store["value"]
596
+
597
+ # NOTE(fmom): According to flash_attn_with_kvcache, "If you pass in k / v, you must make sure that the cache is large enough to hold the new values"
598
+ # Since rotary embedding has changed (to enable larger context), we need to enlarge k_cache and v_cache
599
+ if self.rotary_embedding.end > old_rotary_embed_end:
600
+ k_cache = torch.cat(
601
+ [
602
+ k_cache,
603
+ torch.zeros(
604
+ (
605
+ batch_size,
606
+ self.rotary_embedding.end - old_rotary_embed_end,
607
+ self.n_local_kv_heads,
608
+ self.d_qk,
609
+ ),
610
+ dtype=query_states.dtype,
611
+ device=query_states.device,
612
+ ),
613
+ ],
614
+ dim=1,
615
+ )
616
+
617
+ v_cache = torch.cat(
618
+ [
619
+ v_cache,
620
+ torch.zeros(
621
+ (
622
+ batch_size,
623
+ self.rotary_embedding.end - old_rotary_embed_end,
624
+ self.n_local_kv_heads,
625
+ self.d_v,
626
+ ),
627
+ dtype=query_states.dtype,
628
+ device=query_states.device,
629
+ ),
630
+ ],
631
+ dim=1,
632
+ )
633
+
634
+ assert (
635
+ k_cache.shape[1] == self.rotary_embedding.end
636
+ ), f"Cache size {k_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"
637
+ assert (
638
+ v_cache.shape[1] == self.rotary_embedding.end
639
+ ), f"Cache size {v_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"
640
+
641
+ # [batch_size, seq_length, num_heads, d_qk]
642
+ query_states = query_states.view(
643
+ batch_size, q_length, self.n_local_q_heads, self.d_qk
644
+ ) # [batch_size, q_length, self.n_heads, d_qk]
645
+ kv_length = key_states.shape[1]
646
+ key_states = key_states.view(
647
+ batch_size, kv_length, self.n_local_kv_heads, self.d_qk
648
+ ) # [batch_size, kv_length, self.n_heads, d_qk]
649
+ value_states = value_states.view(
650
+ batch_size, kv_length, self.n_local_kv_heads, self.d_v
651
+ ) # [batch_size, kv_length, self.n_heads, d_v]
652
+
653
+ # NOTE: this scale is for µTransfer,
654
+ # in SP, we use sqrt(1/d_h)
655
+ softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
656
+ attention_output = flash_attn_with_kvcache(
657
+ query_states,
658
+ k_cache,
659
+ v_cache,
660
+ key_states,
661
+ value_states,
662
+ rotary_cos=None,
663
+ rotary_sin=None,
664
+ # TODO @nouamane: seems like this doesn't help to indicate padding in (for first iteration it's just 0)
665
+ cache_seqlens=position_offsets.contiguous(),
666
+ softmax_scale=softmax_scale,
667
+ causal=True,
668
+ rotary_interleaved=False, # GPT-NeoX style
669
+ )
670
+
671
+ store.update(
672
+ {
673
+ "key": k_cache, # flash-attn has updated with new key_states using cache_seqlens
674
+ "value": v_cache,
675
+ "position_offsets": position_offsets,
676
+ }
677
+ )
678
+
679
+ else: # Training case
680
+ # Apply rotary embeddings to query/key states
681
+ # NOTE: The layout is different from models/llama.py which is [batch_size, num_heads, seq_length, d_qk]
682
+ # Here it is, [batch_size, seq_length, num_heads, d_qk]
683
+ # [2, batch_size, seq_length, num_heads, d_qk]
684
+ key_value_states = torch.cat([key_states.unsqueeze(0), value_states.unsqueeze(0)], dim=0)
685
+ # [batch_size, seq_length, 2, num_heads, d_qk]
686
+ key_value_states = key_value_states.permute(1, 2, 0, 3, 4).contiguous()
687
+ query_states, key_value_states = self.flash_rotary_embedding(query_states, kv=key_value_states)
688
+ # [batch_size, seq_length, num_heads, d_qk]
689
+ key_states, value_states = torch.split(key_value_states, 1, dim=2)
690
+
691
+ q_sequence_mask = sequence_mask
692
+ kv_sequence_mask = sequence_mask
693
+
694
+ kv_length = key_states.shape[1]
695
+ # [batch_size, seq_length, num_heads, d_qk]
696
+ # Shaping for use with flash-attn's `flash_attn_varlen_func`
697
+ query_states = query_states.view(
698
+ batch_size * q_length, self.n_local_q_heads, self.d_qk
699
+ ) # [batch_size * q_length, self.n_heads, d_qk]
700
+
701
+ key_states = key_states.view(
702
+ batch_size * kv_length, self.n_local_kv_heads, self.d_qk
703
+ ) # [batch_size * kv_length, self.n_heads, d_qk]
704
+ value_states = value_states.view(
705
+ batch_size * kv_length, self.n_local_kv_heads, self.d_v
706
+ ) # [batch_size * kv_length, self.n_heads, d_v]
707
+
708
+ attention_output = self.attention(
709
+ query_states=query_states,
710
+ key_states=key_states,
711
+ value_states=value_states,
712
+ q_sequence_mask=q_sequence_mask,
713
+ kv_sequence_mask=kv_sequence_mask,
714
+ )
715
+
716
+ attention_output = (
717
+ attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1)
718
+ )
719
+ output = self.o_proj(attention_output)
720
+
721
+ return {"hidden_states": output, "sequence_mask": sequence_mask}
722
+
723
+
724
+ class LlamaDecoderLayer(nn.Module):
725
+ def __init__(
726
+ self,
727
+ config: LlamaConfig,
728
+ parallel_config: Optional[ParallelismArgs],
729
+ tp_pg: dist.ProcessGroup,
730
+ layer_idx: int,
731
+ ):
732
+ super().__init__()
733
+ self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
734
+ self.attn = CausalSelfAttention(
735
+ config=config,
736
+ parallel_config=parallel_config,
737
+ tp_pg=tp_pg,
738
+ layer_idx=layer_idx,
739
+ )
740
+
741
+ self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
742
+ self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)
743
+
744
+ self.recompute_layer = parallel_config.recompute_layer
745
+
746
+ def _core_forward(
747
+ self,
748
+ hidden_states: Union[torch.Tensor, TensorPointer],
749
+ sequence_mask: Union[torch.Tensor, TensorPointer],
750
+ ) -> List[Union[torch.Tensor, TensorPointer]]:
751
+ residual = hidden_states
752
+ hidden_states = self.input_layernorm(hidden_states)
753
+
754
+ output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask)
755
+ hidden_states = output["hidden_states"]
756
+ hidden_states = hidden_states + residual
757
+
758
+ residual = hidden_states
759
+ hidden_states = self.post_attention_layernorm(hidden_states)
760
+ hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"]
761
+ hidden_states = hidden_states + residual
762
+
763
+ return hidden_states, output["sequence_mask"]
764
+
765
+ def _checkpointed_forward(
766
+ self,
767
+ hidden_states: torch.Tensor,
768
+ sequence_mask: torch.Tensor,
769
+ ) -> List[torch.Tensor]:
770
+ return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask)
771
+
772
+ def forward(
773
+ self,
774
+ hidden_states: Union[torch.Tensor, TensorPointer],
775
+ sequence_mask: Union[torch.Tensor, TensorPointer],
776
+ ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
777
+
778
+ if self.recompute_layer and not isinstance(hidden_states, TensorPointer):
779
+ hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask)
780
+ else:
781
+ hidden_states, sequence_mask = self._core_forward(hidden_states, sequence_mask)
782
+
783
+ return {
784
+ "hidden_states": hidden_states,
785
+ "sequence_mask": sequence_mask,
786
+ }
787
+
788
+
789
+ class Embedding(nn.Module, AttachableStore):
790
+ def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]):
791
+ super().__init__()
792
+ self.token_embedding = TensorParallelEmbedding(
793
+ num_embeddings=config.vocab_size,
794
+ embedding_dim=config.hidden_size,
795
+ padding_idx=config.pad_token_id,
796
+ pg=tp_pg,
797
+ mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE,
798
+ )
799
+ self.pg = tp_pg
800
+
801
+ def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_size, seq_length]
802
+ store = self.get_local_store()
803
+ if store is not None:
804
+ if "past_length" in store:
805
+ past_length = store["past_length"]
806
+ else:
807
+ past_length = torch.zeros(1, dtype=torch.long, device=input_ids.device).expand(input_ids.shape[0])
808
+
809
+ cumsum_mask = input_mask.cumsum(-1, dtype=torch.long)
810
+ # Store new past_length in store
811
+ store["past_length"] = past_length + cumsum_mask[:, -1]
812
+
813
+ # Format input in `[seq_length, batch_size]` to support high TP with low batch_size
814
+ input_ids = input_ids.transpose(0, 1)
815
+ input_embeds = self.token_embedding(input_ids)
816
+ return {"input_embeds": input_embeds}
817
+
818
+
819
+ class LlamaModel(nn.Module):
820
+ """Build pipeline graph"""
821
+
822
+ def __init__(
823
+ self,
824
+ config: LlamaConfig,
825
+ parallel_context: ParallelContext,
826
+ parallel_config: Optional[ParallelismArgs],
827
+ ):
828
+ super().__init__()
829
+
830
+ # Declare all the nodes
831
+ self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda"))
832
+ self.config = config
833
+ self.parallel_config = parallel_config
834
+ self.parallel_context = parallel_context
835
+ self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
836
+ tp_linear_async_communication = (
837
+ parallel_config.tp_linear_async_communication if parallel_config is not None else False
838
+ )
839
+
840
+ self.token_position_embeddings = PipelineBlock(
841
+ p2p=self.p2p,
842
+ module_builder=Embedding,
843
+ module_kwargs={
844
+ "tp_pg": parallel_context.tp_pg,
845
+ "config": config,
846
+ "parallel_config": parallel_config,
847
+ },
848
+ module_input_keys={"input_ids", "input_mask"},
849
+ module_output_keys={"input_embeds"},
850
+ )
851
+
852
+ log_rank(f"Initialize RoPE Theta = {config.rope_theta}", logger=logger, level=logging.INFO, rank=0)
853
+ if config.rope_interleaved:
854
+ log_rank(
855
+ "The RoPE interleaved version differs from the Transformers implementation. It's better to set rope_interleaved=False if you need to convert the weights to Transformers",
856
+ logger=logger,
857
+ level=logging.INFO,
858
+ rank=0,
859
+ )
860
+
861
+ self.decoder = nn.ModuleList(
862
+ [
863
+ PipelineBlock(
864
+ p2p=self.p2p,
865
+ module_builder=LlamaDecoderLayer,
866
+ module_kwargs={
867
+ "config": config,
868
+ "parallel_config": parallel_config,
869
+ "tp_pg": parallel_context.tp_pg,
870
+ "layer_idx": layer_idx,
871
+ },
872
+ module_input_keys={"hidden_states", "sequence_mask"},
873
+ module_output_keys={"hidden_states", "sequence_mask"},
874
+ )
875
+ for layer_idx in range(config.num_hidden_layers)
876
+ ]
877
+ )
878
+
879
+ self.final_layer_norm = PipelineBlock(
880
+ p2p=self.p2p,
881
+ module_builder=TritonRMSNorm,
882
+ module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps},
883
+ module_input_keys={"input"},
884
+ module_output_keys={"hidden_states"},
885
+ ) # TODO
886
+
887
+ self.lm_head = PipelineBlock(
888
+ p2p=self.p2p,
889
+ # Understand that this means that we return sharded logits that are going to need to be gathered
890
+ module_builder=TensorParallelColumnLinear,
891
+ module_kwargs={
892
+ "in_features": config.hidden_size,
893
+ "out_features": config.vocab_size,
894
+ "pg": parallel_context.tp_pg,
895
+ "bias": False,
896
+ # TODO @thomasw21: refactor so that we store that default in a single place.
897
+ "mode": self.tp_mode,
898
+ "async_communication": tp_linear_async_communication,
899
+ "tp_recompute_allgather": parallel_config.tp_recompute_allgather,
900
+ },
901
+ module_input_keys={"x"},
902
+ module_output_keys={"logits"},
903
+ )
904
+
905
+ self.cast_to_fp32 = PipelineBlock(
906
+ p2p=self.p2p,
907
+ module_builder=lambda: lambda x: x.float(),
908
+ module_kwargs={},
909
+ module_input_keys={"x"},
910
+ module_output_keys={"output"},
911
+ )
912
+
913
+ def forward(
914
+ self,
915
+ input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
916
+ input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
917
+ ):
918
+ return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask)[0]
919
+
920
+ def forward_with_hidden_states(
921
+ self,
922
+ input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
923
+ input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
924
+ ):
925
+ # all tensors are optional as most ranks don't need anything from the dataloader.
926
+
927
+ output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask)
928
+
929
+ hidden_encoder_states = {
930
+ "hidden_states": output["input_embeds"],
931
+ "sequence_mask": input_mask,
932
+ }
933
+ for encoder_block in self.decoder:
934
+ hidden_encoder_states = encoder_block(**hidden_encoder_states)
935
+
936
+ hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"]
937
+
938
+ sharded_logits = self.lm_head(x=hidden_states)["logits"]
939
+
940
+ fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"]
941
+
942
+ return fp32_sharded_logits, hidden_states
943
+
944
+ def get_block_compute_costs(self):
945
+ """Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
946
+ model_config = self.config
947
+ d_ff = model_config.intermediate_size
948
+ d_qkv = model_config.hidden_size // model_config.num_attention_heads
949
+ block_compute_costs = {
950
+ # CausalSelfAttention (qkv proj + attn out) + MLP
951
+ LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
952
+ + 3 * d_ff * model_config.hidden_size,
953
+ # This is the last lm_head
954
+ TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
955
+ }
956
+ return block_compute_costs
957
+
958
+ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
959
+ """Get flops per second for a given model"""
960
+ world_size = self.parallel_context.world_pg.size()
961
+ try:
962
+ num_key_values_heads = self.config.num_key_value_heads
963
+ except AttributeError:
964
+ num_key_values_heads = self.config.num_attention_heads
965
+
966
+ model_flops, hardware_flops = get_flops(
967
+ num_layers=self.config.num_hidden_layers,
968
+ hidden_size=self.config.hidden_size,
969
+ num_heads=self.config.num_attention_heads,
970
+ num_key_value_heads=num_key_values_heads,
971
+ vocab_size=self.config.vocab_size,
972
+ ffn_hidden_size=self.config.intermediate_size,
973
+ seq_len=sequence_length,
974
+ batch_size=global_batch_size,
975
+ )
976
+
977
+ model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12)
978
+ hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12)
979
+ return model_flops_per_s, hardware_flops_per_s
980
+
981
+
982
+ @torch.jit.script
983
+ def masked_mean(loss, label_mask, dtype):
984
+ # type: (Tensor, Tensor, torch.dtype) -> Tensor
985
+ return (loss * label_mask).sum(dtype=dtype) / label_mask.sum()
986
+
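def _masked_mean_sketch():
    # Minimal sketch of masked_mean (illustrative only): the per-token loss is
    # averaged over unmasked positions only, so padded labels do not dilute it.
    loss = torch.tensor([[2.0, 4.0, 8.0], [6.0, 8.0, 8.0]])
    label_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
    assert masked_mean(loss, label_mask, torch.float).item() == (2.0 + 4.0 + 6.0) / 3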
987
+
988
+ class Loss(nn.Module):
989
+ def __init__(self, tp_pg: dist.ProcessGroup):
990
+ super().__init__()
991
+ self.tp_pg = tp_pg
992
+
993
+ def forward(
994
+ self,
995
+ sharded_logits: torch.Tensor, # [seq_length, batch_size, logits]
996
+ label_ids: torch.Tensor, # [batch_size, seq_length]
997
+ label_mask: torch.Tensor, # [batch_size, seq_length]
998
+ ) -> Dict[str, torch.Tensor]:
999
+ # Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision.
1000
+ # https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38
1001
+
1002
+ loss = sharded_cross_entropy(
1003
+ sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float
1004
+ ).transpose(0, 1)
1005
+ # TODO @thomasw21: It's unclear what kind of normalization we want to do.
1006
+ loss = masked_mean(loss, label_mask, dtype=torch.float)
1007
+ # I think indexing causes a sync we don't actually want
1008
+ # loss = loss[label_mask].sum()
1009
+ return {"loss": loss}
1010
+
1011
+
1012
+ class LlamaForTrainingFromOurNanotron(NanotronModel):
1013
+ def __init__(
1014
+ self,
1015
+ config: LlamaConfig,
1016
+ parallel_context: ParallelContext,
1017
+ parallel_config: Optional[ParallelismArgs],
1018
+ random_states: Optional[RandomStates] = None,
1019
+ ):
1020
+ super().__init__()
1021
+ self.model = LlamaModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config)
1022
+ self.loss = PipelineBlock(
1023
+ p2p=self.model.p2p,
1024
+ module_builder=Loss,
1025
+ module_kwargs={"tp_pg": parallel_context.tp_pg},
1026
+ module_input_keys={
1027
+ "sharded_logits",
1028
+ "label_ids",
1029
+ "label_mask",
1030
+ },
1031
+ module_output_keys={"loss"},
1032
+ )
1033
+ self.parallel_context = parallel_context
1034
+ self.config = config
1035
+ self.parallel_config = parallel_config
1036
+
1037
+ def forward(
1038
+ self,
1039
+ input_ids: Union[torch.Tensor, TensorPointer],
1040
+ input_mask: Union[torch.Tensor, TensorPointer],
1041
+ label_ids: Union[torch.Tensor, TensorPointer],
1042
+ label_mask: Union[torch.Tensor, TensorPointer],
1043
+ ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
1044
+ sharded_logits = self.model(
1045
+ input_ids=input_ids,
1046
+ input_mask=input_mask,
1047
+ )
1048
+ loss = self.loss(
1049
+ sharded_logits=sharded_logits,
1050
+ label_ids=label_ids,
1051
+ label_mask=label_mask,
1052
+ )["loss"]
1053
+ return {"loss": loss}
1054
+
1055
+ @torch.no_grad()
1056
+ def init_model_randomly(self, config: Config):
1057
+ """Initialize model parameters randomly.
1058
+ Note:
1059
+ Layernorm weight all 0 or 1 depending on `apply_layernorm_1p`
1060
+ """
1061
+ init_method = config.model.init_method
1062
+ if isinstance(init_method, RandomInit):
1063
+ parametrizator_cls = StandardParametrizator
1064
+ elif isinstance(init_method, SpectralMupInit):
1065
+ parametrizator_cls = SpectralMupParametrizator
1066
+ else:
1067
+ raise ValueError(f"Unknown init method {init_method}")
1068
+
1069
+ parametrizator = parametrizator_cls(config=config.model)
1070
+
1071
+ log_rank(
1072
+ f"Parametrizing model parameters using {parametrizator.__class__.__name__}",
1073
+ logger=logger,
1074
+ level=logging.INFO,
1075
+ rank=0,
1076
+ )
1077
+
1078
+ model = self
1079
+ initialized_parameters = set()
1080
+ # Handle tensor parallelism
1081
+ module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
1082
+ # Fix the root_model
1083
+ module_id_to_prefix[id(model)] = ""
1084
+
1085
+ for param_name, param in model.named_parameters():
1086
+ assert isinstance(param, NanotronParameter)
1087
+
1088
+ module_name, param_name = param_name.rsplit(".", 1)
1089
+
1090
+ if param.is_tied:
1091
+ tied_info = param.get_tied_info()
1092
+ full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
1093
+ module_id_to_prefix=module_id_to_prefix
1094
+ )
1095
+ else:
1096
+ full_param_name = f"{module_name}.{param_name}"
1097
+
1098
+ if full_param_name in initialized_parameters:
1099
+ # Already initialized
1100
+ continue
1101
+
1102
+ module = model.get_submodule(module_name)
1103
+ parametrizator.parametrize(param_name, module)
1104
+
1105
+ assert full_param_name not in initialized_parameters
1106
+ initialized_parameters.add(full_param_name)
1107
+
1108
+ assert initialized_parameters == {
1109
+ param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
1110
+ if param.is_tied
1111
+ else name
1112
+ for name, param in model.named_parameters()
1113
+ }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}"
1114
+
1115
+ def get_embeddings_lm_head_tied_names(self):
1116
+ """Get the names of the tied embeddings and lm_head weights"""
1117
+ if self.config.tie_word_embeddings is True:
1118
+ return ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"]
1119
+ else:
1120
+ return []
1121
+
1122
+ def get_block_compute_costs(self):
1123
+ """Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
1124
+ return self.model.get_block_compute_costs()
1125
+
1126
+ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
1127
+ """Get flops per second for a given model"""
1128
+ return self.model.get_flops_per_sec(iteration_time_in_sec, sequence_length, global_batch_size)
1129
+
1130
+
1131
+ def get_flops(
1132
+ num_layers,
1133
+ hidden_size,
1134
+ num_heads,
1135
+ num_key_value_heads,
1136
+ vocab_size,
1137
+ seq_len,
1138
+ ffn_hidden_size,
1139
+ batch_size=1,
1140
+ ):
1141
+ """Counts flops in an decoder-only model
1142
+ Args:
1143
+ num_layers: number of decoder layers
1144
+ hidden_size: hidden size of the model
1145
+ num_heads: number of heads in the model
1146
+ num_key_value_heads: number of key/value heads in the model
1147
+ ffn_hidden_size: hidden size of the FFN
1148
+ vocab_size: size of the vocabulary
1149
+ seq_len: sequence length of the decoder
1150
+ batch_size: batch size
1151
+ Returns:
1152
+ model_flops: flops in the model (should be independent of the hardware and model implementation)
1153
+ hardware_flops: flops in the hardware (actual flops performed on the hardware). Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf
1154
+ """
1155
+ if num_key_value_heads is None:
1156
+ num_key_value_heads = num_heads
1157
+ hidden_size_per_head = hidden_size // num_heads
1158
+ # In the following we mark the reduced dimension with parentheses
1159
+ # decoder
1160
+ # self attention
1161
+ ## qkv projection
1162
+ decoder_qkv_proj_flops_fwd = (
1163
+ 2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * hidden_size_per_head
1164
+ + 2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * num_key_value_heads * hidden_size_per_head
1165
+ )
1166
+ ## qk logits
1167
+ decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * seq_len
1168
+ ## v logits
1169
+ decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (seq_len) * hidden_size_per_head
1170
+ ## attn out
1171
+ decoder_attn_out_flops_fwd = (
1172
+ 2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * hidden_size
1173
+ )
1174
+ # FF
1175
+ ## 1st layer
1176
+ decoder_ffn_1_flops_fwd = 4 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size
1177
+ ## 2nd layer
1178
+ decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size
1179
+
1180
+ decoder_flops_fwd = (
1181
+ decoder_qkv_proj_flops_fwd
1182
+ + decoder_qk_logits_flops_fwd
1183
+ + decoder_v_logits_flops_fwd
1184
+ + decoder_attn_out_flops_fwd
1185
+ + decoder_ffn_1_flops_fwd
1186
+ + decoder_ffn_2_flops_fwd
1187
+ )
1188
+
1189
+ # lm head
1190
+ lm_head_flops_fwd = 2 * batch_size * seq_len * (hidden_size) * vocab_size
1191
+
1192
+ # the bwd pass requires double the flops in case of matmuls to calculate the gradients with respect to
1193
+ # both input and weight tensors
1194
+ model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd) # 1 for fwd + 2 for bwd
1195
+
1196
+ hardware_flops = model_flops # TODO: This is a placeholder for now
1197
+
1198
+ return model_flops, hardware_flops
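def _get_flops_sketch():
    # Minimal usage sketch for get_flops (illustrative only; the hyperparameters
    # below are made up and not taken from any shipped config).
    model_flops, hardware_flops = get_flops(
        num_layers=16,
        hidden_size=2048,
        num_heads=16,
        num_key_value_heads=16,
        vocab_size=32000,
        seq_len=2048,
        ffn_hidden_size=5632,
        batch_size=4,
    )
    print(f"model TFLOP/step: {model_flops / 1e12:.1f}, hardware TFLOP/step: {hardware_flops / 1e12:.1f}")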